import numpy as np
from .weights import Weights
from .rnn import RNN, checkDims


class LSTM(RNN):
    """
    An implementation of the LSTM layer.
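
    Per timestep t it computes the standard LSTM update, with [x_t, h_{t-1}]
    denoting the concatenation of the current input and the previous hidden state:

        i_t = sigmoid(W_i [x_t, h_{t-1}] + b_i)    (input gate)
        f_t = sigmoid(W_f [x_t, h_{t-1}] + b_f)    (forget gate)
        o_t = sigmoid(W_o [x_t, h_{t-1}] + b_o)    (output gate)
        g_t = tanh(W_g [x_t, h_{t-1}] + b_g)       (cell/candidate gate)
        c_t = f_t * c_{t-1} + i_t * g_t
        h_t = o_t * tanh(c_t)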
    """
    __slots__ = ['inputSize', 'hiddenSize', 'input', 'hidden', 'cell', 'gates', 'batchSize', 'seqLength']

    def __init__(self, inputSize: int, hiddenSize: int, weights: np.ndarray = None, bias: np.ndarray = None) -> None:
        super().__init__()
        self.inputSize = inputSize
        self.hiddenSize = hiddenSize

        # Initialize weights and bias
        self.weights = Weights((inputSize + hiddenSize, 4 * hiddenSize), values=weights)
        self.bias = Weights((4 * hiddenSize,), values=bias)

        # Initialize hidden and cell states
        self.hidden = np.zeros((hiddenSize,))
        self.cell = np.zeros((hiddenSize,))

    def forward(self, input: np.ndarray, hiddenState: np.ndarray = None, cellState: np.ndarray = None) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        forward pass of the LSTM layer
        """
        checkDims(input)

        self.input = input
        self.batchSize, self.seqLength, _ = input.shape

        # Initialize hidden and cell states if not provided
        if hiddenState is None:
            hiddenState = np.zeros((self.batchSize, self.seqLength, self.hiddenSize))
        if cellState is None:
            cellState = np.zeros((self.batchSize, self.seqLength, self.hiddenSize))

        # Initialize the output array and the gate cache used by the backward pass
        output = np.zeros((self.batchSize, self.seqLength, self.hiddenSize))
        self.gates = np.zeros((self.batchSize, self.seqLength, 4 * self.hiddenSize))

        # Running states carried through the recurrence, starting from zero vectors
        hiddenPrev = np.zeros((self.batchSize, self.hiddenSize))
        cellPrev = np.zeros((self.batchSize, self.hiddenSize))

        for t in range(self.seqLength):
            # Concatenate the current input with the previous hidden state; the
            # (input, hidden) order matches the weight slicing in the backward pass
            combined = np.hstack((input[:, t, :], hiddenPrev))
            gates = np.matmul(combined, self.weights.values) + self.bias.values

            # Split the pre-activations into the input, forget, output, and cell gates
            inputGate, forgetGate, outputGate, hiddenGate = np.split(gates, 4, axis=1)

            # Apply sigmoid activation function for input, forget, and output gates
            inputGate = 1 / (1 + np.exp(-inputGate))
            forgetGate = 1 / (1 + np.exp(-forgetGate))
            outputGate = 1 / (1 + np.exp(-outputGate))

            # Apply tanh activation function for the cell gate
            hiddenGate = np.tanh(hiddenGate)

            # Cache the activated gates for the backward pass
            self.gates[:, t, :] = np.hstack((inputGate, forgetGate, outputGate, hiddenGate))

            # Update the cell and hidden state and record the output
            cellState[:, t, :] = forgetGate * cellPrev + inputGate * hiddenGate
            hiddenState[:, t, :] = outputGate * np.tanh(cellState[:, t, :])
            output[:, t, :] = hiddenState[:, t, :]

            # Carry the states over to the next timestep
            hiddenPrev = hiddenState[:, t, :]
            cellPrev = cellState[:, t, :]

        # Keep the full state sequences for the backward pass
        self.hidden = hiddenState
        self.cell = cellState

        return output, hiddenState, cellState

    def backward(self, gradient: np.ndarray, hiddenGradient: np.ndarray = None, cellGradient: np.ndarray = None) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        backward pass of the LSTM layer
        """
        gradInputState = np.zeros_like(self.input)
        dhiddenNext = np.zeros((self.batchSize, self.hiddenSize))
        dcellNext = np.zeros((self.batchSize, self.hiddenSize))
        dW = np.zeros_like(self.weights.values)
        db = np.zeros_like(self.bias.values)

        if hiddenGradient is not None:
            dhiddenNext += hiddenGradient
        if cellGradient is not None:
            dcellNext += cellGradient

        for t in reversed(range(self.seqLength)):
            # Recover the activated gates cached during the forward pass
            inputGate, forgetGate, outputGate, hiddenGate = np.split(self.gates[:, t, :], 4, axis=1)

            # Previous cell and hidden states (zero vectors at t = 0)
            cellPrev = self.cell[:, t - 1, :] if t > 0 else np.zeros((self.batchSize, self.hiddenSize))
            hiddenPrev = self.hidden[:, t - 1, :] if t > 0 else np.zeros((self.batchSize, self.hiddenSize))

            # Total gradient flowing into the hidden state at this timestep
            dh = gradient[:, t, :] + dhiddenNext

            # Partial derivative of loss w.r.t. output gate
            do = dh * np.tanh(self.cell[:, t, :])
            do_input = do * outputGate * (1 - outputGate)

            # Partial derivative of loss w.r.t. cell state
            dc = dcellNext + dh * outputGate * (1 - np.tanh(self.cell[:, t, :]) ** 2)
            dc_bar = dc * inputGate
            dc_bar_input = dc_bar * (1 - hiddenGate ** 2)

            # Partial derivative of loss w.r.t. input gate
            di = dc * hiddenGate
            di_input = di * inputGate * (1 - inputGate)

            # Partial derivative of loss w.r.t. forget gate
            df = dc * cellPrev
            df_input = df * forgetGate * (1 - forgetGate)

            # Stacking the gradients
            dstacked = np.hstack((di_input, df_input, do_input, dc_bar_input))

            # Gradients with respect to weights and biases
            dW += np.matmul(np.hstack((self.input[:, t, :], hiddenPrev)).T, dstacked)
            db += np.sum(dstacked, axis=0)

            # Gradients with respect to inputs
            gradInputState[:, t, :] = np.matmul(dstacked, self.weights.values[:self.inputSize].T)

            # Update for next timestep
            dhiddenNext = np.matmul(dstacked, self.weights.values[self.inputSize:].T)
            dcellNext = forgetGate * dc

        # Store the gradients
        self.weights.deltas = dW
        self.bias.deltas = db

        return gradInputState, dhiddenNext, dcellNext

    def __str__(self) -> str:
        """
        used to print the layer in a human-readable manner
        """
        printString = self.name
        printString += '    input size: ' + str(self.inputSize)
        printString += '    hidden size: ' + str(self.hiddenSize)
        return printString
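

# A minimal usage sketch (illustrative only; assumes this module is imported as
# part of its package, e.g. `from layers.lstm import LSTM`, where the package
# name `layers` is hypothetical):
#
#     layer = LSTM(inputSize=8, hiddenSize=16)
#     x = np.random.randn(2, 5, 8)                        # (batch, sequence, features)
#     output, hidden, cell = layer.forward(x)             # each of shape (2, 5, 16)
#     gradInput, dHidden, dCell = layer.backward(np.ones_like(output))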