import numpy as np
from .layer import Layer
from .weights import Weights


def checkDims(input: np.ndarray) -> None:
    """
    Validate that ``input`` is a 2D array with no empty dimension.

    Layers that operate on (batch, features) data call this before doing any work.

    Raises:
        AssertionError: if ``input`` is not 2D or either dimension is 0.
            The error is raised explicitly rather than via ``assert`` so the check
            still runs under ``python -O``. AssertionError is kept as the type
            for backward compatibility with existing callers.
    """
    if input.ndim != 2:
        raise AssertionError(f"Input should have 2 dimensions, got {input.ndim}")
    batchsize, numFeatures = input.shape
    if batchsize <= 0 or numFeatures <= 0:
        raise AssertionError("All dimensions should be greater than 0")


from abc import ABC, abstractmethod
from .weights import Weights
import numpy as np
from numpy.typing import ArrayLike


class Layer(ABC):
    """
    Abstract base for all layers; only usable through concrete subclasses.

    Subclasses must implement ``forward`` and ``backward``. Each instance
    stores:
      * ``name``    - the concrete class name, used by ``__str__``
      * ``mode``    - '' until ``train()`` or ``eval()`` is called
      * ``layerID`` - a serial number taken from ``Layer.id``, a counter
                      shared by every subclass
    """
    __slots__ = ['name', 'mode', 'layerID']
    id = 0  # next serial number to hand out; always read/written as Layer.id

    def __init__(self) -> None:
        concrete = type(self)
        self.name = concrete.__name__
        self.mode = ''
        # Claim the current counter value and advance the shared counter in one step.
        self.layerID, Layer.id = Layer.id, Layer.id + 1

    @property
    def qualifiedName(self) -> tuple:
        """Return the pair (defining module, class name) of the concrete layer."""
        concrete = type(self)
        return (concrete.__module__, concrete.__name__)

    @abstractmethod
    def forward(self, input: ArrayLike) -> np.ndarray:
        """Compute the layer output; every concrete subclass must override this."""

    def __call__(self, *args: ArrayLike) -> np.ndarray:
        """Let a layer be invoked like a function: ``layer(x)`` is ``layer.forward(x)``."""
        return self.forward(*args)

    @abstractmethod
    def backward(self, gradient: ArrayLike) -> np.ndarray:
        """Propagate an upstream gradient; every concrete subclass must override this."""

    def train(self) -> None:
        """Put the layer into training mode by setting ``mode`` to 'train'."""
        self.mode = 'train'

    def eval(self) -> None:
        """Put the layer into evaluation mode by setting ``mode`` to 'eval'."""
        self.mode = 'eval'

    def __str__(self) -> str:
        """Human readable form of the layer: just its name."""
        return self.name
		


class Flatten(Layer):
    """
    Flatten every sample of a batch into a single feature vector.

    Intended for use after a convolution block: all channels are squeezed
    into one axis so the result can feed a linear layer. The first axis is
    treated as the batch axis and is preserved.
    """
    __slots__ = ['inputShape', 'flatShape']

    def __init__(self) -> None:
        super().__init__()
        self.inputShape = None  # per-sample shape of the most recent forward input
        self.flatShape = None   # number of features per sample after flattening

    def forward(self, input: np.ndarray) -> np.ndarray:
        """
        Reshape ``input`` from (batch, *rest) to (batch, prod(rest)).

        The per-sample shape is recorded on every call so that ``backward``
        restores the shape of the most recent input.
        """
        # Refresh on every call. Caching only the first shape would make a later
        # input with a different per-sample shape be silently reshaped into the
        # wrong number of rows, and backward would restore a stale shape.
        self.inputShape = input.shape[1:]
        # int(): np.prod(()) is the float 1.0, which reshape rejects (1-D input).
        self.flatShape = int(np.prod(self.inputShape))
        # Keep the batch axis explicit instead of inferring it with -1.
        return input.reshape(input.shape[0], self.flatShape)

    def backward(self, gradient: np.ndarray) -> np.ndarray:
        """
        Reshape the upstream gradient back to the shape of the last forward input.

        Raises:
            RuntimeError: if called before any forward pass.
        """
        if self.inputShape is None:
            raise RuntimeError('Flatten.backward called before forward')
        return gradient.reshape(-1, *self.inputShape)


class Dropout(Layer):
    """
    Inverted dropout for 2D (batch, features) inputs, used to reduce overfitting.

    In 'train' mode, ``forward`` does two things:
      * zeroes each element independently with probability ``probability``;
      * rescales the survivors by 1 / (1 - probability).
    ``backward`` then applies the same mask and scale to the gradient.

    In any other mode the layer is the identity in both directions.
    """
    __slots__ = ['size', 'probability', 'mask']

    def __init__(self, size: int, probability: float) -> None:
        super().__init__()
        self.size = size  # within this class only reported by __str__
        # A chained comparison also rejects NaN, which `p < 0 or p > 1` let through.
        if not 0 <= probability <= 1:
            raise ValueError('probability has to be between 0 and 1')
        self.probability = probability
        # Keep-mask from the last training-mode forward. None means the last
        # forward (if any) passed its input through unchanged.
        self.mask = None

    def _rescale(self, masked: np.ndarray) -> np.ndarray:
        """
        Divide already-masked values by the keep probability (1 - probability).

        When the keep probability is 0 every element was dropped, so return
        zeros instead of computing 0/0 (NaN).
        """
        keep = 1 - self.probability
        if keep == 0:
            return np.zeros_like(masked, dtype=np.result_type(masked, 0.0))
        return masked / keep

    def forward(self, input: np.ndarray) -> np.ndarray:
        """
        Apply dropout to the output of a linear layer.

        In 'train' mode a fresh Boolean keep-mask is sampled and stored for
        ``backward``. Otherwise the input is returned unchanged and no mask is kept.
        """
        checkDims(input)
        if self.mode != 'train':
            # Forget any stale training mask so backward mirrors this identity pass.
            self.mask = None
            return input
        self.mask = np.random.random(input.shape) < (1 - self.probability)
        return self._rescale(np.multiply(input, self.mask))

    def backward(self, gradient: np.ndarray) -> np.ndarray:
        """
        Route the upstream gradient through the mask used by the last forward pass.

        If that forward pass applied no mask (non-training mode, or no forward
        yet), the gradient passes through unchanged.
        """
        if self.mask is None:
            return gradient
        return self._rescale(np.multiply(gradient, self.mask))

    def __str__(self) -> str:
        """
        used for print the layer in a human readable manner
        """
        printString = self.name
        printString += '    size: ' + str(self.size)
        printString += '    probability: ' + str(self.probability)
        return printString