forrest-test.ipynb 75.3 KiB
johannes bilk committed
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Testing the Forest\n",
    "\n",
    "## Importing the Basics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random\n",
    "from matplotlib import pyplot as plt\n",
    "from machineLearning.metric import ConfusionMatrix, RegressionScores\n",
    "from machineLearning.utility import ModelIO\n",
    "from machineLearning.rf import (\n",
    "    RandomForest, DecisionTree,\n",
    "    Gini, Entropy, MAE, MSE,\n",
    "    Mode, Mean, Confidence, Probabilities,\n",
    "    CART, ID3, C45,\n",
    "    AdaBoosting, GradientBoosting,\n",
    "    Majority, Average, Median\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generating Test Data\n",
    "\n",
    "Here I generate random test data: two Gaussian blocks, where the second block is shifted by 5, 1.5 and 2.5 along up to three randomly chosen dimensions. For classification tasks each block gets its own label; for regression tasks each sample gets the sum of its coordinates plus a small Gaussian noise term as its target. It is a very simple dummy data set meant for testing the code.\n",
    "\n",
    "The dimensionality and the amount of data can be changed here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
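    "def dataShift(dims):\n",
    "    # Offsets for the second block; pad with zeros only if there are more than three dimensions\n",
    "    offSet = [5, 1.5, 2.5]\n",
    "    offSet.extend([0] * max(dims - len(offSet), 0))\n",
    "    # Randomly choose which dimensions receive the offsets\n",
    "    random.shuffle(offSet)\n",
    "    return offSet[:dims]\n",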
    "\n",
    "# Initialize some parameters\n",
    "totalAmount = 6400\n",
    "dims = 5\n",
    "evalAmount = totalAmount // 4\n",
    "trainAmount = totalAmount - evalAmount\n",
    "offSet = dataShift(dims)\n",
    "\n",
    "# Create covariance matrix\n",
    "cov = np.eye(dims)  # This creates a covariance matrix with variances 1 and covariances 0\n",
    "\n",
    "# Generate random multivariate data\n",
    "oneData = np.random.multivariate_normal(np.zeros(dims), cov, totalAmount)\n",
    "twoData = np.random.multivariate_normal(offSet, cov, totalAmount)\n",
    "\n",
    "# Split the data into training and evaluation sets\n",
    "trainData = np.vstack((oneData[:trainAmount], twoData[:trainAmount]))\n",
    "validData = np.vstack((oneData[trainAmount:], twoData[trainAmount:]))\n",
    "\n",
    "# Labels for classification tasks\n",
    "trainLabels = np.hstack((np.zeros(trainAmount), np.ones(trainAmount)))\n",
    "validLabels = np.hstack((np.zeros(evalAmount), np.ones(evalAmount)))\n",
    "\n",
    "# Targets for regression tasks\n",
    "trainTargets = np.sum(trainData, axis=1) + np.random.normal(0, 0.1, 2*trainAmount)\n",
    "validTargets = np.sum(validData, axis=1) + np.random.normal(0, 0.1, 2*evalAmount)\n",
    "\n",
    "# Shuffle the training data\n",
    "trainIndex = np.random.permutation(len(trainData))\n",
    "trainData = trainData[trainIndex]\n",
    "trainLabels = trainLabels[trainIndex]\n",
    "trainTargets = trainTargets[trainIndex]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating the Forest\n",
    "\n",
    "Here the forest is created. One can set the number of trees and the maximum depth of each tree. Depending on the task, each tree gets a different impurity function and a different leaf function. Finally we add the split algorithm and set the feature percentile: higher values examine more candidate splits but slow training down, while lower values examine fewer candidate splits and speed it up. Depending on the data set, this can have a strong impact on performance.\n",
    "\n",
    "Each tree can have its own depth, leaf function, splitting algorithm and impurity function. In this simple case all trees are created with the same parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "task = 'classifier' # 'classifier'/'regressor'\n",
    "forrest = RandomForest(bootstrapping=False, retrainFirst=False)\n",
    "#forrest.setComponent(GradientBoosting())\n",
    "forrest.setComponent(Majority())\n",
    "for _ in range(5):\n",
    "    tree = DecisionTree(maxDepth=7, minSamplesSplit=2)\n",
    "    if task == 'regressor':\n",
    "        tree.setComponent(MSE())\n",
    "        tree.setComponent(Mean())\n",
    "    elif task == 'classifier':\n",
    "        tree.setComponent(Entropy())\n",
    "        tree.setComponent(Mode())\n",
    "    tree.setComponent(CART(featurePercentile=90))\n",
    "    forrest.append(tree)"
   ]
  },
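  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The trees above use `Entropy` as the impurity function and `CART` as the split algorithm. The next cell is a minimal sketch of that idea and does not reproduce the `machineLearning.rf` implementation. It computes the Shannon entropy of a label array and picks the threshold on a single feature that maximises the information gain. The helpers `entropy` and `bestSplit` and the toy data are hypothetical. The sketch tries every midpoint between neighbouring feature values. Examining fewer candidate thresholds, as described for the feature percentile above, trades split quality for speed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def entropy(labels):\n",
    "    # Shannon entropy (in bits) of an array of class labels\n",
    "    _, counts = np.unique(labels, return_counts=True)\n",
    "    p = counts / counts.sum()\n",
    "    return float(-np.sum(p * np.log2(p)))\n",
    "\n",
    "def bestSplit(feature, labels):\n",
    "    # Try the midpoints between sorted unique values as thresholds and keep the one\n",
    "    # with the largest information gain (parent entropy minus weighted child entropy)\n",
    "    parent = entropy(labels)\n",
    "    values = np.unique(feature)\n",
    "    thresholds = (values[:-1] + values[1:]) / 2\n",
    "    bestThreshold, bestGain = None, 0.0\n",
    "    for t in thresholds:\n",
    "        left, right = labels[feature <= t], labels[feature > t]\n",
    "        weighted = (len(left) * entropy(left) + len(right) * entropy(right)) / len(labels)\n",
    "        gain = parent - weighted\n",
    "        if gain > bestGain:\n",
    "            bestThreshold, bestGain = float(t), gain\n",
    "    return bestThreshold, bestGain\n",
    "\n",
    "demoFeature = np.array([0.1, 0.4, 0.35, 2.2, 2.5, 2.9])\n",
    "demoLabels = np.array([0, 0, 0, 1, 1, 1])\n",
    "print(bestSplit(demoFeature, demoLabels))"
   ]
  },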
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training the Forest\n",
    "\n",
    "Again, depending on the task, we train the forest with targets or labels. Then we call `bake()`, make a prediction on the validation data and print the forest."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tree 1 |⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿| done ✔                  | 18%\n",
      "tree 2 |⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿| done ✔                  | 18%\n",
      "tree 3 |⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿| done ✔                  | 18%\n",
      "tree 4 |⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿| done ✔                  | 18%\n",
      "tree 5 |⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿| done ✔                  | 18%\n",
      "━━━━━━━━━━━━━━━━━━━━━━━━━━━━ forrest ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n",
      "voting: Majority, booster: GradientBoosting, bootstrapping: False\n",
      "\n",
      "—————————————————————— tree: 1/5 ———————————————————————\n",
      "split: CART, impurity: Entropy, leaf: Mode, nodes: 47\n",
      "maxDepth: 7, reached depth: 7, minSamplesSplit: 2\n",
      "························································\n",
      "╴feat: 3 <= 2.29, samples: 9600\n",
      "     ├─feat: 1 <= 2.33, samples: 4747\n",
      "     │   ├─feat: 3 <= 2.06, samples: 4694\n",
      "     │   │   └─╴value: 0.0\n",
      "     │   │   └─╴feat: 0 <= 1.10, samples: 52\n",
      "     │   │       └─╴value: 0.0\n",
      "     │   │       └─╴feat: 4 <= -0.88, samples: 7\n",
      "     │   │           └─╴value: 1.0\n",
      "     │   │           └─╴feat: 4 <= 0.63, samples: 6\n",
      "     │   │               └─╴value: 0.0\n",
      "     │   │               └─╴value: 1.0\n",
      "     │   └─╴feat: 3 <= 1.29, samples: 53\n",
      "     │       └─╴value: 0.0\n",
      "     │       └─╴feat: 0 <= 0.35, samples: 8\n",
      "     │           └─╴value: 0.0\n",
      "     └─╴feat: 3 <= 3.34, samples: 4853\n",
      "         ├─feat: 1 <= 1.46, samples: 267\n",
      "         │   ├─feat: 0 <= 0.93, samples: 82\n",
      "         │   │   ├─feat: 3 <= 2.62, samples: 59\n",
      "         │   │   │   └─╴value: 0.0\n",
      "         │   │   │   └─╴feat: 1 <= 0.09, samples: 29\n",
      "         │   │   │       └─╴value: 0.0\n",
      "         │   │   │       └─╴value: 0.0\n",
      "         │   │   └─╴feat: 1 <= -0.32, samples: 23\n",
      "         │   │       ├─feat: 2 <= -0.30, samples: 4\n",
      "         │   │       │   └─╴value: 0.0\n",
      "         │   │       │   └─╴value: 1.0\n",
      "         │   │       └─╴feat: 3 <= 2.46, samples: 19\n",
      "         │   │           └─╴value: 1.0\n",
      "         │   │           └─╴value: 1.0\n",
      "         │   └─╴feat: 0 <= -0.64, samples: 185\n",
      "         │       ├─feat: 3 <= 2.48, samples: 5\n",
      "         │       │   └─╴value: 0.0\n",
      "         │       │   └─╴value: 1.0\n",
      "         │       └─╴feat: 4 <= 1.87, samples: 180\n",
      "         │           └─╴value: 1.0\n",
      "         │           └─╴feat: 3 <= 2.61, samples: 4\n",
      "         │               └─╴value: 0.0\n",
      "         │               └─╴value: 1.0\n",
      "         └─╴feat: 3 <= 3.57, samples: 4586\n",
      "             ├─feat: 1 <= 0.83, samples: 152\n",
      "             │   ├─feat: 0 <= -0.07, samples: 9\n",
      "             │   │   └─╴value: 1.0\n",
      "             │   └─╴value: 1.0\n",
      "             └─╴value: 1.0\n",
      "\n",
      "—————————————————————— tree: 2/5 ———————————————————————\n",
      "split: CART, impurity: Entropy, leaf: Mode, nodes: 47\n",
      "maxDepth: 7, reached depth: 7, minSamplesSplit: 2\n",
      "························································\n",
      "╴feat: 3 <= 2.29, samples: 9600\n",
      "     ├─feat: 1 <= 2.33, samples: 4747\n",
      "     │   ├─feat: 3 <= 2.06, samples: 4694\n",
      "     │   │   └─╴value: 0.0\n",
      "     │   │   └─╴feat: 0 <= 1.10, samples: 52\n",
      "     │   │       └─╴value: 0.0\n",
      "     │   │       └─╴feat: 0 <= 1.14, samples: 7\n",
      "     │   │           └─╴value: 1.0\n",
      "     │   │           └─╴feat: 4 <= 0.63, samples: 6\n",
      "     │   │               └─╴value: 0.0\n",
      "     │   │               └─╴value: 1.0\n",
      "     │   └─╴feat: 3 <= 1.29, samples: 53\n",
      "     │       └─╴value: 0.0\n",
      "     │       └─╴feat: 0 <= 0.35, samples: 8\n",
      "     │           └─╴value: 0.0\n",
      "     └─╴feat: 3 <= 3.34, samples: 4853\n",
      "         ├─feat: 1 <= 1.46, samples: 267\n",
      "         │   ├─feat: 0 <= 0.93, samples: 82\n",
      "         │   │   ├─feat: 3 <= 2.62, samples: 59\n",
      "         │   │   │   └─╴value: 0.0\n",
      "         │   │   │   └─╴feat: 1 <= 0.09, samples: 29\n",
      "         │   │   │       └─╴value: 0.0\n",
      "         │   │   │       └─╴value: 0.0\n",
      "         │   │   └─╴feat: 1 <= -0.32, samples: 23\n",
      "         │   │       ├─feat: 2 <= -0.30, samples: 4\n",
      "         │   │       │   └─╴value: 0.0\n",
      "         │   │       │   └─╴value: 1.0\n",
      "         │   │       └─╴feat: 3 <= 2.46, samples: 19\n",
      "         │   │           └─╴value: 1.0\n",
      "         │   │           └─╴value: 1.0\n",
      "         │   └─╴feat: 0 <= -0.64, samples: 185\n",
      "         │       ├─feat: 3 <= 2.48, samples: 5\n",
      "         │       │   └─╴value: 0.0\n",
      "         │       │   └─╴value: 1.0\n",
      "         │       └─╴feat: 4 <= 1.87, samples: 180\n",
      "         │           └─╴value: 1.0\n",
      "         │           └─╴feat: 1 <= 1.86, samples: 4\n",
      "         │               └─╴value: 0.0\n",
      "         │               └─╴value: 1.0\n",
      "         └─╴feat: 3 <= 3.57, samples: 4586\n",
      "             ├─feat: 1 <= 0.83, samples: 152\n",
      "             │   ├─feat: 0 <= -0.07, samples: 9\n",
      "             │   │   └─╴value: 1.0\n",
      "             │   └─╴value: 1.0\n",
      "             └─╴value: 1.0\n",
      "\n",
      "—————————————————————— tree: 3/5 ———————————————————————\n",
      "split: CART, impurity: Entropy, leaf: Mode, nodes: 47\n",
      "maxDepth: 7, reached depth: 7, minSamplesSplit: 2\n",
      "························································\n",
      "╴feat: 3 <= 2.29, samples: 9600\n",
      "     ├─feat: 1 <= 2.33, samples: 4747\n",
      "     │   ├─feat: 3 <= 2.06, samples: 4694\n",
      "     │   │   └─╴value: 0.0\n",
      "     │   │   └─╴feat: 0 <= 1.10, samples: 52\n",
      "     │   │       └─╴value: 0.0\n",
      "     │   │       └─╴feat: 3 <= 2.22, samples: 7\n",
      "     │   │           ├─feat: 4 <= 0.63, samples: 6\n",
      "     │   │           │   └─╴value: 0.0\n",
      "     │   │           │   └─╴value: 1.0\n",
      "     │   │           └─╴value: 1.0\n",
      "     │   └─╴feat: 3 <= 1.29, samples: 53\n",
      "     │       └─╴value: 0.0\n",
      "     │       └─╴feat: 1 <= 2.37, samples: 8\n",
      "     │           └─╴value: 0.0\n",
      "     └─╴feat: 3 <= 3.34, samples: 4853\n",
      "         ├─feat: 1 <= 1.46, samples: 267\n",
      "         │   ├─feat: 0 <= 0.93, samples: 82\n",
      "         │   │   ├─feat: 3 <= 2.62, samples: 59\n",
      "         │   │   │   └─╴feat: 1 <= 0.09, samples: 29\n",
      "         │   │   │       └─╴value: 0.0\n",
      "         │   │   │       └─╴value: 0.0\n",
      "         │   │   └─╴feat: 1 <= -0.32, samples: 23\n",
      "         │   │       ├─feat: 2 <= -0.30, samples: 4\n",
      "         │   │       │   └─╴value: 0.0\n",
      "         │   │       │   └─╴value: 1.0\n",
      "         │   │       └─╴feat: 3 <= 2.46, samples: 19\n",
      "         │   │           └─╴value: 1.0\n",
      "         │   │           └─╴value: 1.0\n",
      "         │   └─╴feat: 0 <= -0.64, samples: 185\n",
      "         │       ├─feat: 3 <= 2.48, samples: 5\n",
      "         │       │   └─╴value: 0.0\n",
      "         │       │   └─╴value: 1.0\n",
      "         │       └─╴feat: 4 <= 1.87, samples: 180\n",
      "         │           └─╴value: 1.0\n",
      "         │           └─╴feat: 3 <= 2.61, samples: 4\n",
      "         │               └─╴value: 0.0\n",
      "         │               └─╴value: 1.0\n",
      "         └─╴feat: 3 <= 3.57, samples: 4586\n",
      "             ├─feat: 1 <= 0.83, samples: 152\n",
      "             │   ├─feat: 0 <= -0.07, samples: 9\n",
      "             │   │   └─╴value: 1.0\n",
      "             │   └─╴value: 1.0\n",
      "             └─╴value: 1.0\n",
      "\n",
      "—————————————————————— tree: 4/5 ———————————————————————\n",
      "split: CART, impurity: Entropy, leaf: Mode, nodes: 47\n",
      "maxDepth: 7, reached depth: 7, minSamplesSplit: 2\n",
      "························································\n",
      "╴feat: 3 <= 2.29, samples: 9600\n",
      "     ├─feat: 1 <= 2.33, samples: 4747\n",
      "     │   ├─feat: 3 <= 2.06, samples: 4694\n",
      "     │   │   └─╴value: 0.0\n",
      "     │   │   └─╴feat: 0 <= 1.10, samples: 52\n",
      "     │   │       └─╴value: 0.0\n",
      "     │   │       └─╴feat: 0 <= 1.14, samples: 7\n",
      "     │   │           └─╴value: 1.0\n",
      "     │   │           └─╴feat: 4 <= 0.63, samples: 6\n",
      "     │   │               └─╴value: 0.0\n",
      "     │   │               └─╴value: 1.0\n",
      "     │   └─╴feat: 3 <= 1.29, samples: 53\n",
      "     │       └─╴value: 0.0\n",
      "     │       └─╴feat: 1 <= 2.37, samples: 8\n",
      "     │           └─╴value: 0.0\n",
      "     └─╴feat: 3 <= 3.34, samples: 4853\n",
      "         ├─feat: 1 <= 1.46, samples: 267\n",
      "         │   ├─feat: 0 <= 0.93, samples: 82\n",
      "         │   │   ├─feat: 3 <= 2.62, samples: 59\n",
      "         │   │   │   └─╴value: 0.0\n",
      "         │   │   │   └─╴feat: 1 <= 0.09, samples: 29\n",
      "         │   │   │       └─╴value: 0.0\n",
      "         │   │   │       └─╴value: 0.0\n",
      "         │   │   └─╴feat: 1 <= -0.32, samples: 23\n",
      "         │   │       ├─feat: 2 <= -0.30, samples: 4\n",
      "         │   │       │   └─╴value: 0.0\n",
      "         │   │       │   └─╴value: 1.0\n",
      "         │   │       └─╴feat: 3 <= 2.46, samples: 19\n",
      "         │   │           └─╴value: 1.0\n",
      "         │   │           └─╴value: 1.0\n",
      "         │   └─╴feat: 0 <= -0.64, samples: 185\n",
      "         │       ├─feat: 4 <= 0.49, samples: 5\n",
      "         │       │   └─╴value: 0.0\n",
      "         │       │   └─╴value: 1.0\n",
      "         │       └─╴feat: 4 <= 1.87, samples: 180\n",
      "         │           └─╴value: 1.0\n",
      "         │           └─╴feat: 3 <= 2.61, samples: 4\n",
      "         │               └─╴value: 0.0\n",
      "         │               └─╴value: 1.0\n",
      "         └─╴feat: 3 <= 3.57, samples: 4586\n",
      "             ├─feat: 1 <= 0.83, samples: 152\n",
      "             │   ├─feat: 0 <= -0.07, samples: 9\n",
      "             │   │   └─╴value: 1.0\n",
      "             │   └─╴value: 1.0\n",
      "             └─╴value: 1.0\n",
      "\n",
      "—————————————————————— tree: 5/5 ———————————————————————\n",
      "split: CART, impurity: Entropy, leaf: Mode, nodes: 47\n",
      "maxDepth: 7, reached depth: 7, minSamplesSplit: 2\n",
      "························································\n",
      "╴feat: 3 <= 2.29, samples: 9600\n",
      "     ├─feat: 1 <= 2.33, samples: 4747\n",
      "     │   ├─feat: 3 <= 2.06, samples: 4694\n",
      "     │   │   └─╴value: 0.0\n",
      "     │   │   └─╴feat: 0 <= 1.10, samples: 52\n",
      "     │   │       └─╴value: 0.0\n",
      "     │   │       └─╴feat: 0 <= 1.14, samples: 7\n",
      "     │   │           └─╴value: 1.0\n",
      "     │   │           └─╴feat: 4 <= 0.63, samples: 6\n",
      "     │   │               └─╴value: 0.0\n",
      "     │   │               └─╴value: 1.0\n",
      "     │   └─╴feat: 3 <= 1.29, samples: 53\n",
      "     │       └─╴value: 0.0\n",
      "     │       └─╴feat: 0 <= 0.35, samples: 8\n",
      "     │           └─╴value: 0.0\n",
      "     └─╴feat: 3 <= 3.34, samples: 4853\n",
      "         ├─feat: 1 <= 1.46, samples: 267\n",
      "         │   ├─feat: 0 <= 0.93, samples: 82\n",
      "         │   │   ├─feat: 3 <= 2.62, samples: 59\n",
      "         │   │   │   └─╴value: 0.0\n",
      "         │   │   │   └─╴feat: 1 <= 0.09, samples: 29\n",
      "         │   │   │       └─╴value: 0.0\n",
      "         │   │   │       └─╴value: 0.0\n",
      "         │   │   └─╴feat: 1 <= -0.32, samples: 23\n",
      "         │   │       ├─feat: 2 <= -0.30, samples: 4\n",
      "         │   │       │   └─╴value: 0.0\n",
      "         │   │       │   └─╴value: 1.0\n",
      "         │   │       └─╴feat: 3 <= 2.46, samples: 19\n",
      "         │   │           └─╴value: 1.0\n",
      "         │   │           └─╴value: 1.0\n",
      "         │   └─╴feat: 0 <= -0.64, samples: 185\n",
      "         │       ├─feat: 3 <= 2.48, samples: 5\n",
      "         │       │   └─╴value: 0.0\n",
      "         │       │   └─╴value: 1.0\n",
      "         │       └─╴feat: 4 <= 1.87, samples: 180\n",
      "         │           └─╴value: 1.0\n",
      "         │           └─╴feat: 1 <= 1.86, samples: 4\n",
      "         │               └─╴value: 0.0\n",
      "         │               └─╴value: 1.0\n",
      "         └─╴feat: 3 <= 3.57, samples: 4586\n",
      "             ├─feat: 1 <= 0.83, samples: 152\n",
      "             │   ├─feat: 0 <= -0.07, samples: 9\n",
      "             │   │   └─╴value: 0.0\n",
      "             │   │   └─╴value: 1.0\n",
      "             │   └─╴value: 1.0\n",
      "             └─╴value: 1.0\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "if task == 'regressor':\n",
    "    forrest.train(trainData, trainTargets)\n",
    "elif task == 'classifier':\n",
    "    forrest.train(trainData, trainLabels)\n",
    "forrest.bake()\n",
    "prediction = forrest.eval(validData)\n",
    "print(forrest)"
   ]
  },
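  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The forest combines the predictions of its trees through `Majority()` voting. The next cell is a minimal sketch of plain majority voting. It assumes every tree returns integer class labels, and the function `majorityVote` and the toy predictions are hypothetical, not the library's API. For each sample the votes per class are counted and the class with the most votes wins. With `np.argmax`, ties go to the lower class index; the library may break ties differently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def majorityVote(treePredictions):\n",
    "    # treePredictions has shape (numTrees, numSamples) and holds class labels\n",
    "    preds = np.asarray(treePredictions).astype(int)\n",
    "    numClasses = preds.max() + 1\n",
    "    # votes[c, i] = number of trees that predicted class c for sample i\n",
    "    votes = np.stack([(preds == c).sum(axis=0) for c in range(numClasses)])\n",
    "    # argmax returns the first maximum, so ties go to the lower class index\n",
    "    return np.argmax(votes, axis=0)\n",
    "\n",
    "treePredictions = np.array([\n",
    "    [0, 1, 1, 0],\n",
    "    [0, 1, 0, 0],\n",
    "    [1, 1, 1, 0],\n",
    "    [0, 0, 1, 1],\n",
    "    [0, 1, 1, 0],\n",
    "])\n",
    "print(majorityVote(treePredictions))"
   ]
  },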
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Accuracy(name='tree: 0', accuracy=0.9959375),\n",
       " Accuracy(name='tree: 1', accuracy=0.9959375),\n",
       " Accuracy(name='tree: 2', accuracy=0.9953125),\n",
       " Accuracy(name='tree: 3', accuracy=0.9953125),\n",
       " Accuracy(name='tree: 4', accuracy=0.9959375)]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "forrest.accuracy(validData, validLabels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Create bar plot\n",
    "plt.bar(np.arange(len(forrest.featureImportance)), forrest.featureImportance, color='steelblue')\n",
    "\n",
    "# Add labels and title\n",
    "plt.xlabel('Feature Index')\n",
    "plt.ylabel('Importance')\n",
    "plt.title('Feature Importance')\n",
    "\n",
    "# Add grid\n",
    "plt.grid(True, linestyle='--', alpha=0.6)\n",
    "\n",
    "# Show plot\n",
    "plt.show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluating predictions\n",
    "\n",
    "Depending on the task at hand, we compute a confusion matrix (classification) or simple regression metrics. Since the number of classes is fixed at two, nothing needs to be changed here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "━━━━━━━━━━━━ evaluation ━━━━━━━━━━━━\n",
      "————————— confusion matrix —————————\n",
      "              Class 0     Class 1   \n",
      "····································\n",
      "     Class 0    1591         9      \n",
      "                49%          0%     \n",
      "····································\n",
      "     Class 1     4          1596    \n",
      "                 0%         49%     \n",
      "\n",
      "———————————————————————————————— scores ———————————————————————————————\n",
      "                accuracy       precision      sensitivity      miss rate    \n",
      "·······································································\n",
      "     Class 0     0.996           0.997           0.994           0.006      \n",
      "     Class 1     0.996           0.994           0.998           0.003      \n",
      "·······································································\n",
      "       total     0.996           0.996           0.996           0.004      \n"
     ]
    }
   ],
   "source": [
    "if task == 'regressor':\n",
    "    metrics = RegressionScores(numClasses=2)\n",
    "    metrics.calcScores(prediction, validTargets, validLabels)\n",
    "    print(metrics)\n",
    "elif task == 'classifier':\n",
    "    confusion = ConfusionMatrix(numClasses=2)\n",
    "    confusion.update(prediction, validLabels)\n",
    "    confusion.percentages()\n",
    "    confusion.calcScores()\n",
    "    print(confusion)"
   ]
  },
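  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the scores in the table above can be recomputed from the raw counts. The next cell is a minimal NumPy sketch, not the `ConfusionMatrix` implementation. It assumes rows are true classes and columns are predicted classes, which is consistent with the printed precision and sensitivity values. The function `confusionScores` is hypothetical. The example rebuilds predictions with the same counts as the output above: 9 class-0 samples predicted as class 1 and 4 class-1 samples predicted as class 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def confusionScores(predictions, labels, numClasses=2):\n",
    "    # matrix[t, p] counts the samples of true class t that were predicted as class p\n",
    "    matrix = np.zeros((numClasses, numClasses), dtype=int)\n",
    "    np.add.at(matrix, (labels.astype(int), predictions.astype(int)), 1)\n",
    "    truePos = np.diag(matrix)\n",
    "    precision = truePos / matrix.sum(axis=0)    # correct / all predicted as the class\n",
    "    sensitivity = truePos / matrix.sum(axis=1)  # correct / all truly in the class\n",
    "    missRate = 1 - sensitivity                  # false negative rate\n",
    "    accuracy = truePos.sum() / matrix.sum()\n",
    "    return matrix, precision, sensitivity, missRate, accuracy\n",
    "\n",
    "demoLabels = np.repeat([0, 1], 1600)\n",
    "demoPredictions = demoLabels.copy()\n",
    "demoPredictions[:9] = 1         # 9 class-0 samples predicted as class 1\n",
    "demoPredictions[1600:1604] = 0  # 4 class-1 samples predicted as class 0\n",
    "matrix, precision, sensitivity, missRate, accuracy = confusionScores(demoPredictions, demoLabels)\n",
    "print(matrix)\n",
    "print(precision.round(3), sensitivity.round(3), missRate.round(3), accuracy)"
   ]
  },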
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Saving and Loading a Forest\n",
    "\n",
    "Forests can be converted to dictionaries and saved as JSON files, which lets us load and reuse them later. JSON is also a plain-text format, so the saved model stays human-readable."
   ]
  },
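  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The exact file layout written by the library is not shown here. The next cell is a minimal sketch of why a dictionary representation works well. It stores a toy two-level tree, loosely modelled on the top splits printed above, as nested dictionaries and writes it to a JSON string with Python's standard `json` module. It then reads the string back and uses the restored tree for a prediction. The dictionary keys and the `predict` helper are hypothetical and are not the format used by `ModelIO`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Hypothetical layout: internal nodes store a feature index and a threshold,\n",
    "# leaves store a value. This is not the format written by ModelIO.\n",
    "toyTree = {\n",
    "    'feature': 3, 'threshold': 2.29,\n",
    "    'left': {'value': 0.0},\n",
    "    'right': {\n",
    "        'feature': 3, 'threshold': 3.34,\n",
    "        'left': {'value': 1.0},\n",
    "        'right': {'value': 1.0},\n",
    "    },\n",
    "}\n",
    "\n",
    "def predict(node, sample):\n",
    "    # Walk down the tree until a leaf (a node with a 'value' key) is reached\n",
    "    while 'value' not in node:\n",
    "        goLeft = sample[node['feature']] <= node['threshold']\n",
    "        node = node['left'] if goLeft else node['right']\n",
    "    return node['value']\n",
    "\n",
    "toyText = json.dumps(toyTree, indent=2)  # plain, human-readable text\n",
    "restoredTree = json.loads(toyText)\n",
    "print(restoredTree == toyTree)\n",
    "print(predict(restoredTree, [0.0, 0.0, 0.0, 3.0, 0.0]))"
   ]
  },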
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "━━━━━━━━━━━━━━━━━━━━━━━━━━━━ forrest ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n",
      "voting: Majority, booster: GradientBoosting, bootstrapping: False\n",
      "\n",
      "————————————————————— tree: 01/15 ——————————————————————\n",
      "split: CART, impurity: Entropy, leaf: Mode, nodes: 47\n",
      "maxDepth: 7, reached depth: 7, minSamplesSplit: 2\n",
      "························································\n",
      "╴feat: 3 <= 2.29, samples: 9600\n",
      "     ├─feat: 1 <= 2.33, samples: 4747\n",
      "     │   ├─feat: 3 <= 2.06, samples: 4694\n",
      "     │   │   └─╴value: 0.0\n",
      "     │   │   └─╴feat: 0 <= 1.10, samples: 52\n",
      "     │   │       └─╴value: 0.0\n",
      "     │   │       └─╴feat: 4 <= -0.88, samples: 7\n",
      "     │   │           └─╴value: 1.0\n",
      "     │   │           └─╴feat: 4 <= 0.63, samples: 6\n",
      "     │   │               └─╴value: 0.0\n",
      "     │   │               └─╴value: 1.0\n",
      "     │   └─╴feat: 3 <= 1.29, samples: 53\n",
      "     │       └─╴value: 0.0\n",
      "     │       └─╴feat: 0 <= 0.35, samples: 8\n",
      "     │           └─╴value: 0.0\n",
      "     └─╴feat: 3 <= 3.34, samples: 4853\n",
      "         ├─feat: 1 <= 1.46, samples: 267\n",
      "         │   ├─feat: 0 <= 0.93, samples: 82\n",
      "         │   │   ├─feat: 3 <= 2.62, samples: 59\n",
      "         │   │   │   └─╴value: 0.0\n",
      "         │   │   │   └─╴feat: 1 <= 0.09, samples: 29\n",
      "         │   │   │       └─╴value: 0.0\n",
      "         │   │   │       └─╴value: 0.0\n",
      "         │   │   └─╴feat: 1 <= -0.32, samples: 23\n",
      "         │   │       ├─feat: 2 <= -0.30, samples: 4\n",
      "         │   │       │   └─╴value: 0.0\n",
      "         │   │       │   └─╴value: 1.0\n",
      "         │   │       └─╴feat: 3 <= 2.46, samples: 19\n",
      "         │   │           └─╴value: 1.0\n",
      "         │   │           └─╴value: 1.0\n",
      "         │   └─╴feat: 0 <= -0.64, samples: 185\n",
      "         │       ├─feat: 3 <= 2.48, samples: 5\n",
      "         │       │   └─╴value: 0.0\n",
      "         │       │   └─╴value: 1.0\n",
      "         │       └─╴feat: 4 <= 1.87, samples: 180\n",
      "         │           └─╴value: 1.0\n",
      "         │           └─╴feat: 3 <= 2.61, samples: 4\n",
      "         │               └─╴value: 0.0\n",
      "         │               └─╴value: 1.0\n",
      "         └─╴feat: 3 <= 3.57, samples: 4586\n",
      "             ├─feat: 1 <= 0.83, samples: 152\n",
      "             │   ├─feat: 0 <= -0.07, samples: 9\n",
      "             │   │   └─╴value: 1.0\n",
      "             │   └─╴value: 1.0\n",
      "             └─╴value: 1.0\n",
      "\n",
      "————————————————————— tree: 02/15 ——————————————————————\n",
      "split: CART, impurity: Entropy, leaf: Mode, nodes: 47\n",
      "maxDepth: 7, reached depth: 7, minSamplesSplit: 2\n",
      "························································\n",
      "╴feat: 3 <= 2.29, samples: 9600\n",
      "     ├─feat: 1 <= 2.33, samples: 4747\n",
      "     │   ├─feat: 3 <= 2.06, samples: 4694\n",
      "     │   │   └─╴value: 0.0\n",
      "     │   │   └─╴feat: 0 <= 1.10, samples: 52\n",
      "     │   │       └─╴value: 0.0\n",
      "     │   │       └─╴feat: 0 <= 1.14, samples: 7\n",
      "     │   │           └─╴value: 1.0\n",
      "     │   │           └─╴feat: 4 <= 0.63, samples: 6\n",
      "     │   │               └─╴value: 0.0\n",
      "     │   │               └─╴value: 1.0\n",
      "     │   └─╴feat: 3 <= 1.29, samples: 53\n",
      "     │       └─╴value: 0.0\n",
      "     │       └─╴feat: 0 <= 0.35, samples: 8\n",
      "     │           └─╴value: 0.0\n",
      "     └─╴feat: 3 <= 3.34, samples: 4853\n",
      "         ├─feat: 1 <= 1.46, samples: 267\n",
      "         │   ├─feat: 0 <= 0.93, samples: 82\n",
      "         │   │   ├─feat: 3 <= 2.62, samples: 59\n",
      "         │   │   │   └─╴value: 0.0\n",
      "         │   │   │   └─╴feat: 1 <= 0.09, samples: 29\n",
      "         │   │   │       └─╴value: 0.0\n",
      "         │   │   │       └─╴value: 0.0\n",
      "         │   │   └─╴feat: 1 <= -0.32, samples: 23\n",
      "         │   │       ├─feat: 2 <= -0.30, samples: 4\n",
      "         │   │       │   └─╴value: 0.0\n",
      "         │   │       │   └─╴value: 1.0\n",
      "         │   │       └─╴feat: 3 <= 2.46, samples: 19\n",
      "         │   │           └─╴value: 1.0\n",
      "         │   │           └─╴value: 1.0\n",
      "         │   └─╴feat: 0 <= -0.64, samples: 185\n",
      "         │       ├─feat: 3 <= 2.48, samples: 5\n",
      "         │       │   └─╴value: 0.0\n",
      "         │       │   └─╴value: 1.0\n",
      "         │       └─╴feat: 4 <= 1.87, samples: 180\n",
      "         │           └─╴value: 1.0\n",
      "         │           └─╴feat: 1 <= 1.86, samples: 4\n",
      "         │               └─╴value: 0.0\n",
      "         │               └─╴value: 1.0\n",
      "         └─╴feat: 3 <= 3.57, samples: 4586\n",
      "             ├─feat: 1 <= 0.83, samples: 152\n",
      "             │   ├─feat: 0 <= -0.07, samples: 9\n",
      "             │   │   └─╴value: 1.0\n",
      "             │   └─╴value: 1.0\n",
      "             └─╴value: 1.0\n",
      "\n",
      "————————————————————— tree: 03/15 ——————————————————————\n",
      "split: CART, impurity: Entropy, leaf: Mode, nodes: 47\n",
      "maxDepth: 7, reached depth: 7, minSamplesSplit: 2\n",
      "························································\n",
      "╴feat: 3 <= 2.29, samples: 9600\n",
      "     ├─feat: 1 <= 2.33, samples: 4747\n",
      "     │   ├─feat: 3 <= 2.06, samples: 4694\n",
      "     │   │   └─╴value: 0.0\n",
      "     │   │   └─╴feat: 0 <= 1.10, samples: 52\n",
      "     │   │       └─╴value: 0.0\n",
      "     │   │       └─╴feat: 3 <= 2.22, samples: 7\n",
      "     │   │           ├─feat: 4 <= 0.63, samples: 6\n",
      "     │   │           │   └─╴value: 0.0\n",
      "     │   │           │   └─╴value: 1.0\n",
      "     │   │           └─╴value: 1.0\n",
      "     │   └─╴feat: 3 <= 1.29, samples: 53\n",
      "     │       └─╴value: 0.0\n",
      "     │       └─╴feat: 1 <= 2.37, samples: 8\n",
      "     │           └─╴value: 0.0\n",
      "     └─╴feat: 3 <= 3.34, samples: 4853\n",
      "         ├─feat: 1 <= 1.46, samples: 267\n",
      "         │   ├─feat: 0 <= 0.93, samples: 82\n",
      "         │   │   ├─feat: 3 <= 2.62, samples: 59\n",
      "         │   │   │   └─╴feat: 1 <= 0.09, samples: 29\n",
      "         │   │   │       └─╴value: 0.0\n",
      "         │   │   │       └─╴value: 0.0\n",
      "         │   │   └─╴feat: 1 <= -0.32, samples: 23\n",
      "         │   │       ├─feat: 2 <= -0.30, samples: 4\n",
      "         │   │       │   └─╴value: 0.0\n",
      "         │   │       │   └─╴value: 1.0\n",
      "         │   │       └─╴feat: 3 <= 2.46, samples: 19\n",
      "         │   │           └─╴value: 1.0\n",
      "         │   │           └─╴value: 1.0\n",
      "         │   └─╴feat: 0 <= -0.64, samples: 185\n",
      "         │       ├─feat: 3 <= 2.48, samples: 5\n",
      "         │       │   └─╴value: 0.0\n",
      "         │       │   └─╴value: 1.0\n",
      "         │       └─╴feat: 4 <= 1.87, samples: 180\n",
      "         │           └─╴value: 1.0\n",
      "         │           └─╴feat: 3 <= 2.61, samples: 4\n",
      "         │               └─╴value: 0.0\n",
      "         │               └─╴value: 1.0\n",
      "         └─╴feat: 3 <= 3.57, samples: 4586\n",
      "             ├─feat: 1 <= 0.83, samples: 152\n",
      "             │   ├─feat: 0 <= -0.07, samples: 9\n",
      "             │   │   └─╴value: 1.0\n",
      "             │   └─╴value: 1.0\n",
      "             └─╴value: 1.0\n",
      "\n",
      "————————————————————— tree: 04/15 ——————————————————————\n",
      "split: CART, impurity: Entropy, leaf: Mode, nodes: 47\n",
      "maxDepth: 7, reached depth: 7, minSamplesSplit: 2\n",
      "························································\n",
      "╴feat: 3 <= 2.29, samples: 9600\n",
      "     ├─feat: 1 <= 2.33, samples: 4747\n",
      "     │   ├─feat: 3 <= 2.06, samples: 4694\n",
      "     │   │   └─╴value: 0.0\n",
      "     │   │   └─╴feat: 0 <= 1.10, samples: 52\n",
      "     │   │       └─╴value: 0.0\n",
      "     │   │       └─╴feat: 0 <= 1.14, samples: 7\n",
      "     │   │           └─╴value: 1.0\n",
      "     │   │           └─╴feat: 4 <= 0.63, samples: 6\n",
      "     │   │               └─╴value: 0.0\n",
      "     │   │               └─╴value: 1.0\n",
      "     │   └─╴feat: 3 <= 1.29, samples: 53\n",
      "     │       └─╴value: 0.0\n",
      "     │       └─╴feat: 1 <= 2.37, samples: 8\n",
      "     │           └─╴value: 0.0\n",
      "     └─╴feat: 3 <= 3.34, samples: 4853\n",
      "         ├─feat: 1 <= 1.46, samples: 267\n",
      "         │   ├─feat: 0 <= 0.93, samples: 82\n",
      "         │   │   ├─feat: 3 <= 2.62, samples: 59\n",
      "         │   │   │   └─╴value: 0.0\n",
      "         │   │   │   └─╴feat: 1 <= 0.09, samples: 29\n",
      "         │   │   │       └─╴value: 0.0\n",
      "         │   │   │       └─╴value: 0.0\n",
      "         │   │   └─╴feat: 1 <= -0.32, samples: 23\n",
      "         │   │       ├─feat: 2 <= -0.30, samples: 4\n",
      "         │   │       │   └─╴value: 0.0\n",
      "         │   │       │   └─╴value: 1.0\n",
      "         │   │       └─╴feat: 3 <= 2.46, samples: 19\n",
      "         │   │           └─╴value: 1.0\n",
      "         │   │           └─╴value: 1.0\n",
      "         │   └─╴feat: 0 <= -0.64, samples: 185\n",
      "         │       ├─feat: 4 <= 0.49, samples: 5\n",
      "         │       │   └─╴value: 0.0\n",
      "         │       │   └─╴value: 1.0\n",
      "         │       └─╴feat: 4 <= 1.87, samples: 180\n",
      "         │           └─╴value: 1.0\n",
      "         │           └─╴feat: 3 <= 2.61, samples: 4\n",
      "         │               └─╴value: 0.0\n",
      "         │               └─╴value: 1.0\n",
      "         └─╴feat: 3 <= 3.57, samples: 4586\n",
      "             ├─feat: 1 <= 0.83, samples: 152\n",
      "             │   ├─feat: 0 <= -0.07, samples: 9\n",
      "             │   │   └─╴value: 1.0\n",
      "             │   └─╴value: 1.0\n",
      "             └─╴value: 1.0\n",
      "\n",
      "————————————————————— tree: 05/15 ——————————————————————\n",
      "split: CART, impurity: Entropy, leaf: Mode, nodes: 47\n",
      "maxDepth: 7, reached depth: 7, minSamplesSplit: 2\n",
      "························································\n",
      "╴feat: 3 <= 2.29, samples: 9600\n",
      "     ├─feat: 1 <= 2.33, samples: 4747\n",
      "     │   ├─feat: 3 <= 2.06, samples: 4694\n",
      "     │   │   └─╴value: 0.0\n",
      "     │   │   └─╴feat: 0 <= 1.10, samples: 52\n",
      "     │   │       └─╴value: 0.0\n",
      "     │   │       └─╴feat: 0 <= 1.14, samples: 7\n",
      "     │   │           └─╴value: 1.0\n",
      "     │   │           └─╴feat: 4 <= 0.63, samples: 6\n",
      "     │   │               └─╴value: 0.0\n",
      "     │   │               └─╴value: 1.0\n",
      "     │   └─╴feat: 3 <= 1.29, samples: 53\n",
      "     │       └─╴value: 0.0\n",
      "     │       └─╴feat: 0 <= 0.35, samples: 8\n",
      "     │           └─╴value: 0.0\n",
      "     └─╴feat: 3 <= 3.34, samples: 4853\n",
      "         ├─feat: 1 <= 1.46, samples: 267\n",
      "         │   ├─feat: 0 <= 0.93, samples: 82\n",
      "         │   │   ├─feat: 3 <= 2.62, samples: 59\n",
      "         │   │   │   └─╴value: 0.0\n",
      "         │   │   │   └─╴feat: 1 <= 0.09, samples: 29\n",
      "         │   │   │       └─╴value: 0.0\n",
      "         │   │   │       └─╴value: 0.0\n",
      "         │   │   └─╴feat: 1 <= -0.32, samples: 23\n",
      "         │   │       ├─feat: 2 <= -0.30, samples: 4\n",
      "         │   │       │   └─╴value: 0.0\n",
      "         │   │       │   └─╴value: 1.0\n",
      "         │   │       └─╴feat: 3 <= 2.46, samples: 19\n",
      "         │   │           └─╴value: 1.0\n",
      "         │   │           └─╴value: 1.0\n",
      "         │   └─╴feat: 0 <= -0.64, samples: 185\n",
      "         │       ├─feat: 3 <= 2.48, samples: 5\n",
      "         │       │   └─╴value: 0.0\n",
      "         │       │   └─╴value: 1.0\n",
      "         │       └─╴feat: 4 <= 1.87, samples: 180\n",
      "         │           └─╴value: 1.0\n",
      "         │           └─╴feat: 1 <= 1.86, samples: 4\n",
      "         │               └─╴value: 0.0\n",
      "         │               └─╴value: 1.0\n",
      "         └─╴feat: 3 <= 3.57, samples: 4586\n",
      "             ├─feat: 1 <= 0.83, samples: 152\n",
      "             │   ├─feat: 0 <= -0.07, samples: 9\n",
      "             │   │   └─╴value: 0.0\n",
      "             │   │   └─╴value: 1.0\n",
      "             │   └─╴value: 1.0\n",
      "             └─╴value: 1.0\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "ModelIO.save(forrest, 'forrest-test')\n",
    "newForrest = ModelIO.load('forrest-test')\n",
    "print(newForrest)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "━━━━━━━━━━━━ evaluation ━━━━━━━━━━━━\n",
      "————————— confusion matrix —————————\n",
      "              Class 0     Class 1   \n",
      "····································\n",
      "     Class 0    1591         9      \n",
      "                49%          0%     \n",
      "····································\n",
      "     Class 1     4          1596    \n",
      "                 0%         49%     \n",
      "\n",
      "———————————————————————————————— scores ———————————————————————————————\n",
      "                accuracy       precision      sensitivity      miss rate    \n",
      "·······································································\n",
      "     Class 0     0.996           0.997           0.994           0.006      \n",
      "     Class 1     0.996           0.994           0.998           0.003      \n",
      "·······································································\n",
      "       total     0.996           0.996           0.996           0.004      \n"
     ]
    }
   ],
   "source": [
    "prediction = newForrest.eval(validData)\n",
    "\n",
    "if task == 'regressor':\n",
    "    newMetrics = RegressionScores(numClasses=2)\n",
    "    newMetrics.calcScores(prediction, validTargets, validLabels)\n",
    "    print(newMetrics)\n",
    "elif task == 'classifier':\n",
    "    newConfusion = ConfusionMatrix(numClasses=2)\n",
    "    newConfusion.update(prediction, validLabels)\n",
    "    newConfusion.percentages()\n",
    "    newConfusion.calcScores()\n",
    "    print(newConfusion)"
   ]
  },
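  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The scores above can be recomputed by hand from the confusion matrix. The sketch below is a hedged reconstruction, not the `ConfusionMatrix` implementation: it assumes, as the printed numbers suggest, that rows hold the true labels and columns the predicted labels, and that the `total` row is the unweighted mean over both classes. All variable names are illustrative.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Confusion matrix copied from the evaluation output above.\n",
    "# Assumption: rows are true labels, columns are predicted labels.\n",
    "cm = np.array([[1591, 9],\n",
    "               [4, 1596]])\n",
    "\n",
    "total = cm.sum()\n",
    "tp = np.diag(cm)              # correctly classified samples per class\n",
    "fn = cm.sum(axis=1) - tp      # samples of class c predicted as another class\n",
    "fp = cm.sum(axis=0) - tp      # samples predicted as class c that belong elsewhere\n",
    "tn = total - tp - fn - fp     # samples that involve class c in neither role\n",
    "\n",
    "accuracy = (tp + tn) / total  # one-vs-rest accuracy per class\n",
    "precision = tp / (tp + fp)\n",
    "sensitivity = tp / (tp + fn)  # also called recall\n",
    "missRate = fn / (tp + fn)     # equals 1 - sensitivity\n",
    "\n",
    "for c in range(len(cm)):\n",
    "    print(f'Class {c}: accuracy {accuracy[c]:.3f}, precision {precision[c]:.3f}, '\n",
    "          f'sensitivity {sensitivity[c]:.3f}, miss rate {missRate[c]:.3f}')\n",
    "\n",
    "# Assumed meaning of the 'total' row: macro average over the classes.\n",
    "print(f'total: accuracy {accuracy.mean():.3f}, precision {precision.mean():.3f}, '\n",
    "      f'sensitivity {sensitivity.mean():.3f}, miss rate {missRate.mean():.3f}')\n",
    "```"
   ]
  },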
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Comment\n",
    "\n",
    "The random forest works exactly as well as the tree code: it is built entirely on top of it and therefore inherits all of its problems. The progress bar for training the trees is erratic, because it always shows the current depth level, and since trees are grown recursively it jumps back and forth between levels. Also, when boosting is used, every tree is trained twice."
   ]
  }
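  ,
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The jumping progress bar follows from depth-first recursion: once one branch is finished, training returns to a shallower level before descending again. The sketch below uses a hypothetical toy function `grow`, not the actual `DecisionTree` code, to record the depth of each node in the order a recursive builder visits it.\n",
    "\n",
    "```python\n",
    "def grow(depth, maxDepth, visits):\n",
    "    # Hypothetical toy builder: grows a full binary tree depth-first.\n",
    "    visits.append(depth)  # a per-level progress bar would be updated here\n",
    "    if depth == maxDepth:\n",
    "        return  # leaf reached, unwind to the parent level\n",
    "    grow(depth + 1, maxDepth, visits)  # left child first\n",
    "    grow(depth + 1, maxDepth, visits)  # then the right child\n",
    "\n",
    "\n",
    "visits = []\n",
    "grow(0, 2, visits)\n",
    "print(visits)  # [0, 1, 2, 2, 1, 2, 2]\n",
    "```\n",
    "\n",
    "A progress bar tied to the current depth would therefore move 0, 1, 2, 2, 1, 2, 2 instead of rising steadily."
   ]
  }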
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  },
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}