import numpy as np
from abc import ABC, abstractmethod
from .decisionTree import DecisionTree


class Boosting(ABC):
    """Abstract base for boosting strategies that fit a ``DecisionTree``.

    Subclasses implement :meth:`train`, which fits one tree for a boosting
    round and updates whatever state the strategy keeps between rounds.
    Instances are callable as a shorthand for :meth:`train`.
    """

    def __init__(self) -> None:
        # Human-readable identifier; the concrete subclass's name.
        self.name = self.__class__.__name__

    def __call__(self, tree: DecisionTree, data: np.ndarray, targets: np.ndarray) -> None:
        """Run one boosting round; equivalent to ``self.train(tree, data, targets)``."""
        # ``train`` takes the tree as its first argument, so it must be
        # forwarded; omitting it shifts data/targets and raises TypeError.
        self.train(tree, data, targets)

    @abstractmethod
    def train(self, tree: DecisionTree, data: np.ndarray, targets: np.ndarray) -> None:
        """Fit ``tree`` on ``data``/``targets`` for one round and update strategy state."""

    @property
    def qualifiedName(self) -> tuple:
        """Return ``(module name, class name)`` of the concrete strategy class."""
        return self.__class__.__module__, self.__class__.__name__


class AdaBoosting(Boosting):
    """AdaBoost-style strategy that reweights samples after each tree.

    State kept across calls to :meth:`train`:

    * ``weights`` -- per-sample weights, normalised to sum to 1 after each
      update; initialised uniformly on the first call.
    * ``alpha`` -- list of per-tree vote weights, one appended per round.
    """

    # The weighted error rate is clipped to [_EPS, 1 - _EPS] so a tree that
    # classifies every sample right (or wrong) gets a large but finite alpha
    # instead of +/-inf, which would zero out or NaN the sample weights.
    _EPS = 1e-10

    def __init__(self):
        super().__init__()
        self.alpha = []
        self.weights = None

    def train(self, tree: DecisionTree, data: np.ndarray, targets: np.ndarray) -> None:
        """Fit ``tree`` with the current sample weights, then update the weights.

        NOTE(review): weights persist between calls, so this assumes the same
        samples (same length and order) are passed on every round.
        """
        if self.weights is None:
            # Start from a uniform distribution over the samples.
            self.weights = np.ones(len(targets)) / len(targets)
        tree.train(data, targets, self.weights)
        self.updateWeights(tree, data, targets)

    def updateWeights(self, tree: DecisionTree, data: np.ndarray, targets: np.ndarray) -> np.ndarray:
        """Record this tree's vote weight and reweight the samples in place.

        Misclassified samples are scaled by ``exp(alpha)`` and correctly
        classified ones by ``exp(-alpha)``, then the weights are renormalised.

        Returns:
            The updated (normalised) sample weights, i.e. ``self.weights``.
        """
        predictions = tree.eval(data)
        missed = targets != predictions
        errorRate = np.sum(self.weights[missed]) / np.sum(self.weights)
        errorRate = np.clip(errorRate, self._EPS, 1 - self._EPS)
        treeWeight = 0.5 * np.log((1 - errorRate) / errorRate)
        self.alpha.append(treeWeight)

        # Boost misclassified samples, shrink correctly classified ones.
        self.weights *= np.where(missed, np.exp(treeWeight), np.exp(-treeWeight))
        self.weights /= np.sum(self.weights)  # Normalize weights
        return self.weights


class GradientBoosting(Boosting):
    """Residual-fitting boosting strategy with no shrinkage.

    The first call fits the raw targets. Each later call fits the residuals
    left over by all trees fitted so far. After every fit, that tree's
    predictions are subtracted from the values it was fitted to.
    """

    def __init__(self) -> None:
        super().__init__()
        # Residuals of the ensemble so far; None until the first round.
        self.residuals = None

    def train(self, tree: DecisionTree, data: np.ndarray, targets: np.ndarray) -> None:
        """Fit ``tree`` to the current residuals (or targets on the first round)."""
        fitTargets = targets if self.residuals is None else self.residuals
        tree.train(data, fitTargets)
        # Residuals must accumulate across rounds. Subtracting this tree's
        # output from the raw targets would ignore every earlier tree, because
        # this tree was fitted to residuals, not to the targets.
        self.residuals = fitTargets - tree.eval(data)