Skip to content
Snippets Groups Projects
Commit fed87347 authored by johannes bilk's avatar johannes bilk
Browse files

added some file handling things

parent 1f5a0031
No related branches found
No related tags found
No related merge requests found
......@@ -155,11 +155,11 @@ class Rootable:
if isinstance(index, str):
return self.data[index]
return {key: value[index] for key, value in self.data.items()}
def __setitem__(self, index: str | int | ArrayLike, value: dict | Any) -> None:
"""
Allows setting the value of a column by using strings as keywords,
setting the value of a row by using integer indices or arrays,
setting the value of a row by using integer indices or arrays,
or setting a specific value using a tuple of key and index.
:param index: The column name, row index, or tuple of key and index.
:param value: The value to set.
......@@ -194,7 +194,7 @@ class Rootable:
key, op, value = match.groups()
op = op.strip() # remove any leading and trailing spaces
if op == 'in':
value = eval(value)
mask &= np.isin(self.data[key], value)
......@@ -230,11 +230,11 @@ class Rootable:
numRows = len(self.data[keys[0]])
for i in range(numRows):
yield {key: self.data[key][i] for key in keys}
yield {key: self.data[key][i] for key in keys}
def keys(self) -> list:
    """
    Lists the names of all columns currently held in self.data.
    :return: A new list containing the column names.
    """
    # iterating the mapping directly yields its keys
    return list(self.data)
def items(self) -> list:
    """
    Returns the (column name, column) pairs of self.data.
    The pairs are materialised into a list so the result matches the
    declared return type and behaves like keys(), which also returns a list.
    :return: A list of (key, value) tuples.
    """
    return list(self.data.items())
......@@ -243,10 +243,10 @@ class Rootable:
def get(self, key: str) -> np.ndarray:
return self.data.get(key)
def pop(self, key: str) -> np.ndarray:
    """
    Removes a column from self.data and hands it back to the caller.
    :param key: The column name.
    :return: The value that was stored under key (annotated like get()).
    No default is passed on, so whatever self.data.pop raises for a missing
    key (KeyError for a dict) propagates to the caller.
    """
    return self.data.pop(key)
def stack(self, *columns, toKey: str, pop: bool = True) -> None:
"""
Stacks specified columns into a single column and stores it under a new key.
......@@ -258,32 +258,45 @@ class Rootable:
for column in columns:
if column not in self.data:
raise KeyError(f"Column '{column}' does not exist.")
# Column stack the specified columns
stackedColumn = np.column_stack([self.data[col] for col in columns])
# Flatten if it's 1D for consistency
if stackedColumn.shape[1] == 1:
stackedColumn = stackedColumn.flatten()
# Store it under the new key
self.data[toKey] = stackedColumn
# Remove the original columns if pop is True
if pop:
for column in columns:
self.data.pop(column)
def loadData(self, file: str, events: int = None, selection: str = None) -> None:
def loadData(self, fileName: str, events: int = None, selection: str = None) -> None:
"""
Reads the file off of the hard drive; it automatically creates event numbers.
file: str = it's the whole file path + .root ending
events: int = the number of events to import (None for all)
selection: str = method of event selection ('random' for random selection)
"""
self.eventTree = ur.open(f'{file}:tree')
file, _, treeName = fileName.partition(':')
if not file.endswith('.root'):
file += '.root'
if not treeName:
treeName = 'tree'
try: # checking if file exists
with open(file, 'r') as f:
self.eventTree = ur.open(f'{file}:{treeName}')
except FileNotFoundError:
raise FileNotFoundError(f"File {file} not found.")
numEvents = len(self.eventTree.arrays('PXDClusters/PXDClusters.m_clsCharge', library='np')['PXDClusters/PXDClusters.m_clsCharge'])
if events is not None:
if selection == 'random':
self.eventIndices = np.random.permutation(numEvents)[:events]
......@@ -292,14 +305,17 @@ class Rootable:
clusters = self.eventTree.arrays('PXDClusters/PXDClusters.m_clsCharge', library='np')['PXDClusters/PXDClusters.m_clsCharge'][self.eventIndices]
else:
clusters = self.eventTree.arrays('PXDClusters/PXDClusters.m_clsCharge', library='np')['PXDClusters/PXDClusters.m_clsCharge']
self._getEventNumbers(clusters)
def _getEventNumbers(self, clusters: np.ndarray, offset: int = 0) -> None:
    """
    Generates one event number per cluster from the structure of the pxd clusters.
    Event i contributes len(clusters[i]) copies of (i + offset); the resulting
    flat array is stored in self.data['eventNumber'].
    :param clusters: One entry per event; only the length of each entry is used.
    :param offset: Value added to every event index.
    """
    numEvents = len(clusters)
    # An explicit integer dtype keeps the counts usable by np.repeat even when
    # there are no events at all.
    clustersPerEvent = np.fromiter(
        (len(event) for event in clusters), dtype=np.intp, count=numEvents
    )
    # np.repeat keeps the integer dtype of the event indices; stacking one
    # array per event turned the whole column into floats as soon as a single
    # event had no clusters, and failed outright on an empty input.
    self.data['eventNumber'] = np.repeat(np.arange(numEvents) + offset, clustersPerEvent)
def _getData(self, keyword: str, library: str = 'np') -> np.ndarray:
"""
......@@ -314,28 +330,10 @@ class Rootable:
data = self.eventTree.arrays(keyword, library=library)[keyword][self.eventIndices]
else:
data = self.eventTree.arrays(keyword, library=library)[keyword]
return self._flatten(data)
return np.hstack(data)
except:
return KeyError
def _flatten(self, structure: ArrayLike, maxDepth: int = None, currentDepth: int = 0) -> np.ndarray:
    """
    Private helper used while loading branches: flattens a ragged, nested structure.
    Nested lists and numpy arrays are unpacked recursively; every other element
    is kept as it is. Each recursion level returns its result as a numpy array.
    :param structure: The list/array to flatten.
    :param maxDepth: How many nesting levels to unpack (None unpacks all of them).
    :param currentDepth: Recursion bookkeeping, callers should leave it at 0.
    :return: A numpy array built from the collected elements.
    """
    # whether this level is still allowed to unpack nested containers
    mayDescend = maxDepth is None or currentDepth < maxDepth
    collected = []
    for item in structure:
        if mayDescend and isinstance(item, (list, np.ndarray)):
            collected.extend(self._flatten(item, maxDepth, currentDepth + 1))
        else:
            collected.append(item)
    return np.array(collected)
def getClusters(self) -> None:
"""
this uses the array from __init__ to load different branches into the data dict
......@@ -466,18 +464,18 @@ class Rootable:
# Checking if coordinates have been loaded
if self.gotClusters is False:
self.getCoordinates()
xSquare = np.square(self.data['xPosition'])
ySquare = np.square(self.data['yPosition'])
zSquare = np.square(self.data['zPosition'])
# Avoid division by zero by replacing zeros with a small number
r = np.sqrt(xSquare + ySquare + zSquare)
rSafe = np.where(r == 0, 1e-10, r)
theta = np.arccos(self.data['zPosition'] / rSafe)
phi = np.arctan2(self.data['yPosition'], self.data['xPosition'])
self.data['rPosition'] = r
self.data['thetaPosition'] = theta
self.data['phiPosition'] = phi
......@@ -547,11 +545,11 @@ class Rootable:
momentumZList.append(zmom)
pdgList.append(pdgs)
self.data['momentumX'] = self._flatten(momentumXList)
self.data['momentumY'] = self._flatten(momentumYList)
self.data['momentumZ'] = self._flatten(momentumZList)
self.data['pdg'] = self._flatten(pdgList)
self.data['clsNumber'] = self._flatten(clusterNumbersList)
self.data['momentumX'] = np.hstack(momentumXList)
self.data['momentumY'] = np.hstack(momentumYList)
self.data['momentumZ'] = np.hstack(momentumZList)
self.data['pdg'] = np.hstack(pdgList)
self.data['clsNumber'] = np.hstack(clusterNumbersList)
@staticmethod
def _findMissing(lst: list, length: int) -> list:
......
import unittest
from ..rootable import Rootable
# Manual smoke run: loads one ROOT file and calls every loading step once.
# NOTE(review): nothing is asserted and the relative path only resolves on a
# machine that has this directory layout.
fromShit = Rootable()
fromShit.loadData('../fuckbasf2/root-files/slow_pions.root')
fromShit.getClusters()
# was 'genCoordisnate', which is not the name the class itself calls (getCoordinates)
fromShit.getCoordinates()
fromShit.getLayers()
fromShit.getMatrices()
fromShit.getMCData()
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment