import numpy as np
from numpy.typing import ArrayLike
from typing import Any, Callable
from abc import ABC, abstractmethod
from .backend import BackendInterface, NumpyBackend, CupyBackend, NumbaBackend


class Tensor(object):
    __slots__ = ['_backend', 'data', 'gradient', 'requireGradient', 'gradientFunc', 'batched']

    __backend__ = NumpyBackend()

    def __init__(self, data: Any,
                 gradient: Any = None,
                 gradientFunc: Callable = None,
                 requireGradient: bool = False,
                 batched: bool = True) -> None:

        self._backend = Tensor.__backend__

        #if isinstance(data, (list | np.ndarray)):
        #    data = self._backend.array(data)
        #elif isinstance(data, (int, float)):
        #    data = self._backend.array([data])
        #elif isinstance(data, self.__class__):
        #    gradient = data.gradient if gradient is None else gradient
        #    gradientFunc = data.gradientFunc if gradientFunc is None else gradientFunc
        #    requireGradient = data.requireGradient if requireGradient is False else requireGradient
        #    data = data.data
        
        #if len(data.shape) == 1:
        #    data = self._backend.reshape(data, (1, *data.shape))

        #if gradient is None and requireGradient:
        #    # If gradient is not provided and it's required, initialize it as None
        #    gradient = self._backend.zeros_like(data)
        #elif isinstance(gradient, (list, int, float)):
        #    gradient = self._backend.array(gradient)

        # Checking if the shapes are the same
        #if gradient is not None:
        #    assert data.shape == gradient.shape, "value and gradient must have the same shape"

        self.data = data
        self.gradient = gradient
        self.requireGradient = requireGradient
        self.gradientFunc = gradientFunc
        self.batched = batched

    def zeroGradient(self) -> None:
        """In-place operation for nulling the gradient"""
        if self.requireGradient:
            self.gradient = self._backend.zeros_like(self.data)
        else:
            raise AttributeError("this tensor is not differentiable")

    def backward(self, gradient=None):
        """
        Compute the gradients recursively by applying the chain rule.
        """
        if not self.requireGradient:
            return

        if gradient is None:
            gradient = self._backend.ones_like(self.data)

        # If gradientFunc is not set, this tensor is a leaf of the graph,
        # so there is nothing further to backpropagate through.
        if self.gradientFunc is None:
            return

        # Accumulate gradients instead of overwriting.
        if self.gradient is None:
            self.gradient = self._backend.zeros_like(self.data)
        self.gradient += gradient
        # Let the producing operation compute the gradients of its inputs.
        self.gradientFunc.backward(self.gradient)

    def __repr__(self) -> str:
        """String representation."""
        dataTitle = 'data:\n'
        gradientTitle = 'gradient:\n'
        dataStr = str(self.data)
        gradientStr = str(self.gradient)
        if self.requireGradient is True:
            return dataTitle + dataStr + '\n' + gradientTitle + gradientStr
        else:
            return dataTitle + dataStr

    def copy(self) -> 'Tensor':
        data = self._backend.copy(self.data)
        gradient = self._backend.copy(self.gradient) if self.gradient is not None else None
        return self.__class__(data, gradient, gradientFunc=self.gradientFunc, requireGradient=self.requireGradient, batched=self.batched)

    @property
    def strides(self) -> tuple:
        return self.data.strides

    def __len__(self) -> int:
        """Return the length of the value."""
        return len(self.data)

    @property
    def shape(self) -> tuple:
        """Return the shape of the value."""
        return self.data.shape
    
    @property
    def ndim(self) -> int:
        """Return the number of dimensions of the value."""
        return self.data.ndim

    def reshape(self, newShape) -> 'Tensor':
        return Reshape()(self, newShape)
    
    def transpose(self) -> 'Tensor':
        return Transpose()(self)

    @property
    def T(self) -> 'Tensor':
        return Transpose()(self)

    def tolist(self) -> tuple[list, list] | list:
        if self.requireGradient is True and self.gradient is not None:
            return self.data.tolist(), self.gradient.tolist()
        else:
            return self.data.tolist()

    @classmethod
    def setBackend(cls, backend: BackendInterface) -> None:
        cls.__backend__ = backend

    def __getitem__(self, index):
        """Get an item by index."""
        if self.requireGradient is True and self.gradient is not None:
            return self.__class__(data=self.data[index], gradient=self.gradient[index], requireGradient=True, gradientFunc=self.gradientFunc)
        else:
            return self.__class__(data=self.data[index], requireGradient=self.requireGradient)

    def __setitem__(self, index, value) -> None:
        """Set the value of an item by index."""
        if isinstance(value, self.__class__):
            self.data[index] = value.data
            if value.requireGradient is True:
                if self.gradient is None:
                    self.gradient = self._backend.zeros_like(self.data)
                self.gradient[index] = value.gradient
                self.requireGradient = True
        else:
            self.data[index] = value
            if self.gradient is not None:
                self.gradient[index] = 0

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        if method == '__call__':
            operation = ufuncMap.get(ufunc)
            if operation is not None:
                return operation()(*inputs, **kwargs)
        raise NotImplementedError(f'{ufunc} is not implemented yet')

    def __array_function__(self, func, types, args, kwargs):
        operation = funcMap.get(func)
        if operation is not None:
            return operation()(*args, **kwargs)
        raise NotImplementedError(f'{func} is not implemented yet')
        
    def __add__(self, other: ArrayLike) -> 'Tensor':
        return Add()(self, other)

    def __radd__(self, other: ArrayLike) -> 'Tensor':
        return Add()(other, self)

    def __iadd__(self, other: ArrayLike) -> 'Tensor':
        result = Add()(self, other)
        self.data = result.data
        self.gradient = result.gradient
        self.requireGradient = result.requireGradient
        return self

    def __sub__(self, other: ArrayLike) -> 'Tensor':
        return Subtract()(self, other)

    def __rsub__(self, other: ArrayLike) -> 'Tensor':
        return Subtract()(other, self)

    def __isub__(self, other: ArrayLike) -> 'Tensor':
        result = Subtract()(self, other)
        self.data = result.data
        self.gradient = result.gradient
        self.requireGradient = result.requireGradient
        return self

    def __mul__(self, other: ArrayLike) -> 'Tensor':
        return Multiply()(self, other)

    def __rmul__(self, other: ArrayLike) -> 'Tensor':
        return Multiply()(other, self)

    def __imul__(self, other: ArrayLike) -> 'Tensor':
        result = Multiply()(self, other)
        self.data = result.data
        self.gradient = result.gradient
        self.requireGradient = result.requireGradient
        return self

    def __truediv__(self, other: ArrayLike) -> 'Tensor':
        return Divide()(self, other)

    def __rtruediv__(self, other: ArrayLike) -> 'Tensor':
        return Divide()(other, self)

    def __itruediv__(self, other: ArrayLike) -> 'Tensor':
        result = Divide()(self, other)
        self.data = result.data
        self.gradient = result.gradient
        self.requireGradient = result.requireGradient
        return self

    def __matmul__(self, other: ArrayLike) -> 'Tensor':
        return Matmul()(self, other)

    def __rmatmul__(self, other: ArrayLike) -> 'Tensor':
        return Matmul()(other, self)

    def __imatmul__(self, other: ArrayLike) -> 'Tensor':
        result = Matmul()(self, other)
        self.data = result.data
        self.gradient = result.gradient
        self.requireGradient = result.requireGradient
        return self

    def __pow__(self, other: ArrayLike) -> 'Tensor':
        return Power()(self, other)

    def __rpow__(self, other: ArrayLike) -> 'Tensor':
        return Power()(other, self)

    def __ipow__(self, other: ArrayLike) -> 'Tensor':
        result = Power()(self, other)
        self.data = result.data
        self.gradient = result.gradient
        self.requireGradient = result.requireGradient
        return self

    def __abs__(self) -> 'Tensor':
        return Abs()(self)

    def __pos__(self) -> 'Tensor':
        return Positive()(self)

    def __neg__(self) -> 'Tensor':
        return Negative()(self)

    def __eq__(self, other) -> 'Tensor':
        """Element-wise equality comparison."""
        return Equal()(self, other)

    def __gt__(self, other) -> 'Tensor':
        """Element-wise greater than comparison."""
        return Greater()(self, other)

    def __ge__(self, other) -> 'Tensor':
        """Element-wise greater than or equal to comparison."""
        return GreaterEqual()(self, other)

    def __lt__(self, other) -> 'Tensor':
        """Element-wise less than comparison."""
        return Less()(self, other)

    def __le__(self, other) -> 'Tensor':
        """Element-wise less than or equal to comparison."""
        return LessEqual()(self, other)


def checkTensor(tensor: Tensor) -> Tensor:
    if isinstance(tensor, Tensor):
        return tensor
    return Tensor(tensor)
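

# A minimal end-to-end usage sketch. It assumes the active backend mirrors
# numpy's API (add, multiply, zeros_like, ones_like, ...), as the operations
# below already do, and that leaf gradients are zero-initialised via
# zeroGradient() before backpropagation. Illustrative only; it is not called
# anywhere in this module.
def _exampleBackward():
    x = Tensor(np.array([[1.0, 2.0, 3.0]]), requireGradient=True)
    w = Tensor(np.array([[2.0, 0.5, -1.0]]), requireGradient=True)
    x.zeroGradient()
    w.zeroGradient()

    y = x * w          # dispatches to Multiply() and records it as gradientFunc
    loss = Sum()(y)    # reduce to a single value

    loss.backward()    # seeds with ones_like(loss.data) and applies the chain rule
    # x.gradient now equals w.data and w.gradient equals x.data
    return x.gradient, w.gradient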


#
# Operations
#


class Operation(ABC):
    __slots__ = ['name', 'operationID', 'backend']
    id = 0
    __backend__ = Tensor.__backend__

    def __init__(self) -> None:
        self.name = self.__class__.__name__
        self.operationID = Operation.id
        self.backend = Operation.__backend__
        Operation.id += 1

    @abstractmethod
    def forward(self, *args, **kwargs) -> Tensor:
        raise NotImplementedError

    @abstractmethod
    def backward(self, gradient: np.ndarray) -> None:
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def __repr__(self) -> str:
        return f'{self.name}, {self.operationID}'

    def at(self, tensor: Tensor, indices, value) -> None:
        # Not ready for use yet.
        tensor = self.forward(indices, value)


class TwoTensors(Operation):
    __slots__ = ['tensor1', 'tensor2', 'tensor1BroadcastAxis', 'tensor2BroadcastAxis']

    def __init__(self) -> None:
        super().__init__()
        self.tensor1 = None
        self.tensor2 = None
        self.tensor1BroadcastAxis = None
        self.tensor2BroadcastAxis = None
  
    def getBroadcastAxis(self, data, gradient):
        # Shapes of the forward input and of the incoming gradient
        tensorShape = np.array(data.shape)
        gradientShape = np.array(gradient.shape)

        # Prepend ones to the shape of the smaller array
        if len(tensorShape) < len(gradientShape):
            tensorShape = np.pad(tensorShape, (len(gradientShape) - len(tensorShape), 0), mode='constant', constant_values=1)
        elif len(tensorShape) > len(gradientShape):
            gradientShape = np.pad(gradientShape, (len(tensorShape) - len(gradientShape), 0), mode='constant', constant_values=1)

        # Axes along which the tensor was broadcast during the forward pass
        tensorBroadcastAxis = np.where(tensorShape != gradientShape)[0]

        # Return None if nothing was broadcast
        if tensorBroadcastAxis.size == 0:
            tensorBroadcastAxis = None

        return tensorBroadcastAxis

    def forward(self, tensor1: Tensor, tensor2: Tensor, *args, **kwargs) -> Tensor:
        if not isinstance(tensor1, Tensor):
            tensor1 = Tensor(tensor1)
        if not isinstance(tensor2, Tensor):
            tensor2 = Tensor(tensor2)
        
        requireGradient = tensor1.requireGradient or tensor2.requireGradient
        if requireGradient:
            self.tensor1 = tensor1
            self.tensor2 = tensor2

        data = self._operation(tensor1.data, tensor2.data, *args, **kwargs)

        return Tensor(data, requireGradient=requireGradient, gradientFunc=self)

    def backward(self, gradient: np.ndarray) -> None:
        if self.tensor1 and self.tensor1.requireGradient:
            gradientForTensor1 = self.backend.copy(gradient)

            tensorBroadcastAxis = self.getBroadcastAxis(self.tensor1, gradientForTensor1)
            if tensorBroadcastAxis is not None:
                gradientForTensor1 = self.backend.sum(gradientForTensor1, axis=tuple(tensorBroadcastAxis), keepdims=True)

            self.tensor1.gradient = self._derivativeD1(gradientForTensor1)
            if self.tensor1.gradientFunc:
                self.tensor1.gradientFunc.backward(self.tensor1.gradient)

        if self.tensor2 and self.tensor2.requireGradient:
            gradientForTensor2 = self.backend.copy(gradient)

            tensorBroadcastAxis = self.getBroadcastAxis(self.tensor2, gradientForTensor2)
            if tensorBroadcastAxis is not None:
                gradientForTensor2 = self.backend.sum(gradientForTensor2, axis=tuple(tensorBroadcastAxis), keepdims=True)

            self.tensor2.gradient = self._derivativeD2(gradientForTensor2)
            if self.tensor2.gradientFunc:
                self.tensor2.gradientFunc.backward(self.tensor2.gradient)

    @abstractmethod
    def _operation(self, data1: np.ndarray, data2: np.ndarray, *args, **kwargs) -> np.ndarray:
        raise NotImplementedError

    @abstractmethod
    def _derivativeD1(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        raise NotImplementedError
    
    @abstractmethod
    def _derivativeD2(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        raise NotImplementedError


class OneTensor(Operation):
    __slots__ = ['tensor']

    def __init__(self) -> None:
        super().__init__()
        self.tensor = None

    def forward(self, tensor: Tensor, *args, **kwargs) -> Tensor:
        if not isinstance(tensor, Tensor):
            tensor = Tensor(tensor)

        requireGradient = tensor.requireGradient
        if requireGradient:
            self.tensor = tensor

        data = self._operation(tensor.data, *args, **kwargs)
       
        return Tensor(data, requireGradient=requireGradient, gradientFunc=self)

    def backward(self, gradient: np.ndarray) -> None:
        if self.tensor and self.tensor.requireGradient:
            self.tensor.gradient = self._derivative(gradient)
            if self.tensor.gradientFunc:
                self.tensor.gradientFunc.backward(self.tensor.gradient)

    @abstractmethod
    def _operation(self, data: np.ndarray, *args, **kwargs) -> np.ndarray:
        raise NotImplementedError
    
    @abstractmethod
    def _derivative(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        raise NotImplementedError
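

# A hypothetical OneTensor subclass, shown only to illustrate the contract:
# _operation maps the raw backend array to the forward result, while
# _derivative receives the incoming gradient and returns the gradient for the
# stored input tensor. It is not part of the operation set used elsewhere.
class _ExampleCube(OneTensor):
    def _operation(self, data: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.power(data, 3)

    def _derivative(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        # d(x**3)/dx = 3 * x**2
        return self.backend.multiply(self.backend.multiply(3.0, self.backend.square(self.tensor.data)), gradient)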


#
# Two Tensors
#


class Add(TwoTensors):
    def __init__(self) -> None:
        super().__init__()

    def _operation(self, data1: np.ndarray, data2: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.add(data1, data2, *args, **kwargs)

    def _derivativeD1(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.add(self.tensor1.gradient, gradient)
    
    def _derivativeD2(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.add(self.tensor2.gradient, gradient)


class Subtract(TwoTensors):
    def __init__(self) -> None:
        super().__init__()

    def _operation(self, data1: np.ndarray, data2: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.subtract(data1, data2, *args, **kwargs)

    def _derivativeD1(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.add(self.tensor1.gradient, gradient)
    
    def _derivativeD2(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.subtract(self.tensor2.gradient, gradient)


class Multiply(TwoTensors):
    def __init__(self) -> None:
        super().__init__()

    def _operation(self, data1: np.ndarray, data2: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.multiply(data1, data2, *args, **kwargs)

    def _derivativeD1(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.add(self.tensor1.gradient, self.backend.multiply(self.tensor2.data, gradient))
    
    def _derivativeD2(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.add(self.tensor2.gradient, self.backend.multiply(self.tensor1.data, gradient))


class Divide(TwoTensors):
    def __init__(self) -> None:
        super().__init__()

    def _operation(self, data1: np.ndarray, data2: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.divide(data1, data2, *args, **kwargs)

    def _derivativeD1(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.add(self.tensor1.gradient, self.backend.divide(gradient, self.tensor2.data))
    
    def _derivativeD2(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.subtract(self.tensor2.gradient, self.backend.divide(self.backend.multiply(self.tensor1.data, gradient), self.backend.power(self.tensor2.data, 2)))


class Matmul(TwoTensors):
    def __init__(self) -> None:
        super().__init__()

    def _operation(self, data1: np.ndarray, data2: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.matmul(data1, data2, *args, **kwargs)
    
    # Update the backward pass to handle batch dimension
    def _derivativeD1(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        if len(self.tensor1.data.shape) > 2 or len(self.tensor2.data.shape) > 2:
            return self.backend.add(self.tensor1.gradient, self.backend.matmul(gradient, self.backend.transpose(self.tensor2.data, axes=(0, 2, 1))))
        else:
            return self.backend.add(self.tensor1.gradient, self.backend.matmul(gradient, self.backend.transpose(self.tensor2.data)))
    
    def _derivativeD2(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        if len(self.tensor1.data.shape) > 2 or len(self.tensor2.data.shape) > 2:
            return self.backend.add(self.tensor2.gradient, self.backend.matmul(self.backend.transpose(self.tensor1.data, axes=(0, 2, 1)), gradient))
        else:
            return self.backend.add(self.tensor2.gradient, self.backend.matmul(self.backend.transpose(self.tensor1.data), gradient))
    
    def backward(self, gradient: np.ndarray) -> None:
        if self.tensor1 and self.tensor1.requireGradient:
            gradientForTensor1 = self.backend.copy(gradient)

            self.tensor1.gradient = self._derivativeD1(gradientForTensor1)
            if self.tensor1.gradientFunc:
                self.tensor1.gradientFunc.backward(self.tensor1.gradient)

        if self.tensor2 and self.tensor2.requireGradient:
            gradientForTensor2 = self.backend.copy(gradient)

            self.tensor2.gradient = self._derivativeD2(gradientForTensor2)
            if self.tensor2.gradientFunc:
                self.tensor2.gradientFunc.backward(self.tensor2.gradient)
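

# A small shape sketch for Matmul: for C = A @ B the chain rule gives
# dL/dA = dL/dC @ B.T and dL/dB = A.T @ dL/dC (batched inputs transpose the
# last two axes instead). The helper below is illustrative only and assumes
# the active backend mirrors numpy.
def _exampleMatmulGradients():
    a = Tensor(np.ones((2, 3)), requireGradient=True)
    b = Tensor(np.ones((3, 4)), requireGradient=True)
    a.zeroGradient()
    b.zeroGradient()

    c = a @ b                         # dispatches to Matmul()
    c.backward(np.ones_like(c.data))  # seed dL/dC with ones

    # a.gradient has shape (2, 3) and b.gradient has shape (3, 4)
    return a.gradient.shape, b.gradient.shape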


class Dot(TwoTensors):
    def __init__(self) -> None:
        super().__init__()

    def _operation(self, data1: np.ndarray, data2: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.dot(data1, data2)

    def _derivativeD1(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.multiply(self.tensor2.data, gradient)
    
    def _derivativeD2(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.multiply(self.tensor1.data, gradient)


class Power(TwoTensors):
    def __init__(self) -> None:
        super().__init__()
        
    def _operation(self, data1: np.ndarray, data2: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.power(data1, data2)
    
    def _derivativeD1(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        # d(a**b)/da = b * a**(b - 1)
        local = self.backend.multiply(self.tensor2.data, self.backend.power(self.tensor1.data, self.backend.subtract(self.tensor2.data, 1)))
        return self.backend.add(self.tensor1.gradient, self.backend.multiply(local, gradient))

    def _derivativeD2(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        # d(a**b)/db = a**b * log(a)
        local = self.backend.multiply(self.backend.log(self.tensor1.data), self.backend.power(self.tensor1.data, self.tensor2.data))
        return self.backend.add(self.tensor2.gradient, self.backend.multiply(local, gradient))


#
# Single Tensor
#


class Square(OneTensor):
    def __init__(self) -> None:
        super().__init__()

    def _operation(self, data: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.square(data, *args, **kwargs)
    
    def _derivative(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.multiply(self.backend.multiply(self.tensor.data, 2.0), gradient)


class Sqrt(OneTensor):
    def __init__(self) -> None:
        super().__init__()

    def _operation(self, data: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.sqrt(data, *args, **kwargs)
    
    def _derivative(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.multiply(self.backend.divide(0.5, self.backend.sqrt(self.tensor.data)), gradient)
   

class Log(OneTensor):
    def __init__(self) -> None:
        super().__init__()

    def _operation(self, data: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.log(data, *args, **kwargs)
    
    def _derivative(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.add(self.tensor.gradient, self.backend.multiply((self.backend.divide(1, self.tensor.data)), gradient))


class Exp(OneTensor):
    __slots__ = ['data']

    def __init__(self) -> None:
        super().__init__()
        self.data = None
    
    def _operation(self, data: np.ndarray, *args, **kwargs) -> np.ndarray:
        self.data = self.backend.exp(data, *args, **kwargs)
        return self.data
    
    def _derivative(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.data * gradient


class Sin(OneTensor):
    def __init__(self) -> None:
        super().__init__()

    def _operation(self, data: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.sin(data, *args, **kwargs)
    
    def _derivative(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.multiply(self.backend.cos(self.tensor.data), gradient)


class Cos(OneTensor):
    def __init__(self) -> None:
        super().__init__()
    
    def _operation(self, data: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.cos(data, *args, **kwargs)
    
    def _derivative(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.negative(self.backend.multiply(self.backend.sin(self.tensor.data), gradient))


class Tan(OneTensor):
    def __init__(self) -> None:
        super().__init__()

    def _operation(self, data: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.tan(data, *args, **kwargs)
    
    def _derivative(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.multiply(self.backend.divide(1, self.backend.power(self.backend.cos(self.tensor.data), 2)), gradient)


class Sinh(OneTensor):
    def __init__(self) -> None:
        super().__init__()

    def _operation(self, data: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.sinh(data, *args, **kwargs)
    
    def _derivative(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.multiply(self.backend.cosh(self.tensor.data), gradient)


class Cosh(OneTensor):
    def __init__(self) -> None:
        super().__init__()

    def _operation(self, data: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.cosh(data, *args, **kwargs)
    
    def _derivative(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.multiply(self.backend.sinh(self.tensor.data), gradient)


class Tanh(OneTensor):
    def __init__(self) -> None:
        super().__init__()
    
    def _operation(self, data: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.tanh(data, *args, **kwargs)
    
    def _derivative(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.multiply(self.backend.divide(1, self.backend.power(self.backend.cosh(self.tensor.data), 2)), gradient)


class Abs(OneTensor):
    def __init__(self) -> None:
        super().__init__()

    def _operation(self, data: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.abs(data, *args, **kwargs)
    
    def _derivative(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.multiply(self.backend.sign(self.tensor.data), gradient)


#
# Signs
#


class Sign(OneTensor):
    def __init__(self) -> None:
        super().__init__()

    def _operation(self, data: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.sign(data, *args, **kwargs)
    
    def _derivative(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        # sign(x) is piecewise constant, so its derivative is zero almost everywhere
        return self.backend.zeros_like(gradient)


class Positive(OneTensor):
    def __init__(self) -> None:
        super().__init__()
    
    def _operation(self, data: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.positive(data, *args, **kwargs)
    
    def _derivative(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        # d(+x)/dx = 1, so the gradient passes through unchanged
        return self.backend.positive(gradient)


class Negative(OneTensor):
    def __init__(self) -> None:
        super().__init__()
    
    def _operation(self, data: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.negative(data, *args, **kwargs)
    
    def _derivative(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        # d(-x)/dx = -1
        return self.backend.negative(gradient)


#
# Compare
#


class Equal(TwoTensors):
    __slots__ = ['bools']

    def __init__(self) -> None:
        super().__init__()
        self.bools = None

    def _operation(self, data1: np.ndarray, data2: np.ndarray, *args, **kwargs) -> np.ndarray:
        self.bools = self.backend.equal(data1, data2)
        return self.bools

    def _derivativeD1(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.multiply(self.bools, gradient)

    def _derivativeD2(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.multiply(self.bools, gradient)


class NotEqual(TwoTensors):
    __slots__ = ['bools']

    def __init__(self) -> None:
        super().__init__()

    def _operation(self, data1: np.ndarray, data2: np.ndarray, *args, **kwargs) -> np.ndarray:
        self.bools = self.backend.not_equal(data1, data2)
        return self.bools
    
    def _derivativeD1(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.multiply(self.bools, gradient)

    def _derivativeD2(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.multiply(self.bools, gradient)


class Less(TwoTensors):
    __slots__ = ['bools']
    
    def __init__(self) -> None:
        super().__init__()

    def _operation(self, data1: np.ndarray, data2: np.ndarray, *args, **kwargs) -> np.ndarray:
        self.bools = self.backend.less(data1, data2)
        return self.bools
    
    def _derivativeD1(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.multiply(self.bools, gradient)

    def _derivativeD2(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.multiply(self.bools, gradient)


class LessEqual(TwoTensors):
    __slots__ = ['bools']

    def __init__(self) -> None:
        super().__init__()

    def _operation(self, data1: np.ndarray, data2: np.ndarray, *args, **kwargs) -> np.ndarray:
        self.bools = self.backend.less_equal(data1, data2)
        return self.bools
    
    def _derivativeD1(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.multiply(self.bools, gradient)

    def _derivativeD2(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.multiply(self.bools, gradient)


class Greater(TwoTensors):
    __slots__ = ['bools']

    def __init__(self) -> None:
        super().__init__()

    def _operation(self, data1: np.ndarray, data2: np.ndarray, *args, **kwargs) -> np.ndarray:
        self.bools = self.backend.greater(data1, data2)
        return self.bools
    
    def _derivativeD1(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.multiply(self.bools, gradient)

    def _derivativeD2(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.multiply(self.bools, gradient)


class GreaterEqual(TwoTensors):
    __slots__ = ['bools']

    def __init__(self) -> None:
        super().__init__()

    def _operation(self, data1: np.ndarray, data2: np.ndarray, *args, **kwargs) -> np.ndarray:
        self.bools = self.backend.greater_equal(data1, data2)
        return self.bools
    
    def _derivativeD1(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.multiply(self.bools, gradient)

    def _derivativeD2(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.multiply(self.bools, gradient)


#
# Shaping
#


class Flatten(OneTensor):
    def __init__(self) -> None:
        super().__init__()

    def _operation(self, data: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.reshape(data, newshape=(-1))
    
    def _derivative(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.reshape(gradient, newshape=self.tensor.shape)


class Reshape(OneTensor):
    def __init__(self) -> None:
        super().__init__()

    def _operation(self, data: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.reshape(data, *args, **kwargs)
    
    def _derivative(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.reshape(gradient, newshape=self.tensor.shape)


#
# Broadcasting
#


class Repeat(Operation):
    __slots__ = ['repeats', 'axis', 'tensor']

    def __init__(self) -> None:
        super().__init__()
        self.tensor = None

    def forward(self, tensor: Tensor, repeats: ArrayLike, axis: int = None) -> Tensor:
        tensor = checkTensor(tensor)
        self.repeats = repeats
        self.axis = axis

        requireGradient = tensor.requireGradient
        if requireGradient:
            self.tensor = tensor

        data = self.backend.repeat(tensor.data, repeats=self.repeats, axis=self.axis)

        return Tensor(data, requireGradient=requireGradient, gradientFunc=self)

    def backward(self, gradient) -> None:
        if self.tensor and self.tensor.requireGradient:
            # The gradient of repeat sums the incoming gradient over the repeated
            # copies of each element (assumes an integer repeat count and, when
            # given, a non-negative axis).
            if self.axis is None:
                # repeat with axis=None operates on the flattened array
                grad = self.backend.reshape(gradient, (-1, self.repeats))
                self.tensor.gradient = self.backend.reshape(self.backend.sum(grad, axis=-1), self.tensor.shape)
            else:
                # group the copies along the repeated axis, then fold them together
                groupedShape = list(self.tensor.shape)
                groupedShape.insert(self.axis + 1, self.repeats)
                grad = self.backend.reshape(gradient, groupedShape)
                self.tensor.gradient = self.backend.sum(grad, axis=self.axis + 1)

            if self.tensor.gradientFunc:
                self.tensor.gradientFunc.backward(self.tensor.gradient)


class Tile(Operation):
    __slots__ = ['tensor', 'reps']

    def __init__(self) -> None:
        super().__init__()
        self.tensor = None

    def forward(self, tensor: Tensor, reps: ArrayLike) -> Tensor:
        tensor = checkTensor(tensor)
        self.reps = reps

        requireGradient = tensor.requireGradient
        if requireGradient:
            self.tensor = tensor

        data = self.backend.tile(tensor.data, reps=self.reps)

        return Tensor(data, requireGradient=requireGradient, gradientFunc=self)

    def backward(self, gradient: np.ndarray) -> None:
        if self.tensor and self.tensor.requireGradient:
            reshaped = self.backend.reshape(gradient, self.tensor.shape + self.reps)
            axis = tuple(range(self.tensor.ndim, gradient.ndim))
            self.tensor.gradient = self.backend.sum(reshaped, axis=axis)

            if self.tensor.gradientFunc:
                self.tensor.gradientFunc.backward(self.tensor.gradient)


class Concatenate(Operation):
    __slots__ = ['tensors', 'axis', 'out', 'dtype', 'casting', 'shapes']

    def __init__(self) -> None:
        super().__init__()
        self.tensors = None
        
    def forward(self, tensors: Tensor, axis=0, out=None, dtype=None, casting='same_kind') -> Tensor:
        self.axis = axis
        self.out = out
        self.dtype = dtype
        self.casting = casting

        tensors = [checkTensor(tensor) for tensor in tensors]

        requireGradient = any(tensor.requireGradient for tensor in tensors)
        if requireGradient:
            self.tensors = tensors
            self.shapes = [tensor.shape for tensor in tensors]

        data = self.backend.concatenate([tensor.data for tensor in tensors], axis=self.axis, out=self.out, dtype=self.dtype, casting=self.casting)
        return Tensor(data, requireGradient=requireGradient, gradientFunc=self)

    def backward(self, gradient) -> None:
        grads = self.backend.split(gradient, self.backend.cumsum([shape[self.axis] for shape in self.shapes[:-1]]), axis=self.axis)
        for tensor, grad in zip(self.tensors, grads):
            if tensor.requireGradient:
                tensor.gradient = grad
                if tensor.gradientFunc:
                    tensor.gradientFunc.backward(tensor.gradient)
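

# Backward example for Concatenate: with input shapes [(2, 3), (1, 3)] and
# axis=0, the split points are cumsum([2]) == [2], so an incoming (3, 3)
# gradient is cut back into a (2, 3) piece and a (1, 3) piece, one per input.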


class Hstack(Concatenate):
    def __init__(self):
        super().__init__()
    
    def forward(self, tensors: Tensor, dtype=None, casting='same_kind'):
        return super().forward(tensors, axis=1, out=None, dtype=dtype, casting=casting)


class Vstack(Concatenate):
    def __init__(self):
        super().__init__()
    
    def forward(self, tensors: Tensor, dtype=None, casting='same_kind'):
        return super().forward(tensors, axis=0, out=None, dtype=dtype, casting=casting)


class Dstack(Concatenate):
    def __init__(self):
        super().__init__()
    
    def forward(self, tensors: Tensor):
        return super().forward(tensors, axis=2)


class Split(Operation):
    __slots__ = ['tensor', 'indices', 'axis']

    def __init__(self) -> None:
        super().__init__()
        self.tensor = None
    
    def forward(self, tensor, indices_or_sections, axis=0) -> list[Tensor]:
        tensor = checkTensor(tensor)
        self.indices = indices_or_sections
        self.axis = axis

        requireGradient = tensor.requireGradient
        if requireGradient:
            self.tensor = tensor
        
        data = self.backend.split(tensor.data, self.indices, self.axis)

        return [Tensor(datum, requireGradient=requireGradient, gradientFunc=self) for datum in data]

    def backward(self, gradient: np.ndarray) -> None:
        gradient = self.backend.concatenate(gradient, axis=self.axis)
        if self.tensor and self.tensor.requireGradient:
            self.tensor.gradient = gradient
            if self.tensor.gradientFunc:
                self.tensor.gradientFunc.backward(self.tensor.gradient)


class Hsplit(Split):
    def __init__(self) -> None:
        super().__init__()
        self.tensor = None
    
    def forward(self, tensors: Tensor, indices_or_sections):
        return super().forward(tensors, indices_or_sections=indices_or_sections, axis=1)
    
class Vsplit(Split):
    def __init__(self) -> None:
        super().__init__()
        self.tensor = None
    
    def forward(self, tensors: Tensor, indices_or_sections):
        return super().forward(tensors, indices_or_sections=indices_or_sections, axis=0)

class Dsplit(Split):
    def __init__(self) -> None:
        super().__init__()
        self.tensor = None
    
    def forward(self, tensors: Tensor, indices_or_sections):
        return super().forward(tensors, indices_or_sections=indices_or_sections, axis=2)



#
# Reduce
#


class Sum(Operation):
    __slots__ = ['tensor']

    def __init__(self) -> None:
        super().__init__()
        self.tensor = None

    def forward(self, tensor: Tensor, axis=None, dtype=None, keepdims=False, **kwargs) -> Tensor:
        tensor = checkTensor(tensor)

        requireGradient = tensor.requireGradient
        if requireGradient:
            self.tensor = tensor

        data = self.backend.sum(tensor.data, axis=axis, dtype=dtype, keepdims=keepdims)

        return Tensor(data, requireGradient=requireGradient, gradientFunc=self)

    def backward(self, gradient) -> None:
        if self.tensor and self.tensor.requireGradient:
            self.tensor.gradient = self.backend.broadcast_to(gradient, self.tensor.shape)
            if self.tensor.gradientFunc:
                self.tensor.gradientFunc.backward(self.tensor.gradient)


class Prod(Operation):
    __slots__ = ['tensor', 'product']

    def __init__(self) -> None:
        super().__init__()
        self.tensor = None

    def forward(self, tensor: Tensor, axis=None, dtype=None, keepdims=False) -> Tensor:
        tensor = checkTensor(tensor)

        requireGradient = tensor.requireGradient
        if requireGradient:
            self.tensor = tensor

        data = tensor.data
        self.product = self.backend.prod(data, axis=axis, dtype=dtype, keepdims=keepdims)

        return Tensor(self.product, requireGradient=requireGradient, gradientFunc=self)

    def backward(self, gradient) -> None:
        if self.tensor and self.tensor.requireGradient:
            tensorNoneZero = self.backend.where(self.tensor.data != 0, self.tensor.data, 1)
            self.tensor.gradient = self.backend.multiply(gradient, self.backend.divide(self.product, tensorNoneZero))
            if self.tensor.gradientFunc:
                self.tensor.gradientFunc.backward(self.tensor.gradient)
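

# Gradient note for Prod: d(prod)/dx_i = prod / x_i for x_i != 0, e.g. for
# x = [2, 3, 4] the product is 24 and the gradients are [12, 8, 6]. The
# where() above only swaps zeros for ones to avoid dividing by zero, so the
# gradient returned at a zero entry is an approximation.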


#
# Minimum/Maximum etc
#


class Maximum(Operation):
    __slots__ = ['tensor1', 'tensor2', 'data']

    def __init__(self):
        super().__init__()
        self.tensor1 = None
        self.tensor2 = None

    def forward(self, tensor1: Tensor, tensor2: Tensor, out=None, where=True, casting='same_kind', order='K', dtype=None, subok=True) -> Tensor:
        tensor1 = checkTensor(tensor1)
        tensor2 = checkTensor(tensor2)

        requireGradient = tensor1.requireGradient or tensor2.requireGradient
        if requireGradient:
            self.tensor1 = tensor1
            self.tensor2 = tensor2

        self.data = self.backend.maximum(tensor1.data, tensor2.data, out=out, where=where, casting=casting, order=order, dtype=dtype, subok=subok)

        return Tensor(self.data, requireGradient=requireGradient, gradientFunc=self)

    def backward(self, gradient) -> None:
        if self.tensor1 and self.tensor1.requireGradient:
            # A mask that is True where tensor1 had the maximum value, False elsewhere
            mask = (self.tensor1.data == self.data)
            self.tensor1.gradient = self.backend.multiply(gradient, mask)
            if self.tensor1.gradientFunc:
                self.tensor1.gradientFunc.backward(self.tensor1.gradient)

        if self.tensor2 and self.tensor2.requireGradient:
            # A mask that is True where tensor2 had the maximum value, False elsewhere
            mask = (self.tensor2.data == self.data)
            self.tensor2.gradient = self.backend.multiply(gradient, mask)
            if self.tensor2.gradientFunc:
                self.tensor2.gradientFunc.backward(self.tensor2.gradient)
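

# Backward example for Maximum: for tensor1 = [1, 5] and tensor2 = [3, 2] the
# forward result is [3, 5], so the masks are [False, True] and [True, False]
# and each input only receives the gradient where it supplied the maximum.
# Where the inputs are equal, both masks are True and both inputs receive the
# full gradient.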


class Minimum(Operation):
    __slots__ = ['tensor1', 'tensor2', 'data']

    def __init__(self):
        super().__init__()
        self.tensor1 = None
        self.tensor2 = None

    def forward(self, tensor1: Tensor, tensor2: Tensor, out=None, where=True, casting='same_kind', order='K', dtype=None, subok=True) -> Tensor:
        tensor1 = checkTensor(tensor1)
        tensor2 = checkTensor(tensor2)

        requireGradient = tensor1.requireGradient or tensor2.requireGradient
        if requireGradient:
            self.tensor1 = tensor1
            self.tensor2 = tensor2

        self.data = self.backend.minimum(tensor1.data, tensor2.data, out=out, where=where, casting=casting, order=order, dtype=dtype, subok=subok)

        return Tensor(self.data, requireGradient=requireGradient, gradientFunc=self)

    def backward(self, gradient) -> None:
        if self.tensor1 and self.tensor1.requireGradient:
            # A mask that is True where tensor1 had the minimum value, False elsewhere
            mask = (self.tensor1.data == self.data)
            self.tensor1.gradient = self.backend.multiply(gradient, mask)
            if self.tensor1.gradientFunc:
                self.tensor1.gradientFunc.backward(self.tensor1.gradient)

        if self.tensor2 and self.tensor2.requireGradient:
            # A mask that is True where tensor2 had the minimum value, False elsewhere
            mask = (self.tensor2.data == self.data)
            self.tensor2.gradient = self.backend.multiply(gradient, mask)
            if self.tensor2.gradientFunc:
                self.tensor2.gradientFunc.backward(self.tensor2.gradient)


#
# Min/Max etc
#


class Max(Operation):
    __slots__ = ['tensor', 'mask']

    def __init__(self):
        super().__init__()
        self.tensor = None

    def forward(self, tensor: Tensor, axis=None, keepdims=False) -> Tensor:
        tensor = checkTensor(tensor)
        data = self.backend.max(tensor.data, axis=axis, keepdims=keepdims)

        requireGradient = tensor.requireGradient
        if requireGradient:
            self.tensor = tensor
            self.mask = (tensor.data == self.backend.broadcast_to(data, tensor.shape))

        return Tensor(data, requireGradient=requireGradient, gradientFunc=self)

    def backward(self, gradient) -> None:
        if self.tensor and self.tensor.requireGradient:
            self.tensor.gradient = self.backend.multiply(self.mask, gradient)
            if self.tensor.gradientFunc:
                self.tensor.gradientFunc.backward(self.tensor.gradient)


class Min(Operation):
    __slots__ = ['tensor', 'mask']

    def __init__(self) -> None:
        super().__init__()
        self.tensor = None

    def forward(self, tensor: Tensor, axis=None, keepdims=False) -> Tensor:
        tensor = checkTensor(tensor)
        data = self.backend.min(tensor.data, axis=axis, keepdims=keepdims)

        requireGradient = tensor.requireGradient
        if requireGradient:
            self.tensor = tensor
            self.mask = (tensor.data == self.backend.broadcast_to(data, tensor.shape))

        return Tensor(data, requireGradient=requireGradient, gradientFunc=self)

    def backward(self, gradient: np.ndarray) -> None:
        if self.tensor and self.tensor.requireGradient:
            self.tensor.gradient = self.backend.multiply(self.mask, gradient)
            if self.tensor.gradientFunc:
                self.tensor.gradientFunc.backward(self.tensor.gradient)


class Mean(Operation):
    __slots__ = ['tensor', 'divisor']

    def __init__(self) -> None:
        super().__init__()
        self.tensor = None

    def forward(self, tensor: Tensor, axis=None, keepdims=False) -> Tensor:
        tensor = checkTensor(tensor)
        data = self.backend.mean(tensor.data, axis=axis, keepdims=keepdims)

        requireGradient = tensor.requireGradient
        if requireGradient:
            self.tensor = tensor

            if axis is None:
                self.divisor = self.backend.prod(tensor.shape)
            elif isinstance(axis, int):
                self.divisor = self.backend.prod(tensor.shape[axis])
            else:
                self.divisor = self.backend.prod([tensor.shape[i] for i in axis])

        return Tensor(data, requireGradient=requireGradient, gradientFunc=self)

    def backward(self, gradient: np.ndarray) -> None:
        if self.tensor and self.tensor.requireGradient:
            self.tensor.gradient = self.backend.broadcast_to(self.backend.divide(gradient, self.divisor), self.tensor.shape)
            if self.tensor.gradientFunc:
                self.tensor.gradientFunc.backward(self.tensor.gradient)


class Var(Operation):
    __slots__ = ['tensor', 'divisor', 'diff']

    def __init__(self) -> None:
        super().__init__()
        self.tensor = None

    def forward(self, tensor: Tensor, axis=None, ddof=0, keepdims=False) -> Tensor:
        tensor = checkTensor(tensor)
        data = self.backend.var(tensor.data, axis=axis, ddof=ddof, keepdims=keepdims)

        requireGradient = tensor.requireGradient
        if requireGradient:
            self.tensor = tensor
            self.diff = self.backend.subtract(tensor.data, self.backend.mean(tensor.data, axis=axis, keepdims=True))

            if axis is None:
                self.divisor = self.backend.subtract(self.backend.prod(tensor.shape), ddof)
            elif isinstance(axis, int):
                self.divisor = self.backend.subtract(self.backend.prod(tensor.shape[axis]), ddof)
            else:
                self.divisor = self.backend.subtract(self.backend.prod([tensor.shape[i] for i in axis]), ddof)

        return Tensor(data, requireGradient=requireGradient, gradientFunc=self)

    def backward(self, gradient: np.ndarray) -> None:
        if self.tensor and self.tensor.requireGradient:
            self.tensor.gradient = self.backend.multiply(self.backend.multiply(self.backend.divide(2.0, self.divisor), self.diff), gradient)
            if self.tensor.gradientFunc:
                self.tensor.gradientFunc.backward(self.tensor.gradient)


class Std(Operation):
    __slots__ = ['tensor', 'divisor', 'diff', 'std']

    def __init__(self) -> None:
        super().__init__()
        self.tensor = None

    def forward(self, tensor: Tensor, axis=None, keepdims=False) -> Tensor:
        tensor = checkTensor(tensor)
        data = self.backend.std(tensor.data, axis=axis, keepdims=keepdims)

        requireGradient = tensor.requireGradient
        if requireGradient:
            self.tensor = tensor
            self.std = data
            self.diff = self.backend.subtract(tensor.data, self.backend.mean(tensor.data, axis=axis, keepdims=True))

            if axis is None:
                self.divisor = self.backend.prod(tensor.shape)
            elif isinstance(axis, int):
                self.divisor = self.backend.prod(tensor.shape[axis])
            else:
                self.divisor = self.backend.prod([tensor.shape[i] for i in axis])

        return Tensor(data, requireGradient=requireGradient, gradientFunc=self)

    def backward(self, gradient: np.ndarray) -> None:
        if self.tensor and self.tensor.requireGradient:
            # d(std)/dx_i = (x_i - mean) / (N * std)
            self.tensor.gradient = self.backend.multiply(gradient, self.backend.divide(self.diff, self.backend.multiply(self.divisor, self.std)))
            if self.tensor.gradientFunc:
                self.tensor.gradientFunc.backward(self.tensor.gradient)


#
# Others
#


class Pad(Operation):
    __slots__ = ['tensor', 'padding', 'mode', 'constant_values']

    def __init__(self) -> None:
        super().__init__()
        self.tensor = None

    def forward(self, tensor: Tensor, pad_with, mode='constant', constant_values=0) -> Tensor:
        tensor = checkTensor(tensor)

        self.padding = pad_with
        self.mode = mode
        self.constant_values = constant_values

        requireGradient = tensor.requireGradient
        if requireGradient:
            self.tensor = tensor

        data = self.backend.pad(tensor.data, self.padding, mode=self.mode, constant_values=self.constant_values)

        return Tensor(data, requireGradient=requireGradient, gradientFunc=self)

    def backward(self, gradient) -> None:
        if self.tensor and self.tensor.requireGradient:
            slices = tuple(slice(pad[0], -pad[1] if pad[1] != 0 else None) for pad in self.padding)
            self.tensor.gradient = gradient[slices]
            if self.tensor.gradientFunc:
                self.tensor.gradientFunc.backward(self.tensor.gradient)
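

# Backward example for Pad: with pad_with = ((1, 2), (0, 0)) the slices built
# above are (slice(1, -2), slice(0, None)), which simply cut the padded border
# back off the incoming gradient.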


class Insert(Operation):
    __slots__ = ['tensor', 'values', 'index']

    def __init__(self) -> None:
        super().__init__()
        self.tensor = None

    def forward(self, tensor: Tensor, values: Tensor, index: ArrayLike) -> Tensor:
        self.index = index
        tensor = checkTensor(tensor)
        values = checkTensor(values)

        requireGradient = tensor.requireGradient or values.requireGradient
        if requireGradient:
            self.tensor = tensor
            self.values = values

        data = self.backend.insert(tensor.data, self.index, values.data)
        return Tensor(data, requireGradient=requireGradient, gradientFunc=self)

    def backward(self, gradient) -> None:
        if self.tensor and self.tensor.requireGradient:
            self.tensor.gradient = self.backend.delete(gradient, self.index)
            if self.tensor.gradientFunc:
                self.tensor.gradientFunc.backward(self.tensor.gradient)

        if self.values and self.values.requireGradient:
            self.values.gradient = gradient[self.index]
            if self.values.gradientFunc:
                self.values.gradientFunc.backward(self.values.gradient)


class Transpose(Operation):
    __slots__ = ['tensor']

    def __init__(self) -> None:
        super().__init__()
        self.tensor = None

    def forward(self, tensor: Tensor) -> Tensor:
        tensor = checkTensor(tensor)

        if tensor.requireGradient:
            self.tensor = tensor

        data = self.backend.transpose(tensor.data)
        return Tensor(data, requireGradient=tensor.requireGradient, gradientFunc=self)

    def backward(self, gradient) -> None:
        if self.tensor and self.tensor.requireGradient:
            self.tensor.gradient = self.backend.transpose(gradient)
            if self.tensor.gradientFunc:
                self.tensor.gradientFunc.backward(self.tensor.gradient)


class Where(Operation):
    __slots__ = ['condition', 'tensor1', 'tensor2']

    def __init__(self) -> None:
        super().__init__()
        self.tensor1 = None
        self.tensor2 = None

    def forward(self, condition, tensor1: Tensor, tensor2: Tensor) -> Tensor:
        tensor1 = checkTensor(tensor1)
        tensor2 = checkTensor(tensor2)

        requireGradient = tensor1.requireGradient or tensor2.requireGradient
        if requireGradient:
            self.condition = condition
            self.tensor1 = tensor1
            self.tensor2 = tensor2

        data = self.backend.where(condition, tensor1.data, tensor2.data)
        return Tensor(data, requireGradient=requireGradient, gradientFunc=self)

    def backward(self, gradient) -> None:
        if self.tensor1 and self.tensor1.requireGradient:
            self.tensor1.gradient = self.backend.multiply(gradient, self.condition)
            if self.tensor1.gradientFunc:
                self.tensor1.gradientFunc.backward(self.tensor1.gradient)

        if self.tensor2 and self.tensor2.requireGradient:
            self.tensor2.gradient = self.backend.multiply(gradient, self.backend.logical_not(self.condition))
            if self.tensor2.gradientFunc:
                self.tensor2.gradientFunc.backward(self.tensor2.gradient)


class Cumsum(OneTensor):
    def __init__(self, axis) -> None:
        super().__init__()
        self.axis = axis

    def _operation(self, data: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.cumsum(data, self.axis)
    
    def _derivative(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        # dL/dx_i = sum_{j >= i} g_j: a cumulative sum taken in reverse order along
        # the same axis (assumes the backend exposes numpy's flip)
        flipped = self.backend.flip(gradient, axis=self.axis)
        return self.backend.flip(self.backend.cumsum(flipped, self.axis), axis=self.axis)


class Cumprod(OneTensor):
    def __init__(self, axis) -> None:
        super().__init__()
        self.axis = axis

    def _operation(self, data: np.ndarray, *args, **kwargs) -> np.ndarray:
        return self.backend.cumprod(data, self.axis)
    
    def _derivative(self, gradient: np.ndarray, *args, **kwargs) -> np.ndarray:
        # dL/dx_i = sum_{j >= i} g_j * cumprod_j / x_i for non-zero inputs
        # (assumes the backend exposes numpy's flip)
        weighted = self.backend.multiply(gradient, self._operation(self.tensor.data))
        summed = self.backend.flip(self.backend.cumsum(self.backend.flip(weighted, axis=self.axis), self.axis), axis=self.axis)
        return self.backend.divide(summed, self.tensor.data)


#
# Not working correctly
#


class AsStrided(Operation):
    """
    An as_strided operation with backward pass for convolutional gradients
    """

    __slots__ = ['tensor', 'patches', 'shape', 'strides']

    def __init__(self) -> None:
        super().__init__()

    def forward(self, tensor: Tensor, shape=None, strides=None, subok=False) -> Tensor:
        tensor = checkTensor(tensor)

        requireGradient = tensor.requireGradient
        if requireGradient:
            self.tensor = tensor

        self.shape = shape
        self.strides = strides
        self.patches = self.backend.as_strided(tensor.data, shape=shape, strides=strides, subok=False)
        gradientPatches = self.backend.as_strided(tensor.gradient, shape=shape, strides=strides, subok=False)

        return Tensor(data=self.patches, gradient=gradientPatches, requireGradient=requireGradient, gradientFunc=self)

    def backward(self, gradient) -> None:
        if self.tensor and self.tensor.requireGradient:
            # Sum up the gradient patches into the tensor gradient
            self.tensor.gradient = gradient.sum(tuple(self.backend.arange(gradient.ndim - self.tensor.ndim)))
            if self.tensor.gradientFunc:
                self.tensor.gradientFunc.backward(self.tensor.gradient)


class SlidingWindow(Operation):
    """
    An as_strided operation with backward pass for convolutional gradients
    """

    __slots__ = ['tensor', 'patches', 'shape', 'axis']

    def __init__(self) -> None:
        super().__init__()

    def forward(self, tensor: Tensor, window_shape=None, axis=None, *, subok=False, writeable=True) -> Tensor:
        tensor = checkTensor(tensor)

        requireGradient = tensor.requireGradient
        if requireGradient:
            self.tensor = tensor

        self.shape = window_shape
        self.axis = axis
        self.patches = self.backend.sliding_window_view(tensor.data, window_shape=window_shape, axis=axis, subok=subok, writeable=writeable)
        gradientPatches = self.backend.sliding_window_view(tensor.gradient, window_shape=window_shape, axis=axis, subok=subok, writeable=writeable)

        return Tensor(data=self.patches, gradient=gradientPatches, requireGradient=requireGradient, gradientFunc=self)

    def backward(self, gradient) -> None:
        if self.tensor and self.tensor.requireGradient:
            # Sum up the gradient patches into the tensor gradient
            self.tensor.gradient = gradient.sum(tuple(self.backend.range(gradient.ndim - self.tensor.data.ndim)))
            if self.tensor.gradientFunc:
                self.tensor.gradientFunc.backward(self.tensor.gradient)


class Einsum(Operation):
    """
    A placeholder einsum operation with backward pass for convolutional gradients
    """

    __slots__ = ['tensor1', 'tensor2', 'einsums']

    def __init__(self) -> None:
        super().__init__()

    def forward(self, tensor1: Tensor, tensor2: Tensor, optimize=False) -> Tensor:
        tensor1 = checkTensor(tensor1)
        tensor2 = checkTensor(tensor2)

        requireGradient = tensor1.requireGradient or tensor2.requireGradient
        if requireGradient:
            self.tensor1 = tensor1
            self.tensor2 = tensor2

        self.einsums = self.backend.einsum('bihwkl,oikl->bohw', tensor1.data, tensor2.data, optimize=optimize)
        return Tensor(self.einsums, requireGradient=requireGradient, gradientFunc=self)

    def backward(self, gradient) -> None:
        if self.tensor1 and self.tensor1.requireGradient:
            # Create gradient patches for tensor1
            self.tensor1.gradient = self.backend.as_strided(gradient,
                                                            shape=(*self.tensor1.data.shape, *self.tensor2.data.shape[-2:]),
                                                            strides=(*self.tensor1.data.strides, 0, 0))
            if self.tensor1.gradientFunc:
                self.tensor1.gradientFunc.backward(self.tensor1.gradient)

        if self.tensor2 and self.tensor2.requireGradient:
            # Create gradient patches for tensor2
            self.tensor2.gradient = self.backend.as_strided(gradient,
                                                            shape=(*self.tensor2.data.shape[:-2], *self.tensor1.data.shape[-2:]),
                                                            strides=(0, 0, *self.tensor1.data.strides[-2:]))
            if self.tensor2.gradientFunc:
                self.tensor2.gradientFunc.backward(self.tensor2.gradient)


#
# Mapping from Numpy to Tensor
#


ufuncMap = {
    np.add: Add,
    np.subtract: Subtract,
    np.multiply: Multiply,
    np.divide: Divide,
    np.matmul: Matmul,
    np.dot: Dot,
    np.power: Power,
    np.sqrt: Sqrt,
    np.log: Log,
    np.exp: Exp,
    np.sin: Sin,
    np.cos: Cos,
    np.tan: Tan,
    np.sinh: Sinh,
    np.cosh: Cosh,
    np.tanh: Tanh,
    np.abs: Abs,
    np.sign: Sign,
    np.positive: Positive,
    np.negative: Negative,
    np.maximum: Maximum,
    np.minimum: Minimum
}

funcMap = {
    np.sum: Sum,
    np.prod: Prod,
    np.repeat: Repeat,
    np.tile: Tile,
    np.max: Max,
    np.min: Min,
    np.mean: Mean,
    np.var: Var,
    np.std: Std,
    np.reshape: Reshape,
    np.transpose: Transpose,
    np.concatenate: Concatenate,
    np.hstack: Hstack,
    np.vstack: Vstack,
    np.dstack: Dstack,
    np.split: Split,
    np.hsplit: Hsplit,
    np.vsplit: Vsplit,
    np.dsplit: Dsplit,
    np.pad: Pad,
    np.insert: Insert,
    np.where: Where,
    np.einsum: Einsum
}
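

# Minimal sketch of the numpy interoperability provided by the maps above:
# calling a mapped ufunc on Tensor arguments is routed through
# Tensor.__array_ufunc__ to the corresponding Operation, so the result stays
# inside the autograd graph. Illustrative only; assumes the active backend
# mirrors numpy.
def _exampleUfuncDispatch():
    x = Tensor(np.array([[1.0, 2.0]]), requireGradient=True)
    y = Tensor(np.array([[3.0, 4.0]]), requireGradient=True)

    z = np.multiply(x, y)  # dispatched to Multiply()(x, y)
    return isinstance(z, Tensor), type(z.gradientFunc).__name__  # (True, 'Multiply')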