Skip to content
Snippets Groups Projects
tree-test.py 3.98 KiB
Newer Older
johannes bilk's avatar
johannes bilk committed
import numpy as np
import sys, random
from matplotlib import pyplot as plt
from machineLearning.rf import (
    DecisionTree,
    Gini, Entropy, MAE, MSE,
    Mode, Mean,
    CART, ID3, C45,
    UsersChoice, Variance, Random, MutualInformation, ANOVA, KendallTau
)
from machineLearning.metric import ConfusionMatrix
from machineLearning.utility import Time
from machineLearning.settings.treeSettings import TreeSettings
from machineLearning.data import Data
johannes bilk's avatar
johannes bilk committed


def dataShift(dims):
    offSet = [0.25, 0.5, 0.25]
    diffLen = abs(len(offSet) - dims)
    offSet.extend([0] * diffLen)
    random.shuffle(offSet)
    return offSet[:dims]


def getImurity(impurity: str):
    if impurity == 'gini':
        return Gini() # Use Gini index as the impurity measure
    elif impurity == 'entropy':
        return Entropy() # Use Entropy index as the impurity measure
    elif impurity == 'mae':
        return MAE() # Use MAE index as the impurity measure
    elif impurity == 'mse':
        return MSE() # Use MSE index as the impurity measure


def getLeaf(leaf: str):
    if leaf == 'mode':
        return Mode() # Use mode as the leaf function
    elif leaf == 'mean':
        return Mean() # Use mean as the leaf function


def getSplit(split: str, percentile: int = None):
    if split == 'id3':
        return ID3(percentile) # Use ID3 algorithm for splitting
    elif split == 'c45':
        return C45(percentile) # Use C4.5 algorithm for splitting
    elif split == 'cart':
        return CART(percentile) # Use CART algorithm for splitting


def getFeatureSelection(selection: str, *args):
    if selection == 'choice':
        return UsersChoice(*args)
    elif selection == 'variance':
        return Variance(*args)
    elif selection == 'random':
        return Random(*args)
    elif selection == 'mutual':
        return MutualInformation(*args)
    elif selection == 'anova':
        return ANOVA(*args)
    elif selection == 'kendall':
        return KendallTau(*args)


if __name__ == "__main__":
    settings = TreeSettings()
    try:
        configFile = sys.argv[1]
        settings.getConfig(configFile)
        settings.setConfig()
    except IndexError:
        pass
    print(settings)

    # Create a timer object to measure execution time
    timer = Time()

    print("Importing data...\n")
    timer.start()
    data = Data(trainAmount=settings['trainAmount'], evalAmount=settings['validAmount'], dataPath=settings['dataPath'], normalize=settings['normalize'])
    data.inputFeatures(*settings['features'])
    data.importData(*settings['dataFiles'])
    print(data)
    timer.record("Importing Data")

    # Create and train a decision tree
    timer.start()
    print('Setting up tree')
    tree = DecisionTree(settings['depth'], settings['minSamples'])
    tree.setComponent(getImurity(settings['impurity']))
    tree.setComponent(getLeaf(settings['leaf']))
    tree.setComponent(getSplit(settings['split'],settings['percentile']))
    if settings['featSelection'] is not None:
        tree.setFeatureSelection(settings['featSelection'], settings['featParameter']) # Use random feature selection
    timer.record("Tree setup")

    # Train the tree using the training data
    timer.start()
    print('begin training...')
    #tree.train(trainData,trainLabels)
    tree.train(data.trainSet.data,data.trainSet.labels.argmax(1))
    timer.record("Training")

    # Evaluate the tree on the validation data
    timer.start()
    print('making predictions...')
    #prediction = tree.eval(validData)
    prediction = tree.eval(data.evalSet.data)
    timer.record("Prediction")

    # Print the trained decision tree
    print(tree)
    print()

    # Compute confusion matrix to evaluate the performance of the decision tree
    confusion = ConfusionMatrix(2)
    #confusion.update(prediction, validLabels)
    confusion.update(prediction, data.evalSet.labels.argmax(1))
    confusion.percentages()
    confusion.calcScores()

    # Compute confusion matrix to evaluate the performance of the decision tree
    print(confusion)
    print()
    print(timer)