import warnings
from itertools import permutations
from typing import Optional

import numpy as np
from numpy.typing import ArrayLike

# promote all warnings to errors so that zero divisions inside 'calcScores'
# raise RuntimeWarning and can be caught and replaced by NaN
warnings.filterwarnings("error")


class ConfusionMatrix:
    """
    Builds a confusion matrix from predicted and target labels and
    derives performance scores from it. The set of scores to calculate
    can be reconfigured via 'setScores'.
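
    Example (a minimal sketch, assuming integer class labels):
        >>> cm = ConfusionMatrix(numClasses=2)
        >>> cm.update(np.array([0, 1, 1]), np.array([0, 1, 0]))
        >>> cm.matrix
        array([[1, 1],
               [0, 1]])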
    """
    __slots__ = ['name', 'numClasses', 'matrix', 'procent', 'classes',
                 'classNames', 'nameLength', 'scoreNames', 'scoreLength',
                 'scoreFormular', 'scores', 'totals', '_scoreByFormular',
                 '_scoreByName', '_wrongFormular', '_wrongName']

    def __init__(self, numClasses: Optional[int] = None, classNames: Optional[list] = None) -> None:
        self.name = self.__class__.__name__

        if numClasses is None and classNames is None:
            raise ValueError('must provide the number of classes, the class names, or both')
        if numClasses is not None and classNames is not None:
            if len(classNames) != numClasses:
                raise ValueError('the number of classes must match the length of class names')

        # matrix
        self.numClasses = len(classNames) if numClasses is None else numClasses
        self.matrix = np.zeros((self.numClasses, self.numClasses), dtype=int)
        self.procent = np.zeros((self.numClasses, self.numClasses), dtype=float)
        self.classes = np.arange(0, self.numClasses)

        # class names
        self.classNames = [f'Class {i}' for i in range(self.numClasses)] if classNames is None else classNames
        self.nameLength = [len(item) for item in self.classNames]

        # scores
        self.scoreNames = ['accuracy', 'precision', 'sensitivity', 'miss rate']
        self.scoreLength = [len(item) for item in self.scoreNames]
        self.scoreFormular = ['(tp+tn)/(tp+tn+fp+fn)', 'tp/(tp+fp)', 'tp/(tp+fn)', 'fn/(fn+tp)']
        self.scores = np.zeros((len(self.classNames), len(self.scoreNames)))

        # total scores
        self.totals = np.zeros(len(self.scoreNames))

        # all supported score formulas and score names, used when the user
        # configures custom scores via 'setScores'; since addition is
        # commutative, every ordering of the summed terms maps to the same name
        self._scoreByFormular = {}
        for name, numerator, denominator in [
                ('sensitivity', ['tp'], ['tp', 'fn']),
                ('rejection', ['tn'], ['tn', 'fp']),
                ('miss rate', ['fn'], ['fn', 'tp']),
                ('precision', ['tp'], ['tp', 'fp']),
                ('fall-out', ['fp'], ['fp', 'tn']),
                ('false omission', ['fn'], ['fn', 'tn']),
                ('false discovery', ['fp'], ['fp', 'tp']),
                ('f1 score', ['2*tp'], ['2*tp', 'fp', 'fn']),
                ('threat score', ['tp'], ['tp', 'fn', 'fp']),
                ('accuracy', ['tp', 'tn'], ['tp', 'fn', 'fp', 'tn'])]:
            for num in permutations(numerator):
                numString = '+'.join(num)
                if '+' in numString or '*' in numString:
                    numString = f'({numString})'
                for den in permutations(denominator):
                    self._scoreByFormular[numString + '/(' + '+'.join(den) + ')'] = name

        self._scoreByName = {'sensitivity': 'tp/(tp+fn)',
                                  'recall': 'tp/(tp+fn)',
                                 'hitrate': 'tp/(tp+fn)',
                        'truepositiverate': 'tp/(tp+fn)',
                            'truepositive': 'tp/(tp+fn)',
                                     'tpr': 'tp/(tp+fn)',

                             'specificity': 'tn/(tn+fp)',
                             'selectivity': 'tn/(tn+fp)',
                        'truenegativerate': 'tn/(tn+fp)',
                            'truenegative': 'tn/(tn+fp)',
                                     'tnr': 'tn/(tn+fp)',

                               'precision': 'tp/(tp+fp)',
                 'positivepredictivevalue': 'tp/(tp+fp)',
                      'positivepredictive': 'tp/(tp+fp)',
                                     'ppv': 'tp/(tp+fp)',

                               'rejection': 'tn/(tn+fn)',
                 'negativepredictivevalue': 'tn/(tn+fn)',
                      'negativepredictive': 'tn/(tn+fn)',
                                     'npv': 'tn/(tn+fn)',

                                'missrate': 'fn/(fn+tp)',
                       'falsenegativerate': 'fn/(fn+tp)',
                           'falsenegative': 'fn/(fn+tp)',
                                     'fnr': 'fn/(fn+tp)',

                                 'fallout': 'fp/(fp+tn)',
                       'falsepositiverate': 'fp/(fp+tn)',
                           'falsepositive': 'fp/(fp+tn)',
                                     'fpr': 'fp/(fp+tn)',

                      'falsediscoveryrate': 'fp/(fp+tp)',
                          'falsediscovery': 'fp/(fp+tp)',
                                     'fdr': 'fp/(fp+tp)',

                       'falseomissionrate': 'fn/(fn+tn)',
                           'falseomission': 'fn/(fn+tn)',
                                     'for': 'fn/(fn+tn)',

                             'threatscore': 'tp/(tp+fn+fp)',
                    'criticalsuccessindex': 'tp/(tp+fn+fp)',
                         'criticalsuccess': 'tp/(tp+fn+fp)',
                                      'ts': 'tp/(tp+fn+fp)',
                                     'csi': 'tp/(tp+fn+fp)',

                                'accuracy': '(tp+tn)/(tp+fn+fp+tn)',
                                     'acc': '(tp+tn)/(tp+fn+fp+tn)',

                                 'f1score': '(2*tp)/(2*tp+fp+fn)'}
        self._wrongFormular = []
        self._wrongName = []

    def update(self, prediction: ArrayLike, target: ArrayLike) -> None:
        """
        Update the confusion matrix based on new predictions and targets.
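
        One-hot inputs are collapsed to class labels first. Example
        (a small sketch with two one-hot encoded samples):
            >>> cm = ConfusionMatrix(numClasses=2)
            >>> cm.update(np.array([[1, 0], [0, 1]]), np.array([[1, 0], [1, 0]]))
            >>> int(cm.matrix.sum())
            2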
        """

        # coerce the inputs to arrays and collapse one-hot encodings
        # to categorical labels
        prediction = np.asarray(prediction)
        target = np.asarray(target)
        if target.ndim == 2:
            target = np.argmax(target, axis=-1)
            prediction = np.argmax(prediction, axis=-1)
        # loop across the different combinations of actual / predicted classes
        for i in range(self.numClasses):
            for j in range(self.numClasses):
                # count the number of instances in each combination of actual / predicted classes
                self.matrix[i, j] += np.sum((target == self.classes[i]) & (prediction == self.classes[j]))

    def percentages(self) -> None:
        """
        Convert the confusion matrix counts to percentages of all samples.
        """
        total = self.matrix.sum()
        # guard against an empty matrix; 0/0 would raise because warnings are errors
        if total > 0:
            self.procent = np.round(100 * (self.matrix / total), 2)

    def calcScores(self) -> None:
        """
        Calculate the scores based on the confusion matrix.
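
        Example (a quick sketch; the default configuration has four scores):
            >>> cm = ConfusionMatrix(numClasses=2)
            >>> cm.update(np.array([0, 1, 1]), np.array([0, 1, 0]))
            >>> cm.calcScores()
            >>> cm.scores.shape
            (2, 4)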
        """

        # read tp, fn, fp and tn for every class from the confusion matrix:
        # true positives sit on the diagonal, false negatives fill the rest
        # of the row, false positives the rest of the column, and true
        # negatives everything else
        tptensor = self.matrix.diagonal()
        fntensor = self.matrix.sum(1) - self.matrix.diagonal()
        fptensor = self.matrix.sum(0) - self.matrix.diagonal()
        tntensor = self.matrix.sum() - self.matrix.sum(1) - self.matrix.sum(0) + self.matrix.diagonal()

        # calculate the scores
        for i, formular in enumerate(self.scoreFormular):
            for j in range(len(self.classNames)):
                # tp, fn, fp and tn look unused to linters, but every formula
                # is a string built from them, and eval(formular) evaluates
                # that string in this scope
                tp, fn, fp, tn = tptensor[j], fntensor[j], fptensor[j], tntensor[j]
                try:
                    self.scores[j, i] = eval(formular)
                except RuntimeWarning:
                    # division by zero raises because warnings are promoted to errors
                    self.scores[j, i] = np.nan

        # average the per-class scores, ignoring NaNs, to get the totals
        self.totals = np.nanmean(self.scores, axis=0)

    def setScores(self, *scores: str) -> None:
        """
        Configure custom scores, all of which are derived from the
        confusion matrix. Each entry of 'scores' is either a score name
        or a formula string built from tp, tn, fp and fn. Every call
        overwrites all previously configured scores.
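
        Example (a small sketch; names and formula strings can be mixed):
            >>> cm = ConfusionMatrix(numClasses=2)
            >>> cm.setScores('accuracy', 'tp/(tp+fn)', 'f1 score')
            >>> cm.scoreNames
            ['accuracy', 'sensitivity', 'f1 score']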
        """
        self.scoreNames = []
        self.scoreFormular = []
        for score in scores:
            name = score.lower().replace(' ', '').replace('-', '')
            if score in self._scoreByFormular:
                self.scoreNames.append(self._scoreByFormular[score])
                self.scoreFormular.append(score)
            elif name in self._scoreByName:
                self.scoreNames.append(score)
                self.scoreFormular.append(self._scoreByName[name])
            elif any(symbol in score for symbol in '()+-/*'):
                # arithmetic symbols indicate a formula that is not supported
                self._wrongFormular.append(score)
            else:
                self._wrongName.append(score)

        # resize the score buffers so that 'calcScores' and '__str__'
        # match the new configuration
        self.scoreLength = [len(item) for item in self.scoreNames]
        self.scores = np.zeros((len(self.classNames), len(self.scoreNames)))
        self.totals = np.zeros(len(self.scoreNames))

    def __str__(self) -> str:
        """
        Format the confusion matrix and the scores as a printable string.
        """
        lengthAddition = 5
        nameWidth = np.max(self.nameLength) + lengthAddition
        center = (self.numClasses + 1) * nameWidth
        printString = ''

        # print the section title
        if np.sum(self.scores) > 0:
            printString += ' evaluation '.center(center, '·') + '\n'

        # print the confusion matrix
        printString += ' confusion matrix '.center(center, '·') + '\n'
        printString += ''.ljust(nameWidth)
        for head in self.classNames:
            printString += head.center(nameWidth)
        printString += '\n' + '·' * center + '\n'
        for i, (line, pro) in enumerate(zip(self.matrix, self.procent)):
            printString += self.classNames[i].rjust(nameWidth)
            for item in line:
                printString += str(int(item.item())).center(nameWidth)
            printString += '\n'
            if np.sum(self.procent) > 0:
                printString += ''.rjust(nameWidth)
                for item in pro:
                    printString += (str(int(item.item())) + '%').center(nameWidth)
                printString += '\n'
                if i < self.numClasses - 1:
                    printString += '·' * center + '\n'

        # print the scores
        if np.sum(self.scores) > 0:
            scoreWidth = np.max(self.scoreLength) + lengthAddition
            center = np.max(self.nameLength) + len(self.scoreNames) * scoreWidth
            printString += '\n' + ' scores '.center(center, '·') + '\n'
            printString += ''.ljust(nameWidth)
            for head in self.scoreNames:
                printString += head.center(scoreWidth)
            printString += '\n' + '·' * center + '\n'
            for i, line in enumerate(self.scores):
                printString += self.classNames[i].rjust(nameWidth)
                for item in line:
                    printString += str(round(item.item(), 3)).center(scoreWidth)
                printString += '\n'
            printString += '·' * center + '\n'
            printString += 'total'.rjust(nameWidth)
            for item in self.totals:
                printString += str(round(item.item(), 3)).center(scoreWidth)
        return printString
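

# A minimal usage sketch (an illustrative assumption about the intended
# workflow, not part of the original class): simulate a noisy 3-class
# classifier, fill the matrix, and print counts, percentages and scores.
if __name__ == '__main__':
    rng = np.random.default_rng(seed=0)
    target = rng.integers(0, 3, size=200)
    # the hypothetical classifier is correct ~70% of the time,
    # otherwise it predicts a random class
    noise = rng.integers(0, 3, size=200)
    prediction = np.where(rng.random(200) < 0.7, target, noise)

    cm = ConfusionMatrix(numClasses=3, classNames=['cat', 'dog', 'bird'])
    cm.update(prediction, target)
    cm.percentages()
    cm.calcScores()
    print(cm)

    # custom scores can be given by name or by formula string
    cm.setScores('accuracy', 'f1 score', 'tp/(tp+fn)')
    cm.calcScores()
    print(cm)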