import numpy as np
from abc import ABC, abstractmethod
from .decisionTree import DecisionTree


class Pruning(ABC):
    """
    Base class for pruning methods.
    """
    def __init__(self) -> None:
        self.name = self.__class__.__name__

    def __call__(self, tree: DecisionTree, evalData: np.ndarray, evalTargets: np.ndarray) -> None:
        self.prune(tree, evalData, evalTargets)

    @abstractmethod
    def prune(self, tree: DecisionTree, evalData: np.ndarray, evalTargets: np.ndarray) -> None:
        """
        Prune a decision tree.
        """
        pass


class ReducedError(Pruning):
    """
    Reduced Error Pruning.
    """
    def __init__(self) -> None:
        super().__init__()

    def prune(self, tree: DecisionTree, evalData: np.ndarray, evalTargets: np.ndarray) -> None:
        predictions = tree.eval(evalData)
        initialAccuracy = np.mean(predictions == evalTargets)

        # Traverse the tree in reverse (from leaves to root)
        for node in tree.breadthLast():
            if not node.hasChildren:  # Skip leaf nodes
                continue

            # Temporarily remove children
            initialValues = node.values
            initialLeft, initialRight = node.left, node.right
            node.popChildren()

            # Check accuracy on validation set
            predictions = tree.eval(evalData)
            prunedAccuracy = np.mean(predictions == evalTargets)

            # If accuracy decreases, restore children
            if prunedAccuracy < initialAccuracy:
                node.hasChildren = True
                node.values = initialValues
                node.left, node.right = initialLeft, initialRight


class CostComplexity(Pruning):
    """
    Cost Complexity Pruning.
    """
    def __init__(self) -> None:
        super().__init__()

    def prune(self, tree: DecisionTree, evalData: np.ndarray, evalTargets: np.ndarray) -> None:
        # Initial complexity and fit
        initialComplexity = tree.countNodes()
        initialFit = np.sum((tree.eval(evalData) - evalTargets) ** 2)

        # Traverse the tree in reverse (from leaves to root)
        for node in tree.breadthLast():
            if not node.hasChildren:  # Skip leaf nodes
                continue

            # Temporarily remove children
            initialValues = node.values
            initialLeft, initialRight = node.left, node.right
            node.popChildren()

            # Compute complexity and fit of pruned tree
            prunedComplexity = tree.countNodes()
            prunedFit = np.sum((tree.eval(evalData) - evalTargets) ** 2)

            # If sum of complexity and fit is higher for pruned tree, restore children
            if prunedComplexity + prunedFit > initialComplexity + initialFit:
                node.hasChildren = True
                #node.values = initialValues
                node.left, node.right = initialLeft, initialRight


class PessimisticError(Pruning):
    """
    Pessimistic Error Pruning.
    """
    def __init__(self):
        super().__init__()

    def prune(self, tree: DecisionTree, evalData: np.ndarray, evalTargets: np.ndarray):
        # Traverse the tree in reverse breadth-first order
        for node in tree.breadthLast():
            if not node.hasChildren:  # Skip leaf nodes
                continue

            # Calculate the estimated error rate of the node and its children
            nodeError = self.estimateError(node, evalData, evalTargets)
            childrenError = sum(self.estimateError(child, evalData, evalTargets) for child in node.getChildren())

            # If the node's error rate is lower, prune its children
            if nodeError <= childrenError:
                node.popChildren()

    def estimateError(self, node, evalData: np.ndarray, evalTargets: np.ndarray) -> float:
        """
        Estimate the error of the node using the validation set
        """
        # Get the indices of the data that fall within this node
        indices = np.where(evalData[:, node.feature] <= node.threshold)[0]

        # Check if node is a leaf or not
        if not node.hasChildren:
            # This node's prediction is the most common target value among the data in this node
            prediction = np.argmax(np.bincount(evalTargets[indices].astype(int)))

            # Compute error as the proportion of incorrect predictions
            error = np.sum(evalTargets[indices] != prediction) / len(indices)
        else:
            # This node is not a leaf, so its error is the average error of its children
            error = np.mean([self.estimateError(child, evalData, evalTargets) for child in node.getChildren()])

        return error