Skip to content
Snippets Groups Projects
quantizer.py 4.47 KiB
Newer Older
  • Learn to ignore specific revisions
  • johannes bilk's avatar
    johannes bilk committed
    import numpy as np
    
    from collections import namedtuple
    
    from .module import Module
    from .layer import Layer
    
    # Define the named tuple type outside your class
    QuantizationError = namedtuple('QuantizationError', ['roundingError', 'clippingError'])
    
    
    
    johannes bilk's avatar
    johannes bilk committed
    class Quantizer:
        """
        Post-training quantizer for a network/module.

        Call an instance with a module to quantize every quantizable
        parameter in it; `dequantize` restores float values afterwards.
        The `quantizationError` property reports the two main error
        sources (rounding and clipping) as a named tuple.
        """

        def __init__(self, bits: int = 8, *, perChannel: bool = False, quantizationScheme: str = "symmetric") -> None:
            """
            Initializes the quantizer.

            Parameters:
                bits (int): The bit width for quantization.
                perChannel (bool): Per-channel quantization flag. Stored but not
                    read anywhere in this class — presumably consumed by the
                    parameters' quantize() implementation; TODO confirm.
                quantizationScheme (str): Scheme forwarded to each parameter's
                    quantize() call (e.g. "symmetric").
            """
            self.bits = bits
            self.perChannel = perChannel
            self.scheme = quantizationScheme

        def callibrate(self) -> None:
            """
            Calibrates and (pareto-)minimizes quantization errors.

            NOTE: not implemented yet. The misspelled name ("calibrate" would
            be correct) is kept to preserve the public API for existing callers.
            """
            pass

        @property
        def quantizationError(self) -> QuantizationError:
            """
            Returns the two main quantization errors as a
            QuantizationError(roundingError, clippingError) named tuple.
            """
            return QuantizationError(self._roundingError(), self._clippingError())

        @staticmethod
        def _layerParams(layer: Layer):
            """
            Returns the layer's parameter iterable, or None when the layer
            does not expose a 'params' method (nothing to quantize).
            """
            try:
                return layer.params()
            except AttributeError:
                # no 'params' method on this layer — caller should skip it
                return None

        def _roundingError(self) -> float:
            """
            Mean absolute rounding error: average absolute difference between
            the stored values and their dequantized counterparts, taken over
            all elements of all parameters. Returns 0 for an empty module.
            """
            totalError = 0.
            totalElements = 0

            for layer in self.module:
                params = self._layerParams(layer)
                if params is None:
                    continue

                for param in params:
                    dequantizedWeights = param.dequantize()
                    totalError += np.sum(np.abs(param._values - dequantizedWeights))
                    totalElements += np.prod(param._values.shape)

            return totalError / totalElements if totalElements > 0 else 0

        def _clippingError(self) -> float:
            """
            Mean absolute clipping error: how far values fall outside each
            parameter's quantization range [qMin, qMax], averaged over all
            elements of all parameters. Returns 0 for an empty module.
            """
            totalClippingError = 0.
            totalElements = 0

            for layer in self.module:
                params = self._layerParams(layer)
                if params is None:
                    continue

                # reuse the params obtained above (the original fetched
                # layer.params() a second time, outside the guard)
                for param in params:
                    qMin, qMax = param.quantizationRange

                    # distance below the lower bound (<= 0 per element)
                    belowMin = np.minimum(param._values - qMin, 0)
                    # distance above the upper bound (>= 0 per element)
                    aboveMax = np.maximum(param._values - qMax, 0)

                    # sum of absolute overshoots is this parameter's clipping error
                    totalClippingError += np.sum(np.abs(belowMin) + np.abs(aboveMax))
                    totalElements += np.prod(param._values.shape)

            return totalClippingError / totalElements if totalElements > 0 else 0

        def __call__(self, module: Module) -> None:
            """
            Stores the module and applies quantization to all of its
            quantizable parameters.
            """
            self.module = module

            for layer in self.module:
                self._quantizeLayer(layer)

        def dequantize(self) -> None:
            """
            Applies dequantization to all dequantizable parameters in the
            module. The instance must have been called with a module first.
            """
            for layer in self.module:
                self._dequantizeLayer(layer)

        def _quantizeLayer(self, layer: Layer) -> None:
            """
            Quantizes the parameters of a single layer, if it exposes any.
            """
            params = self._layerParams(layer)
            if params is None:
                return

            for param in params:
                param.quantize(bits=self.bits, scheme=self.scheme)

        def _dequantizeLayer(self, layer: Layer) -> None:
            """
            Dequantizes the parameters of a single layer, if it exposes any.
            """
            params = self._layerParams(layer)
            if params is None:
                return

            for param in params:
                param.dequantize()