import numpy as np

from .weights import Weights
from .rnn import RNN, checkDims


class LSTM(RNN):
    """ An implementation of the LSTM layer. """
    __slots__ = ['inputSize', 'hiddenSize', 'input', 'hidden', 'cell',
                 'gates', 'batchSize', 'seqLength']

    def __init__(self, inputSize: int, hiddenSize: int,
                 weights: np.ndarray = None, bias: np.ndarray = None) -> None:
        super().__init__()
        self.inputSize = inputSize
        self.hiddenSize = hiddenSize

        # Initialize weights and bias; the four gates are laid out along the
        # last axis in the order input, forget, output, cell
        self.weights = Weights((inputSize + hiddenSize, 4 * hiddenSize), values=weights)
        self.bias = Weights((4 * hiddenSize,), values=bias)

        # Initialize hidden and cell states
        self.hidden = np.zeros((hiddenSize,))
        self.cell = np.zeros((hiddenSize,))

    def forward(self, input: np.ndarray, hiddenState: np.ndarray = None,
                cellState: np.ndarray = None) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Forward pass of the LSTM layer.

        `input` has shape (batch, sequence, inputSize); `hiddenState` and
        `cellState` are optional initial states of shape (batch, hiddenSize).
        Returns the per-timestep outputs and the final hidden and cell states.
        """
        checkDims(input)
        self.input = input
        self.batchSize, self.seqLength, _ = input.shape

        # Initialize hidden and cell states if not provided
        if hiddenState is None:
            hiddenState = np.zeros((self.batchSize, self.hiddenSize))
        if cellState is None:
            cellState = np.zeros((self.batchSize, self.hiddenSize))

        # Cache the per-timestep states and gate activations for the backward
        # pass; index 0 of the state caches holds the initial states
        self.hidden = np.zeros((self.batchSize, self.seqLength + 1, self.hiddenSize))
        self.cell = np.zeros((self.batchSize, self.seqLength + 1, self.hiddenSize))
        self.gates = np.zeros((self.batchSize, self.seqLength, 4 * self.hiddenSize))
        self.hidden[:, 0, :] = hiddenState
        self.cell[:, 0, :] = cellState

        # Initialize output array
        output = np.zeros((self.batchSize, self.seqLength, self.hiddenSize))

        for t in range(self.seqLength):
            # Stack the input and the previous hidden state; the order must
            # match the weight layout (input rows first, hidden rows second)
            combined = np.hstack((input[:, t, :], self.hidden[:, t, :]))
            gates = np.matmul(combined, self.weights.values) + self.bias.values

            # Split the pre-activations into the four gates along the feature axis
            inputGate, forgetGate, outputGate, hiddenGate = np.split(gates, 4, axis=1)

            # Apply the sigmoid activation function for the input, forget and output gates
            inputGate = 1 / (1 + np.exp(-inputGate))
            forgetGate = 1 / (1 + np.exp(-forgetGate))
            outputGate = 1 / (1 + np.exp(-outputGate))

            # Apply the tanh activation function for the cell gate
            hiddenGate = np.tanh(hiddenGate)

            # Cache the activated gates for the backward pass
            self.gates[:, t, :] = np.hstack((inputGate, forgetGate, outputGate, hiddenGate))

            # Update the cell and hidden states
            self.cell[:, t + 1, :] = forgetGate * self.cell[:, t, :] + inputGate * hiddenGate
            self.hidden[:, t + 1, :] = outputGate * np.tanh(self.cell[:, t + 1, :])
            output[:, t, :] = self.hidden[:, t + 1, :]

        return output, self.hidden[:, -1, :], self.cell[:, -1, :]

    def backward(self, gradient: np.ndarray, hiddenGradient: np.ndarray = None,
                 cellGradient: np.ndarray = None) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Backward pass of the LSTM layer.

        `gradient` is the upstream gradient w.r.t. the output, of shape
        (batch, sequence, hiddenSize). Returns the gradients w.r.t. the input
        and the initial hidden and cell states.
        """
        gradInputState = np.zeros_like(self.input)
        dhiddenNext = np.zeros((self.batchSize, self.hiddenSize))
        dcellNext = np.zeros((self.batchSize, self.hiddenSize))
        dW = np.zeros_like(self.weights.values)
        db = np.zeros_like(self.bias.values)

        if hiddenGradient is not None:
            dhiddenNext += hiddenGradient
        if cellGradient is not None:
            dcellNext += cellGradient

        for t in reversed(range(self.seqLength)):
            # Retrieve the gate activations cached during the forward pass
            inputGate, forgetGate, outputGate, hiddenGate = np.split(self.gates[:, t, :], 4, axis=1)

            # Total gradient flowing into the hidden state at this timestep
            dh = gradient[:, t, :] + dhiddenNext

            # Partial derivative of the loss w.r.t. the output gate
            do = dh * np.tanh(self.cell[:, t + 1, :])
            do_input = do * outputGate * (1 - outputGate)

            # Partial derivative of the loss w.r.t. the cell state
            dc = dcellNext + dh * outputGate * (1 - np.tanh(self.cell[:, t + 1, :]) ** 2)
            dc_bar = dc * inputGate
            dc_bar_input = dc_bar * (1 - hiddenGate ** 2)

            # Partial derivative of the loss w.r.t. the input gate
            di = dc * hiddenGate
            di_input = di * inputGate * (1 - inputGate)

            # Partial derivative of the loss w.r.t. the forget gate
            df = dc * self.cell[:, t, :]
            df_input = df * forgetGate * (1 - forgetGate)

            # Stack the gate gradients in the same order as the weight columns
            dstacked = np.hstack((di_input, df_input, do_input, dc_bar_input))

            # Gradients with respect to the weights and biases
            combined = np.hstack((self.input[:, t, :], self.hidden[:, t, :]))
            dW += np.matmul(combined.T, dstacked)
            db += np.sum(dstacked, axis=0)

            # Gradient with respect to the input at this timestep
            gradInputState[:, t, :] = np.matmul(dstacked, self.weights.values[:self.inputSize].T)

            # Propagate the hidden and cell gradients to the previous timestep
            dhiddenNext = np.matmul(dstacked, self.weights.values[self.inputSize:].T)
            dcellNext = forgetGate * dc

        # Store the gradients
        self.weights.deltas = dW
        self.bias.deltas = db

        return gradInputState, dhiddenNext, dcellNext

    def __str__(self) -> str:
        """ Returns a human readable description of the layer. """
        printString = self.name
        printString += ' input size: ' + str(self.inputSize)
        printString += ' hidden size: ' + str(self.hiddenSize)
        return printString
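

# A minimal usage sketch, assuming this module is run inside its package so
# the relative imports resolve, and that Weights draws random initial values
# when `values` is None; the shapes and seed below are illustrative only.
if __name__ == '__main__':
    np.random.seed(0)
    lstm = LSTM(inputSize=8, hiddenSize=16)

    # A dummy batch of 4 sequences, each 10 steps of 8 features
    x = np.random.randn(4, 10, 8)
    output, hidden, cell = lstm.forward(x)
    print(output.shape, hidden.shape, cell.shape)   # (4, 10, 16) (4, 16) (4, 16)

    # Backpropagate a dummy upstream gradient of the same shape as the output
    gradInput, dHidden, dCell = lstm.backward(np.ones_like(output))
    print(gradInput.shape)                          # (4, 10, 8)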