import numpy as np
from ..common import calcSpherical
from concurrent.futures import ThreadPoolExecutor


class ClusterCoordinates:
    """
    This class takes care of cluster coordinates.

    It holds one lookup table per pxd panel (sensor id -> shift vector and rotation
    angle, sensor id -> layer and ladder) and uses them to convert the 2d uv
    positions of clusters into 3d xyz positions.
    """
    def __init__(self) -> None:
        # these are the sensor IDs of the pxd modules/panels from the root file, they are
        # used to identify on which panel a cluster event happened
        self.panelIDs = np.array([ 8480,  8512,  8736,  8768,  8992,  9024,  9248,  9280,
                                   9504,  9536,  9760,  9792, 10016, 10048, 10272, 10304,
                                  16672, 16704, 16928, 16960, 17184, 17216, 17440, 17472,
                                  17696, 17728, 17952, 17984, 18208, 18240, 18464, 18496,
                                  18720, 18752, 18976, 19008, 19232, 19264, 19488, 19520])

        # every line in this corresponds to one entry in the array above, this is used
        # to put the projected uv plane in the right position
        self.panelShifts = np.array([[ 1.3985    ,  0.2652658 ,  3.68255],
                                     [ 1.3985    ,  0.23238491, -0.88255],
                                     [ 0.80146531,  1.17631236,  3.68255],
                                     [ 0.82407264,  1.15370502, -0.88255],
                                     [-0.2582769 ,  1.3985    ,  3.68255],
                                     [-0.2322286 ,  1.3985    , -0.88255],
                                     [-1.17531186,  0.80246583,  3.68255],
                                     [-1.15510614,  0.82267151, -0.88255],
                                     [-1.3985    , -0.2645974 ,  3.68255],
                                     [-1.3985    , -0.23012119, -0.88255],
                                     [-0.80591227, -1.17186534,  3.68255],
                                     [-0.82344228, -1.15433536, -0.88255],
                                     [ 0.26975836, -1.3985    ,  3.68255],
                                     [ 0.23326624, -1.3985    , -0.88255],
                                     [ 1.1746111 , -0.80316652,  3.68255],
                                     [ 1.15205703, -0.82572062, -0.88255],
                                     [ 2.2015    ,  0.26959865,  5.01305],
                                     [ 2.2015    ,  0.2524582 , -1.21305],
                                     [ 1.77559093,  1.32758398,  5.01305],
                                     [ 1.78212569,  1.31626522, -1.21305],
                                     [ 0.87798948,  2.03516717,  5.01305],
                                     [ 0.88478563,  2.03124357, -1.21305],
                                     [-0.26129975,  2.2015    ,  5.01305],
                                     [-0.25184137,  2.2015    , -1.21305],
                                     [-1.32416655,  1.77756402,  5.01305],
                                     [-1.31417539,  1.78333226, -1.21305],
                                     [-2.03421133,  0.87964512,  5.01305],
                                     [-2.02960691,  0.88762038, -1.21305],
                                     [-2.2015    , -0.25954151,  5.01305],
                                     [-2.2015    , -0.24969109, -1.21305],
                                     [-1.77636043, -1.32625112,  5.01305],
                                     [-1.78138268, -1.31755219, -1.21305],
                                     [-0.87493138, -2.03693277,  5.01305],
                                     [-0.8912978 , -2.02748378, -1.21305],
                                     [ 0.26489725, -2.2015    ,  5.01305],
                                     [ 0.25364439, -2.2015    , -1.21305],
                                     [ 1.3269198 , -1.7759744 ,  5.01305],
                                     [ 1.32258793, -1.77847528, -1.21305],
                                     [ 2.03616649, -0.87625871,  5.01305],
                                     [ 2.02936825, -0.8880338 , -1.21305]])

        # every entry here corresponds to the entries in the array above, these are
        # the angles (in degrees) used for rotating the projected uv plane around z
        self.panelRotations = np.array([ 90,  90, 135, 135, 180, 180, 225, 225, 270, 270, 315, 315, 360,
                                        360, 405, 405,  90,  90, 120, 120, 150, 150, 180, 180, 210, 210,
                                        240, 240, 270, 270, 300, 300, 330, 330, 360, 360, 390, 390, 420,
                                        420])

        # the layer and ladder arrays, for finding them from sensor id
        # NOTE(review): the ladder sequence jumps from 10 to 12 (there is no 11) and
        # therefore ends at 21 - this looks like a possible typo, confirm the intended
        # numbering before relying on the ladder values
        self.panelLayer  = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2])
        self.panelLadder = np.array([1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21])

        # all transformations are stored in a dict, with the sensor id as a keyword
        self.transformation = {}
        self.layersLadders = {}
        for panelID, shift, rotation, layer, ladder in zip(self.panelIDs, self.panelShifts, self.panelRotations,
                                                           self.panelLayer, self.panelLadder):
            self.transformation[panelID] = [shift, rotation]
            self.layersLadders[panelID] = [layer, ladder]

    def get(self, uPositions: np.ndarray, vPositions: np.ndarray, sensorIDs: np.ndarray) -> dict:
        """
        converting the uv coordinates, together with sensor ids, into xyz coordinates

        uPositions, vPositions and sensorIDs hold one entry per cluster. The result is
        a dict with the keys 'xPosition', 'yPosition' and 'zPosition', each holding one
        float array with one entry per cluster. Empty input gives three empty arrays.

        raises KeyError if a sensor id is not one of the known panels
        """
        # the conversion is a single vectorized numpy evaluation, splitting it over
        # threads would only add overhead
        xPosition, yPosition, zPosition = self._process(uPositions, vPositions, sensorIDs)

        return {'xPosition': xPosition,
                'yPosition': yPosition,
                'zPosition': zPosition}

    def _lookupPerCluster(self, table: dict, sensorIDs: np.ndarray) -> tuple[list, np.ndarray]:
        """
        a private helper that resolves every distinct sensor id once in the given table

        returns the table entries of the distinct sensor ids, plus an index array that
        maps every cluster to its entry; raises KeyError for an unknown sensor id
        """
        distinctIDs, clusterToEntry = np.unique(np.asarray(sensorIDs), return_inverse=True)
        entries = [table[sensorID] for sensorID in distinctIDs]
        return entries, clusterToEntry

    def _process(self, uPositions: np.ndarray, vPositions: np.ndarray, sensorIDs: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        a private method for transposing/converting 2d uv coords into 3d xyz coordinates

        every cluster is projected into 3d space as (u, 0, v), rotated around the z axis
        by the angle of its panel and then moved by the shift vector of its panel

        raises KeyError if a sensor id is not one of the known panels
        """
        ids = np.asarray(sensorIDs)
        if ids.size == 0:
            # nothing to convert, np.unique based lookup below needs at least one entry
            return np.zeros(0), np.zeros(0), np.zeros(0)

        u = np.asarray(uPositions, dtype=float)
        v = np.asarray(vPositions, dtype=float)

        # grabbing the shift vector and rotation angle of every cluster
        entries, clusterToEntry = self._lookupPerCluster(self.transformation, ids)
        shifts = np.array([entry[0] for entry in entries], dtype=float)[clusterToEntry]
        angles = np.array([entry[1] for entry in entries], dtype=float)[clusterToEntry]
        theta = np.deg2rad(angles)

        # rotating (u, 0, v) around z: the y component of the projected point is zero,
        # so only the first column of the rotation matrix contributes to x and y and
        # v passes through unchanged into z
        xArr = np.cos(theta) * u + shifts[:, 0]
        yArr = np.sin(theta) * u + shifts[:, 1]
        zArr = v + shifts[:, 2]

        return xArr, yArr, zArr

    def layers(self, sensorIDs: np.ndarray) -> dict:
        """
        looks up the corresponding layers and ladders for every cluster

        returns a dict with the keys 'layer' and 'ladder', each holding one int array
        with one entry per sensor id; raises KeyError for an unknown sensor id
        """
        ids = np.asarray(sensorIDs)
        if ids.size == 0:
            return {'layer': np.empty(0, dtype=int),
                   'ladder': np.empty(0, dtype=int)}

        entries, clusterToEntry = self._lookupPerCluster(self.layersLadders, ids)
        # one (layer, ladder) row per cluster
        layerLadderPairs = np.array(entries, dtype=int)[clusterToEntry]

        return {'layer': layerLadderPairs[:, 0].copy(),
               'ladder': layerLadderPairs[:, 1].copy()}

    def sphericals(self, xPosition: np.ndarray, yPosition: np.ndarray, zPosition: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        this calculates spherical coordinates from xyz coordinates

        the calculation itself is delegated to calcSpherical
        """
        return calcSpherical(xPosition, yPosition, zPosition)