import numpy as np
from numpy.typing import ArrayLike
from .layer import Layer


class L1Norm(Layer):
    """Layer that divides its input by the input's L1 norm.

    The norm is ``sum(abs(input))`` reduced over ``axis`` with the reduced
    dimensions kept, so it broadcasts back against the input.

    Parameters
    ‾‾‾‾‾‾‾‾‾‾
    axis : None, int or tuple of int, optional
        Axis/axes forwarded to ``sum`` when computing the norm. ``None`` (the
        default) reduces over every element.
    epsilon : float, optional
        Added to the norm before it is inverted so that an all-zero input does
        not divide by zero. Default ``1e-8``.

    Attributes
    ‾‾‾‾‾‾‾‾‾‾
    axis, epsilon
        The constructor arguments.
    scales : array or None
        ``-sign(output)`` from the latest ``forward`` call; ``None`` before the
        first call.
    gradient : array
        Float buffer reset to zeros by ``forward`` and accumulated into by
        ``backward``.

    Methods
    ‾‾‾‾‾‾‾
    forward(input)
        Return the normalized input and reset the backward state.
    backward(gradient)
        Accumulate ``scales`` and add the result into ``gradient`` in place.
    """

    def __init__(self, axis=None, epsilon: float = 1e-8) -> None:
        super().__init__()
        self.axis = axis
        self.epsilon = epsilon
        # Filled in by forward(); backward() reads it.
        self.scales = None

    def forward(self, input: ArrayLike) -> ArrayLike:
        """Divide ``input`` by its L1 norm along ``self.axis``.

        Parameters
        ‾‾‾‾‾‾‾‾‾‾
        input : array_like
            Values to normalize.

        Returns
        ‾‾‾‾‾‾‾
        array
            ``input / (sum(abs(input), axis, keepdims=True) + epsilon)``.
            Also stores ``scales`` and a zeroed ``gradient`` buffer shaped
            like this result.
        """
        l1 = np.abs(input).sum(axis=self.axis, keepdims=True)
        inv_l1 = 1. / (l1 + self.epsilon)
        output = input * inv_l1
        # NOTE(review): -sign(output) is not the derivative of this forward
        # pass (the diagonal Jacobian term would be (1 - abs(output)) * inv_l1,
        # and cross terms are ignored). Kept as-is; confirm it is intended.
        self.scales = -np.sign(output)
        self.gradient = np.zeros_like(output, dtype=float)
        return output

    def backward(self, gradient: ArrayLike) -> ArrayLike:
        """Add this layer's stored update into ``gradient``.

        ``scales`` is first added to ``self.gradient``, so repeated calls
        without an intervening ``forward`` keep accumulating. The buffer is
        then added element-wise into ``gradient``. ``forward`` must have been
        called first (``scales`` is ``None`` otherwise).

        Parameters
        ‾‾‾‾‾‾‾‾‾‾
        gradient : ndarray
            Mutable NumPy array, broadcast-compatible with the forward output.
            It is modified in place through ``gradient[:] +=``.

        Returns
        ‾‾‾‾‾‾‾
        ndarray
            The same ``gradient`` object after the in-place update.
        """
        self.gradient += self.scales
        gradient[:] += self.gradient
        return gradient


class L2Norm(Layer):
    """Layer that divides its input by the input's Euclidean (L2) norm.

    The norm is ``sqrt(sum(input**2) + epsilon)`` reduced over ``axis`` with
    the reduced dimensions kept, so it broadcasts back against the input.

    Parameters
    ‾‾‾‾‾‾‾‾‾‾
    axis : None, int or tuple of int, optional
        Axis/axes forwarded to ``sum`` when computing the norm. ``None`` (the
        default) reduces over every element.
    epsilon : float, optional
        Added to the sum of squares before the square root so that an all-zero
        input does not divide by zero. Default ``1e-8``.

    Attributes
    ‾‾‾‾‾‾‾‾‾‾
    axis, epsilon
        The constructor arguments.
    scales : array or None
        ``(1 - output) * inv_norm`` from the latest ``forward`` call; ``None``
        before the first call.
    gradient : array
        Float buffer reset to zeros by ``forward`` and accumulated into by
        ``backward``.

    Methods
    ‾‾‾‾‾‾‾
    forward(input)
        Return the normalized input and reset the backward state.
    backward(gradient)
        Accumulate ``scales`` and add the result into ``gradient`` in place.
    """

    def __init__(self, axis=None, epsilon: float = 1e-8) -> None:
        super().__init__()
        self.axis = axis
        self.scales = None
        self.epsilon = epsilon

    def forward(self, input: ArrayLike) -> ArrayLike:
        """Divide ``input`` by its L2 norm along ``self.axis``.

        Parameters
        ‾‾‾‾‾‾‾‾‾‾
        input : array_like
            Values to normalize. Lists and other array-likes are converted
            with ``np.asarray``; ndarray inputs are used as-is.

        Returns
        ‾‾‾‾‾‾‾
        ndarray
            ``input / sqrt(sum(input**2, axis, keepdims=True) + epsilon)``.
            Also stores ``scales`` and a zeroed ``gradient`` buffer shaped
            like this result.
        """
        # Convert first: ``input * input`` on a plain list raises TypeError,
        # and on np.matrix it would be a matrix product, not element-wise.
        x = np.asarray(input)
        inv_norm = 1. / np.sqrt((x * x).sum(axis=self.axis, keepdims=True)
                                + self.epsilon)
        output = x * inv_norm
        # NOTE(review): this is not the derivative of the forward pass (the
        # diagonal Jacobian term would be (1 - output**2) * inv_norm, and
        # cross terms are ignored). Kept unchanged; confirm it is intended.
        self.scales = (1. - output) * inv_norm
        self.gradient = np.zeros_like(output, dtype=float)
        return output

    def backward(self, gradient: ArrayLike) -> ArrayLike:
        """Add this layer's stored update into ``gradient``.

        ``scales`` is first added to ``self.gradient``, so repeated calls
        without an intervening ``forward`` keep accumulating. The buffer is
        then added element-wise into ``gradient``. ``forward`` must have been
        called first (``scales`` is ``None`` otherwise).

        Parameters
        ‾‾‾‾‾‾‾‾‾‾
        gradient : ndarray
            Mutable NumPy array, broadcast-compatible with the forward output.
            It is modified in place through ``gradient[:] +=``.

        Returns
        ‾‾‾‾‾‾‾
        ndarray
            The same ``gradient`` object after the in-place update.
        """
        self.gradient += self.scales
        # Add the buffer that forward() allocated and the line above filled
        # (this class never defines ``self.delta``).
        gradient[:] += self.gradient
        return gradient