Skip to content
Snippets Groups Projects
weights.py 6.85 KiB
Newer Older
  • Learn to ignore specific revisions
  • johannes bilk's avatar
    johannes bilk committed
    import numpy as np
    from numpy.typing import ArrayLike
    
    
    def initializeWeights(size: tuple | int, scale: float = 1.0, init: str = 'random') -> ArrayLike:
        """
        Initialize filter using a normal distribution with and a
        standard deviation inversely proportional the square root of the number of units
        """
        if init == 'random':
            stddev = scale/np.sqrt(np.prod(size))
            return np.random.normal(loc=0, scale=stddev, size=size)
        elif init == 'ones':
            return np.ones(size)
        elif init == 'zeros':
            return np.zeros(size)
        else:
            raise ValueError('not a valid init argument')
    
    
    class Weights(object):
        """
        the idea behind class is to combine everything an optimizer needs into one object
        this way layers and optimizers don't need to take care of storing and providing
        things like previous updates or cache
        """
    
        __slots__ = ['_values', '_quantizedValues', 'prevValues', 'deltas', 'prevDeltas', 'cache', 'scale', 'maxValue', '_useQuantization', 'zeroPoint', 'qMin', 'qMax']
    
    johannes bilk's avatar
    johannes bilk committed
    
        def __init__(self, size: tuple | int, values: ArrayLike = None, init: str = 'random') -> None:
    
    johannes bilk's avatar
    johannes bilk committed
            self._values = initializeWeights(size, init=init) if values is None else values
    
    johannes bilk's avatar
    johannes bilk committed
            self.prevValues = None
            self.deltas = np.zeros(size)
            self.prevDeltas = None
            self.cache = None
    
    
    johannes bilk's avatar
    johannes bilk committed
            self._quantizedValues = np.zeros(size)
            self.scale = 1
            self.maxValue = 0
    
    johannes bilk's avatar
    johannes bilk committed
            self._useQuantization = False
    
        @property
        def values(self):
            """
            Depending on the _useQuantized flag, return either the original
            or quantized (and dequantized back) weight values for computation.
            """
            if self._useQuantization:
                return self.dequantize()
            else:
                return self._values
    
        @values.setter
        def values(self, newValues):
            """
            Allow updates to the weight values directly.
            """
            self._values = newValues
    
    
    johannes bilk's avatar
    johannes bilk committed
        @property
        def qualifiedName(self) -> tuple:
            return self.__class__.__module__, self.__class__.__name__
    
        def toDict(self) -> dict:
            saveDict = {}
    
    johannes bilk's avatar
    johannes bilk committed
            saveDict['size'] = self._values.shape
            saveDict['values'] = self._values.tolist()
    
    johannes bilk's avatar
    johannes bilk committed
            saveDict['deltas'] = self.deltas.tolist()
            if self.prevValues is not None:
                saveDict['prevValues'] = self.prevValues.tolist()
            saveDict['cache'] = {}
            if type(self.cache) == dict:
                saveDict['cache']['values'] = {}
                for key in self.cache:
                    saveDict['cache']['values'][key] = self.cache[key].tolist()
                saveDict['cache']['type'] = 'dict'
            elif type(self.cache) == np.ndarray:
                saveDict['cache']['values'] = self.cache.tolist()
                saveDict['cache']['type'] = 'np.ndarray'
    
    
            if self._useQuantization:
                saveDict['quantization'] = {}
                saveDict['quantization']['scale'] = int(self.scale)
    
                saveDict['quantization']['qMin'] = int(self.qMin)
                saveDict['quantization']['qMax'] = int(self.qMax)
                saveDict['quantization']['zeroPoint'] = int(self.zeroPoint)
    
                saveDict['quantization']['maxValue'] = float(self.maxValue)
                saveDict['quantization']['quantizedValues'] = self._quantizedValues.tolist()
    
    
    johannes bilk's avatar
    johannes bilk committed
            return saveDict
    
        def fromDict(self, loadDict: dict) -> None:
    
    johannes bilk's avatar
    johannes bilk committed
            self._values = np.array(loadDict['values'])
    
    johannes bilk's avatar
    johannes bilk committed
            self.deltas = np.array(loadDict['deltas'])
            if 'prevValues' in loadDict:
                self.prevValues = np.array(loadDict['prevValues'])
            if loadDict['cache']['type'] == 'np.ndarray':
                self.cache = np.array(loadDict['cache']['values'])
            elif loadDict['cache']['type'] == 'dict':
                self.cache = {}
                for key in loadDict['cache']['values']:
                    self.cache[key] = np.array(loadDict['cache']['values'][key])
    
            if 'quantization' in loadDict:
                self.scale = loadDict['quantization']['scale']
                self.maxValue = loadDict['quantization']['maxValue']
    
                self.zeroPoint = loadDict['quantization']['zeroPoint']
                self.qMin = loadDict['quantization']['qMin']
                self.qMax = loadDict['quantization']['qMax']
    
                self._quantizedValues = np.array(loadDict['quantization']['quantizedValues'])
                self._useQuantization = True
            else:
                self._useQuantization = False
    
    
    johannes bilk's avatar
    johannes bilk committed
        def quantize(self, bits: int = 8, scheme: str = "symmetric"):
    
            """
            Quantizes the weight values to a specified bit width.
            """
            self.zeroPoint = 0  # Zero-point is used in asymmetric quantization
    
            # Determine qMin and qMax based on the quantization scheme and bit width
            if scheme == "symmetric":
                self.qMax = 2 ** (bits - 1) - 1
                self.qMin = - self.qMax
            elif scheme == "asymmetric":
                sefl.qMax = 2 ** bits - 1
                self.qMin = 0
            else:
                raise ValueError(f"{scheme} is not a recognized quantization scheme")
    
            if scheme == "asymmetric":
                # Adjust the scale and zeroPoint for asymmetric quantization
                data_min = np.min(self._values)
                data_max = np.max(self._values)
    
                # Scale calculation based on the actual range of the data
                self.scale = (data_max - data_min) / (self.qMax - self.qMin)
    
                # Zero-point calculation, ensuring it's within the quantized value range
                self.zeroPoint = self.qMin - round(data_min / self.scale)
                self.zeroPoint = max(self.qMin, min(self.qMax, self.zeroPoint))
            else:
                # For symmetric quantization, scale is based on the maximum absolute value
    
    johannes bilk's avatar
    johannes bilk committed
                self.maxValue = np.max(np.abs(self._values))
    
                self.scale = (self.qMax - self.qMin) / (2 * self.maxValue)
    
            # Apply quantization
            if scheme == "asymmetric":
                self._quantizedValues = np.round(self._values / self.scale) + self.zeroPoint
            else:
    
    johannes bilk's avatar
    johannes bilk committed
                self._quantizedValues = np.round(self._values * self.scale).astype(np.int32)
    
    
            self._quantizedValues = np.clip(self._quantizedValues, self.qMin, self.qMax).astype(np.int32)  # Ensure values are within range
    
            self._useQuantization = True
    
    johannes bilk's avatar
    johannes bilk committed
    
        def dequantize(self):
            """
            Dequantizes the weight values back to floating point.
            """
            if self._useQuantization:
    
                # For both symmetric and asymmetric, the formula below applies because
                # for symmetric quantization, zeroPoint is 0.
                # Note: Ensure self.scale and self.zeroPoint are correctly set during quantization.
                return (self._quantizedValues - self.zeroPoint) * self.scale
            else:
                # If not quantized, simply return the original values.
                return self._values
    
        @property
        def quantizationRange(self) -> tuple[int, int]:
            return self.qMin, self.qMax