import numpy as np
import uproot as ur


class fromRoot:
    """Helpers to read event data from a ROOT tree and reshape it per cluster.

    The constructor only builds static geometry tables (panel IDs, shifts,
    rotation angles, layer/ladder numbers) and two look-up dictionaries keyed
    by the panel ID as a string. A tree is attached with :meth:`loadData`.
    """

    def __init__(self):
        """Create the panel geometry tables and the derived look-up tables."""
        # Panel IDs; used to determine layer and ladder and to convert u/v
        # coordinates to Cartesian coordinates.
        self.panelIDs = np.array([ 8480,  8512,  8736,  8768,  8992,  9024,  9248,  9280,
                                   9504,  9536,  9760,  9792, 10016, 10048, 10272, 10304,
                                  16672, 16704, 16928, 16960, 17184, 17216, 17440, 17472,
                                  17696, 17728, 17952, 17984, 18208, 18240, 18464, 18496,
                                  18720, 18752, 18976, 19008, 19232, 19264, 19488, 19520])

        # Coordinate shift that moves u/v coordinates to the panel position.
        self.panelShifts = np.array([[ 1.3985    ,  0.2652658 ,  3.68255],
                                     [ 1.3985    ,  0.23238491, -0.88255],
                                     [ 0.80146531,  1.17631236,  3.68255],
                                     [ 0.82407264,  1.15370502, -0.88255],
                                     [-0.2582769 ,  1.3985    ,  3.68255],
                                     [-0.2322286 ,  1.3985    , -0.88255],
                                     [-1.17531186,  0.80246583,  3.68255],
                                     [-1.15510614,  0.82267151, -0.88255],
                                     [-1.3985    , -0.2645974 ,  3.68255],
                                     [-1.3985    , -0.23012119, -0.88255],
                                     [-0.80591227, -1.17186534,  3.68255],
                                     [-0.82344228, -1.15433536, -0.88255],
                                     [ 0.26975836, -1.3985    ,  3.68255],
                                     [ 0.23326624, -1.3985    , -0.88255],
                                     [ 1.1746111 , -0.80316652,  3.68255],
                                     [ 1.15205703, -0.82572062, -0.88255],
                                     [ 2.2015    ,  0.26959865,  5.01305],
                                     [ 2.2015    ,  0.2524582 , -1.21305],
                                     [ 1.77559093,  1.32758398,  5.01305],
                                     [ 1.78212569,  1.31626522, -1.21305],
                                     [ 0.87798948,  2.03516717,  5.01305],
                                     [ 0.88478563,  2.03124357, -1.21305],
                                     [-0.26129975,  2.2015    ,  5.01305],
                                     [-0.25184137,  2.2015    , -1.21305],
                                     [-1.32416655,  1.77756402,  5.01305],
                                     [-1.31417539,  1.78333226, -1.21305],
                                     [-2.03421133,  0.87964512,  5.01305],
                                     [-2.02960691,  0.88762038, -1.21305],
                                     [-2.2015    , -0.25954151,  5.01305],
                                     [-2.2015    , -0.24969109, -1.21305],
                                     [-1.77636043, -1.32625112,  5.01305],
                                     [-1.78138268, -1.31755219, -1.21305],
                                     [-0.87493138, -2.03693277,  5.01305],
                                     [-0.8912978 , -2.02748378, -1.21305],
                                     [ 0.26489725, -2.2015    ,  5.01305],
                                     [ 0.25364439, -2.2015    , -1.21305],
                                     [ 1.3269198 , -1.7759744 ,  5.01305],
                                     [ 1.32258793, -1.77847528, -1.21305],
                                     [ 2.03616649, -0.87625871,  5.01305],
                                     [ 2.02936825, -0.8880338 , -1.21305]])

        # Rotation angles (degrees) that orient each panel correctly.
        self.panelRotations = np.array([ 90,  90, 135, 135, 180, 180, 225, 225, 270, 270, 315, 315, 360,
                                        360, 405, 405,  90,  90, 120, 120, 150, 150, 180, 180, 210, 210,
                                        240, 240, 270, 270, 300, 300, 330, 330, 360, 360, 390, 390, 420,
                                        420])

        # Layer number per panel. The first 16 panels share the inner shift
        # radius (1.3985) and 45 degree rotation steps, the remaining 24 share
        # the outer radius (2.2015) and 30 degree steps, so the split is 16/24.
        self.panelLayer = np.array([1] * 16 + [2] * 24)
        # Ladder number per panel (two consecutive panels share one ladder).
        # NOTE(review): the numbering skips the value 11 - confirm intended.
        self.panelLadder = np.array([1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10,
                                     12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19,
                                     20, 20, 21, 21])

        # Look-up tables for the u/v -> x/y/z transformation and for the
        # layer/ladder numbers, both keyed by the panel ID as a string.
        self.transformation = {}
        self.layersLadders = {}
        for i, panelID in enumerate(self.panelIDs):
            key = str(panelID)
            self.transformation[key] = [self.panelShifts[i], self.panelRotations[i]]
            self.layersLadders[key] = [self.panelLayer[i], self.panelLadder[i]]

    def loadData(self, file, path='.'):
        """Open ``{path}/{file}.root`` and store its ``tree`` object as ``self.eventTree``."""
        self.eventTree = ur.open(f'{path}/{file}.root:tree')

    def getData(self, keyword: str):
        """Read one branch of the event tree; the full keyword must be given.

        Returns:
            The entry ``keyword`` of ``self.eventTree.arrays(keyword, library='np')``.

        Raises:
            KeyError: if the branch cannot be read (including the case that no
                tree has been loaded yet). The underlying error is chained.
        """
        try:
            return self.eventTree.arrays(keyword, library='np')[keyword]
        except Exception as err:
            # The original code returned the KeyError class here, which hid
            # the failure from the caller; raise it instead.
            raise KeyError(f'could not read {keyword!r} from the event tree') from err

    def genEventNumbers(self, clusters):
        """Return, for every cluster, the index of the event it belongs to.

        ``clusters`` is a sequence with one sized entry per event; entry ``i``
        contributes ``len(clusters[i])`` copies of ``i`` to the flat result.
        """
        eventNumbers = [np.full(len(cluster), i) for i, cluster in enumerate(clusters)]
        return self.flatten(eventNumbers)

    def getEventData(self, relations, *args):
        """Reorder per-event data so that it lines up with ``relations``.

        For every sequence ``arg`` in ``args`` a nested list is built whose
        element ``[i][k]`` is ``arg[relations[i][k]]``.

        Returns:
            The single nested list if exactly one ``arg`` was passed,
            otherwise a list with one nested list per ``arg``.
        """
        returnList = [
            [[arg[index] for index in references] for references in relations]
            for arg in args
        ]
        if len(returnList) == 1:
            return returnList[0]
        return returnList

    def flatten(self, structure, max_depth=None, current_depth=0):
        """Recursively flatten nested lists / numpy arrays into an object array.

        Nesting is resolved until ``max_depth`` levels have been entered
        (``None`` means no limit); deeper containers are kept as elements.
        """
        flat_list = []
        for element in structure:
            descend = max_depth is None or current_depth < max_depth
            if descend and isinstance(element, (list, np.ndarray)):
                flat_list.extend(self.flatten(element, max_depth, current_depth + 1))
            else:
                flat_list.append(element)
        return np.array(flat_list, dtype=object)

    def getClustersFlattened(self, uCellIDs, vCellIDs, cellCharges, clusterDigits, matrixSize=(9,9)):
        """Build one pixel matrix per cluster for all events.

        Per event, ``clusterDigits[event]`` holds one collection of digit
        indices per cluster; these index into ``uCellIDs[event]``,
        ``vCellIDs[event]`` and ``cellCharges[event]``. Each cluster is drawn
        into a zero matrix of shape ``matrixSize`` with its highest-charge
        digit at the centre cell; digits falling outside the matrix are
        dropped. A cluster without digits yields an all-zero matrix.

        Returns:
            ``np.array(matrices, dtype=object)`` with exactly one matrix per
            cluster, in event order.
        """
        # Integer division gives the centre index for every size; np.round
        # rounds half to even and is off by one for sizes such as 3, 7 or 11.
        centerU, centerV = matrixSize[0] // 2, matrixSize[1] // 2
        matrices = []
        for event in range(len(cellCharges)):
            digitsU = uCellIDs[event]
            digitsV = vCellIDs[event]
            digitsCharge = np.asarray(cellCharges[event])
            for indices in clusterDigits[event]:
                cacheImg = np.zeros(matrixSize)
                if len(indices) > 0:
                    seed = indices[digitsCharge[indices].argmax()]
                    uMax, vMax = int(digitsU[seed]), int(digitsV[seed])
                    for index in indices:
                        uu = int(digitsU[index]) - uMax + centerU
                        vv = int(digitsV[index]) - vMax + centerV
                        if 0 <= uu < matrixSize[0] and 0 <= vv < matrixSize[1]:
                            cacheImg[uu, vv] = digitsCharge[index]
                matrices.append(cacheImg)
        # Collecting into a list sizes the result by the number of clusters;
        # pre-sizing by the number of digits left trailing 0 entries behind.
        return np.array(matrices, dtype=object)

    def rotShiftVector(self, vector, angle, shift=(0, 0, 0), scale=1):
        """Scale the third component, rotate about the third axis, then shift.

        Args:
            vector: 3-component vector.
            angle: rotation angle in degrees.
            shift: 3-component offset added after the rotation.
            scale: factor applied to the third component before rotating.
        """
        theta = np.deg2rad(angle)
        cosT, sinT = np.cos(theta), np.sin(theta)
        rotMatrix = np.array([[cosT, -sinT, 0], [sinT, cosT, 0], [0, 0, 1]])
        scaleMatrix = np.diag([1, 1, scale])
        return rotMatrix.dot(scaleMatrix.dot(vector)) + np.asarray(shift)

    def getCartesianFlattened(self, uPositions, vPositions, sensorIDs, transformations: dict = None):
        """Compute the Cartesian coordinates of all entries of all events.

        Each (u, v) pair is turned into the vector ``(u, 0, v)`` and passed
        through :meth:`rotShiftVector` with the shift and angle stored in
        ``transformations[str(sensorID)]``.

        Args:
            uPositions, vPositions, sensorIDs: one sequence per event.
            transformations: look-up table ``str(id) -> (shift, angle)``;
                defaults to ``self.transformation``.

        Returns:
            Three flat object arrays ``(x, y, z)``; events without entries
            contribute nothing.
        """
        if transformations is None:
            transformations = self.transformation
        xArr, yArr, zArr = [], [], []
        for event, sensors in enumerate(sensorIDs):
            uPos = uPositions[event]
            vPos = vPositions[event]
            points = np.vstack((uPos, np.zeros(len(uPos)), vPos)).T
            for point, sensorID in zip(points, sensors):
                shift, angle = transformations[str(sensorID)]
                x, y, z = self.rotShiftVector(point, angle, shift)
                xArr.append(x)
                yArr.append(y)
                zArr.append(z)
        return np.array(xArr, dtype=object), np.array(yArr, dtype=object), np.array(zArr, dtype=object)

    def getLayers(self, sensorIDs, layersLadders: dict = None):
        """Look up the layer and ladder number for every sensor ID of an event.

        ``layersLadders`` maps ``str(id) -> (layer, ladder)`` and defaults to
        ``self.layersLadders``. Returns ``(layers, ladders)`` as numpy arrays.
        """
        if layersLadders is None:
            layersLadders = self.layersLadders
        layers, ladders = [], []
        for sensorID in sensorIDs:
            layer, ladder = layersLadders[str(sensorID)]
            layers.append(layer)
            ladders.append(ladder)
        return np.array(layers), np.array(ladders)

    def findMissing(self, lst: list, length: int) -> list:
        """Return the sorted indices in ``range(length)`` that do not occur in ``lst``."""
        return sorted(set(range(0, length)) - set(lst))

    def fillMCList(self, fromClusters, toClusters, length):
        """Pad the MC-cluster relation to ``length`` entries, using -1 for gaps.

        Positions in ``range(length)`` that are absent from ``fromClusters``
        get ``-1``; the remaining positions are filled, in ascending order,
        with the consecutive entries of ``toClusters``. An entry that
        ``int()`` rejects with ``TypeError`` contributes its first element.
        """
        # A set makes the membership test in the loop O(1).
        missingIndex = set(self.findMissing(fromClusters, length))
        testList = [-1] * length
        fillIndex = 0
        for i in range(length):
            if i in missingIndex:
                continue
            try:
                testList[i] = int(toClusters[fillIndex])
            except TypeError:
                testList[i] = int(toClusters[fillIndex][0])
            fillIndex += 1
        return testList

    def getMCData(self, toClusters, pdgs, xMom, yMom, zMom):
        """Reorder the MC data of one event according to ``toClusters``.

        Every reference equal to ``-1`` yields ``0`` in all outputs; any other
        reference is used as an index into ``pdgs``, ``xMom``, ``yMom`` and
        ``zMom``.

        Returns:
            Object arrays ``(pdg, px, py, pz)``.
        """
        pdgList, pxList, pyList, pzList = [], [], [], []
        for reference in toClusters:
            if reference == -1:
                pdgList.append(0)
                pxList.append(0)
                pyList.append(0)
                pzList.append(0)
            else:
                pdgList.append(pdgs[reference])
                pxList.append(xMom[reference])
                pyList.append(yMom[reference])
                pzList.append(zMom[reference])
        # dtype=object is what the former dtype=list resolved to.
        return (np.array(pdgList, dtype=object), np.array(pxList, dtype=object),
                np.array(pyList, dtype=object), np.array(pzList, dtype=object))