Skip to content
Snippets Groups Projects
forrest-test.py 4.62 KiB
Newer Older
  • Learn to ignore specific revisions
  • johannes bilk's avatar
    johannes bilk committed
    import numpy as np
    import sys
    from rf.randomForrest import RandomForest
    from rf.decisionTree import DecisionTree
    from rf.impurityMeasure import Gini, Entropy, MAE, MSE
    from rf.leafFunction import Mode, Mean
    from rf.featureSelection import UsersChoice, Variance, Random, MutualInformation, ANOVA, KendallTau
    from rf.splitAlgorithm import CART, ID3, C45
    from metric.confusionMatrix import ConfusionMatrix
    from utility.timer import Time
    from settings.forrestSettings import ForrestSettings
    from rf.voting import Majority, Confidence, Average, Median
    from data.data import Data
    from rf.boosting import AdaBoosting, GradientBoosting
    
    
    def getImurity(impurity: str):
        if impurity == 'gini':
            return Gini() # Use Gini index as the impurity measure
        elif impurity == 'entropy':
            return Entropy() # Use Entropy index as the impurity measure
        elif impurity == 'mae':
            return MAE() # Use MAE index as the impurity measure
        elif impurity == 'mse':
            return MSE() # Use MSE index as the impurity measure
    
    
    def getLeaf(leaf: str):
        if leaf == 'mode':
            return Mode() # Use mode as the leaf function
        elif leaf == 'mean':
            return Mean() # Use mean as the leaf function
    
    
    def getSplit(split: str, percentile: int = None):
        if split == 'id3':
            return ID3(percentile) # Use ID3 algorithm for splitting
        elif split == 'c45':
            return C45(percentile) # Use C4.5 algorithm for splitting
        elif split == 'cart':
            return CART(percentile) # Use CART algorithm for splitting
    
    
    def getVoting(voting: str, weights: list):
        if voting == 'majority':
            return Majority(weights)
        elif voting == 'confidence':
            return Confidence(weights)
        elif voting == 'average':
            return Average(weights)
        elif voting == 'median':
            return Median(weights)
    
    
    def getFeatureSelection(selection: str, *args):
        if selection == 'choice':
            return UsersChoice(*args)
        elif selection == 'variance':
            return Variance(*args)
        elif selection == 'random':
            return Random(*args)
        elif selection == 'mutual':
            return MutualInformation(*args)
        elif selection == 'anova':
            return ANOVA(*args)
        elif selection == 'kendall':
            return KendallTau(*args)
    
    
    def getBooster(booster: str):
        if booster == 'adaptive':
            return AdaBoosting()
        elif booster == 'gradient':
            return GradientBoosting()
    
    
    if __name__ == "__main__":
        settings = ForrestSettings()
        try:
            configFile = sys.argv[1]
            settings.getConfig(configFile)
            settings.setConfig()
        except IndexError:
            pass
        print(settings)
    
        # Create a timer object to measure execution time
        timer = Time()
    
        print("Importing data...\n")
        timer.start()
        data = Data(trainAmount=settings['trainAmount'], evalAmount=settings['validAmount'], dataPath=settings['dataPath'], normalize=settings['normalize'])
        data.inputFeatures(*settings['features'])
        data.importData(*settings['dataFiles'])
        print(data)
        timer.record("Importing Data")
    
        # Set up random forest
        timer.start()
        print("setting up forrest")
        forrest = RandomForest(bootstrapping=settings['bootstraping'], retrainFirst=settings['retrainFirst'])
        if settings['booster'] is not None:
            forrest.setComponent(getBooster(settings['booster']))
        if settings['voting'] is not None:
            forrest.setComponent(getVoting(settings['voting'], settings['votingWeights']))
        for i in range(settings['numTrees']):
            tree = DecisionTree(settings['depth'][i], settings['minSamples'][i])
            tree.setComponent(getImurity(settings['impurity'][i]))
            tree.setComponent(getLeaf(settings['leaf'][i]))
            tree.setComponent(getSplit(settings['split'][i], settings['percentile'][i]))
            if settings['featSelection'][i] is not None:
                tree.setComponent(getFeatureSelection(settings['featSelection'][i], settings['featParameter'][i]))
            forrest.append(tree)
        timer.record("Forrest setup")
    
        # Train the random forest
        timer.start()
        print("begin training")
        forrest.train(data.trainSet.data,data.trainSet.labels.argmax(1))
        timer.record("Training")
    
        # Evaluate the random forest
        timer.start()
        print("making predictions\n")
        #prediction = forrest.eval(validData)
        prediction = tree.eval(data.evalSet.data)
        timer.record("Prediction")
        print(forrest)
        print()
    
        # Calculate and print confusion matrix
        confusion = ConfusionMatrix(2)
        confusion.update(prediction, data.evalSet.labels.argmax(1))
        confusion.percentages()
        confusion.calcScores()
        print(confusion)
        print()
    
        # Print total execution time
        print(timer)