Skip to content
Snippets Groups Projects
backend.py 32.2 KiB
Newer Older
  • Learn to ignore specific revisions
  • johannes bilk's avatar
    johannes bilk committed
    from abc import ABC, abstractmethod
    import numpy as np
    
    
    class BackendInterface(ABC):
        """Abstract array API implemented by the NumPy, CuPy and Numba backends.

        Every abstract method is the counterpart of the NumPy function with the
        same name; concrete backends forward extra positional/keyword arguments
        to their library unchanged. Exceptions to the naming rule: ``flatten``
        means ``reshape(x, -1)``, and ``as_strided`` / ``sliding_window_view``
        come from ``numpy.lib.stride_tricks``.
        """

        # init

        @abstractmethod
        def array(self, x):
            """Create a backend array from ``x`` (``numpy.array``)."""
            raise NotImplementedError()

        @abstractmethod
        def copy(self, x, *args, **kwargs):
            """Return a copy of ``x`` (``numpy.copy``)."""
            raise NotImplementedError()

        @abstractmethod
        def zeros(self, x, *args, **kwargs):
            """Return a zero-filled array whose shape is ``x`` (``numpy.zeros``)."""
            raise NotImplementedError()

        @abstractmethod
        def ones(self, x, *args, **kwargs):
            """Return a one-filled array whose shape is ``x`` (``numpy.ones``)."""
            raise NotImplementedError()

        @abstractmethod
        def zeros_like(self, x, *args, **kwargs):
            """Return zeros shaped like ``x`` (``numpy.zeros_like``)."""
            raise NotImplementedError()

        @abstractmethod
        def ones_like(self, x, *args, **kwargs):
            """Return ones shaped like ``x`` (``numpy.ones_like``)."""
            raise NotImplementedError()

        @abstractmethod
        def arange(self, *args, **kwargs):
            """Return evenly spaced values (``numpy.arange``)."""
            raise NotImplementedError()

        # double tensor

        @abstractmethod
        def add(self, x, y, *args, **kwargs):
            """Element-wise ``x + y`` (``numpy.add``)."""
            raise NotImplementedError()

        @abstractmethod
        def subtract(self, x, y, *args, **kwargs):
            """Element-wise ``x - y`` (``numpy.subtract``)."""
            raise NotImplementedError()

        @abstractmethod
        def multiply(self, x, y, *args, **kwargs):
            """Element-wise ``x * y`` (``numpy.multiply``)."""
            raise NotImplementedError()

        @abstractmethod
        def divide(self, x, y, *args, **kwargs):
            """Element-wise ``x / y`` (``numpy.divide``)."""
            raise NotImplementedError()

        @abstractmethod
        def matmul(self, x, y, *args, **kwargs):
            """Matrix product of ``x`` and ``y`` (``numpy.matmul``)."""
            raise NotImplementedError()

        @abstractmethod
        def dot(self, x, y, *args, **kwargs):
            """Dot product of ``x`` and ``y`` (``numpy.dot``)."""
            raise NotImplementedError()

        @abstractmethod
        def power(self, x, y, *args, **kwargs):
            """Element-wise ``x ** y`` (``numpy.power``)."""
            raise NotImplementedError()

        # single tensor

        @abstractmethod
        def square(self, x, *args, **kwargs):
            """Element-wise square (``numpy.square``)."""
            raise NotImplementedError()

        @abstractmethod
        def sqrt(self, x, *args, **kwargs):
            """Element-wise square root (``numpy.sqrt``)."""
            raise NotImplementedError()

        @abstractmethod
        def log(self, x, *args, **kwargs):
            """Element-wise natural logarithm (``numpy.log``)."""
            raise NotImplementedError()

        @abstractmethod
        def exp(self, x, *args, **kwargs):
            """Element-wise exponential (``numpy.exp``)."""
            raise NotImplementedError()

        @abstractmethod
        def sin(self, x, *args, **kwargs):
            """Element-wise sine (``numpy.sin``)."""
            raise NotImplementedError()

        @abstractmethod
        def cos(self, x, *args, **kwargs):
            """Element-wise cosine (``numpy.cos``)."""
            raise NotImplementedError()

        @abstractmethod
        def tan(self, x, *args, **kwargs):
            """Element-wise tangent (``numpy.tan``)."""
            raise NotImplementedError()

        @abstractmethod
        def sinh(self, x, *args, **kwargs):
            """Element-wise hyperbolic sine (``numpy.sinh``)."""
            raise NotImplementedError()

        @abstractmethod
        def cosh(self, x, *args, **kwargs):
            """Element-wise hyperbolic cosine (``numpy.cosh``)."""
            raise NotImplementedError()

        @abstractmethod
        def tanh(self, x, *args, **kwargs):
            """Element-wise hyperbolic tangent (``numpy.tanh``)."""
            raise NotImplementedError()

        @abstractmethod
        def abs(self, x, *args, **kwargs):
            """Element-wise absolute value (``numpy.abs``)."""
            raise NotImplementedError()

        # signs

        @abstractmethod
        def sign(self, x, *args, **kwargs):
            """Element-wise sign (``numpy.sign``)."""
            raise NotImplementedError()

        @abstractmethod
        def positive(self, x, *args, **kwargs):
            """Element-wise ``+x`` (``numpy.positive``)."""
            raise NotImplementedError()

        # Declared exactly once: a second ``negative`` definition used to
        # silently shadow this one.
        @abstractmethod
        def negative(self, x, *args, **kwargs):
            """Element-wise ``-x`` (``numpy.negative``)."""
            raise NotImplementedError()

        # Compare

        @abstractmethod
        def equal(self, x, y, *args, **kwargs):
            """Element-wise ``x == y`` (``numpy.equal``)."""
            raise NotImplementedError()

        @abstractmethod
        def not_equal(self, x, y, *args, **kwargs):
            """Element-wise ``x != y`` (``numpy.not_equal``)."""
            raise NotImplementedError()

        @abstractmethod
        def less(self, x, y, *args, **kwargs):
            """Element-wise ``x < y`` (``numpy.less``)."""
            raise NotImplementedError()

        @abstractmethod
        def less_equal(self, x, y, *args, **kwargs):
            """Element-wise ``x <= y`` (``numpy.less_equal``)."""
            raise NotImplementedError()

        @abstractmethod
        def greater(self, x, y, *args, **kwargs):
            """Element-wise ``x > y`` (``numpy.greater``)."""
            raise NotImplementedError()

        @abstractmethod
        def greater_equal(self, x, y, *args, **kwargs):
            """Element-wise ``x >= y`` (``numpy.greater_equal``)."""
            raise NotImplementedError()

        # logic
        # The binary logical ops name ``y`` explicitly, like the other binary
        # ops above; implementations taking ``(x, *args)`` remain compatible.

        @abstractmethod
        def logical_and(self, x, y, *args, **kwargs):
            """Element-wise logical AND (``numpy.logical_and``)."""
            raise NotImplementedError()

        @abstractmethod
        def logical_or(self, x, y, *args, **kwargs):
            """Element-wise logical OR (``numpy.logical_or``)."""
            raise NotImplementedError()

        @abstractmethod
        def logical_xor(self, x, y, *args, **kwargs):
            """Element-wise logical XOR (``numpy.logical_xor``)."""
            raise NotImplementedError()

        @abstractmethod
        def logical_not(self, x, *args, **kwargs):
            """Element-wise logical NOT (``numpy.logical_not``)."""
            raise NotImplementedError()

        # shaping

        @abstractmethod
        def flatten(self, x, *args, **kwargs):
            """Return ``x`` as a 1-D array, i.e. ``reshape(x, -1)``."""
            raise NotImplementedError()

        @abstractmethod
        def reshape(self, x, *args, **kwargs):
            """Give ``x`` a new shape (``numpy.reshape``)."""
            raise NotImplementedError()

        # broadcasting

        @abstractmethod
        def broadcast_to(self, x, *args, **kwargs):
            """Broadcast ``x`` to a shape (``numpy.broadcast_to``)."""
            raise NotImplementedError()

        @abstractmethod
        def repeat(self, x, *args, **kwargs):
            """Repeat elements of ``x`` (``numpy.repeat``)."""
            raise NotImplementedError()

        @abstractmethod
        def tile(self, x, *args, **kwargs):
            """Tile ``x`` (``numpy.tile``)."""
            raise NotImplementedError()

        @abstractmethod
        def concatenate(self, x, *args, **kwargs):
            """Join the sequence of arrays ``x`` (``numpy.concatenate``)."""
            raise NotImplementedError()

        @abstractmethod
        def split(self, x, *args, **kwargs):
            """Split ``x`` into sub-arrays (``numpy.split``)."""
            raise NotImplementedError()

        # reduce

        @abstractmethod
        def sum(self, x, *args, **kwargs):
            """Sum of elements (``numpy.sum``)."""
            raise NotImplementedError()

        @abstractmethod
        def prod(self, x, *args, **kwargs):
            """Product of elements (``numpy.prod``)."""
            raise NotImplementedError()

        # min/max etc

        @abstractmethod
        def max(self, x, *args, **kwargs):
            """Maximum (``numpy.max``)."""
            raise NotImplementedError()

        @abstractmethod
        def min(self, x, *args, **kwargs):
            """Minimum (``numpy.min``)."""
            raise NotImplementedError()

        @abstractmethod
        def mean(self, x, *args, **kwargs):
            """Arithmetic mean (``numpy.mean``)."""
            raise NotImplementedError()

        @abstractmethod
        def var(self, x, *args, **kwargs):
            """Variance (``numpy.var``)."""
            raise NotImplementedError()

        @abstractmethod
        def std(self, x, *args, **kwargs):
            """Standard deviation (``numpy.std``)."""
            raise NotImplementedError()

        # others

        @abstractmethod
        def pad(self, x, *args, **kwargs):
            """Pad ``x`` (``numpy.pad``)."""
            raise NotImplementedError()

        @abstractmethod
        def insert(self, x, *args, **kwargs):
            """Insert values into ``x`` (``numpy.insert``)."""
            raise NotImplementedError()

        @abstractmethod
        def transpose(self, x, *args, **kwargs):
            """Permute the axes of ``x`` (``numpy.transpose``)."""
            raise NotImplementedError()

        @abstractmethod
        def where(self, x, *args, **kwargs):
            """Condition-based selection / indices (``numpy.where``)."""
            raise NotImplementedError()

        @abstractmethod
        def cumsum(self, x, *args, **kwargs):
            """Cumulative sum (``numpy.cumsum``)."""
            raise NotImplementedError()

        @abstractmethod
        def cumprod(self, x, *args, **kwargs):
            """Cumulative product (``numpy.cumprod``)."""
            raise NotImplementedError()

        # not working yet

        @abstractmethod
        def as_strided(self, x, *args, **kwargs):
            """Strided view of ``x`` (``numpy.lib.stride_tricks.as_strided``)."""
            raise NotImplementedError()

        @abstractmethod
        def sliding_window_view(self, x, *args, **kwargs):
            """Sliding-window view (``numpy.lib.stride_tricks.sliding_window_view``)."""
            raise NotImplementedError()

        @abstractmethod
        def einsum(self, subscript, x, y, *args, **kwargs):
            """Einstein summation over operands ``x`` and ``y`` (``numpy.einsum``)."""
            raise NotImplementedError()
    
    
    class NumpyBackend(BackendInterface):
        """CPU backend that forwards every operation to NumPy.

        Each method calls the NumPy function of the same name through the class
        attribute ``np`` (``self.np``) and passes any extra positional/keyword
        arguments through unchanged. ``flatten`` is ``np.reshape(x, -1)`` and
        the stride-trick helpers come from ``np.lib.stride_tricks``.
        """

        import numpy as np  # bound as a class attribute, used as ``self.np``

        # init

        def array(self, x, *args, **kwargs):
            # Extra arguments (e.g. ``dtype=``) are forwarded; ``array(x)`` is unchanged.
            return self.np.array(x, *args, **kwargs)

        def copy(self, x, *args, **kwargs):
            return self.np.copy(x, *args, **kwargs)

        def zeros(self, x, *args, **kwargs):
            return self.np.zeros(x, *args, **kwargs)

        def ones(self, x, *args, **kwargs):
            return self.np.ones(x, *args, **kwargs)

        def zeros_like(self, x, *args, **kwargs):
            return self.np.zeros_like(x, *args, **kwargs)

        def ones_like(self, x, *args, **kwargs):
            return self.np.ones_like(x, *args, **kwargs)

        def arange(self, *args, **kwargs):
            return self.np.arange(*args, **kwargs)

        # double tensor

        def add(self, x, y, *args, **kwargs):
            return self.np.add(x, y, *args, **kwargs)

        def subtract(self, x, y, *args, **kwargs):
            return self.np.subtract(x, y, *args, **kwargs)

        def multiply(self, x, y, *args, **kwargs):
            return self.np.multiply(x, y, *args, **kwargs)

        def divide(self, x, y, *args, **kwargs):
            return self.np.divide(x, y, *args, **kwargs)

        def matmul(self, x, y, *args, **kwargs):
            return self.np.matmul(x, y, *args, **kwargs)

        def dot(self, x, y, *args, **kwargs):
            return self.np.dot(x, y, *args, **kwargs)

        def power(self, x, y, *args, **kwargs):
            return self.np.power(x, y, *args, **kwargs)

        # single tensor

        def square(self, x, *args, **kwargs):
            return self.np.square(x, *args, **kwargs)

        def sqrt(self, x, *args, **kwargs):
            return self.np.sqrt(x, *args, **kwargs)

        def log(self, x, *args, **kwargs):
            return self.np.log(x, *args, **kwargs)

        def exp(self, x, *args, **kwargs):
            return self.np.exp(x, *args, **kwargs)

        def sin(self, x, *args, **kwargs):
            return self.np.sin(x, *args, **kwargs)

        def cos(self, x, *args, **kwargs):
            return self.np.cos(x, *args, **kwargs)

        def tan(self, x, *args, **kwargs):
            return self.np.tan(x, *args, **kwargs)

        def sinh(self, x, *args, **kwargs):
            return self.np.sinh(x, *args, **kwargs)

        def cosh(self, x, *args, **kwargs):
            return self.np.cosh(x, *args, **kwargs)

        def tanh(self, x, *args, **kwargs):
            return self.np.tanh(x, *args, **kwargs)

        def abs(self, x, *args, **kwargs):
            return self.np.abs(x, *args, **kwargs)

        # signs

        def sign(self, x, *args, **kwargs):
            return self.np.sign(x, *args, **kwargs)

        def positive(self, x, *args, **kwargs):
            return self.np.positive(x, *args, **kwargs)

        def negative(self, x, *args, **kwargs):
            return self.np.negative(x, *args, **kwargs)

        # compare

        def equal(self, x, y, *args, **kwargs):
            return self.np.equal(x, y, *args, **kwargs)

        def not_equal(self, x, y, *args, **kwargs):
            return self.np.not_equal(x, y, *args, **kwargs)

        def less(self, x, y, *args, **kwargs):
            return self.np.less(x, y, *args, **kwargs)

        def less_equal(self, x, y, *args, **kwargs):
            return self.np.less_equal(x, y, *args, **kwargs)

        def greater(self, x, y, *args, **kwargs):
            return self.np.greater(x, y, *args, **kwargs)

        def greater_equal(self, x, y, *args, **kwargs):
            return self.np.greater_equal(x, y, *args, **kwargs)

        # logic
        # Binary logical ops name ``y`` explicitly like the other binary ops;
        # positional calls behave exactly as before.

        def logical_and(self, x, y, *args, **kwargs):
            return self.np.logical_and(x, y, *args, **kwargs)

        def logical_or(self, x, y, *args, **kwargs):
            return self.np.logical_or(x, y, *args, **kwargs)

        def logical_xor(self, x, y, *args, **kwargs):
            return self.np.logical_xor(x, y, *args, **kwargs)

        def logical_not(self, x, *args, **kwargs):
            return self.np.logical_not(x, *args, **kwargs)

        # shaping

        def flatten(self, x, **kwargs):
            # reshape to -1 yields a 1-D result (a view when NumPy can avoid a copy)
            return self.np.reshape(x, -1, **kwargs)

        def reshape(self, x, *args, **kwargs):
            return self.np.reshape(x, *args, **kwargs)

        # broadcasting

        def broadcast_to(self, x, *args, **kwargs):
            return self.np.broadcast_to(x, *args, **kwargs)

        def repeat(self, x, *args, **kwargs):
            return self.np.repeat(x, *args, **kwargs)

        def tile(self, x, *args, **kwargs):
            return self.np.tile(x, *args, **kwargs)

        def concatenate(self, x, *args, **kwargs):
            return self.np.concatenate(x, *args, **kwargs)

        def split(self, x, *args, **kwargs):
            return self.np.split(x, *args, **kwargs)

        # reduce

        def sum(self, x, *args, **kwargs):
            return self.np.sum(x, *args, **kwargs)

        def prod(self, x, *args, **kwargs):
            return self.np.prod(x, *args, **kwargs)

        # min/max etc

        def max(self, x, *args, **kwargs):
            return self.np.max(x, *args, **kwargs)

        def min(self, x, *args, **kwargs):
            return self.np.min(x, *args, **kwargs)

        def mean(self, x, *args, **kwargs):
            return self.np.mean(x, *args, **kwargs)

        def var(self, x, *args, **kwargs):
            return self.np.var(x, *args, **kwargs)

        def std(self, x, *args, **kwargs):
            return self.np.std(x, *args, **kwargs)

        # others

        def pad(self, x, *args, **kwargs):
            return self.np.pad(x, *args, **kwargs)

        def insert(self, x, *args, **kwargs):
            return self.np.insert(x, *args, **kwargs)

        def transpose(self, x, *args, **kwargs):
            return self.np.transpose(x, *args, **kwargs)

        def where(self, x, *args, **kwargs):
            return self.np.where(x, *args, **kwargs)

        def cumsum(self, x, *args, **kwargs):
            return self.np.cumsum(x, *args, **kwargs)

        def cumprod(self, x, *args, **kwargs):
            return self.np.cumprod(x, *args, **kwargs)

        # not working yet
        # NOTE(review): marked "not working yet" upstream; confirm the current status.

        def as_strided(self, x, *args, **kwargs):
            return self.np.lib.stride_tricks.as_strided(x, *args, **kwargs)

        def sliding_window_view(self, x, *args, **kwargs):
            return self.np.lib.stride_tricks.sliding_window_view(x, *args, **kwargs)

        def einsum(self, subscript, x, y, *args, **kwargs):
            return self.np.einsum(subscript, x, y, *args, **kwargs)
    
    
    class CupyBackend(BackendInterface):
        """GPU backend that forwards every operation to CuPy.

        Each method calls the CuPy function of the same name through the class
        attribute ``cp`` (``self.cp``) and passes any extra positional/keyword
        arguments through unchanged. ``flatten`` is ``cp.reshape(x, -1)`` and
        the stride-trick helpers are looked up in ``cp.lib.stride_tricks``.

        If CuPy cannot be imported the class is still defined, but ``cp`` is
        never bound, so every method raises ``AttributeError`` when called.
        """

        try:
            import cupy as cp  # bound as a class attribute, used as ``self.cp``
        except ImportError:
            # Best-effort on purpose (presumably so this module imports without CuPy).
            pass

        # init

        def array(self, x, *args, **kwargs):
            # Extra arguments (e.g. ``dtype=``) are forwarded; ``array(x)`` is unchanged.
            return self.cp.array(x, *args, **kwargs)

        def copy(self, x, *args, **kwargs):
            return self.cp.copy(x, *args, **kwargs)

        def zeros(self, x, *args, **kwargs):
            return self.cp.zeros(x, *args, **kwargs)

        def ones(self, x, *args, **kwargs):
            return self.cp.ones(x, *args, **kwargs)

        def zeros_like(self, x, *args, **kwargs):
            return self.cp.zeros_like(x, *args, **kwargs)

        def ones_like(self, x, *args, **kwargs):
            return self.cp.ones_like(x, *args, **kwargs)

        def arange(self, *args, **kwargs):
            return self.cp.arange(*args, **kwargs)

        # double tensor

        def add(self, x, y, *args, **kwargs):
            return self.cp.add(x, y, *args, **kwargs)

        def subtract(self, x, y, *args, **kwargs):
            return self.cp.subtract(x, y, *args, **kwargs)

        def multiply(self, x, y, *args, **kwargs):
            return self.cp.multiply(x, y, *args, **kwargs)

        def divide(self, x, y, *args, **kwargs):
            return self.cp.divide(x, y, *args, **kwargs)

        def matmul(self, x, y, *args, **kwargs):
            return self.cp.matmul(x, y, *args, **kwargs)

        def dot(self, x, y, *args, **kwargs):
            return self.cp.dot(x, y, *args, **kwargs)

        def power(self, x, y, *args, **kwargs):
            return self.cp.power(x, y, *args, **kwargs)

        # single tensor

        def square(self, x, *args, **kwargs):
            return self.cp.square(x, *args, **kwargs)

        def sqrt(self, x, *args, **kwargs):
            return self.cp.sqrt(x, *args, **kwargs)

        def log(self, x, *args, **kwargs):
            return self.cp.log(x, *args, **kwargs)

        def exp(self, x, *args, **kwargs):
            return self.cp.exp(x, *args, **kwargs)

        def sin(self, x, *args, **kwargs):
            return self.cp.sin(x, *args, **kwargs)

        def cos(self, x, *args, **kwargs):
            return self.cp.cos(x, *args, **kwargs)

        def tan(self, x, *args, **kwargs):
            return self.cp.tan(x, *args, **kwargs)

        def sinh(self, x, *args, **kwargs):
            return self.cp.sinh(x, *args, **kwargs)

        def cosh(self, x, *args, **kwargs):
            return self.cp.cosh(x, *args, **kwargs)

        def tanh(self, x, *args, **kwargs):
            return self.cp.tanh(x, *args, **kwargs)

        def abs(self, x, *args, **kwargs):
            return self.cp.abs(x, *args, **kwargs)

        # signs

        def sign(self, x, *args, **kwargs):
            return self.cp.sign(x, *args, **kwargs)

        def positive(self, x, *args, **kwargs):
            return self.cp.positive(x, *args, **kwargs)

        def negative(self, x, *args, **kwargs):
            return self.cp.negative(x, *args, **kwargs)

        # compare

        def equal(self, x, y, *args, **kwargs):
            return self.cp.equal(x, y, *args, **kwargs)

        def not_equal(self, x, y, *args, **kwargs):
            return self.cp.not_equal(x, y, *args, **kwargs)

        def less(self, x, y, *args, **kwargs):
            return self.cp.less(x, y, *args, **kwargs)

        def less_equal(self, x, y, *args, **kwargs):
            return self.cp.less_equal(x, y, *args, **kwargs)

        def greater(self, x, y, *args, **kwargs):
            return self.cp.greater(x, y, *args, **kwargs)

        def greater_equal(self, x, y, *args, **kwargs):
            return self.cp.greater_equal(x, y, *args, **kwargs)

        # logic
        # Binary logical ops name ``y`` explicitly like the other binary ops;
        # positional calls behave exactly as before.

        def logical_and(self, x, y, *args, **kwargs):
            return self.cp.logical_and(x, y, *args, **kwargs)

        def logical_or(self, x, y, *args, **kwargs):
            return self.cp.logical_or(x, y, *args, **kwargs)

        def logical_xor(self, x, y, *args, **kwargs):
            return self.cp.logical_xor(x, y, *args, **kwargs)

        def logical_not(self, x, *args, **kwargs):
            return self.cp.logical_not(x, *args, **kwargs)

        # shaping

        def flatten(self, x, **kwargs):
            return self.cp.reshape(x, -1, **kwargs)

        def reshape(self, x, *args, **kwargs):
            return self.cp.reshape(x, *args, **kwargs)

        # broadcasting

        def broadcast_to(self, x, *args, **kwargs):
            return self.cp.broadcast_to(x, *args, **kwargs)

        def repeat(self, x, *args, **kwargs):
            return self.cp.repeat(x, *args, **kwargs)

        def tile(self, x, *args, **kwargs):
            return self.cp.tile(x, *args, **kwargs)

        def concatenate(self, x, *args, **kwargs):
            return self.cp.concatenate(x, *args, **kwargs)

        def split(self, x, *args, **kwargs):
            return self.cp.split(x, *args, **kwargs)

        # reduce

        def sum(self, x, *args, **kwargs):
            return self.cp.sum(x, *args, **kwargs)

        def prod(self, x, *args, **kwargs):
            return self.cp.prod(x, *args, **kwargs)

        # min/max etc

        def max(self, x, *args, **kwargs):
            return self.cp.max(x, *args, **kwargs)

        def min(self, x, *args, **kwargs):
            return self.cp.min(x, *args, **kwargs)

        def mean(self, x, *args, **kwargs):
            return self.cp.mean(x, *args, **kwargs)

        def var(self, x, *args, **kwargs):
            return self.cp.var(x, *args, **kwargs)

        def std(self, x, *args, **kwargs):
            return self.cp.std(x, *args, **kwargs)

        # others

        def pad(self, x, *args, **kwargs):
            return self.cp.pad(x, *args, **kwargs)

        def insert(self, x, *args, **kwargs):
            # NOTE(review): confirm the installed CuPy version provides ``insert``.
            return self.cp.insert(x, *args, **kwargs)

        def transpose(self, x, *args, **kwargs):
            # Must dispatch to cp.transpose (permute axes), not cp.negative.
            return self.cp.transpose(x, *args, **kwargs)

        def where(self, x, *args, **kwargs):
            return self.cp.where(x, *args, **kwargs)

        def cumsum(self, x, *args, **kwargs):
            return self.cp.cumsum(x, *args, **kwargs)

        def cumprod(self, x, *args, **kwargs):
            return self.cp.cumprod(x, *args, **kwargs)

        # not working yet
        # NOTE(review): marked "not working yet" upstream; confirm the current status.

        def as_strided(self, x, *args, **kwargs):
            return self.cp.lib.stride_tricks.as_strided(x, *args, **kwargs)

        def sliding_window_view(self, x, *args, **kwargs):
            return self.cp.lib.stride_tricks.sliding_window_view(x, *args, **kwargs)

        def einsum(self, subscript, x, y, *args, **kwargs):
            return self.cp.einsum(subscript, x, y, *args, **kwargs)
    
    
    class NumbaBackend(BackendInterface):
        from numba import jit
        # init
    
        @staticmethod
        @jit(nopython=True)
        def array(x):
            """Return ``np.array(x)``, compiled by numba in nopython mode."""
            return np.array(x)
    
        @staticmethod
        @jit(nopython=True)
        def copy(x, *args, **kwargs):
            """Forward to ``np.copy``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.copy(x, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def zeros(x, *args, **kwargs):
            """Forward to ``np.zeros`` (``x`` is the shape); numba nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.zeros(x, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def ones(x, *args, **kwargs):
            """Forward to ``np.ones`` (``x`` is the shape); numba nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.ones(x, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def zeros_like(x, *args, **kwargs):
            """Forward to ``np.zeros_like``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.zeros_like(x, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def ones_like(x, *args, **kwargs):
            """Forward to ``np.ones_like``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.ones_like(x, *args, **kwargs)
    
    johannes bilk's avatar
    johannes bilk committed
        @staticmethod
        @jit(nopython=True)
        def arange(*args, **kwargs):
            """Forward to ``np.arange``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.arange(*args, **kwargs)
    
        # double tensor
    
        @staticmethod
        @jit(nopython=True)
        def add(x, y, *args, **kwargs):
            """Forward to ``np.add(x, y)``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.add(x, y, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def subtract(x, y, *args, **kwargs):
            """Forward to ``np.subtract(x, y)``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.subtract(x, y, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def multiply(x, y, *args, **kwargs):
            """Forward to ``np.multiply(x, y)``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.multiply(x, y, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def divide(x, y, *args, **kwargs):
            """Forward to ``np.divide(x, y)``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.divide(x, y, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def matmul(x, y, *args, **kwargs):
            """Forward to ``np.matmul(x, y)``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.matmul(x, y, *args, **kwargs)
    
    johannes bilk's avatar
    johannes bilk committed
        @staticmethod
        @jit(nopython=True)
        def dot(x, y, *args, **kwargs):
            """Forward to ``np.dot(x, y)``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.dot(x, y, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def power(x, y, *args, **kwargs):
            """Forward to ``np.power(x, y)``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.power(x, y, *args, **kwargs)
    
        # single tensor
    
        @staticmethod
        @jit(nopython=True)
        def square(x, *args, **kwargs):
            """Forward to ``np.square``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.square(x, *args, **kwargs)
    
    johannes bilk's avatar
    johannes bilk committed
        @staticmethod
        @jit(nopython=True)
        def sqrt(x, *args, **kwargs):
            """Forward to ``np.sqrt``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.sqrt(x, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def log(x, *args, **kwargs):
            """Forward to ``np.log``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.log(x, *args, **kwargs)
    
    johannes bilk's avatar
    johannes bilk committed
        @staticmethod
        @jit(nopython=True)
        def exp(x, *args, **kwargs):
            """Forward to ``np.exp``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.exp(x, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def sin(x, *args, **kwargs):
            """Forward to ``np.sin``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.sin(x, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def cos(x, *args, **kwargs):
            """Forward to ``np.cos``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.cos(x, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def tan(x, *args, **kwargs):
            """Forward to ``np.tan``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.tan(x, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def sinh(x, *args, **kwargs):
            """Forward to ``np.sinh``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.sinh(x, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def cosh(x, *args, **kwargs):
            """Forward to ``np.cosh``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.cosh(x, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def tanh(x, *args, **kwargs):
            """Forward to ``np.tanh``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.tanh(x, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def abs(x, *args, **kwargs):
            """Forward to ``np.abs``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.abs(x, *args, **kwargs)
    
        # signs
    
        @staticmethod
        @jit(nopython=True)
        def sign(x, *args, **kwargs):
            """Forward to ``np.sign``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.sign(x, *args, **kwargs)
    
    johannes bilk's avatar
    johannes bilk committed
        @staticmethod
        @jit(nopython=True)
        def positive(x, *args, **kwargs):
            """Forward to ``np.positive``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.positive(x, *args, **kwargs)
    
    johannes bilk's avatar
    johannes bilk committed
        @staticmethod
        @jit(nopython=True)
        def negative(x, *args, **kwargs):
            """Forward to ``np.negative``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.negative(x, *args, **kwargs)
    
        # compare
    
        @staticmethod
        @jit(nopython=True)
        def equal(x, y, *args, **kwargs):
            """Forward to ``np.equal(x, y)``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.equal(x, y, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def not_equal(x, y, *args, **kwargs):
            """Forward to ``np.not_equal(x, y)``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.not_equal(x, y, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def less(x, y, *args, **kwargs):
            """Forward to ``np.less(x, y)``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.less(x, y, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def less_equal(x, y, *args, **kwargs):
            """Forward to ``np.less_equal(x, y)``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.less_equal(x, y, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def greater(x, y, *args, **kwargs):
            """Forward to ``np.greater(x, y)``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.greater(x, y, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def greater_equal(x, y, *args, **kwargs):
            """Forward to ``np.greater_equal(x, y)``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.greater_equal(x, y, *args, **kwargs)
    
    johannes bilk's avatar
    johannes bilk committed
        # logic
    
    johannes bilk's avatar
    johannes bilk committed
        @staticmethod
        @jit(nopython=True)
        def logical_and(x, *args, **kwargs):
            """Forward to ``np.logical_and``; the second operand arrives via ``*args``."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.logical_and(x, *args, **kwargs)
    
    johannes bilk's avatar
    johannes bilk committed
        @staticmethod
        @jit(nopython=True)
        def logical_or(x, *args, **kwargs):
            """Forward to ``np.logical_or``; the second operand arrives via ``*args``."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.logical_or(x, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def logical_xor(x, *args, **kwargs):
            """Forward to ``np.logical_xor``; the second operand arrives via ``*args``."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.logical_xor(x, *args, **kwargs)
    
    johannes bilk's avatar
    johannes bilk committed
        @staticmethod
        @jit(nopython=True)
        def logical_not(x, *args, **kwargs):
            """Forward to ``np.logical_not``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.logical_not(x, *args, **kwargs)
    
    johannes bilk's avatar
    johannes bilk committed
        # shaping
    
        @staticmethod
        @jit(nopython=True)
        def flatten(x, **kwargs):
            """Return ``np.reshape(x, -1)`` (1-D result); numba nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.reshape(x, -1, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def reshape(x, *args, **kwargs):
            """Forward to ``np.reshape``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.reshape(x, *args, **kwargs)
    
        # broadcasting
    
        @staticmethod
        @jit(nopython=True)
        def broadcast_to(x, *args, **kwargs):
            """Forward to ``np.broadcast_to``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.broadcast_to(x, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def repeat(x, *args, **kwargs):
            """Forward to ``np.repeat``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.repeat(x, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def tile(x, *args, **kwargs):
            """Forward to ``np.tile``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs, or np.tile in
            # all versions; confirm this compiles.
            return np.tile(x, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def concatenate(x, *args, **kwargs):
            """Forward to ``np.concatenate`` (``x`` is the sequence of arrays)."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.concatenate(x, *args, **kwargs)
    
    johannes bilk's avatar
    johannes bilk committed
        @staticmethod
        @jit(nopython=True)
        def split(x, *args, **kwargs):
            """Forward to ``np.split``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.split(x, *args, **kwargs)
    
        # reduce
    
        @staticmethod
        @jit(nopython=True)
        def sum(x, *args, **kwargs):
            """Forward to ``np.sum``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.sum(x, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def prod(x, *args, **kwargs):
            """Forward to ``np.prod``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.prod(x, *args, **kwargs)
    
        # min/max etc
    
        @staticmethod
        @jit(nopython=True)
        def max(x, *args, **kwargs):
            """Forward to ``np.max``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.max(x, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def min(x, *args, **kwargs):
            """Forward to ``np.min``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.min(x, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def mean(x, *args, **kwargs):
            """Forward to ``np.mean``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.mean(x, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def var(x, *args, **kwargs):
            """Forward to ``np.var``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.var(x, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def std(x, *args, **kwargs):
            """Forward to ``np.std``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.std(x, *args, **kwargs)
    
        # others
    
        @staticmethod
        @jit(nopython=True)
        def pad(x, *args, **kwargs):
            """Forward to ``np.pad``; compiled by numba in nopython mode."""
            # NOTE(review): confirm numba supports np.pad and forwarding **kwargs here.
            return np.pad(x, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def insert(x, *args, **kwargs):
            """Forward to ``np.insert``; compiled by numba in nopython mode."""
            # NOTE(review): confirm numba supports np.insert and forwarding **kwargs here.
            return np.insert(x, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def transpose(x, *args, **kwargs):
            """Forward to ``np.transpose``; compiled by numba in nopython mode."""
            # NOTE(review): numba may not support forwarding **kwargs; confirm this compiles.
            return np.transpose(x, *args, **kwargs)