from .module import Module
from .layer import Layer
class Quantizer:
    """
    A class that can take a network/module and quantize it post training.
    """

    def __init__(self, module: Module, bits: int = 8, *, perChannel: bool = False, quantizationScheme: str = "symmetric") -> None:
        """
        Initializes the quantizer with a network/module to quantize.

        Parameters:
            module (Module): The network/module to be quantized.
            bits (int): The bit width for quantization.
            perChannel (bool): Whether quantization is applied per channel
                (keyword-only).
            quantizationScheme (str): Quantization scheme passed through to the
                parameters' `quantize` calls, e.g. "symmetric" (keyword-only).
        """
        self.module = module
        self.bits = bits
        self.perChannel = perChannel
        self.scheme = quantizationScheme

    def calibrate(self) -> None:
        """
        Calibrates the quantizer to (Pareto-)minimize quantization errors.

        NOTE(review): not implemented yet; intentionally a no-op stub.
        """
        pass

    def callibrate(self) -> None:
        """
        Deprecated misspelled alias of `calibrate`, kept so existing callers
        keep working. Prefer `calibrate`.
        """
        self.calibrate()

    @property
    def quantizationError(self) -> tuple[float, float]:
        """
        Returns the two main errors of the quantization as a
        (rounding error, clipping error) tuple.
        """
        return self._roundingError(), self._clippingError()

    def _roundingError(self) -> float:
        """
        Calculates the rounding error.

        NOTE(review): not implemented yet; currently returns None despite the
        float annotation.
        """
        pass

    def _clippingError(self) -> float:
        """
        Calculates the clipping error.

        NOTE(review): not implemented yet; currently returns None despite the
        float annotation.
        """
        pass

    def quantize(self) -> None:
        """
        Applies quantization to all quantizable parameters in the module.
        """
        # Iterating the module yields its layers — assumes Module is iterable.
        for layer in self.module:
            self._quantizeLayer(layer)

    def dequantize(self) -> None:
        """
        Applies dequantization to all dequantizable parameters in the module.
        """
        for layer in self.module:
            self._dequantizeLayer(layer)

    def _quantizeLayer(self, layer: Layer) -> None:
        """
        Quantizes the weights (and biases) of a single layer if applicable.

        Layers without weights/bias (or with them set to None) are skipped.
        """
        # Check if layer has weights attribute and quantize
        if hasattr(layer, 'weights') and layer.weights is not None:
            layer.weights.quantize(bits=self.bits, scheme=self.scheme)
        # Check if layer has bias attribute and quantize
        if hasattr(layer, 'bias') and layer.bias is not None:
            layer.bias.quantize(bits=self.bits, scheme=self.scheme)

    def _dequantizeLayer(self, layer: Layer) -> None:
        """
        Dequantizes the weights (and biases) of a single layer if applicable.

        Layers without weights/bias (or with them set to None) are skipped.
        """
        # Check if layer has weights attribute and dequantize
        if hasattr(layer, 'weights') and layer.weights is not None:
            layer.weights.dequantize()
        # Check if layer has bias attribute and dequantize
        if hasattr(layer, 'bias') and layer.bias is not None:
            layer.bias.dequantize()