import numpy as np
import uproot as ur


class fromRoot:
    """Read PXD cluster, digit and MC-particle branches from a ROOT ``tree``.

    Every ``get*``/``gen*`` method stores its results in ``self.data``, one
    entry per cluster, flattened over all events. Values can be accessed
    column-wise with ``obj['key']`` or row-wise with ``obj[index]``.

    Typical use::

        reader = fromRoot()
        reader.loadData('file.root')
        reader.getClusters()
        reader.genCoordinates()
    """

    def __init__(self):
        # Panel (sensor) IDs, used to look up layer/ladder and the u/v -> x/y/z
        # transformation. Decoding an ID as layer = id >> 13,
        # ladder = (id >> 8) & 0x1F, sensor = (id >> 5) & 0x7 gives
        # layer 1 / ladders 1-8 for the first 16 IDs and
        # layer 2 / ladders 1-12 for the last 24, two sensors per ladder.
        self.panelIDs = np.array([ 8480,  8512,  8736,  8768,  8992,  9024,  9248,  9280,
                                   9504,  9536,  9760,  9792, 10016, 10048, 10272, 10304,
                                  16672, 16704, 16928, 16960, 17184, 17216, 17440, 17472,
                                  17696, 17728, 17952, 17984, 18208, 18240, 18464, 18496,
                                  18720, 18752, 18976, 19008, 19232, 19264, 19488, 19520])

        # Offset (x, y, z) added after rotating local u/v coordinates into the
        # panel's position; one row per entry of panelIDs.
        self.panelShifts = np.array([[ 1.3985    ,  0.2652658 ,  3.68255],
                                     [ 1.3985    ,  0.23238491, -0.88255],
                                     [ 0.80146531,  1.17631236,  3.68255],
                                     [ 0.82407264,  1.15370502, -0.88255],
                                     [-0.2582769 ,  1.3985    ,  3.68255],
                                     [-0.2322286 ,  1.3985    , -0.88255],
                                     [-1.17531186,  0.80246583,  3.68255],
                                     [-1.15510614,  0.82267151, -0.88255],
                                     [-1.3985    , -0.2645974 ,  3.68255],
                                     [-1.3985    , -0.23012119, -0.88255],
                                     [-0.80591227, -1.17186534,  3.68255],
                                     [-0.82344228, -1.15433536, -0.88255],
                                     [ 0.26975836, -1.3985    ,  3.68255],
                                     [ 0.23326624, -1.3985    , -0.88255],
                                     [ 1.1746111 , -0.80316652,  3.68255],
                                     [ 1.15205703, -0.82572062, -0.88255],
                                     [ 2.2015    ,  0.26959865,  5.01305],
                                     [ 2.2015    ,  0.2524582 , -1.21305],
                                     [ 1.77559093,  1.32758398,  5.01305],
                                     [ 1.78212569,  1.31626522, -1.21305],
                                     [ 0.87798948,  2.03516717,  5.01305],
                                     [ 0.88478563,  2.03124357, -1.21305],
                                     [-0.26129975,  2.2015    ,  5.01305],
                                     [-0.25184137,  2.2015    , -1.21305],
                                     [-1.32416655,  1.77756402,  5.01305],
                                     [-1.31417539,  1.78333226, -1.21305],
                                     [-2.03421133,  0.87964512,  5.01305],
                                     [-2.02960691,  0.88762038, -1.21305],
                                     [-2.2015    , -0.25954151,  5.01305],
                                     [-2.2015    , -0.24969109, -1.21305],
                                     [-1.77636043, -1.32625112,  5.01305],
                                     [-1.78138268, -1.31755219, -1.21305],
                                     [-0.87493138, -2.03693277,  5.01305],
                                     [-0.8912978 , -2.02748378, -1.21305],
                                     [ 0.26489725, -2.2015    ,  5.01305],
                                     [ 0.25364439, -2.2015    , -1.21305],
                                     [ 1.3269198 , -1.7759744 ,  5.01305],
                                     [ 1.32258793, -1.77847528, -1.21305],
                                     [ 2.03616649, -0.87625871,  5.01305],
                                     [ 2.02936825, -0.8880338 , -1.21305]])

        # Rotation about the z axis, in degrees, that orients each panel
        # (45 degree steps for the 8 layer-1 ladders, 30 degree steps for the
        # 12 layer-2 ladders).
        self.panelRotations = np.array([ 90,  90, 135, 135, 180, 180, 225, 225, 270, 270, 315, 315, 360,
                                        360, 405, 405,  90,  90, 120, 120, 150, 150, 180, 180, 210, 210,
                                        240, 240, 270, 270, 300, 300, 330, 330, 360, 360, 390, 390, 420,
                                        420])

        # Layer and ladder number per panel: 16 layer-1 panels (ladders 1-8)
        # followed by 24 layer-2 panels (ladders 1-12), two sensors per ladder.
        # This agrees with decoding panelIDs as described above.
        self.panelLayer = np.array([1] * 16 + [2] * 24)
        self.panelLadder = np.concatenate([np.repeat(np.arange(1, 9), 2),
                                           np.repeat(np.arange(1, 13), 2)])

        # Look-up tables keyed by str(panelID):
        #   transformation -> [shift (x, y, z), rotation in degrees]
        #   layersLadders  -> [layer, ladder]
        self.transformation = {}
        self.layersLadders = {}
        for panelID, shift, rotation, layer, ladder in zip(self.panelIDs, self.panelShifts,
                                                           self.panelRotations, self.panelLayer,
                                                           self.panelLadder):
            key = str(panelID)
            self.transformation[key] = [shift, rotation]
            self.layersLadders[key] = [layer, ladder]

        # Tree handle; set by loadData().
        self.eventTree = None

        self.gotClusters = False
        self.clusters = ['PXDClusters/PXDClusters.m_clsCharge',
                         'PXDClusters/PXDClusters.m_seedCharge',
                         'PXDClusters/PXDClusters.m_clsSize',
                         'PXDClusters/PXDClusters.m_uSize',
                         'PXDClusters/PXDClusters.m_vSize',
                         'PXDClusters/PXDClusters.m_uPosition',
                         'PXDClusters/PXDClusters.m_vPosition',
                         'PXDClusters/PXDClusters.m_sensorID']

        self.digits = ['PXDDigits/PXDDigits.m_uCellID',
                       'PXDDigits/PXDDigits.m_vCellID',
                       'PXDDigits/PXDDigits.m_charge']

        # Per cluster: indices of the digits that belong to it.
        self.clusterToDigis = 'PXDClustersToPXDDigits/m_elements/m_elements.m_to'

        self.mcData = ['MCParticles/MCParticles.m_pdg',
                       'MCParticles/MCParticles.m_momentum_x',
                       'MCParticles/MCParticles.m_momentum_y',
                       'MCParticles/MCParticles.m_momentum_z']

        # Cluster -> MC-particle relation: 'm_to' holds the MC-particle
        # indices, 'm_from' the cluster indices of each relation element.
        self.clusterToMC = 'PXDClustersToMCParticles/m_elements/m_elements.m_to'
        self.mcToCluster = 'PXDClustersToMCParticles/m_elements/m_elements.m_from'

        self.data = {}

    def __getitem__(self, index):
        """Return column ``index`` if it is a string, else row ``index`` of every column."""
        if isinstance(index, str):
            return self.data[index]
        return {key: value[index] for key, value in self.data.items()}

    def loadData(self, file):
        """Open ``<file>:tree`` and generate per-cluster event numbers.

        Any results from a previously loaded file are discarded, so columns
        from two files are never mixed.
        """
        self.eventTree = ur.open(f'{file}:tree')
        self.data = {}
        self.gotClusters = False
        self._genEventNumbers()

    def _readBranch(self, keyword: str):
        """Return the per-event arrays of branch ``keyword`` (full branch path).

        Raises:
            RuntimeError: if no tree has been loaded yet.
        """
        if self.eventTree is None:
            raise RuntimeError('no tree loaded; call loadData() first')
        return self.eventTree.arrays(keyword, library='np')[keyword]

    def _getData(self, keyword: str):
        """Read branch ``keyword`` (full branch path) and flatten it over all events.

        Raises:
            KeyError: if the branch does not exist in the tree.
            RuntimeError: if no tree has been loaded yet.
        """
        try:
            data = self._readBranch(keyword)
        except KeyError as err:
            raise KeyError(f'branch {keyword!r} not found in tree') from err
        return self._flatten(data)

    def _flatten(self, structure, max_depth=None, current_depth=0):
        """Recursively flatten nested lists/arrays into a 1-D object array.

        ``max_depth`` limits how many nesting levels are expanded; ``None``
        expands everything.
        """
        flat_list = []

        for element in structure:
            if isinstance(element, (list, np.ndarray)) and (max_depth is None or current_depth < max_depth):
                flat_list.extend(self._flatten(element, max_depth, current_depth + 1))
            else:
                flat_list.append(element)

        return np.array(flat_list, dtype=object)

    def _genEventNumbers(self):
        """Store the index of the event each cluster belongs to in ``data['eventNumbers']``."""
        clusters = self._readBranch('PXDClusters/PXDClusters.m_clsCharge')
        eventNumbers = [np.full(len(charges), event) for event, charges in enumerate(clusters)]
        self.data['eventNumbers'] = self._flatten(eventNumbers)

    def getClusters(self):
        """Load all cluster branches.

        Each branch is stored under the part of its name after the last
        underscore (e.g. 'uPosition', 'sensorID').
        """
        self.gotClusters = True
        for branch in self.clusters:
            self.data[branch.split('_')[-1]] = self._getData(branch)

    def getMatrices(self):
        """Build one pixel-charge image per cluster and store them in ``data['ADCs']``."""
        uCellIDs, vCellIDs, cellCharges = [self._readBranch(branch) for branch in self.digits]
        clusterDigits = self._readBranch(self.clusterToDigis)

        self.data['ADCs'] = self._genMatrices(uCellIDs, vCellIDs, cellCharges, clusterDigits)

    def _genMatrices(self, uCellIDs, vCellIDs, cellCharges, clusterDigits, matrixSize=(9, 9)):
        """Return one ``matrixSize`` charge image per cluster, for all events.

        Each image is centred on the cluster's highest-charge digit. Digits
        that fall outside the image are dropped. A cluster without digits
        yields an all-zero image.
        """
        plotRange = np.array(matrixSize) // 2
        images = []

        for event in range(len(cellCharges)):
            # Use signed integers: cell IDs may be unsigned, and offsets left of
            # or below the seed pixel would otherwise wrap around.
            digitsU = np.asarray(uCellIDs[event], dtype=np.int64)
            digitsV = np.asarray(vCellIDs[event], dtype=np.int64)
            digitsCharge = np.asarray(cellCharges[event])

            for indices in clusterDigits[event]:
                image = np.zeros(matrixSize)
                indices = np.asarray(indices, dtype=np.int64)
                if indices.size == 0:
                    images.append(image)
                    continue

                charges = digitsCharge[indices]
                seed = indices[charges.argmax()]
                uPos = digitsU[indices] - digitsU[seed] + plotRange[0]
                vPos = digitsV[indices] - digitsV[seed] + plotRange[1]

                inside = (uPos >= 0) & (uPos < matrixSize[0]) & (vPos >= 0) & (vPos < matrixSize[1])
                image[uPos[inside], vPos[inside]] = charges[inside]
                images.append(image)

        return np.array(images, dtype=object)

    def genCoordisnate(self):
        """Store global x/y/z cluster positions in ``data['xcoords'|'ycoords'|'zcoords']``.

        Loads the cluster branches first if that has not happened yet.
        """
        if not self.gotClusters:
            self.getClusters()
        xcoords, ycoords, zcoords = self._getCartesian(self.data['uPosition'], self.data['vPosition'], self.data['sensorID'])
        self.data['xcoords'] = xcoords
        self.data['ycoords'] = ycoords
        self.data['zcoords'] = zcoords

    # correctly spelled alias; the original name is kept for existing callers
    genCoordinates = genCoordisnate

    def _getCartesian(self, uPositions, vPositions, sensorIDs):
        """Transform local (u, v) positions into global (x, y, z) float arrays.

        The point (u, 0, v) is rotated about the z axis by the panel's angle,
        then shifted by the panel's offset. This gives:

            x = u*cos(t) + sx,  y = u*sin(t) + sy,  z = v + sz

        Raises:
            KeyError: if a sensor ID is not in the panel table.
        """
        u = np.asarray(uPositions, dtype=float)
        v = np.asarray(vPositions, dtype=float)
        lookups = [self.transformation[str(sensorID)] for sensorID in sensorIDs]
        shifts = np.array([shift for shift, _ in lookups], dtype=float).reshape(-1, 3)
        theta = np.deg2rad(np.array([angle for _, angle in lookups], dtype=float))

        xArr = u * np.cos(theta) + shifts[:, 0]
        yArr = u * np.sin(theta) + shifts[:, 1]
        zArr = v + shifts[:, 2]
        return xArr, yArr, zArr

    def getLayers(self):
        """Store the layer and ladder number of every cluster in ``data['layers']`` / ``data['ladder']``."""
        if not self.gotClusters:
            self.getClusters()
        pairs = [self.layersLadders[str(sensorID)] for sensorID in self.data['sensorID']]
        self.data['layers'] = np.array([layer for layer, _ in pairs])
        self.data['ladder'] = np.array([ladder for _, ladder in pairs])

    def getMCData(self):
        """Attach the related MC particle's PDG code and momentum to every cluster.

        Clusters without a related MC particle get 0 for all four values and
        -1 in ``data['clsNumber']`` (the related MC-particle index).
        """
        pdg, momentumX, momentumY, momentumZ = [self._readBranch(branch) for branch in self.mcData]
        clusterToMC = self._readBranch(self.clusterToMC)
        mcToCluster = self._readBranch(self.mcToCluster)
        clsCharge = self._readBranch('PXDClusters/PXDClusters.m_clsCharge')

        momentumXList, momentumYList, momentumZList = [], [], []
        pdgList = []
        clusterNumbersList = []
        for event, (relFrom, relTo) in enumerate(zip(mcToCluster, clusterToMC)):
            references = self._fillMCList(relFrom, relTo, len(clsCharge[event]))
            clusterNumbersList.append(references)
            pdgs, xmom, ymom, zmom = self._getMCData(references, pdg[event], momentumX[event],
                                                     momentumY[event], momentumZ[event])
            momentumXList.append(xmom)
            momentumYList.append(ymom)
            momentumZList.append(zmom)
            pdgList.append(pdgs)

        self.data['momentumX'] = self._flatten(momentumXList)
        self.data['momentumY'] = self._flatten(momentumYList)
        self.data['momentumZ'] = self._flatten(momentumZList)
        self.data['pdg'] = self._flatten(pdgList)
        self.data['clsNumber'] = self._flatten(clusterNumbersList)

    def _findMissing(self, lst: list, length: int) -> list:
        """Return the sorted indices in ``range(length)`` that do not occur in ``lst``."""
        return sorted(set(range(0, length)) - set(lst))

    @staticmethod
    def _firstTarget(targets) -> int:
        """Return the first MC index of a relation target (scalar or sequence); -1 if empty."""
        if np.ndim(targets) == 0:
            return int(targets)
        flat = np.ravel(targets)
        return int(flat[0]) if flat.size else -1

    def _fillMCList(self, fromClusters, toClusters, length):
        """Map every cluster index in ``range(length)`` to its related MC-particle index.

        Each relation element maps ``fromClusters[k]`` (a cluster index) to
        ``toClusters[k]`` (an MC index or a sequence of them; the first is
        used). Clusters that appear in no relation element get -1. If a
        cluster appears more than once, its first relation element wins.
        """
        references = [-1] * length
        for clusterIndex, targets in zip(fromClusters, toClusters):
            clusterIndex = int(clusterIndex)
            if references[clusterIndex] == -1:
                references[clusterIndex] = self._firstTarget(targets)
        return references

    def _getMCData(self, toClusters, pdgs, xMom, yMom, zMom):
        """Look up PDG code and momentum components for each MC reference of one event.

        A reference of -1 (no related MC particle) yields 0 for all four values.
        Returns four object arrays: (pdg, px, py, pz).
        """
        pxList, pyList, pzList = [], [], []
        pdgList = []
        for references in toClusters:
            if references == -1:
                pxList.append(0)
                pyList.append(0)
                pzList.append(0)
                pdgList.append(0)
            else:
                pxList.append(xMom[references])
                pyList.append(yMom[references])
                pzList.append(zMom[references])
                pdgList.append(pdgs[references])
        return (np.array(pdgList, dtype=object), np.array(pxList, dtype=object),
                np.array(pyList, dtype=object), np.array(pzList, dtype=object))