package jsat.classifiers.trees;

import java.util.ArrayList;
import java.util.List;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.math.SpecialMath;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/trees/TreePruner.class */
/**
 * Static helpers that prune a decision tree, reached through a
 * {@link TreeNodeVisitor}, using a list of labelled data points.
 *
 * <p>The class is not instantiable; use one of the {@code prune} overloads.
 */
public class TreePruner {

    /**
     * The pruning strategies accepted by
     * {@link TreePruner#prune(TreeNodeVisitor, PruningMethod, List)}.
     */
    public enum PruningMethod {
        /** The tree is left untouched. */
        NONE,
        /**
         * A leaf is cut from its parent when the parent's local classification is
         * at least as accurate (by weight) as the leaf's on the data reaching it.
         */
        REDUCED_ERROR,
        /**
         * A node is collapsed, or replaced by its most populated child, when that
         * lowers a binomial upper bound on the weighted error.
         */
        ERROR_BASED
    }

    /**
     * Probability argument handed to the bound computation when
     * {@link PruningMethod#ERROR_BASED} is requested through the public API.
     * NOTE(review): presumably the usual C4.5-style confidence level - confirm.
     */
    private static final double ERROR_BASED_ALPHA = 0.25d;

    private TreePruner() {
    }

    /**
     * Prunes the tree using every data point of the given data set.
     *
     * @param treeNodeVisitor the root of the tree to prune
     * @param pruningMethod the strategy to apply
     * @param classificationDataSet the data used to judge which nodes to prune
     */
    public static void prune(TreeNodeVisitor treeNodeVisitor, PruningMethod pruningMethod, ClassificationDataSet classificationDataSet) {
        prune(treeNodeVisitor, pruningMethod, classificationDataSet.getAsDPPList());
    }

    /**
     * Prunes the tree using the given labelled data points.
     *
     * @param treeNodeVisitor the root of the tree to prune
     * @param pruningMethod the strategy to apply; {@link PruningMethod#NONE} returns immediately
     * @param list the labelled data points used to judge which nodes to prune
     * @throws RuntimeException if {@code pruningMethod} is none of the known constants
     *         (this includes {@code null})
     */
    public static void prune(TreeNodeVisitor treeNodeVisitor, PruningMethod pruningMethod, List<DataPointPair<Integer>> list) {
        if (pruningMethod == PruningMethod.NONE) {
            return;
        }
        if (pruningMethod == PruningMethod.REDUCED_ERROR) {
            pruneReduceError(null, -1, treeNodeVisitor, list);
        } else if (pruningMethod == PruningMethod.ERROR_BASED) {
            pruneErrorBased(null, -1, treeNodeVisitor, list, ERROR_BASED_ALPHA);
        } else {
            throw new RuntimeException("BUG: please report");
        }
    }

    /**
     * Creates {@code count} empty, independent buckets, one per child path.
     */
    private static List<List<DataPointPair<Integer>>> newSplits(int count) {
        List<List<DataPointPair<Integer>>> splits = new ArrayList<List<DataPointPair<Integer>>>(count);
        for (int i = 0; i < count; i++) {
            splits.add(new ArrayList<DataPointPair<Integer>>());
        }
        return splits;
    }

    /**
     * Recursively applies reduced-error pruning below {@code current}.
     *
     * @param parent the node {@code current} hangs from, or {@code null} for the root
     * @param pathFollowed the index of {@code current} among {@code parent}'s paths
     * @param current the node being examined; {@code null} is treated as nothing to prune
     * @param testSet the data points that reach {@code current}
     * @return the number of paths disabled in this subtree, including the path to
     *         {@code current} itself when it was cut
     */
    private static int pruneReduceError(TreeNodeVisitor parent, int pathFollowed, TreeNodeVisitor current, List<DataPointPair<Integer>> testSet) {
        if (current == null) {
            return 0;
        }
        int nodesPruned = 0;
        if (!current.isLeaf()) {
            int numSplits = current.childrenCount();
            List<List<DataPointPair<Integer>>> splits = newSplits(numSplits);
            List<DataPointPair<Integer>> hadMissing = new ArrayList<DataPointPair<Integer>>(0);
            // Each child only sees the points that would descend into it; points for
            // which the node reports a negative path are set aside and shared out below.
            for (DataPointPair<Integer> dpp : testSet) {
                int path = current.getPath(dpp.getDataPoint());
                if (path >= 0) {
                    splits.get(path).add(dpp);
                } else {
                    hadMissing.add(dpp);
                }
            }
            if (!hadMissing.isEmpty()) {
                DecisionStump.distributMissing(splits, hadMissing);
            }
            // Children are visited from the last path to the first, as in the original code.
            for (int i = numSplits - 1; i >= 0; i--) {
                nodesPruned += pruneReduceError(current, i, current.getChild(i), splits.get(i));
            }
        }
        // Pruning the children may have turned this node into a leaf, so re-check.
        // The root (no parent) can never be cut.
        if (!current.isLeaf() || parent == null) {
            return nodesPruned;
        }
        double childCorrect = 0.0d;
        double parentCorrect = 0.0d;
        for (DataPointPair<Integer> dpp : testSet) {
            DataPoint dp = dpp.getDataPoint();
            int truth = dpp.getPair().intValue();
            if (current.localClassify(dp).mostLikely() == truth) {
                childCorrect += dp.getWeight();
            }
            if (parent.localClassify(dp).mostLikely() == truth) {
                parentCorrect += dp.getWeight();
            }
        }
        // Ties go to the parent: the leaf is kept only if it is strictly better.
        if (parentCorrect < childCorrect) {
            return nodesPruned;
        }
        parent.disablePath(pathFollowed);
        return nodesPruned + 1;
    }

    /**
     * Recursively applies error-based pruning below {@code current}.
     *
     * <p>Three estimates are compared for an internal node: the summed bound of its
     * enabled children, the bound if the node itself classified every point, and the
     * bound if the child that received the most points replaced the node. The
     * cheapest option is applied to the tree.
     *
     * <p>Misclassified weight is accumulated as a {@code double} so that fractional
     * data point weights are not truncated.
     *
     * @param parent the node {@code current} hangs from, or {@code null} for the root
     * @param pathFollowed the index of {@code current} among {@code parent}'s paths
     * @param current the node being examined
     * @param testSet the data points that reach {@code current}
     * @param alpha the probability argument forwarded to the bound computation
     * @return the error bound of the (possibly pruned) subtree, or {@code 0} when
     *         {@code current} is {@code null} or {@code testSet} is empty
     */
    private static double pruneErrorBased(TreeNodeVisitor parent, int pathFollowed, TreeNodeVisitor current, List<DataPointPair<Integer>> testSet, double alpha) {
        if (current == null || testSet.isEmpty()) {
            return 0.0d;
        }
        if (current.isLeaf()) {
            double leafErrors = 0.0d;
            double leafWeight = 0.0d;
            for (DataPointPair<Integer> dpp : testSet) {
                double weight = dpp.getDataPoint().getWeight();
                if (current.localClassify(dpp.getDataPoint()).mostLikely() != dpp.getPair().intValue()) {
                    leafErrors += weight;
                }
                leafWeight += weight;
            }
            return computeBinomialUpperBound(leafWeight, alpha, leafErrors);
        }

        List<List<DataPointPair<Integer>>> splits = newSplits(current.childrenCount());
        List<DataPointPair<Integer>> hadMissing = new ArrayList<DataPointPair<Integer>>(0);
        double nodeErrors = 0.0d;
        double totalWeight = 0.0d;
        for (DataPointPair<Integer> dpp : testSet) {
            DataPoint dp = dpp.getDataPoint();
            if (current.localClassify(dp).mostLikely() != dpp.getPair().intValue()) {
                nodeErrors += dp.getWeight();
            }
            totalWeight += dp.getWeight();
            int path = current.getPath(dp);
            if (path >= 0) {
                splits.get(path).add(dpp);
            } else {
                hadMissing.add(dpp);
            }
        }
        if (!hadMissing.isEmpty()) {
            DecisionStump.distributMissing(splits, hadMissing);
        }

        // Recurse into every enabled child, summing their bounds and remembering the
        // child that received the largest number of points.
        double subTreeError = 0.0d;
        int maxChildSize = 0;
        int maxChild = -1;
        for (int path = 0; path < splits.size(); path++) {
            if (current.isPathDisabled(path)) {
                continue;
            }
            List<DataPointPair<Integer>> split = splits.get(path);
            subTreeError += pruneErrorBased(current, path, current.getChild(path), split, alpha);
            if (maxChildSize < split.size()) {
                maxChildSize = split.size();
                maxChild = path;
            }
        }

        double nodeError = computeBinomialUpperBound(totalWeight, alpha, nodeErrors);

        // Bound obtained if the most populated child classified every point of this node.
        double maxChildError;
        if (maxChild == -1) {
            maxChildError = Double.POSITIVE_INFINITY;
        } else {
            TreeNodeVisitor maxChildNode = current.getChild(maxChild);
            double otherErrors = 0.0d;
            for (List<DataPointPair<Integer>> split : splits) {
                for (DataPointPair<Integer> dpp : split) {
                    if (maxChildNode.classify(dpp.getDataPoint()).mostLikely() != dpp.getPair().intValue()) {
                        otherErrors += dpp.getDataPoint().getWeight();
                    }
                }
            }
            maxChildError = computeBinomialUpperBound(totalWeight, alpha, otherErrors);
        }

        if (maxChildError < nodeError && maxChildError < subTreeError && parent != null) {
            try {
                parent.setPath(pathFollowed, current.getChild(maxChild));
                return maxChildError;
            } catch (UnsupportedOperationException ignored) {
                // The parent cannot have a path replaced; deliberately fall through
                // and consider the remaining options instead.
            }
        }
        if (nodeError >= subTreeError) {
            return subTreeError;
        }
        // Collapsing is cheaper than keeping the subtree: disable every path.
        // childrenCount() is re-read on each pass, exactly as the original loop did.
        for (int path = 0; path < current.childrenCount(); path++) {
            current.disablePath(path);
        }
        return nodeError;
    }

    /**
     * Computes {@code d * (1 - invBetaIncReg(d2, d - d3 + 1e-9, d3 + 1))}.
     *
     * @param d the total weight of the points considered
     * @param d2 the probability passed to {@link SpecialMath#invBetaIncReg}
     * @param d3 the weight of the misclassified points
     * @return the resulting bound, on the same scale as {@code d}
     */
    private static double computeBinomialUpperBound(double d, double d2, double d3) {
        // The 1e-9 offset keeps the first shape argument strictly positive when d == d3.
        return d * (1.0d - SpecialMath.invBetaIncReg(d2, (d - d3) + 1.0E-9d, d3 + 1.0d));
    }
}
