Skip to content
Snippets Groups Projects
backend.py 32.5 KiB
Newer Older
  • Learn to ignore specific revisions
  • johannes bilk's avatar
    johannes bilk committed
    
        @staticmethod
        @jit(nopython=True)
        def where(x, *args, **kwargs):
            """Numba-compiled pass-through to ``np.where``.

            Everything after ``x`` is forwarded to ``np.where`` unchanged.

            NOTE(review): numba's documentation lists explicit ``**kwargs`` as
            unsupported in nopython mode; confirm this signature compiles with
            the numba version in use.
            """
            selected = np.where(x, *args, **kwargs)
            return selected
    
        @staticmethod
        @jit(nopython=True)
        def cumsum(x, *args, **kwargs):
            """Numba-compiled pass-through to ``np.cumsum``.

            Extra positional/keyword arguments are forwarded unchanged.

            NOTE(review): numba documents explicit ``**kwargs`` as unsupported in
            nopython mode, and its ``np.cumsum`` support may not accept extra
            arguments such as ``axis`` -- TODO confirm with the numba version in use.
            """
            running_total = np.cumsum(x, *args, **kwargs)
            return running_total
    
        @staticmethod
        @jit(nopython=True)
        def cumprod(x, *args, **kwargs):
            """Numba-compiled pass-through to ``np.cumprod``.

            Extra positional/keyword arguments are forwarded unchanged.

            NOTE(review): numba documents explicit ``**kwargs`` as unsupported in
            nopython mode, and its ``np.cumprod`` support may not accept extra
            arguments such as ``axis`` -- TODO confirm with the numba version in use.
            """
            running_product = np.cumprod(x, *args, **kwargs)
            return running_product
    
        # not working yet
    
        @staticmethod
        @jit(nopython=True)
        def as_strided(x, *args, **kwargs):
            return np.lib.stride_tricks.as_strided(x, *args, **kwargs)
        
        @staticmethod
        @jit(nopython=True)
        def sliding_window_view( x, *args, **kwargs):
            return np.lib.stride_tricks.sliding_window_view(x, *args, **kwargs)
    
        @staticmethod
        @jit(nopython=True)
        def einsum(x, y, *args, **kwargs):
            return np.einsum('bihwkl,oikl->bohw', x, y, *args, **kwargs)
    
    
    class PytorchBackend(BackendInterface):
        """PyTorch implementation of ``BackendInterface``.

        Each method forwards to the matching ``torch`` function through the
        class attribute ``torch``, keeping NumPy-style method names. Where the
        torch function's semantics differ from NumPy's, a NOTE on the method
        says so.
        """

        # torch is an optional dependency: if the import fails the class is still
        # defined, but it has no ``torch`` attribute, so every method raises
        # AttributeError on ``self.torch``.
        try:
            import torch
        except ImportError:
            pass

        # init

        def array(self, x, *args, **kwargs):
            return self.torch.tensor(x, *args, **kwargs)

        def copy(self, x, *args, **kwargs):
            # Extra arguments are accepted for interface compatibility but ignored.
            return self.torch.clone(x)

        def zeros(self, x, *args, **kwargs):
            return self.torch.zeros(x, *args, **kwargs)

        def ones(self, x, *args, **kwargs):
            return self.torch.ones(x, *args, **kwargs)

        def zeros_like(self, x, *args, **kwargs):
            return self.torch.zeros_like(x, *args, **kwargs)

        def ones_like(self, x, *args, **kwargs):
            return self.torch.ones_like(x, *args, **kwargs)

        def arange(self, *args, **kwargs):
            return self.torch.arange(*args, **kwargs)

        # double tensor

        def add(self, x, y, *args, **kwargs):
            return self.torch.add(x, y, *args, **kwargs)

        def subtract(self, x, y, *args, **kwargs):
            return self.torch.subtract(x, y, *args, **kwargs)

        def multiply(self, x, y, *args, **kwargs):
            return self.torch.multiply(x, y, *args, **kwargs)

        def divide(self, x, y, *args, **kwargs):
            return self.torch.divide(x, y, *args, **kwargs)

        def matmul(self, x, y, *args, **kwargs):
            return self.torch.matmul(x, y, *args, **kwargs)

        def dot(self, x, y, *args, **kwargs):
            # NOTE: torch.dot only accepts 1-D tensors (np.dot is more general).
            return self.torch.dot(x, y, *args, **kwargs)

        def power(self, x, y, *args, **kwargs):
            return self.torch.pow(x, y, *args, **kwargs)

        # single tensor

        def square(self, x, *args, **kwargs):
            return self.torch.square(x, *args, **kwargs)

        def sqrt(self, x, *args, **kwargs):
            return self.torch.sqrt(x, *args, **kwargs)

        def log(self, x, *args, **kwargs):
            return self.torch.log(x, *args, **kwargs)

        def exp(self, x, *args, **kwargs):
            return self.torch.exp(x, *args, **kwargs)

        def sin(self, x, *args, **kwargs):
            return self.torch.sin(x, *args, **kwargs)

        def cos(self, x, *args, **kwargs):
            return self.torch.cos(x, *args, **kwargs)

        def tan(self, x, *args, **kwargs):
            return self.torch.tan(x, *args, **kwargs)

        def sinh(self, x, *args, **kwargs):
            return self.torch.sinh(x, *args, **kwargs)

        def cosh(self, x, *args, **kwargs):
            return self.torch.cosh(x, *args, **kwargs)

        def tanh(self, x, *args, **kwargs):
            return self.torch.tanh(x, *args, **kwargs)

        def abs(self, x, *args, **kwargs):
            return self.torch.abs(x, *args, **kwargs)

        # signs

        def sign(self, x, *args, **kwargs):
            return self.torch.sign(x, *args, **kwargs)

        def positive(self, x, *args, **kwargs):
            return self.torch.positive(x, *args, **kwargs)

        def negative(self, x, *args, **kwargs):
            return self.torch.negative(x, *args, **kwargs)

        # compare

        def equal(self, x, y, *args, **kwargs):
            # torch.eq is element-wise like np.equal and like not_equal below;
            # torch.equal would instead return a single bool for the whole tensor.
            return self.torch.eq(x, y, *args, **kwargs)

        def not_equal(self, x, y, *args, **kwargs):
            return self.torch.not_equal(x, y, *args, **kwargs)

        def less(self, x, y, *args, **kwargs):
            return self.torch.less(x, y, *args, **kwargs)

        def less_equal(self, x, y, *args, **kwargs):
            return self.torch.less_equal(x, y, *args, **kwargs)

        def greater(self, x, y, *args, **kwargs):
            return self.torch.greater(x, y, *args, **kwargs)

        def greater_equal(self, x, y, *args, **kwargs):
            return self.torch.greater_equal(x, y, *args, **kwargs)

        # logic

        def logical_and(self, x, *args, **kwargs):
            return self.torch.logical_and(x, *args, **kwargs)

        def logical_or(self, x, *args, **kwargs):
            return self.torch.logical_or(x, *args, **kwargs)

        def logical_xor(self, x, *args, **kwargs):
            return self.torch.logical_xor(x, *args, **kwargs)

        def logical_not(self, x, *args, **kwargs):
            return self.torch.logical_not(x, *args, **kwargs)

        # shaping

        def flatten(self, x, **kwargs):
            # Pass the target shape as a tuple; torch.reshape expects a sequence of ints.
            return self.torch.reshape(x, (-1,), **kwargs)

        def reshape(self, x, *args, **kwargs):
            return self.torch.reshape(x, *args, **kwargs)

        # broadcasting

        def broadcast_to(self, x, *args, **kwargs):
            return self.torch.broadcast_to(x, *args, **kwargs)

        def repeat(self, x, *args, **kwargs):
            # torch has no top-level ``repeat`` function; np.repeat's element-wise
            # repetition corresponds to torch.repeat_interleave.
            return self.torch.repeat_interleave(x, *args, **kwargs)

        def tile(self, x, *args, **kwargs):
            return self.torch.tile(x, *args, **kwargs)

        def concatenate(self, x, *args, **kwargs):
            return self.torch.cat(x, *args, **kwargs)

        def split(self, x, *args, **kwargs):
            # NOTE: torch.split takes a chunk size (or list of sizes), not a number
            # of sections / split indices as np.split does.
            return self.torch.split(x, *args, **kwargs)

        # reduce

        def sum(self, x, *args, **kwargs):
            # Forward dim/keepdim so reductions over a given axis are honoured.
            return self.torch.sum(x, *args, **kwargs)

        def prod(self, x, *args, **kwargs):
            return self.torch.prod(x, *args, **kwargs)

        # min/max etc

        def max(self, x, *args, **kwargs):
            # NOTE: with a dim argument torch.max returns (values, indices).
            return self.torch.max(x, *args, **kwargs)

        def min(self, x, *args, **kwargs):
            # NOTE: with a dim argument torch.min returns (values, indices).
            return self.torch.min(x, *args, **kwargs)

        def mean(self, x, *args, **kwargs):
            return self.torch.mean(x, *args, **kwargs)

        def var(self, x, *args, **kwargs):
            # NOTE: torch.var defaults to the unbiased estimator (correction=1),
            # whereas np.var defaults to ddof=0.
            return self.torch.var(x, *args, **kwargs)

        def std(self, x, *args, **kwargs):
            # NOTE: torch.std defaults to the unbiased estimator (correction=1),
            # whereas np.std defaults to ddof=0.
            return self.torch.std(x, *args, **kwargs)

        # others

        def pad(self, x, *args, **kwargs):
            # torch has no top-level ``pad``; the functional API lives in
            # torch.nn.functional. NOTE: the pad spec follows torch's convention
            # (a flat sequence starting from the last dimension), not np.pad's
            # per-axis pairs.
            return self.torch.nn.functional.pad(x, *args, **kwargs)

        def insert(self, x, *args, **kwargs):
            # torch has no counterpart to np.insert; fail explicitly (as
            # sliding_window_view does) instead of with an AttributeError.
            raise NotImplementedError('insert is not supported by the PyTorch backend')

        def transpose(self, x, *args, **kwargs):
            """Permute the axes of ``x``.

            With an explicit axes argument (first positional or ``axes=``) the
            axes are permuted accordingly, as ``np.transpose`` does. Without one
            the first two dimensions are swapped (identical to ``np.transpose``
            for 2-D input). Tensors with fewer than two dimensions are returned
            unchanged.
            """
            axes = args[0] if args else kwargs.get('axes')
            if axes is not None:
                return self.torch.permute(x, tuple(axes))
            if x.dim() < 2:
                return x
            return self.torch.transpose(x, 0, 1)

        def where(self, x, *args, **kwargs):
            return self.torch.where(x, *args, **kwargs)

        def cumsum(self, x, *args, **kwargs):
            return self.torch.cumsum(x, *args, **kwargs)

        def cumprod(self, x, *args, **kwargs):
            return self.torch.cumprod(x, *args, **kwargs)

        # not working yet

        def as_strided(self, x, *args, **kwargs):
            # NOTE: torch strides are counted in elements, NumPy's in bytes.
            return self.torch.as_strided(x, *args, **kwargs)

        def sliding_window_view(self, x, *args, **kwargs):
            raise NotImplementedError

        def einsum(self, subscript, x, y, *args, **kwargs):
            # Unlike a fixed-contraction einsum, the subscript string is passed explicitly.
            return self.torch.einsum(subscript, x, y, *args, **kwargs)
    
    
    class TensorflowBackend(BackendInterface):
        """TensorFlow backend placeholder.

        No operations are implemented yet. The class body only tries to bind the
        optional ``tensorflow`` package as the class attribute ``tf``.
        """

        # tensorflow is optional: on ImportError the class is still defined,
        # just without a ``tf`` attribute.
        try:
            import tensorflow as tf
        except ImportError:
            pass

        # TODO: implement the BackendInterface methods using TensorFlow's functions.