package jsat.classifiers.trees;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.classifiers.trees.ImpurityScore;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/trees/MDI.class */
public class MDI implements TreeFeatureImportanceInference {
    /** Impurity measure used to score the class distribution of the data reaching each node. */
    private final ImpurityScore.ImpurityMeasure im;

    /**
     * Creates a Mean Decrease in Impurity (MDI) importance calculator.
     * <p>
     * The training data is pushed down the tree exposed by a {@link TreeLearner}. At every internal
     * node, the weighted decrease in impurity produced by the node's split is computed. It is scaled
     * by the fraction of the total data weight that reached the node. It is then credited to every
     * feature reported by {@link TreeNodeVisitor#featuresUsed()}.
     *
     * @param impurityMeasure the impurity measure used to score each node
     */
    public MDI(ImpurityScore.ImpurityMeasure impurityMeasure) {
        this.im = impurityMeasure;
    }

    /**
     * Creates an MDI importance calculator that uses the
     * {@link ImpurityScore.ImpurityMeasure#GINI GINI} impurity measure.
     */
    public MDI() {
        this(ImpurityScore.ImpurityMeasure.GINI);
    }

    /**
     * Computes the MDI importance of every feature for the given tree and data.
     *
     * @param treeLearner the trained tree whose node structure is walked
     * @param dataSet the data pushed through the tree; must be a {@link ClassificationDataSet}
     * @return an array of length {@code dataSet.getNumFeatures()}; entry {@code j} is the summed,
     *         weight-normalised impurity decrease of all nodes that use feature {@code j}. All
     *         entries are zero when the data carries no weight.
     * @throws IllegalArgumentException if {@code dataSet} is not a {@link ClassificationDataSet}
     */
    @Override // jsat.classifiers.trees.TreeFeatureImportanceInference
    public <Type extends DataSet> double[] getImportanceStats(TreeLearner treeLearner, DataSet<Type> dataSet) {
        if (!(dataSet instanceof ClassificationDataSet)) {
            throw new IllegalArgumentException("MDI currently only supports classification datasets");
        }
        ClassificationDataSet classData = (ClassificationDataSet) dataSet;
        double[] importances = new double[dataSet.getNumFeatures()];

        // visit() partitions this list in place with ListIterator.remove(), so work on a private,
        // guaranteed-mutable copy rather than whatever list the data set handed back.
        List<DataPointPair<Integer>> data =
                new ArrayList<DataPointPair<Integer>>(classData.getAsDPPList());
        int numClasses = classData.getClassSize();

        ImpurityScore rootScore = new ImpurityScore(numClasses, this.im);
        for (DataPointPair<Integer> dpp : data) {
            rootScore.addPoint(dpp.getDataPoint(), dpp.getPair().intValue());
        }

        double totalWeight = rootScore.getSumOfWeights();
        if (totalWeight <= 0) {
            // No data weight to attribute; recursing would divide by zero and write NaN.
            return importances;
        }
        visit(treeLearner.getTreeNodeVisitor(), rootScore, data, importances, totalWeight, numClasses);
        return importances;
    }

    /**
     * Credits the impurity decrease of {@code node}'s split to the features it uses, then recurses
     * into its children.
     * <p>
     * The data is partitioned in place. Child 0 reuses {@code data} and {@code score}. Points routed
     * to any other child are moved out of them into fresh per-child lists and scores. Points the node
     * cannot route (negative path) are removed from both, so both arguments are mutated.
     *
     * @param node the node to process; {@code null} or a leaf ends the recursion
     * @param score impurity of exactly the points in {@code data}; reused as child 0's score
     * @param data the points that reached this node; reused as child 0's data
     * @param importances per-feature accumulator that receives this subtree's contributions
     * @param totalWeight total weight of all training points, used to normalise each contribution
     * @param numClasses number of target classes, used to size new impurity scores
     */
    private void visit(TreeNodeVisitor node, ImpurityScore score, List<DataPointPair<Integer>> data,
            double[] importances, double totalWeight, int numClasses) {
        if (node == null || node.isLeaf()) {
            return;
        }
        final double parentImpurity = score.getScore();
        final double parentWeight = score.getSumOfWeights();
        final int childCount = node.childrenCount();

        List<List<DataPointPair<Integer>>> childData =
                new ArrayList<List<DataPointPair<Integer>>>(childCount);
        List<ImpurityScore> childScores = new ArrayList<ImpurityScore>(childCount);
        childData.add(data);
        childScores.add(score);
        for (int c = 1; c < childCount; c++) {
            childData.add(new ArrayList<DataPointPair<Integer>>());
            childScores.add(new ImpurityScore(numClasses, this.im));
        }

        ListIterator<DataPointPair<Integer>> iter = data.listIterator();
        while (iter.hasNext()) {
            DataPointPair<Integer> dpp = iter.next();
            int label = dpp.getPair().intValue();
            DataPoint dp = dpp.getDataPoint();
            int path = node.getPath(dp);
            if (path == 0) {
                continue; // stays in the shared list/score that become child 0's
            }
            // Any other path leaves child 0: take the point out of BOTH its score and its list.
            // Removing only from the score (as before) left the point in child 0's data, where a
            // later split could removePoint() it a second time and corrupt that score.
            score.removePoint(dp, label);
            iter.remove();
            if (path > 0) {
                childScores.get(path).addPoint(dp, label);
                childData.get(path).add(dpp);
            }
            // NOTE(review): path < 0 presumably means the point cannot be routed (e.g. a missing
            // value) — TODO confirm; such points are dropped from the rest of the traversal.
        }

        double decrease = parentImpurity;
        for (ImpurityScore child : childScores) {
            // The 1e-5 keeps the ratio finite when no weight reaches this node.
            decrease -= child.getScore() * (child.getSumOfWeights() / (1.0E-5d + parentWeight));
        }
        // Weight the decrease by the fraction of all data that reached this node.
        double contribution = (decrease * parentWeight) / totalWeight;
        for (int feature : node.featuresUsed()) {
            importances[feature] += contribution;
        }

        for (int c = 0; c < childScores.size(); c++) {
            visit(node.getChild(c), childScores.get(c), childData.get(c), importances, totalWeight, numClasses);
        }
    }
}
