from abc import abstractmethod

import numpy as np

from .layer import Layer
from .module import Module
from .weights import Weights


class Regularization(Layer):
    """Base class for weight-penalty layers.

    The layer is an identity in the forward pass. Subclasses implement
    ``backward`` to add the gradient of their penalty, computed from the
    parameters collected at construction time.

    Attributes:
        name: The concrete class name (e.g. ``"L1Regularization"``).
        Lambda: Penalty strength; subclasses scale their gradient by it.
        params: Parameter objects gathered from every layer that exposes
            a ``params()`` method.
    """

    def __init__(self, layers: list | Module, Lambda: float) -> None:
        """Collect the parameters of ``layers`` and store the penalty strength.

        Args:
            layers: An iterable of layers (a list, or a ``Module`` —
                presumably iterable over its layers; TODO confirm).
            Lambda: Penalty strength, stored as ``self.Lambda``.
        """
        self.name = type(self).__name__
        # ``lambda`` is a Python keyword; the subclasses read ``self.Lambda``.
        self.Lambda = Lambda
        self.params = []

        for layer in layers:
            # Layers without a ``params`` method (e.g. parameter-free layers)
            # contribute nothing to the penalty. An explicit presence check
            # avoids swallowing AttributeErrors raised *inside* ``params()``,
            # which the previous ``try/except`` silently hid.
            params_fn = getattr(layer, "params", None)
            if params_fn is None:
                continue
            self.params.extend(params_fn())

    def forward(self, input: np.ndarray) -> np.ndarray:
        """Return ``input`` unchanged; the penalty acts only on gradients."""
        return input

    @abstractmethod
    def backward(self, gradient: np.ndarray) -> np.ndarray:
        """Return ``gradient`` with this penalty's gradient added.

        Must be implemented by subclasses.
        """
        raise NotImplementedError


class L1Regularization(Regularization):
    """L1 (lasso) weight penalty.

    The gradient it adds, ``Lambda * sign(w)`` per weight, is that of the
    penalty ``Lambda * sum(|w|)`` over the collected params.
    """

    def backward(self, gradient: np.ndarray) -> np.ndarray:
        """Add the L1 penalty gradient of every collected param to ``gradient``.

        Args:
            gradient: Incoming gradient array. It is not modified; a new
                array is returned.

        Returns:
            ``gradient`` plus ``self.Lambda * np.sign(param.values)`` summed
            over ``self.params``.
        """
        # NOTE(review): assumes every ``param.values`` broadcasts against
        # ``gradient`` — TODO confirm shapes with the owning layers.
        # Out-of-place addition: the caller's array is left untouched, and
        # integer/low-precision inputs are promoted instead of raising a
        # casting error as the in-place ``+=`` would.
        for param in self.params:
            gradient = gradient + self.Lambda * np.sign(param.values)
        return gradient


class L2Regularization(Regularization):
    """L2 (ridge / weight-decay) penalty.

    The gradient it adds, ``2 * Lambda * w`` per weight, is that of the
    penalty ``Lambda * sum(w ** 2)`` over the collected params.
    """

    def backward(self, gradient: np.ndarray) -> np.ndarray:
        """Add the L2 penalty gradient of every collected param to ``gradient``.

        Args:
            gradient: Incoming gradient array. It is not modified; a new
                array is returned.

        Returns:
            ``gradient`` plus ``self.Lambda * 2 * param.values`` summed over
            ``self.params``.
        """
        # NOTE(review): assumes every ``param.values`` broadcasts against
        # ``gradient`` — TODO confirm shapes with the owning layers.
        # Out-of-place addition: the caller's array is left untouched, and
        # integer/low-precision inputs are promoted instead of raising a
        # casting error as the in-place ``+=`` would.
        for param in self.params:
            gradient = gradient + self.Lambda * 2 * param.values
        return gradient