from collections import namedtuple

import numpy as np

from .module import Module
from .layer import Layer

# Define the named tuple type outside the class
QuantizationError = namedtuple('QuantizationError', ['roundingError', 'clippingError'])
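
# For reference, symmetric quantization of a tensor `w` to `bits` bits typically
# looks like the sketch below (a generic illustration of the scheme, not
# necessarily the implementation behind `param.quantize`, which lives elsewhere):
#
#   qMax = 2 ** (bits - 1) - 1
#   scale = np.max(np.abs(w)) / qMax
#   wQuantized = np.clip(np.round(w / scale), -qMax, qMax)
#   wDequantized = wQuantized * scale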

class Quantizer:
    """
    A class that takes a network/module and quantizes it post training.
    """

    def __init__(self, module: Module, bits: int = 8, *, perChannel: bool = False, quantizationScheme: str = "symmetric") -> None:
        """
        Initializes the quantizer with a network/module to quantize.

        Parameters:
            module (Module): The network/module to be quantized.
            bits (int): The bit width for quantization.
            perChannel (bool): Whether to quantize each channel separately instead of per tensor.
            quantizationScheme (str): The quantization scheme to use, e.g. "symmetric".
        """
        self.module = module
        self.bits = bits
        self.perChannel = perChannel
        self.scheme = quantizationScheme

    def calibrate(self) -> None:
        """
        Calibrates the quantization parameters to (Pareto-)minimize the
        rounding and clipping errors.
        """
        pass
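
        # A possible calibration strategy (an assumption, not implemented here):
        # sweep candidate clipping ranges for each parameter, quantize with each
        # candidate, and keep the range whose combined rounding and clipping
        # error (see quantizationError below) is Pareto-optimal.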

    @property
    def quantizationError(self) -> QuantizationError:
        """
        Returns the two main errors of the quantization: rounding and clipping.
        """
        return QuantizationError(self._roundingError(), self._clippingError())

    def _roundingError(self) -> float:
        """
        A private method for calculating the mean absolute rounding error.
        """
        totalError = 0.
        totalElements = 0
        for layer in self.module:
            try:
                params = layer.params()
            except AttributeError:
                # 'params' method not found in the layer, it contributes no error
                continue
            for param in params:
                dequantizedWeights = param.dequantize()
                errorWeights = np.abs(param._values - dequantizedWeights)
                totalError += np.sum(errorWeights)
                totalElements += np.prod(param._values.shape)
        # Calculate the mean absolute error
        meanError = totalError / totalElements if totalElements > 0 else 0.
        return meanError
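
    # Illustrative numbers for the rounding error: with a scale of 0.1, a weight
    # of 0.13 quantizes to round(0.13 / 0.1) = 1 and dequantizes back to 0.1, so
    # that element contributes |0.13 - 0.1| = 0.03 to the mean absolute error.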

    def _clippingError(self) -> float:
        """
        A private method for calculating the mean absolute clipping error.
        """
        totalClippingError = 0.
        totalElements = 0
        for layer in self.module:
            try:
                params = layer.params()
            except AttributeError:
                # 'params' method not found in the layer, it contributes no error
                continue
            for param in params:
                # Each parameter exposes its quantization range (qMin, qMax)
                qMin, qMax = param.quantizationRange
                # Calculate clipping error for values below qMin
                belowMin = np.minimum(param._values - qMin, 0)
                # Calculate clipping error for values above qMax
                aboveMax = np.maximum(param._values - qMax, 0)
                # Sum of absolute errors gives total clipping error for the parameter
                paramClippingError = np.sum(np.abs(belowMin) + np.abs(aboveMax))
                totalClippingError += paramClippingError
                # Update total elements for averaging
                totalElements += np.prod(param._values.shape)
        # Compute the mean clipping error across all quantized elements
        meanClippingError = totalClippingError / totalElements if totalElements > 0 else 0.
        return meanClippingError
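
    # Illustrative numbers for the clipping error: with values [-3, 0, 5] and a
    # quantization range of (qMin, qMax) = (-2, 4), the elements -3 and 5 are
    # each clipped by 1, so the mean clipping error is (1 + 0 + 1) / 3 = 2 / 3.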
"""
Applies quantization to all quantizable parameters in the module.
"""
for layer in self.module:
self._quantizeLayer(layer)
"""
Applies dequantization to all dequantizable parameters in the module.
"""
for layer in self.module:
self._dequantizeLayer(layer)

    def _quantizeLayer(self, layer: Layer) -> None:
        """
        Quantizes the weights (and biases) of a single layer if applicable.
        """
        try:
            params = layer.params()
        except AttributeError:
            # 'params' method not found in the layer, skip updating
            return
        for param in params:
            param.quantize(bits=self.bits, scheme=self.scheme)

    def _dequantizeLayer(self, layer: Layer) -> None:
        """
        Dequantizes the weights (and biases) of a single layer if applicable.
        """
        try:
            params = layer.params()
        except AttributeError:
            # 'params' method not found in the layer, skip updating
            return
        for param in params:
            param.dequantize()
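
# Example usage (a minimal sketch; `buildModule` is a stand-in for however a
# Module is constructed elsewhere in this codebase):
#
#   module = buildModule()
#   quantizer = Quantizer(module, bits=8, quantizationScheme="symmetric")
#   quantizer.calibrate()
#   quantizer.quantize()
#   error = quantizer.quantizationError
#   print(error.roundingError, error.clippingError)
#   quantizer.dequantize()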