Skip to content
Snippets Groups Projects
weights.py 3.96 KiB
import numpy as np
from numpy.typing import ArrayLike


def initializeWeights(size: tuple | int, scale: float = 1.0, init: str = 'random') -> ArrayLike:
    """
    Initialize filter using a normal distribution with and a
    standard deviation inversely proportional the square root of the number of units
    """
    if init == 'random':
        stddev = scale/np.sqrt(np.prod(size))
        return np.random.normal(loc=0, scale=stddev, size=size)
    elif init == 'ones':
        return np.ones(size)
    elif init == 'zeros':
        return np.zeros(size)
    else:
        raise ValueError('not a valid init argument')


class Weights(object):
    """
    the idea behind class is to combine everything an optimizer needs into one object
    this way layers and optimizers don't need to take care of storing and providing
    things like previous updates or cache
    """
    __slots__ = ['_values', '_quantizedValues', 'prevValues', 'deltas', 'prevDeltas', 'cache', 'scale', 'maxValue', '_useQuantization']

    def __init__(self, size: tuple | int, values: ArrayLike = None, init: str = 'random') -> None:
        self._values = initializeWeights(size, init=init) if values is None else values
        self.prevValues = None
        self.deltas = np.zeros(size)
        self.prevDeltas = None
        self.cache = None

        self._quantizedValues = np.zeros(size)
        self.scale = 1
        self.maxValue = 0
        self._useQuantization = False

    @property
    def values(self):
        """
        Depending on the _useQuantized flag, return either the original
        or quantized (and dequantized back) weight values for computation.
        """
        if self._useQuantization:
            return self.dequantize()
        else:
            return self._values

    @values.setter
    def values(self, newValues):
        """
        Allow updates to the weight values directly.
        """
        self._values = newValues

    @property
    def qualifiedName(self) -> tuple:
        return self.__class__.__module__, self.__class__.__name__

    def toDict(self) -> dict:
        saveDict = {}
        saveDict['size'] = self._values.shape
        saveDict['values'] = self._values.tolist()
        saveDict['deltas'] = self.deltas.tolist()
        if self.prevValues is not None:
            saveDict['prevValues'] = self.prevValues.tolist()
        saveDict['cache'] = {}
        if type(self.cache) == dict:
            saveDict['cache']['values'] = {}
            for key in self.cache:
                saveDict['cache']['values'][key] = self.cache[key].tolist()
            saveDict['cache']['type'] = 'dict'
        elif type(self.cache) == np.ndarray:
            saveDict['cache']['values'] = self.cache.tolist()
            saveDict['cache']['type'] = 'np.ndarray'

        return saveDict

    def fromDict(self, loadDict: dict) -> None:
        self._values = np.array(loadDict['values'])
        self.deltas = np.array(loadDict['deltas'])
        if 'prevValues' in loadDict:
            self.prevValues = np.array(loadDict['prevValues'])
        if loadDict['cache']['type'] == 'np.ndarray':
            self.cache = np.array(loadDict['cache']['values'])
        elif loadDict['cache']['type'] == 'dict':
            self.cache = {}
            for key in loadDict['cache']['values']:
                self.cache[key] = np.array(loadDict['cache']['values'][key])

    def quantize(self, bits: int = 8, scheme: str = "symmetric"):
            """
            Quantizes the weight values to a specified bit width.
            """
            self.maxValue = np.max(np.abs(self._values))
            self.scale = (2 ** bits - 1) / self.maxValue
            self._quantizedValues = np.round(self._values * self.scale).astype(np.int32)

            self._useQuantization = True

    def dequantize(self):
        """
        Dequantizes the weight values back to floating point.
        """
        if self._useQuantization:
            return self._quantizedValues.astype(np.float32) / self.scale
        return self._values