import numpy as np
from module import Module
from layer import Layer
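

# A minimal sketch of the symmetric scheme named by `quantizationScheme` in the
# Quantizer below. These module-level helpers (`_symmetricQuantizeSketch`,
# `_roundingErrorSketch`) are illustrative assumptions operating on plain numpy
# arrays; they are not part of the Quantizer API.
def _symmetricQuantizeSketch(values: np.ndarray, bits: int = 8) -> tuple[np.ndarray, float]:
    """Maps floats onto signed integers using one scale per tensor."""
    qmax = 2 ** (bits - 1) - 1                  # e.g. 127 for 8 bits
    scale = float(np.abs(values).max()) / qmax  # largest magnitude maps to qmax
    if scale == 0.0:                            # all-zero tensor: any scale works
        scale = 1.0
    quantized = np.clip(np.round(values / scale), -qmax, qmax)
    return quantized.astype(np.int32), scale


def _roundingErrorSketch(values: np.ndarray, bits: int = 8) -> float:
    """Mean squared error of a quantize/dequantize round trip."""
    quantized, scale = _symmetricQuantizeSketch(values, bits)
    return float(np.mean((values - quantized * scale) ** 2))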


class Quantizer:
    """
    A class that takes a network/module and quantizes its parameters post-training.
    """
    def __init__(self, module: Module, bits: int = 8, *, perChannel: bool = False, quantizationScheme: str = "symmetric") -> None:
        """
        Initializes the quantizer with a network/module to quantize.

        Parameters:
            module (Module): The network/module to be quantized.
            bits (int): The bit width for quantization.
            perChannel (bool): Whether to use one scale per channel instead of one per tensor.
            quantizationScheme (str): The quantization scheme to use, e.g. "symmetric".
        """
        self.module = module
        self.bits = bits
        self.perChannel = perChannel
        self.scheme = quantizationScheme

    def calibrate(self) -> None:
        """
        Calibrates the quantization parameters to minimize (Pareto-balance)
        the rounding and clipping errors. Not yet implemented.
        """
        pass

    @property
    def quantizationError(self) -> tuple[float, float]:
        """
        Returns the two main quantization errors: the rounding error and the clipping error.
        """
        return self._roundingError(), self._clippingError()

    def _roundingError(self) -> float:
        """
        A private method for calculating the rounding error, i.e. the loss
        introduced by snapping values to the integer grid. Not yet implemented.
        """
        raise NotImplementedError

    def _clippingError(self) -> float:
        """
        A private method for calculating the clipping error, i.e. the loss
        introduced by saturating values outside the representable range.
        Not yet implemented.
        """
        raise NotImplementedError

    def quantizeModule(self) -> None:
        """
        Applies quantization to all quantizable parameters in the module.
        """
        for layer in self.module:
            self._quantizeLayer(layer)

    def dequantizeModule(self) -> None:
        """
        Applies dequantization to all dequantizable parameters in the module.
        """
        for layer in self.module:
            self._dequantizeLayer(layer)

    def _quantizeLayer(self, layer: Layer) -> None:
        """
        Quantizes the weights (and biases) of a single layer if applicable.
        """
        # Check if layer has weights attribute and quantize
        if hasattr(layer, 'weights') and layer.weights is not None:
            layer.weights.quantize(bits=self.bits, scheme=self.scheme)

        # Check if layer has bias attribute and quantize
        if hasattr(layer, 'bias') and layer.bias is not None:
            layer.bias.quantize(bits=self.bits, scheme=self.scheme)

    def _dequantizeLayer(self, layer: Layer) -> None:
        """
        Dequantizes the weights (and biases) of a single layer if applicable.
        """
        # Check if layer has weights attribute and dequantize
        if hasattr(layer, 'weights') and layer.weights is not None:
            layer.weights.dequantize()

        # Check if layer has bias attribute and dequantize
        if hasattr(layer, 'bias') and layer.bias is not None:
            layer.bias.dequantize()
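

# A minimal usage sketch. The commented-out Quantizer calls assume a trained
# `Module` whose weight/bias tensors expose `quantize`/`dequantize`, exactly as
# the methods above require; `buildModule()` is a hypothetical stand-in for
# however such a network is constructed. The runnable part below demonstrates
# only the module-level sketch helpers on random weights.
if __name__ == "__main__":
    # module = buildModule()                                              # hypothetical constructor
    # quantizer = Quantizer(module, bits=8, quantizationScheme="symmetric")
    # quantizer.quantizeModule()                                          # store parameters as integers
    # quantizer.dequantizeModule()                                        # restore float parameters

    # Self-contained demo of the sketch helpers:
    rng = np.random.default_rng(seed=0)
    weights = rng.normal(size=(4, 4)).astype(np.float32)
    quantized, scale = _symmetricQuantizeSketch(weights, bits=8)
    print("scale:", scale)
    print("rounding error:", _roundingErrorSketch(weights, bits=8))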