import numpy as np
from collections import namedtuple
from .module import Module
from .layer import Layer


# Named tuple holding the two post-training quantization error metrics
QuantizationError = namedtuple('QuantizationError', ['roundingError', 'clippingError'])


class Quantizer:
    """
    A class that applies post-training quantization to a network/module.
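
    Example (illustrative sketch only; ``model`` stands for any ``Module`` whose
    layers expose ``params()`` with ``quantize``/``dequantize`` support):

        quantizer = Quantizer(bits=8, quantizationScheme="symmetric")
        quantizer(model)                    # quantize every quantizable layer
        print(quantizer.quantizationError)  # (roundingError, clippingError)
        quantizer.dequantize()              # dequantize the layers again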
    """
    def __init__(self, bits: int = 8, *, perChannel: bool = False, quantizationScheme: str = "symmetric") -> None:
        """
        Initializes the quantizer with its quantization settings. The module to
        quantize is supplied later via __call__.

        Parameters:
            bits (int): The bit width for quantization.
            perChannel (bool): Whether to quantize per channel instead of per tensor.
            quantizationScheme (str): The quantization scheme, e.g. "symmetric".
        """
        self.bits = bits
        self.perChannel = perChannel
        self.scheme = quantizationScheme

    def calibrate(self) -> None:
        """
        Calibrates the quantization parameters to jointly minimize (in the
        Pareto sense) the rounding and clipping errors.
        """
        pass

    @property
    def quantizationError(self) -> QuantizationError:
        """
        this returns the two main errors of the quantization
        """
        return QuantizationError(self._roundingError(), self._clippingError())

    def _roundingError(self) -> float:
        """
        A private method for calculating the mean absolute rounding error.
        """
        totalError = 0.
        totalElements = 0

        for layer in self.module:
            try:
                params = layer.params()
            except AttributeError:
                # The layer exposes no 'params' method, so it has nothing to quantize; skip it
                continue

            for param in params:
                dequantizedWeights = param.dequantize()
                errorWeights = np.abs(param._values - dequantizedWeights)
                totalError += np.sum(errorWeights)
                totalElements += np.prod(param._values.shape)

        # Calculate the mean absolute error
        meanError = totalError / totalElements if totalElements > 0 else 0.0
        return meanError

    def _clippingError(self) -> float:
        """
        A private method for calculating the mean absolute clipping error.
        """
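        # Per-element clipping error: the distance from each original value to
        # the representable range [qMin, qMax] (zero for values inside it).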
        totalClippingError = 0.
        totalElements = 0

        for layer in self.module:
            try:
                params = layer.params()
            except AttributeError:
                # The layer exposes no 'params' method, so it has nothing to quantize; skip it
                continue

            for param in params:
                # The representable quantization range (qMin, qMax) for this
                # parameter; values outside it are clipped during quantization
                qMin, qMax = param.quantizationRange

                # Calculate clipping error for values below qMin
                belowMin = np.minimum(param._values - qMin, 0)
                # Calculate clipping error for values above qMax
                aboveMax = np.maximum(param._values - qMax, 0)

                # Sum of absolute errors gives total clipping error for the parameter
                paramClippingError = np.sum(np.abs(belowMin) + np.abs(aboveMax))
                totalClippingError += paramClippingError

                # Update total elements for averaging
                totalElements += np.prod(param._values.shape)

        # Mean absolute clipping error over all quantized elements
        meanClippingError = totalClippingError / totalElements if totalElements > 0 else 0.0
        return meanClippingError

    def __call__(self, module: Module) -> None:
        """
        Quantizes all quantizable parameters in the given module and stores a
        reference to it for subsequent error reporting and dequantization.
        """
        self.module = module
        for layer in self.module:
            self._quantizeLayer(layer)

    def dequantize(self) -> None:
        """
        Applies dequantization to all dequantizable parameters in the module.
        """
        for layer in self.module:
            self._dequantizeLayer(layer)

    def _quantizeLayer(self, layer: Layer) -> None:
        """
        Quantizes the weights (and biases) of a single layer if applicable.
        """
        try:
            params = layer.params()
        except AttributeError:
            # The layer exposes no 'params' method, so it has nothing to quantize; skip it
            return

        for param in params:
            param.quantize(bits=self.bits, scheme=self.scheme)

    def _dequantizeLayer(self, layer: Layer) -> None:
        """
        Dequantizes the weights (and biases) of a single layer if applicable.
        """
        try:
            params = layer.params()
        except AttributeError:
            # The layer exposes no 'params' method, so it has nothing to dequantize; skip it
            return

        for param in params:
            param.dequantize()
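

if __name__ == "__main__":
    # Illustrative usage sketch only: _StubParam and _StubLayer are hypothetical
    # stand-ins for the real parameter/layer API (params(), quantize(),
    # dequantize(), quantizationRange). They exist purely to exercise the
    # Quantizer call flow and its error metrics.

    class _StubParam:
        """Toy parameter implementing the interface the Quantizer relies on."""

        def __init__(self, values: np.ndarray) -> None:
            self._values = values
            self._quantized = values
            self._scale = 1.0
            self.quantizationRange = (float(values.min()), float(values.max()))

        def quantize(self, bits: int, scheme: str) -> None:
            # Plain symmetric linear quantization to signed integers; this toy
            # stub ignores `scheme` and always behaves symmetrically.
            qMax = 2 ** (bits - 1) - 1
            self._scale = float(np.max(np.abs(self._values))) / qMax or 1.0
            self._quantized = np.clip(
                np.round(self._values / self._scale), -qMax - 1, qMax
            )
            self.quantizationRange = (-(qMax + 1) * self._scale, qMax * self._scale)

        def dequantize(self) -> np.ndarray:
            return self._quantized * self._scale

    class _StubLayer:
        """Toy layer exposing params() the way the Quantizer expects."""

        def __init__(self) -> None:
            self._weights = _StubParam(np.random.randn(4, 4))

        def params(self) -> list:
            return [self._weights]

    quantizer = Quantizer(bits=8)
    quantizer([_StubLayer(), _StubLayer()])  # any iterable of layers works here
    print(quantizer.quantizationError)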