Skip to content
Snippets Groups Projects
weights.py 2.65 KiB
Newer Older
  • Learn to ignore specific revisions
  • johannes bilk's avatar
    johannes bilk committed
    import numpy as np
    from numpy.typing import ArrayLike
    
    
    def initializeWeights(size: tuple | int, scale: float = 1.0, init: str = 'random') -> ArrayLike:
        """
        Initialize filter using a normal distribution with and a
        standard deviation inversely proportional the square root of the number of units
        """
        if init == 'random':
            stddev = scale/np.sqrt(np.prod(size))
            return np.random.normal(loc=0, scale=stddev, size=size)
        elif init == 'ones':
            return np.ones(size)
        elif init == 'zeros':
            return np.zeros(size)
        else:
            raise ValueError('not a valid init argument')
    
    
    class Weights(object):
        """
        the idea behind class is to combine everything an optimizer needs into one object
        this way layers and optimizers don't need to take care of storing and providing
        things like previous updates or cache
        """
        __slots__ = ['values', 'prevValues', 'deltas', 'prevDeltas', 'cache']
    
        def __init__(self, size: tuple | int, values: ArrayLike = None, init: str = 'random') -> None:
            self.values = initializeWeights(size, init=init) if values is None else values
            self.prevValues = None
            self.deltas = np.zeros(size)
            self.prevDeltas = None
            self.cache = None
    
        @property
        def qualifiedName(self) -> tuple:
            return self.__class__.__module__, self.__class__.__name__
    
        def toDict(self) -> dict:
            saveDict = {}
            saveDict['size'] = self.values.shape
            saveDict['values'] = self.values.tolist()
            saveDict['deltas'] = self.deltas.tolist()
            if self.prevValues is not None:
                saveDict['prevValues'] = self.prevValues.tolist()
            saveDict['cache'] = {}
            if type(self.cache) == dict:
                saveDict['cache']['values'] = {}
                for key in self.cache:
                    saveDict['cache']['values'][key] = self.cache[key].tolist()
                saveDict['cache']['type'] = 'dict'
            elif type(self.cache) == np.ndarray:
                saveDict['cache']['values'] = self.cache.tolist()
                saveDict['cache']['type'] = 'np.ndarray'
    
            return saveDict
    
        def fromDict(self, loadDict: dict) -> None:
            self.values = np.array(loadDict['values'])
            self.deltas = np.array(loadDict['deltas'])
            if 'prevValues' in loadDict:
                self.prevValues = np.array(loadDict['prevValues'])
            if loadDict['cache']['type'] == 'np.ndarray':
                self.cache = np.array(loadDict['cache']['values'])
            elif loadDict['cache']['type'] == 'dict':
                self.cache = {}
                for key in loadDict['cache']['values']:
                    self.cache[key] = np.array(loadDict['cache']['values'][key])