from abc import ABC, abstractmethod
import numpy as np


class BackendInterface(ABC):
    """Abstract array-backend API modelled on NumPy.

    Concrete backends implement every abstract method by delegating to their
    array library. Method names follow NumPy (``NumpyBackend`` maps each one
    to the ``numpy`` function of the same name); extra positional and keyword
    arguments are meant to be passed through to the library call, so which
    extras are accepted depends on the backend.

    Every method here only raises ``NotImplementedError``; because they are
    all ``@abstractmethod``, a subclass cannot be instantiated until it
    overrides every one of them.
    """

    # init

    @abstractmethod
    def array(self, x, *args, **kwargs):
        """Build a backend array from ``x`` (like ``numpy.array``).

        Extra arguments (e.g. a dtype) are forwarded by implementations that
        support them.
        """
        raise NotImplementedError()

    @abstractmethod
    def copy(self, x, *args, **kwargs):
        """Return a copy of ``x`` (like ``numpy.copy``)."""
        raise NotImplementedError()

    @abstractmethod
    def zeros(self, x, *args, **kwargs):
        """Return a zero-filled array of shape ``x`` (like ``numpy.zeros``)."""
        raise NotImplementedError()

    @abstractmethod
    def ones(self, x, *args, **kwargs):
        """Return a one-filled array of shape ``x`` (like ``numpy.ones``)."""
        raise NotImplementedError()

    @abstractmethod
    def zeros_like(self, x, *args, **kwargs):
        """Return zeros shaped like ``x`` (like ``numpy.zeros_like``)."""
        raise NotImplementedError()

    @abstractmethod
    def ones_like(self, x, *args, **kwargs):
        """Return ones shaped like ``x`` (like ``numpy.ones_like``)."""
        raise NotImplementedError()

    @abstractmethod
    def arange(self, *args, **kwargs):
        """Return evenly spaced values (like ``numpy.arange``)."""
        raise NotImplementedError()

    # double tensor

    @abstractmethod
    def add(self, x, y, *args, **kwargs):
        """Element-wise ``x + y`` (like ``numpy.add``)."""
        raise NotImplementedError()

    @abstractmethod
    def subtract(self, x, y, *args, **kwargs):
        """Element-wise ``x - y`` (like ``numpy.subtract``)."""
        raise NotImplementedError()

    @abstractmethod
    def multiply(self, x, y, *args, **kwargs):
        """Element-wise ``x * y`` (like ``numpy.multiply``)."""
        raise NotImplementedError()

    @abstractmethod
    def divide(self, x, y, *args, **kwargs):
        """Element-wise ``x / y`` (like ``numpy.divide``)."""
        raise NotImplementedError()

    @abstractmethod
    def matmul(self, x, y, *args, **kwargs):
        """Matrix product of ``x`` and ``y`` (like ``numpy.matmul``)."""
        raise NotImplementedError()

    @abstractmethod
    def dot(self, x, y, *args, **kwargs):
        """Dot product of ``x`` and ``y`` (like ``numpy.dot``)."""
        raise NotImplementedError()

    @abstractmethod
    def power(self, x, y, *args, **kwargs):
        """Element-wise ``x ** y`` (like ``numpy.power``)."""
        raise NotImplementedError()

    # single tensor

    @abstractmethod
    def square(self, x, *args, **kwargs):
        """Element-wise square (like ``numpy.square``)."""
        raise NotImplementedError()

    @abstractmethod
    def sqrt(self, x, *args, **kwargs):
        """Element-wise square root (like ``numpy.sqrt``)."""
        raise NotImplementedError()

    @abstractmethod
    def log(self, x, *args, **kwargs):
        """Element-wise natural logarithm (like ``numpy.log``)."""
        raise NotImplementedError()

    @abstractmethod
    def exp(self, x, *args, **kwargs):
        """Element-wise exponential (like ``numpy.exp``)."""
        raise NotImplementedError()

    @abstractmethod
    def sin(self, x, *args, **kwargs):
        """Element-wise sine (like ``numpy.sin``)."""
        raise NotImplementedError()

    @abstractmethod
    def cos(self, x, *args, **kwargs):
        """Element-wise cosine (like ``numpy.cos``)."""
        raise NotImplementedError()

    @abstractmethod
    def tan(self, x, *args, **kwargs):
        """Element-wise tangent (like ``numpy.tan``)."""
        raise NotImplementedError()

    @abstractmethod
    def sinh(self, x, *args, **kwargs):
        """Element-wise hyperbolic sine (like ``numpy.sinh``)."""
        raise NotImplementedError()

    @abstractmethod
    def cosh(self, x, *args, **kwargs):
        """Element-wise hyperbolic cosine (like ``numpy.cosh``)."""
        raise NotImplementedError()

    @abstractmethod
    def tanh(self, x, *args, **kwargs):
        """Element-wise hyperbolic tangent (like ``numpy.tanh``)."""
        raise NotImplementedError()

    @abstractmethod
    def abs(self, x, *args, **kwargs):
        """Element-wise absolute value (like ``numpy.abs``)."""
        raise NotImplementedError()

    # signs (``negative`` was previously declared twice; one declaration is enough)

    @abstractmethod
    def sign(self, x, *args, **kwargs):
        """Element-wise sign (like ``numpy.sign``)."""
        raise NotImplementedError()

    @abstractmethod
    def positive(self, x, *args, **kwargs):
        """Element-wise ``+x`` (like ``numpy.positive``)."""
        raise NotImplementedError()

    @abstractmethod
    def negative(self, x, *args, **kwargs):
        """Element-wise ``-x`` (like ``numpy.negative``)."""
        raise NotImplementedError()

    # Compare

    @abstractmethod
    def equal(self, x, y, *args, **kwargs):
        """Element-wise ``x == y`` (like ``numpy.equal``)."""
        raise NotImplementedError()

    @abstractmethod
    def not_equal(self, x, y, *args, **kwargs):
        """Element-wise ``x != y`` (like ``numpy.not_equal``)."""
        raise NotImplementedError()

    @abstractmethod
    def less(self, x, y, *args, **kwargs):
        """Element-wise ``x < y`` (like ``numpy.less``)."""
        raise NotImplementedError()

    @abstractmethod
    def less_equal(self, x, y, *args, **kwargs):
        """Element-wise ``x <= y`` (like ``numpy.less_equal``)."""
        raise NotImplementedError()

    @abstractmethod
    def greater(self, x, y, *args, **kwargs):
        """Element-wise ``x > y`` (like ``numpy.greater``)."""
        raise NotImplementedError()

    @abstractmethod
    def greater_equal(self, x, y, *args, **kwargs):
        """Element-wise ``x >= y`` (like ``numpy.greater_equal``)."""
        raise NotImplementedError()

    # logic (binary ops receive their second operand through ``args``)

    @abstractmethod
    def logical_and(self, x, *args, **kwargs):
        """Element-wise logical AND (like ``numpy.logical_and``)."""
        raise NotImplementedError()

    @abstractmethod
    def logical_or(self, x, *args, **kwargs):
        """Element-wise logical OR (like ``numpy.logical_or``)."""
        raise NotImplementedError()

    @abstractmethod
    def logical_xor(self, x, *args, **kwargs):
        """Element-wise logical XOR (like ``numpy.logical_xor``)."""
        raise NotImplementedError()

    @abstractmethod
    def logical_not(self, x, *args, **kwargs):
        """Element-wise logical NOT (like ``numpy.logical_not``)."""
        raise NotImplementedError()

    # shaping

    @abstractmethod
    def flatten(self, x, *args, **kwargs):
        """Return ``x`` reshaped to one dimension."""
        raise NotImplementedError()

    @abstractmethod
    def reshape(self, x, *args, **kwargs):
        """Return ``x`` with a new shape (like ``numpy.reshape``)."""
        raise NotImplementedError()

    # broadcasting

    @abstractmethod
    def broadcast_to(self, x, *args, **kwargs):
        """Broadcast ``x`` to a shape (like ``numpy.broadcast_to``)."""
        raise NotImplementedError()

    @abstractmethod
    def repeat(self, x, *args, **kwargs):
        """Repeat elements of ``x`` (like ``numpy.repeat``)."""
        raise NotImplementedError()

    @abstractmethod
    def tile(self, x, *args, **kwargs):
        """Tile ``x`` (like ``numpy.tile``)."""
        raise NotImplementedError()

    @abstractmethod
    def concatenate(self, x, *args, **kwargs):
        """Join the sequence of arrays ``x`` (like ``numpy.concatenate``)."""
        raise NotImplementedError()

    @abstractmethod
    def split(self, x, *args, **kwargs):
        """Split ``x`` into sub-arrays (like ``numpy.split``)."""
        raise NotImplementedError()

    # reduce

    @abstractmethod
    def sum(self, x, *args, **kwargs):
        """Sum of elements (like ``numpy.sum``)."""
        raise NotImplementedError()

    @abstractmethod
    def prod(self, x, *args, **kwargs):
        """Product of elements (like ``numpy.prod``)."""
        raise NotImplementedError()

    # min/max etc

    @abstractmethod
    def max(self, x, *args, **kwargs):
        """Maximum (like ``numpy.max``)."""
        raise NotImplementedError()

    @abstractmethod
    def min(self, x, *args, **kwargs):
        """Minimum (like ``numpy.min``)."""
        raise NotImplementedError()

    @abstractmethod
    def mean(self, x, *args, **kwargs):
        """Arithmetic mean (like ``numpy.mean``)."""
        raise NotImplementedError()

    @abstractmethod
    def var(self, x, *args, **kwargs):
        """Variance (like ``numpy.var``)."""
        raise NotImplementedError()

    @abstractmethod
    def std(self, x, *args, **kwargs):
        """Standard deviation (like ``numpy.std``)."""
        raise NotImplementedError()

    # others

    @abstractmethod
    def pad(self, x, *args, **kwargs):
        """Pad ``x`` (like ``numpy.pad``)."""
        raise NotImplementedError()

    @abstractmethod
    def insert(self, x, *args, **kwargs):
        """Insert values into ``x`` (like ``numpy.insert``)."""
        raise NotImplementedError()

    @abstractmethod
    def transpose(self, x, *args, **kwargs):
        """Permute the axes of ``x`` (like ``numpy.transpose``)."""
        raise NotImplementedError()

    @abstractmethod
    def where(self, x, *args, **kwargs):
        """Select by condition ``x`` (like ``numpy.where``)."""
        raise NotImplementedError()

    @abstractmethod
    def cumsum(self, x, *args, **kwargs):
        """Cumulative sum (like ``numpy.cumsum``)."""
        raise NotImplementedError()

    @abstractmethod
    def cumprod(self, x, *args, **kwargs):
        """Cumulative product (like ``numpy.cumprod``)."""
        raise NotImplementedError()

    # not working yet (not every backend supports these)

    @abstractmethod
    def as_strided(self, x, *args, **kwargs):
        """Strided view of ``x`` (like ``numpy.lib.stride_tricks.as_strided``)."""
        raise NotImplementedError()

    @abstractmethod
    def sliding_window_view(self, x, *args, **kwargs):
        """Sliding-window view of ``x`` (like ``numpy.lib.stride_tricks.sliding_window_view``)."""
        raise NotImplementedError()

    @abstractmethod
    def einsum(self, subscript, x, y, *args, **kwargs):
        """Einstein summation of ``x`` and ``y`` per ``subscript`` (like ``numpy.einsum``)."""
        raise NotImplementedError()


class NumpyBackend(BackendInterface):
    """Reference backend: each method forwards to the ``numpy`` function of the same name.

    ``numpy`` is bound as the class attribute ``np`` and reached through
    ``self.np``. Extra positional/keyword arguments are passed through to the
    NumPy call unchanged.
    """

    import numpy as np

    # init

    def array(self, x, *args, **kwargs):
        # Extras such as ``dtype`` are forwarded (as PytorchBackend.array does);
        # ``array(x)`` alone behaves exactly as before.
        return self.np.array(x, *args, **kwargs)

    def copy(self, x, *args, **kwargs):
        return self.np.copy(x, *args, **kwargs)

    def zeros(self, x, *args, **kwargs):
        return self.np.zeros(x, *args, **kwargs)

    def ones(self, x, *args, **kwargs):
        return self.np.ones(x, *args, **kwargs)

    def zeros_like(self, x, *args, **kwargs):
        return self.np.zeros_like(x, *args, **kwargs)

    def ones_like(self, x, *args, **kwargs):
        return self.np.ones_like(x, *args, **kwargs)

    def arange(self, *args, **kwargs):
        return self.np.arange(*args, **kwargs)

    # double tensor

    def add(self, x, y, *args, **kwargs):
        return self.np.add(x, y, *args, **kwargs)

    def subtract(self, x, y, *args, **kwargs):
        return self.np.subtract(x, y, *args, **kwargs)

    def multiply(self, x, y, *args, **kwargs):
        return self.np.multiply(x, y, *args, **kwargs)

    def divide(self, x, y, *args, **kwargs):
        return self.np.divide(x, y, *args, **kwargs)

    def matmul(self, x, y, *args, **kwargs):
        return self.np.matmul(x, y, *args, **kwargs)

    def dot(self, x, y, *args, **kwargs):
        return self.np.dot(x, y, *args, **kwargs)

    def power(self, x, y, *args, **kwargs):
        return self.np.power(x, y, *args, **kwargs)

    # single tensor

    def square(self, x, *args, **kwargs):
        return self.np.square(x, *args, **kwargs)

    def sqrt(self, x, *args, **kwargs):
        return self.np.sqrt(x, *args, **kwargs)

    def log(self, x, *args, **kwargs):
        return self.np.log(x, *args, **kwargs)

    def exp(self, x, *args, **kwargs):
        return self.np.exp(x, *args, **kwargs)

    def sin(self, x, *args, **kwargs):
        return self.np.sin(x, *args, **kwargs)

    def cos(self, x, *args, **kwargs):
        return self.np.cos(x, *args, **kwargs)

    def tan(self, x, *args, **kwargs):
        return self.np.tan(x, *args, **kwargs)

    def sinh(self, x, *args, **kwargs):
        return self.np.sinh(x, *args, **kwargs)

    def cosh(self, x, *args, **kwargs):
        return self.np.cosh(x, *args, **kwargs)

    def tanh(self, x, *args, **kwargs):
        return self.np.tanh(x, *args, **kwargs)

    def abs(self, x, *args, **kwargs):
        return self.np.abs(x, *args, **kwargs)

    # signs

    def sign(self, x, *args, **kwargs):
        return self.np.sign(x, *args, **kwargs)

    def positive(self, x, *args, **kwargs):
        return self.np.positive(x, *args, **kwargs)

    def negative(self, x, *args, **kwargs):
        return self.np.negative(x, *args, **kwargs)

    # compare

    def equal(self, x, y, *args, **kwargs):
        return self.np.equal(x, y, *args, **kwargs)

    def not_equal(self, x, y, *args, **kwargs):
        return self.np.not_equal(x, y, *args, **kwargs)

    def less(self, x, y, *args, **kwargs):
        return self.np.less(x, y, *args, **kwargs)

    def less_equal(self, x, y, *args, **kwargs):
        return self.np.less_equal(x, y, *args, **kwargs)

    def greater(self, x, y, *args, **kwargs):
        return self.np.greater(x, y, *args, **kwargs)

    def greater_equal(self, x, y, *args, **kwargs):
        return self.np.greater_equal(x, y, *args, **kwargs)

    # logic (binary ops take their second operand through ``args``)

    def logical_and(self, x, *args, **kwargs):
        return self.np.logical_and(x, *args, **kwargs)

    def logical_or(self, x, *args, **kwargs):
        return self.np.logical_or(x, *args, **kwargs)

    def logical_xor(self, x, *args, **kwargs):
        return self.np.logical_xor(x, *args, **kwargs)

    def logical_not(self, x, *args, **kwargs):
        return self.np.logical_not(x, *args, **kwargs)

    # shaping

    def flatten(self, x, **kwargs):
        # reshape to -1 collapses every axis into one
        return self.np.reshape(x, -1, **kwargs)

    def reshape(self, x, *args, **kwargs):
        return self.np.reshape(x, *args, **kwargs)

    # broadcasting

    def broadcast_to(self, x, *args, **kwargs):
        return self.np.broadcast_to(x, *args, **kwargs)

    def repeat(self, x, *args, **kwargs):
        return self.np.repeat(x, *args, **kwargs)

    def tile(self, x, *args, **kwargs):
        return self.np.tile(x, *args, **kwargs)

    def concatenate(self, x, *args, **kwargs):
        return self.np.concatenate(x, *args, **kwargs)

    def split(self, x, *args, **kwargs):
        return self.np.split(x, *args, **kwargs)

    # reduce

    def sum(self, x, *args, **kwargs):
        return self.np.sum(x, *args, **kwargs)

    def prod(self, x, *args, **kwargs):
        return self.np.prod(x, *args, **kwargs)

    # min/max etc

    def max(self, x, *args, **kwargs):
        return self.np.max(x, *args, **kwargs)

    def min(self, x, *args, **kwargs):
        return self.np.min(x, *args, **kwargs)

    def mean(self, x, *args, **kwargs):
        return self.np.mean(x, *args, **kwargs)

    def var(self, x, *args, **kwargs):
        return self.np.var(x, *args, **kwargs)

    def std(self, x, *args, **kwargs):
        return self.np.std(x, *args, **kwargs)

    # others

    def pad(self, x, *args, **kwargs):
        return self.np.pad(x, *args, **kwargs)

    def insert(self, x, *args, **kwargs):
        return self.np.insert(x, *args, **kwargs)

    def transpose(self, x, *args, **kwargs):
        return self.np.transpose(x, *args, **kwargs)

    def where(self, x, *args, **kwargs):
        return self.np.where(x, *args, **kwargs)

    def cumsum(self, x, *args, **kwargs):
        return self.np.cumsum(x, *args, **kwargs)

    def cumprod(self, x, *args, **kwargs):
        return self.np.cumprod(x, *args, **kwargs)

    # not working yet (stride tricks and einsum)

    def as_strided(self, x, *args, **kwargs):
        return self.np.lib.stride_tricks.as_strided(x, *args, **kwargs)

    def sliding_window_view(self, x, *args, **kwargs):
        return self.np.lib.stride_tricks.sliding_window_view(x, *args, **kwargs)

    def einsum(self, subscript, x, y, *args, **kwargs):
        return self.np.einsum(subscript, x, y, *args, **kwargs)


class CupyBackend(BackendInterface):
    """Backend that forwards each method to the ``cupy`` function of the same name.

    CuPy is imported as the class attribute ``cp`` when it is installed. If it
    is missing, the class still loads (so the module stays importable), but
    every method then fails with ``AttributeError`` on ``self.cp``.
    Extra positional/keyword arguments are passed through unchanged.
    """

    try:
        import cupy as cp
    except ImportError:
        # CuPy is optional; see the class docstring for the failure mode.
        pass

    # init

    def array(self, x, *args, **kwargs):
        # Extras such as ``dtype`` are forwarded; ``array(x)`` is unchanged.
        return self.cp.array(x, *args, **kwargs)

    def copy(self, x, *args, **kwargs):
        return self.cp.copy(x, *args, **kwargs)

    def zeros(self, x, *args, **kwargs):
        return self.cp.zeros(x, *args, **kwargs)

    def ones(self, x, *args, **kwargs):
        return self.cp.ones(x, *args, **kwargs)

    def zeros_like(self, x, *args, **kwargs):
        return self.cp.zeros_like(x, *args, **kwargs)

    def ones_like(self, x, *args, **kwargs):
        return self.cp.ones_like(x, *args, **kwargs)

    def arange(self, *args, **kwargs):
        return self.cp.arange(*args, **kwargs)

    # double tensor

    def add(self, x, y, *args, **kwargs):
        return self.cp.add(x, y, *args, **kwargs)

    def subtract(self, x, y, *args, **kwargs):
        return self.cp.subtract(x, y, *args, **kwargs)

    def multiply(self, x, y, *args, **kwargs):
        return self.cp.multiply(x, y, *args, **kwargs)

    def divide(self, x, y, *args, **kwargs):
        return self.cp.divide(x, y, *args, **kwargs)

    def matmul(self, x, y, *args, **kwargs):
        return self.cp.matmul(x, y, *args, **kwargs)

    def dot(self, x, y, *args, **kwargs):
        return self.cp.dot(x, y, *args, **kwargs)

    def power(self, x, y, *args, **kwargs):
        return self.cp.power(x, y, *args, **kwargs)

    # single tensor

    def square(self, x, *args, **kwargs):
        return self.cp.square(x, *args, **kwargs)

    def sqrt(self, x, *args, **kwargs):
        return self.cp.sqrt(x, *args, **kwargs)

    def log(self, x, *args, **kwargs):
        return self.cp.log(x, *args, **kwargs)

    def exp(self, x, *args, **kwargs):
        return self.cp.exp(x, *args, **kwargs)

    def sin(self, x, *args, **kwargs):
        return self.cp.sin(x, *args, **kwargs)

    def cos(self, x, *args, **kwargs):
        return self.cp.cos(x, *args, **kwargs)

    def tan(self, x, *args, **kwargs):
        return self.cp.tan(x, *args, **kwargs)

    def sinh(self, x, *args, **kwargs):
        return self.cp.sinh(x, *args, **kwargs)

    def cosh(self, x, *args, **kwargs):
        return self.cp.cosh(x, *args, **kwargs)

    def tanh(self, x, *args, **kwargs):
        return self.cp.tanh(x, *args, **kwargs)

    def abs(self, x, *args, **kwargs):
        return self.cp.abs(x, *args, **kwargs)

    # signs

    def sign(self, x, *args, **kwargs):
        return self.cp.sign(x, *args, **kwargs)

    def positive(self, x, *args, **kwargs):
        return self.cp.positive(x, *args, **kwargs)

    def negative(self, x, *args, **kwargs):
        return self.cp.negative(x, *args, **kwargs)

    # compare

    def equal(self, x, y, *args, **kwargs):
        return self.cp.equal(x, y, *args, **kwargs)

    def not_equal(self, x, y, *args, **kwargs):
        return self.cp.not_equal(x, y, *args, **kwargs)

    def less(self, x, y, *args, **kwargs):
        return self.cp.less(x, y, *args, **kwargs)

    def less_equal(self, x, y, *args, **kwargs):
        return self.cp.less_equal(x, y, *args, **kwargs)

    def greater(self, x, y, *args, **kwargs):
        return self.cp.greater(x, y, *args, **kwargs)

    def greater_equal(self, x, y, *args, **kwargs):
        return self.cp.greater_equal(x, y, *args, **kwargs)

    # logic (binary ops take their second operand through ``args``)

    def logical_and(self, x, *args, **kwargs):
        return self.cp.logical_and(x, *args, **kwargs)

    def logical_or(self, x, *args, **kwargs):
        return self.cp.logical_or(x, *args, **kwargs)

    def logical_xor(self, x, *args, **kwargs):
        return self.cp.logical_xor(x, *args, **kwargs)

    def logical_not(self, x, *args, **kwargs):
        return self.cp.logical_not(x, *args, **kwargs)

    # shaping

    def flatten(self, x, **kwargs):
        # reshape to -1 collapses every axis into one
        return self.cp.reshape(x, -1, **kwargs)

    def reshape(self, x, *args, **kwargs):
        return self.cp.reshape(x, *args, **kwargs)

    # broadcasting

    def broadcast_to(self, x, *args, **kwargs):
        return self.cp.broadcast_to(x, *args, **kwargs)

    def repeat(self, x, *args, **kwargs):
        return self.cp.repeat(x, *args, **kwargs)

    def tile(self, x, *args, **kwargs):
        return self.cp.tile(x, *args, **kwargs)

    def concatenate(self, x, *args, **kwargs):
        return self.cp.concatenate(x, *args, **kwargs)

    def split(self, x, *args, **kwargs):
        return self.cp.split(x, *args, **kwargs)

    # reduce

    def sum(self, x, *args, **kwargs):
        return self.cp.sum(x, *args, **kwargs)

    def prod(self, x, *args, **kwargs):
        return self.cp.prod(x, *args, **kwargs)

    # min/max etc

    def max(self, x, *args, **kwargs):
        return self.cp.max(x, *args, **kwargs)

    def min(self, x, *args, **kwargs):
        return self.cp.min(x, *args, **kwargs)

    def mean(self, x, *args, **kwargs):
        return self.cp.mean(x, *args, **kwargs)

    def var(self, x, *args, **kwargs):
        return self.cp.var(x, *args, **kwargs)

    def std(self, x, *args, **kwargs):
        return self.cp.std(x, *args, **kwargs)

    # others

    def pad(self, x, *args, **kwargs):
        return self.cp.pad(x, *args, **kwargs)

    def insert(self, x, *args, **kwargs):
        return self.cp.insert(x, *args, **kwargs)

    def transpose(self, x, *args, **kwargs):
        # Previously called ``cp.negative`` here, returning ``-x`` instead of
        # the transpose.
        return self.cp.transpose(x, *args, **kwargs)

    def where(self, x, *args, **kwargs):
        return self.cp.where(x, *args, **kwargs)

    def cumsum(self, x, *args, **kwargs):
        return self.cp.cumsum(x, *args, **kwargs)

    def cumprod(self, x, *args, **kwargs):
        return self.cp.cumprod(x, *args, **kwargs)

    # not working yet (stride tricks and einsum)

    def as_strided(self, x, *args, **kwargs):
        return self.cp.lib.stride_tricks.as_strided(x, *args, **kwargs)

    def sliding_window_view(self, x, *args, **kwargs):
        return self.cp.lib.stride_tricks.sliding_window_view(x, *args, **kwargs)

    def einsum(self, subscript, x, y, *args, **kwargs):
        return self.cp.einsum(subscript, x, y, *args, **kwargs)


class NumbaBackend(BackendInterface):
    """Backend whose operations are NumPy calls compiled with ``jit(nopython=True)``.

    Every operation is a ``staticmethod`` wrapping the matching call on the
    module-level ``np``, so it can be called on the class or an instance
    without a ``self`` argument.

    Numba does not support explicit ``**kwargs`` in jitted functions, so the
    jitted wrappers accept extra arguments positionally only. Which extras are
    accepted depends on Numba's support for the underlying NumPy function.
    """

    try:
        from numba import jit
    except ImportError:
        # Numba is optional, like the other third-party backends. An unguarded
        # import here made the whole module fail to import without Numba.
        # This stand-in decorator leaves each wrapper uncompiled, so the
        # methods run as plain NumPy when Numba is absent.
        def jit(*_args, **_kwargs):
            def _identity(func):
                return func
            return _identity

    # init

    @staticmethod
    @jit(nopython=True)
    def array(x):
        return np.array(x)

    @staticmethod
    @jit(nopython=True)
    def copy(x, *args):
        return np.copy(x, *args)

    @staticmethod
    @jit(nopython=True)
    def zeros(x, *args):
        return np.zeros(x, *args)

    @staticmethod
    @jit(nopython=True)
    def ones(x, *args):
        return np.ones(x, *args)

    @staticmethod
    @jit(nopython=True)
    def zeros_like(x, *args):
        return np.zeros_like(x, *args)

    @staticmethod
    @jit(nopython=True)
    def ones_like(x, *args):
        return np.ones_like(x, *args)

    @staticmethod
    @jit(nopython=True)
    def arange(*args):
        return np.arange(*args)

    # double tensor

    @staticmethod
    @jit(nopython=True)
    def add(x, y, *args):
        return np.add(x, y, *args)

    @staticmethod
    @jit(nopython=True)
    def subtract(x, y, *args):
        return np.subtract(x, y, *args)

    @staticmethod
    @jit(nopython=True)
    def multiply(x, y, *args):
        return np.multiply(x, y, *args)

    @staticmethod
    @jit(nopython=True)
    def divide(x, y, *args):
        return np.divide(x, y, *args)

    @staticmethod
    @jit(nopython=True)
    def matmul(x, y, *args):
        return np.matmul(x, y, *args)

    @staticmethod
    @jit(nopython=True)
    def dot(x, y, *args):
        return np.dot(x, y, *args)

    @staticmethod
    @jit(nopython=True)
    def power(x, y, *args):
        return np.power(x, y, *args)

    # single tensor

    @staticmethod
    @jit(nopython=True)
    def square(x, *args):
        return np.square(x, *args)

    @staticmethod
    @jit(nopython=True)
    def sqrt(x, *args):
        return np.sqrt(x, *args)

    @staticmethod
    @jit(nopython=True)
    def log(x, *args):
        return np.log(x, *args)

    @staticmethod
    @jit(nopython=True)
    def exp(x, *args):
        return np.exp(x, *args)

    @staticmethod
    @jit(nopython=True)
    def sin(x, *args):
        return np.sin(x, *args)

    @staticmethod
    @jit(nopython=True)
    def cos(x, *args):
        return np.cos(x, *args)

    @staticmethod
    @jit(nopython=True)
    def tan(x, *args):
        return np.tan(x, *args)

    @staticmethod
    @jit(nopython=True)
    def sinh(x, *args):
        return np.sinh(x, *args)

    @staticmethod
    @jit(nopython=True)
    def cosh(x, *args):
        return np.cosh(x, *args)

    @staticmethod
    @jit(nopython=True)
    def tanh(x, *args):
        return np.tanh(x, *args)

    @staticmethod
    @jit(nopython=True)
    def abs(x, *args):
        return np.abs(x, *args)

    # signs

    @staticmethod
    @jit(nopython=True)
    def sign(x, *args):
        return np.sign(x, *args)

    @staticmethod
    @jit(nopython=True)
    def positive(x, *args):
        return np.positive(x, *args)

    @staticmethod
    @jit(nopython=True)
    def negative(x, *args):
        return np.negative(x, *args)

    # compare

    @staticmethod
    @jit(nopython=True)
    def equal(x, y, *args):
        return np.equal(x, y, *args)

    @staticmethod
    @jit(nopython=True)
    def not_equal(x, y, *args):
        return np.not_equal(x, y, *args)

    @staticmethod
    @jit(nopython=True)
    def less(x, y, *args):
        return np.less(x, y, *args)

    @staticmethod
    @jit(nopython=True)
    def less_equal(x, y, *args):
        return np.less_equal(x, y, *args)

    @staticmethod
    @jit(nopython=True)
    def greater(x, y, *args):
        return np.greater(x, y, *args)

    @staticmethod
    @jit(nopython=True)
    def greater_equal(x, y, *args):
        return np.greater_equal(x, y, *args)

    # logic (binary ops take their second operand through ``args``)

    @staticmethod
    @jit(nopython=True)
    def logical_and(x, *args):
        return np.logical_and(x, *args)

    @staticmethod
    @jit(nopython=True)
    def logical_or(x, *args):
        return np.logical_or(x, *args)

    @staticmethod
    @jit(nopython=True)
    def logical_xor(x, *args):
        return np.logical_xor(x, *args)

    @staticmethod
    @jit(nopython=True)
    def logical_not(x, *args):
        return np.logical_not(x, *args)

    # shaping

    @staticmethod
    @jit(nopython=True)
    def flatten(x):
        # reshape to -1 collapses every axis into one
        return np.reshape(x, -1)

    @staticmethod
    @jit(nopython=True)
    def reshape(x, *args):
        return np.reshape(x, *args)

    # broadcasting

    @staticmethod
    @jit(nopython=True)
    def broadcast_to(x, *args):
        return np.broadcast_to(x, *args)

    @staticmethod
    @jit(nopython=True)
    def repeat(x, *args):
        return np.repeat(x, *args)

    @staticmethod
    @jit(nopython=True)
    def tile(x, *args):
        return np.tile(x, *args)

    @staticmethod
    @jit(nopython=True)
    def concatenate(x, *args):
        return np.concatenate(x, *args)

    @staticmethod
    @jit(nopython=True)
    def split(x, *args):
        return np.split(x, *args)

    # reduce

    @staticmethod
    @jit(nopython=True)
    def sum(x, *args):
        return np.sum(x, *args)

    @staticmethod
    @jit(nopython=True)
    def prod(x, *args):
        return np.prod(x, *args)

    # min/max etc

    @staticmethod
    @jit(nopython=True)
    def max(x, *args):
        return np.max(x, *args)

    @staticmethod
    @jit(nopython=True)
    def min(x, *args):
        return np.min(x, *args)

    @staticmethod
    @jit(nopython=True)
    def mean(x, *args):
        return np.mean(x, *args)

    @staticmethod
    @jit(nopython=True)
    def var(x, *args):
        return np.var(x, *args)

    @staticmethod
    @jit(nopython=True)
    def std(x, *args):
        return np.std(x, *args)

    # others

    @staticmethod
    @jit(nopython=True)
    def pad(x, *args):
        return np.pad(x, *args)

    @staticmethod
    @jit(nopython=True)
    def insert(x, *args):
        return np.insert(x, *args)

    @staticmethod
    @jit(nopython=True)
    def transpose(x, *args):
        return np.transpose(x, *args)

    @staticmethod
    @jit(nopython=True)
    def where(x, *args):
        return np.where(x, *args)

    @staticmethod
    @jit(nopython=True)
    def cumsum(x, *args):
        return np.cumsum(x, *args)

    @staticmethod
    @jit(nopython=True)
    def cumprod(x, *args):
        return np.cumprod(x, *args)

    # not working yet

    @staticmethod
    @jit(nopython=True)
    def as_strided(x, *args):
        return np.lib.stride_tricks.as_strided(x, *args)

    @staticmethod
    @jit(nopython=True)
    def sliding_window_view(x, *args):
        return np.lib.stride_tricks.sliding_window_view(x, *args)

    @staticmethod
    def einsum(subscript, x, y, *args, **kwargs):
        # Matches BackendInterface.einsum: the subscript is now a parameter
        # instead of the hard-coded 'bihwkl,oikl->bohw'. Not jitted:
        # NOTE(review): np.einsum is not believed to be in Numba's supported
        # NumPy subset, so this runs as plain NumPy — confirm.
        return np.einsum(subscript, x, y, *args, **kwargs)


class PytorchBackend(BackendInterface):
    """Backend that maps each NumPy-named method onto ``torch``.

    PyTorch is imported as the class attribute ``torch`` when installed. If it
    is missing, the class still loads, but every method then fails with
    ``AttributeError`` on ``self.torch``. Extra positional/keyword arguments
    are passed through unchanged unless a method says otherwise. Where the
    torch function behaves differently from its NumPy namesake, the method
    carries a note.
    """

    try:
        import torch
    except ImportError:
        # PyTorch is optional; see the class docstring for the failure mode.
        pass

    # init

    def array(self, x, *args, **kwargs):
        return self.torch.tensor(x, *args, **kwargs)

    def copy(self, x, *args, **kwargs):
        # Extra arguments are ignored; torch.clone has no NumPy-style options.
        return self.torch.clone(x)

    def zeros(self, x, *args, **kwargs):
        return self.torch.zeros(x, *args, **kwargs)

    def ones(self, x, *args, **kwargs):
        return self.torch.ones(x, *args, **kwargs)

    def zeros_like(self, x, *args, **kwargs):
        return self.torch.zeros_like(x, *args, **kwargs)

    def ones_like(self, x, *args, **kwargs):
        return self.torch.ones_like(x, *args, **kwargs)

    def arange(self, *args, **kwargs):
        return self.torch.arange(*args, **kwargs)

    # double tensor

    def add(self, x, y, *args, **kwargs):
        return self.torch.add(x, y, *args, **kwargs)

    def subtract(self, x, y, *args, **kwargs):
        return self.torch.subtract(x, y, *args, **kwargs)

    def multiply(self, x, y, *args, **kwargs):
        return self.torch.multiply(x, y, *args, **kwargs)

    def divide(self, x, y, *args, **kwargs):
        return self.torch.divide(x, y, *args, **kwargs)

    def matmul(self, x, y, *args, **kwargs):
        return self.torch.matmul(x, y, *args, **kwargs)

    def dot(self, x, y, *args, **kwargs):
        return self.torch.dot(x, y, *args, **kwargs)

    def power(self, x, y, *args, **kwargs):
        return self.torch.pow(x, y, *args, **kwargs)

    # single tensor

    def square(self, x, *args, **kwargs):
        return self.torch.square(x, *args, **kwargs)

    def sqrt(self, x, *args, **kwargs):
        return self.torch.sqrt(x, *args, **kwargs)

    def log(self, x, *args, **kwargs):
        return self.torch.log(x, *args, **kwargs)

    def exp(self, x, *args, **kwargs):
        return self.torch.exp(x, *args, **kwargs)

    def sin(self, x, *args, **kwargs):
        return self.torch.sin(x, *args, **kwargs)

    def cos(self, x, *args, **kwargs):
        return self.torch.cos(x, *args, **kwargs)

    def tan(self, x, *args, **kwargs):
        return self.torch.tan(x, *args, **kwargs)

    def sinh(self, x, *args, **kwargs):
        return self.torch.sinh(x, *args, **kwargs)

    def cosh(self, x, *args, **kwargs):
        return self.torch.cosh(x, *args, **kwargs)

    def tanh(self, x, *args, **kwargs):
        return self.torch.tanh(x, *args, **kwargs)

    def abs(self, x, *args, **kwargs):
        return self.torch.abs(x, *args, **kwargs)

    # signs

    def sign(self, x, *args, **kwargs):
        return self.torch.sign(x, *args, **kwargs)

    def positive(self, x, *args, **kwargs):
        return self.torch.positive(x, *args, **kwargs)

    def negative(self, x, *args, **kwargs):
        return self.torch.negative(x, *args, **kwargs)

    # compare

    def equal(self, x, y, *args, **kwargs):
        # torch.equal returns one bool for whole-tensor equality; torch.eq is
        # the element-wise counterpart of numpy.equal (matching not_equal).
        return self.torch.eq(x, y, *args, **kwargs)

    def not_equal(self, x, y, *args, **kwargs):
        return self.torch.not_equal(x, y, *args, **kwargs)

    def less(self, x, y, *args, **kwargs):
        return self.torch.less(x, y, *args, **kwargs)

    def less_equal(self, x, y, *args, **kwargs):
        return self.torch.less_equal(x, y, *args, **kwargs)

    def greater(self, x, y, *args, **kwargs):
        return self.torch.greater(x, y, *args, **kwargs)

    def greater_equal(self, x, y, *args, **kwargs):
        return self.torch.greater_equal(x, y, *args, **kwargs)

    # logic (binary ops take their second operand through ``args``)

    def logical_and(self, x, *args, **kwargs):
        return self.torch.logical_and(x, *args, **kwargs)

    def logical_or(self, x, *args, **kwargs):
        return self.torch.logical_or(x, *args, **kwargs)

    def logical_xor(self, x, *args, **kwargs):
        return self.torch.logical_xor(x, *args, **kwargs)

    def logical_not(self, x, *args, **kwargs):
        return self.torch.logical_not(x, *args, **kwargs)

    # shaping

    def flatten(self, x, **kwargs):
        # reshape to -1 collapses every dimension into one
        return self.torch.reshape(x, -1, **kwargs)

    def reshape(self, x, *args, **kwargs):
        return self.torch.reshape(x, *args, **kwargs)

    # broadcasting

    def broadcast_to(self, x, *args, **kwargs):
        return self.torch.broadcast_to(x, *args, **kwargs)

    def repeat(self, x, *args, **kwargs):
        # There is no torch.repeat function (Tensor.repeat tiles instead);
        # repeat_interleave is the element-repeating counterpart of numpy.repeat.
        return self.torch.repeat_interleave(x, *args, **kwargs)

    def tile(self, x, *args, **kwargs):
        return self.torch.tile(x, *args, **kwargs)

    def concatenate(self, x, *args, **kwargs):
        return self.torch.cat(x, *args, **kwargs)

    def split(self, x, *args, **kwargs):
        # NOTE: an integer argument to torch.split is a chunk *size*, whereas
        # numpy.split takes a number of sections; kept as-is for callers.
        return self.torch.split(x, *args, **kwargs)

    # reduce

    def sum(self, x, *args, **kwargs):
        # Arguments (e.g. the reduction dim) used to be dropped, which always
        # summed over every element.
        return self.torch.sum(x, *args, **kwargs)

    def prod(self, x, *args, **kwargs):
        return self.torch.prod(x, *args, **kwargs)

    # min/max etc

    def max(self, x, *args, **kwargs):
        # NOTE(review): with a reduction dim, torch.max returns a
        # (values, indices) named tuple rather than only the values.
        return self.torch.max(x, *args, **kwargs)

    def min(self, x, *args, **kwargs):
        # NOTE(review): with a reduction dim, torch.min returns a
        # (values, indices) named tuple rather than only the values.
        return self.torch.min(x, *args, **kwargs)

    def mean(self, x, *args, **kwargs):
        return self.torch.mean(x, *args, **kwargs)

    def var(self, x, *args, **kwargs):
        # NOTE(review): torch's default estimator may differ from numpy's
        # ddof=0 default — confirm callers pass the correction they expect.
        return self.torch.var(x, *args, **kwargs)

    def std(self, x, *args, **kwargs):
        # NOTE(review): same estimator caveat as var.
        return self.torch.std(x, *args, **kwargs)

    # others

    def pad(self, x, *args, **kwargs):
        # There is no public torch.pad; padding lives in torch.nn.functional.
        # The pad spec follows torch's convention, not numpy's pad_width.
        return self.torch.nn.functional.pad(x, *args, **kwargs)

    def insert(self, x, *args, **kwargs):
        # torch.insert does not exist; fail with an explicit message instead
        # of an AttributeError on the torch module.
        raise NotImplementedError("PyTorch has no counterpart of numpy.insert")

    def transpose(self, x, *args, **kwargs):
        # Honour an explicit axes permutation, given either as one sequence
        # (numpy.transpose style), as separate ints, or as ``axes=``.
        # Previously any axes were silently ignored.
        if len(args) > 1:
            axes = args
        elif args:
            axes = args[0]
        else:
            axes = kwargs.get("axes")
        if axes is not None:
            return self.torch.permute(x, tuple(axes))
        if x.dim() < 2:
            # 0-D/1-D inputs are returned as-is, like numpy.transpose;
            # torch.transpose(x, 0, 1) would raise for them.
            return x
        # NOTE: without axes this swaps only dims 0 and 1 (unchanged
        # behaviour); numpy.transpose reverses all axes, so they agree on 2-D.
        return self.torch.transpose(x, 0, 1)

    def where(self, x, *args, **kwargs):
        return self.torch.where(x, *args, **kwargs)

    def cumsum(self, x, *args, **kwargs):
        return self.torch.cumsum(x, *args, **kwargs)

    def cumprod(self, x, *args, **kwargs):
        return self.torch.cumprod(x, *args, **kwargs)

    # not working yet

    def as_strided(self, x, *args, **kwargs):
        return self.torch.as_strided(x, *args, **kwargs)

    def sliding_window_view(self, x, *args, **kwargs):
        raise NotImplementedError

    def einsum(self, subscript, x, y, *args, **kwargs):
        return self.torch.einsum(subscript, x, y, *args, **kwargs)


class TensorflowBackend(BackendInterface):
    """Placeholder for a TensorFlow backend; no operations are implemented yet.

    TensorFlow is imported as the class attribute ``tf`` when it is installed.
    None of the abstract methods are overridden, so the class is still
    abstract and cannot be instantiated.
    """

    try:
        import tensorflow as tf
    except ImportError:
        # TensorFlow is optional; without it the class simply has no ``tf``.
        pass

    # TODO: implement the BackendInterface methods with TensorFlow functions.