quantizer.py

    import numpy as np
    
    from .module import Module
    from .layer import Layer
    
    
    class Quantizer:
        """
        A class that takes a network/module and quantizes it post-training.
        """
        def __init__(self, module: Module, bits: int = 8, *, perChannel: bool = False, quantizationScheme: str = "symmetric") -> None:
            """
            Initializes the quantizer with a network/module to quantize.
    
            Parameters:
                module (Module): The network/module to be quantized.
                bits (int): The bit width for quantization.
            """
            self.module = module
            self.bits = bits
            self.perChannel = perChannel
            self.scheme = quantizationScheme
    
        def calibrate(self) -> None:
            """
            Calibrates the quantization parameters, Pareto-minimizing the
            trade-off between rounding and clipping error. Not implemented yet.
            """
            raise NotImplementedError
    
        @property
        def quantizationError(self) -> tuple[float, float]:
            """
            this returns the two main errors of the quantization
            """
            return self._roundingError(), self._clippingError()
    
        def _roundingError(self) -> float:
            """
            a private methode for calculating the rounding error
            """
            pass
    
        def _clippingError(self) -> float:
            """
            a private methode for calculating the clipping error
            """
            pass
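
        # NOTE (a sketch of the underlying math, an assumption rather than the
        # original design): for a symmetric scheme with b bits, a scale
        # s = max|w| / (2**(b - 1) - 1) maps weights onto integers
        # q = clip(round(w / s), -(2**(b - 1) - 1), 2**(b - 1) - 1); the
        # rounding error is the residual w - s * q on in-range values, the
        # clipping error the residual on values saturated at the range limits.
        # A runnable numpy illustration is appended at the end of this file.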
    
    
        def quantize(self) -> None:
            """
            Applies quantization to all quantizable parameters in the module.
            """
            for layer in self.module:
                self._quantizeLayer(layer)
    
    
        def dequantize(self) -> None:
            """
            Applies dequantization to all dequantizable parameters in the module.
            """
            for layer in self.module:
                self._dequantizeLayer(layer)
    
        def _quantizeLayer(self, layer: Layer) -> None:
            """
            Quantizes the weights (and biases) of a single layer if applicable.
            """
            # Check if layer has weights attribute and quantize
            if hasattr(layer, 'weights') and layer.weights is not None:
                layer.weights.quantize(bits=self.bits, scheme=self.scheme)
    
            # Check if layer has bias attribute and quantize
            if hasattr(layer, 'bias') and layer.bias is not None:
                layer.bias.quantize(bits=self.bits, scheme=self.scheme)
    
        def _dequantizeLayer(self, layer: Layer) -> None:
            """
            Dequantizes the weights (and biases) of a single layer if applicable.
            """
            # Check if layer has weights attribute and dequantize
            if hasattr(layer, 'weights') and layer.weights is not None:
                layer.weights.dequantize()
    
            # Check if layer has bias attribute and dequantize
            if hasattr(layer, 'bias') and layer.bias is not None:
                layer.bias.dequantize()
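

    # ---------------------------------------------------------------------
    # Illustration (a sketch, not part of the original class): symmetric
    # per-tensor quantization and the two error terms that
    # Quantizer.quantizationError refers to, written against plain numpy
    # arrays. The actual work is assumed to live in the Tensor objects'
    # quantize()/dequantize() methods; the function name and the clipRatio
    # parameter below are hypothetical.
    def _symmetricQuantizationErrors(values: np.ndarray, bits: int = 8,
                                     clipRatio: float = 1.0) -> tuple[float, float]:
        """Returns (roundingError, clippingError) for a symmetric scheme."""
        qmax = 2 ** (bits - 1) - 1
        # clipRatio < 1 shrinks the representable range: finer resolution
        # inside, at the price of clipping the largest values
        threshold = clipRatio * float(np.abs(values).max())  # assumes a nonzero max
        scale = threshold / qmax
        quantized = np.clip(np.round(values / scale), -qmax, qmax)
        dequantized = quantized * scale
        clipped = np.abs(values) > threshold
        roundingError = float(np.abs(values - dequantized)[~clipped].mean()) if (~clipped).any() else 0.0
        clippingError = float(np.abs(values - dequantized)[clipped].mean()) if clipped.any() else 0.0
        return roundingError, clippingError

    # Example use of the class above (assuming `net` is a Module instance):
    #   quantizer = Quantizer(net, bits=8, quantizationScheme="symmetric")
    #   quantizer.quantize()      # weights/biases now hold quantized values
    #   quantizer.dequantize()    # restore floating-point values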