import numpy as np
from .module import Module
class Parallel(Module):
    """
    A list of layers that are called in parallel.
    Implements the forward and backward pass over all of them.
    """

    __slots__ = ['shapes', 'batchSize', 'outputShapes', 'slices', 'splits', 'inputs']

    def __init__(self, layers: list = None, splits: list = None) -> None:
        """
        layers: the layers/modules to run side by side
        splits: optional column boundaries used to slice a single input
                among the layers (must include both endpoints)
        """
        super().__init__(layers)
        self.shapes = None
        self.splits = splits
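
    # Example (illustrative values, not part of the API): splits=[0, 4, 7]
    # slices one 7-column input into inputs[:, 0:4] for the first layer and
    # inputs[:, 4:7] for the second, so len(splits) must be len(self) + 1.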
    def forward(self, *inputs: np.ndarray) -> np.ndarray:
        """
        calls all layers in parallel and stacks their outputs column-wise
        """
        self.batchSize = inputs[0].shape[0]
        # prepare one input per layer
        if self.splits is not None:
            # slice the single input column-wise along the split boundaries
            inputs = [inputs[0][:, one:two] for one, two in zip(self.splits, self.splits[1:])]
        elif len(inputs) == 1:
            # a single input is broadcast to every layer
            inputs = [inputs[0]] * len(self)
        if len(inputs) != len(self):
            raise TypeError('the number of inputs must equal the number of layers/modules, or be exactly one')
        self.inputs = inputs
        outputs = []
        slices = [0]  # column offsets, used for slicing up the incoming gradient
        self.outputShapes = []
        # run every layer on its input and flatten the result per sample
        for layer, inp in zip(self, inputs):
            output = layer(inp)
            self.outputShapes.append(output.shape[1:])
            outputs.append(output.reshape(self.batchSize, -1))
            slices.append(outputs[-1].shape[1])
        # cumulative offsets for splitting the gradient in backward
        self.slices = np.cumsum(slices)
        return np.hstack(outputs)
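
    # Example (illustrative shapes): two layers producing (32, 8) and
    # (32, 2, 3) flatten to widths 8 and 6, so self.slices becomes
    # [0, 8, 14] and forward returns a (32, 14) array.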
    def __call__(self, *inputs: np.ndarray) -> np.ndarray:
        """
        overrides Module's __call__ so that multiple inputs can be passed
        """
        return self.forward(*inputs)
    def backward(self, gradient: np.ndarray) -> list:
        """
        calls all layers in parallel in reverse and returns their input gradients
        """
        gradients = []
        # undo the column-wise stacking: give every layer its slice of the
        # gradient, restored to that layer's original output shape
        for layer, start, stop, shape in zip(self, self.slices, self.slices[1:], self.outputShapes):
            grad = layer.backward(gradient[:, start:stop].reshape(self.batchSize, *shape))
            gradients.append(grad)
        return gradients
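
# Minimal usage sketch (hypothetical: assumes a `Dense` layer exists in this
# package with compatible forward/backward methods; names and shapes below
# are illustrative, not part of the real API):
#
#     branches = Parallel([Dense(4, 8), Dense(3, 8)], splits=[0, 4, 7])
#     x = np.random.rand(32, 7)        # batch of 32 samples, 7 features
#     out = branches(x)                # (32, 16): both outputs stacked
#     grads = branches.backward(np.ones_like(out))
#     # grads is a list with one gradient per layer: (32, 4) and (32, 3)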