import numpy as np
from collections import namedtuple

from .module import Module
from .layer import Layer

# Define the named tuple type outside the class so it can be imported directly.
QuantizationError = namedtuple('QuantizationError', ['roundingError', 'clippingError'])


class Quantizer:
    """
    A class that can take a network/module and quantize it post-training.
    """

    def __init__(self, bits: int = 8, *, perChannel: bool = False,
                 quantizationScheme: str = "symmetric") -> None:
        """
        Initializes the quantizer.

        Parameters:
            bits (int): The bit width for quantization.
            perChannel (bool): Whether each channel gets its own quantization scale.
            quantizationScheme (str): The quantization scheme, e.g. "symmetric".
        """
        self.bits = bits
        self.perChannel = perChannel
        self.scheme = quantizationScheme
        self.module = None  # Set when the quantizer is applied to a module.

    def calibrate(self) -> None:
        """
        Calibrates the quantizer to (Pareto-)minimize the quantization errors.
        """
        pass

    @property
    def quantizationError(self) -> QuantizationError:
        """
        Returns the two main quantization errors: rounding and clipping.
        """
        return QuantizationError(self._roundingError(), self._clippingError())

    def _roundingError(self) -> float:
        """
        Calculates the mean absolute rounding error over all quantized parameters.
        """
        totalError = 0.
        totalElements = 0
        for layer in self.module:
            try:
                params = layer.params()
            except AttributeError:
                # Layer has no 'params' method; skip it.
                continue
            for param in params:
                dequantizedWeights = param.dequantize()
                errorWeights = np.abs(param._values - dequantizedWeights)
                totalError += np.sum(errorWeights)
                totalElements += np.prod(param._values.shape)
        # Mean absolute error over all elements.
        return totalError / totalElements if totalElements > 0 else 0.

    def _clippingError(self) -> float:
        """
        Calculates the mean absolute clipping error over all quantized parameters.
        """
        totalClippingError = 0.
        totalElements = 0
        for layer in self.module:
            try:
                params = layer.params()
            except AttributeError:
                # Layer has no 'params' method; skip it.
                continue
            for param in params:
                # Each parameter exposes its representable range [qMin, qMax].
                qMin, qMax = param.quantizationRange
                # Clipping error contributed by values below qMin ...
                belowMin = np.minimum(param._values - qMin, 0)
                # ... and by values above qMax.
                aboveMax = np.maximum(param._values - qMax, 0)
                # The sum of absolute deviations is the parameter's clipping error.
                totalClippingError += np.sum(np.abs(belowMin) + np.abs(aboveMax))
                totalElements += np.prod(param._values.shape)
        # Mean clipping error over all elements.
        return totalClippingError / totalElements if totalElements > 0 else 0.

    def __call__(self, module: Module) -> None:
        """
        Applies quantization to all quantizable parameters in the module.
        """
        self.module = module
        for layer in self.module:
            self._quantizeLayer(layer)

    def dequantize(self) -> None:
        """
        Applies dequantization to all dequantizable parameters in the module.
        """
        for layer in self.module:
            self._dequantizeLayer(layer)

    def _quantizeLayer(self, layer: Layer) -> None:
        """
        Quantizes the weights (and biases) of a single layer if applicable.
        """
        try:
            params = layer.params()
        except AttributeError:
            # Layer has no 'params' method; skip it.
            return
        for param in params:
            param.quantize(bits=self.bits, scheme=self.scheme)

    def _dequantizeLayer(self, layer: Layer) -> None:
        """
        Dequantizes the weights (and biases) of a single layer if applicable.
        """
        try:
            params = layer.params()
        except AttributeError:
            # Layer has no 'params' method; skip it.
            return
        for param in params:
            param.dequantize()
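

if __name__ == "__main__":
    # Minimal usage sketch, run with `python -m <package>.quantizer` so the
    # relative imports resolve. The stand-in classes below only mirror the
    # parameter API that Quantizer relies on (params(), quantize(),
    # dequantize(), _values, quantizationRange); the real Module/Layer
    # implementations in this repo may differ, so treat this as an
    # assumption-laden demo rather than the canonical API.
    class _DemoParam:
        def __init__(self, values: np.ndarray) -> None:
            self._values = values
            self._quantized = None
            self._scale = 1.0
            self._qMax = 1

        @property
        def quantizationRange(self):
            # Representable range of the most recent (symmetric) quantization.
            return -self._scale * self._qMax, self._scale * self._qMax

        def quantize(self, bits: int, scheme: str) -> None:
            # Plain symmetric uniform quantization to signed integers.
            self._qMax = 2 ** (bits - 1) - 1
            maxAbs = float(np.max(np.abs(self._values))) or 1.0  # guard all-zero input
            self._scale = maxAbs / self._qMax
            self._quantized = np.clip(
                np.round(self._values / self._scale), -self._qMax, self._qMax
            )

        def dequantize(self) -> np.ndarray:
            return self._quantized * self._scale

    class _DemoLayer:
        def __init__(self) -> None:
            self._params = [_DemoParam(np.random.randn(4, 4))]

        def params(self):
            return self._params

    demoModule = [_DemoLayer()]  # Quantizer only needs an iterable of layers.
    quantizer = Quantizer(bits=8)
    quantizer(demoModule)
    print(quantizer.quantizationError)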