import numpy as np
from module import Module
from layer import Layer


class Quantizer:
    """
    A class that can take a network/module and quantize it post training.

    The quantizer iterates over the module layer by layer and asks each
    layer's ``weights`` and ``bias`` parameters (when present and not None)
    to quantize or dequantize themselves.
    """

    def __init__(self, module: Module, bits: int = 8, *, perChannel: bool = False, quantizationScheme: str = "symmetric") -> None:
        """
        Initializes the quantizer with a network/module to quantize.

        Parameters:
            module (Module): The network/module to be quantized. It is
                iterated layer by layer when (de)quantizing.
            bits (int): The bit width for quantization. Must be at least 1.
            perChannel (bool): Stored as ``self.perChannel``. Nothing in this
                class reads it yet, so it is not forwarded to the parameters.
            quantizationScheme (str): Stored as ``self.scheme`` and passed as
                ``scheme=`` to every parameter's ``quantize`` call.

        Raises:
            ValueError: If ``bits`` is smaller than 1.
        """
        # Fail early: a non-positive bit width would otherwise only surface
        # (or silently misbehave) deep inside each parameter's quantize call.
        if bits < 1:
            raise ValueError(f"bits must be a positive integer, got {bits!r}")

        self.module = module
        self.bits = bits
        self.perChannel = perChannel
        self.scheme = quantizationScheme

    def callibrate(self) -> None:
        """
        Intended to calibrate the quantizer and minimize (Pareto) quantization errors.

        Not implemented yet: calling it currently does nothing. The spelling
        of the method name is kept as-is for backward compatibility.
        """
        # TODO: implement calibration.
        pass

    @property
    def quantizationError(self) -> tuple[float, float]:
        """
        Returns the two main errors of the quantization.

        Returns:
            tuple: ``(roundingError, clippingError)``. Both entries are None
            for as long as the two private error helpers are unimplemented.
        """
        return self._roundingError(), self._clippingError()

    def _roundingError(self) -> float:
        """
        A private method for calculating the rounding error.

        Not implemented yet: currently returns None.
        """
        # TODO: implement the rounding error computation.
        pass

    def _clippingError(self) -> float:
        """
        A private method for calculating the clipping error.

        Not implemented yet: currently returns None.
        """
        # TODO: implement the clipping error computation.
        pass

    def quantizeModule(self) -> None:
        """
        Applies quantization to all quantizable parameters in the module.
        """
        for layer in self.module:
            self._quantizeLayer(layer)

    def dequantizeModule(self) -> None:
        """
        Applies dequantization to all dequantizable parameters in the module.
        """
        for layer in self.module:
            self._dequantizeLayer(layer)

    @staticmethod
    def _layerParameters(layer: Layer):
        """
        Yields the layer's ``weights`` and then its ``bias``, skipping any
        attribute that is missing or set to None.
        """
        for attributeName in ("weights", "bias"):
            # getattr with a default covers both "no such attribute" and
            # "attribute is None" in one check.
            parameter = getattr(layer, attributeName, None)
            if parameter is not None:
                yield parameter

    def _quantizeLayer(self, layer: Layer) -> None:
        """
        Quantizes the weights (and biases) of a single layer if applicable.

        Each parameter is quantized through its own ``quantize`` method,
        using this quantizer's bit width and scheme.
        """
        for parameter in self._layerParameters(layer):
            parameter.quantize(bits=self.bits, scheme=self.scheme)

    def _dequantizeLayer(self, layer: Layer) -> None:
        """
        Dequantizes the weights (and biases) of a single layer if applicable.

        Each parameter is dequantized through its own ``dequantize`` method;
        nothing is returned.
        """
        for parameter in self._layerParameters(layer):
            parameter.dequantize()