Skip to content
Snippets Groups Projects
parallel.py 2.13 KiB
Newer Older
  • Learn to ignore specific revisions
  • johannes bilk's avatar
    johannes bilk committed
    import numpy as np
    from .module import Module
    
    
    class Parallel(Module):
        """
        a container whose layers are called side by side, one input per layer

        forward flattens every layer output to (batchSize, -1) and stacks the
        results column-wise; backward cuts the incoming gradient back into the
        matching column blocks and hands each block to its layer
        """
        __slots__ = ['shapes', 'batchSize', 'outputShapes', 'slices', 'splits', 'inputs']

        def __init__(self, layers: list = None, splits: list = None) -> None:
            """
            layers: forwarded unchanged to the Module constructor
            splits: optional column boundaries; when given, forward slices its
                    first input into the pieces [:, splits[i]:splits[i+1]] and
                    feeds one piece to each layer
            """
            super().__init__(layers)
            self.shapes = None
            self.splits = splits

        def forward(self, *inputs: np.ndarray) -> np.ndarray:
            """
            calls all layers in parallel and stacks the flattened outputs

            the inputs are distributed in one of three ways:
              * splits set:  the first input is cut column-wise at the split
                             boundaries, one piece per layer
              * one input:   the same array is passed to every layer
              * otherwise:   the i-th input goes to the i-th layer

            returns the layer outputs, each reshaped to (batchSize, -1), joined
            with np.hstack

            raises TypeError if no input is given or if the number of inputs
            (after splitting/replicating) differs from the number of layers
            """
            # without this guard an empty call dies on inputs[0] with an
            # IndexError that says nothing about the actual mistake
            if not inputs:
                raise TypeError('number of input elements must be equal to Layers/Modules or one')

            self.batchSize = inputs[0].shape[0]

            # bring the inputs into a one-per-layer form
            if self.splits is not None:
                inputs = [inputs[0][:, start:stop] for start, stop in zip(self.splits, self.splits[1:])]
                self.inputs = inputs
            elif len(inputs) == 1:
                inputs = [inputs[0]] * len(self)
            if len(inputs) != len(self):
                raise TypeError('number of input elements must be equal to Layers/Modules or one')

            outputs = []
            widths = [0]  # leading zero so the cumulative sum yields start offsets
            self.outputShapes = []

            # iterating over layers with their inputs
            for layer, layerInput in zip(self, inputs):
                output = layer(layerInput)
                # remember the un-flattened shape so backward can restore it
                self.outputShapes.append(output.shape[1:])
                flat = output.reshape(self.batchSize, -1)
                outputs.append(flat)
                widths.append(flat.shape[1])

            # column offsets of every layer's block inside the stacked output
            self.slices = np.cumsum(widths)

            return np.hstack(outputs)

        def __call__(self, *inputs: np.ndarray) -> np.ndarray:
            """
            overrides the call method so the container accepts multiple inputs;
            simply delegates to forward
            """
            return self.forward(*inputs)

        def backward(self, gradient: np.ndarray) -> list:
            """
            splits the incoming gradient and runs every layer's backward pass

            gradient: array whose columns line up with the stacked output of
                      the preceding forward call; forward must have run first
                      because its slices, outputShapes and batchSize are used

            returns a list with one entry per layer, each being whatever that
            layer's backward returned
            """
            # NOTE(review): the per-layer gradients are returned as a plain list
            # and are not merged, even when forward replicated or split a single
            # input - confirm that the caller expects a list here
            gradients = []
            for layer, start, stop, shape in zip(self, self.slices, self.slices[1:], self.outputShapes):
                # cut out this layer's columns and restore its original output shape
                block = gradient[:, start:stop].reshape(self.batchSize, *shape)
                gradients.append(layer.backward(block))
            return gradients