package jsat.classifiers.trees;

import java.util.Random;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.evaluation.Accuracy;
import jsat.classifiers.evaluation.ClassificationScore;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.regression.evaluation.MeanSquaredError;
import jsat.regression.evaluation.RegressionScore;
import jsat.utils.random.XORWOW;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/trees/MDA.class */
/**
 * Estimates feature importance for a tree model with a Mean-Decrease-Accuracy style procedure.
 * <p>
 * First a baseline score is computed for the unmodified tree on the given data. Then, for each
 * feature, every data point is walked through the tree again. Whenever the walk reaches a node
 * whose {@code featuresUsed()} contains that feature, the point is sent down a randomly chosen
 * <i>different</i> child. The relative loss in score caused by this corruption is reported as the
 * feature's importance.
 * <p>
 * Classification problems are scored with {@link Accuracy}, regression problems with
 * {@link MeanSquaredError}.
 */
public class MDA implements TreeFeatureImportanceInference {
    /** Template score for classification data; a clone is used per call. */
    private ClassificationScore cs_base = new Accuracy();
    /** Template score for regression data; a clone is used per call. */
    private RegressionScore rs_base = new MeanSquaredError();

    /** Added to the baseline score so the relative change never divides by zero. */
    private static final double BASELINE_EPSILON = 0.001d;

    /**
     * Computes one importance value per feature of {@code dataSet}.
     * <p>
     * For a {@link ClassificationDataSet} the learner is used as a {@link Classifier}. For a
     * {@link RegressionDataSet} it is used as a {@link Regressor}. For any other data set type,
     * nothing is computed and an all-zero array is returned.
     *
     * @param treeLearner the trained tree to evaluate
     * @param dataSet the data used to measure the tree's performance
     * @param <Type> the concrete data set type
     * @return for each feature, {@code (baseline - corrupted) / (baseline + 0.001)} when higher
     *         scores are better, or {@code (corrupted - baseline) / (baseline + 0.001)} when lower
     *         scores are better. Larger values mean corrupting the feature hurt more.
     */
    @Override // jsat.classifiers.trees.TreeFeatureImportanceInference
    public <Type extends DataSet> double[] getImportanceStats(TreeLearner treeLearner, DataSet<Type> dataSet) {
        double[] importances = new double[dataSet.getNumFeatures()];
        // A fresh generator per call: corruption choices are random.
        Random random = new XORWOW();
        // NOTE(review): the baseline below uses the model's classify/regress, while the corrupted
        // passes use node-local prediction. This assumes both agree when no split is corrupted.
        if (dataSet instanceof ClassificationDataSet) {
            ClassificationDataSet cds = (ClassificationDataSet) dataSet;
            ClassificationScore score = this.cs_base.m514clone();
            score.prepare(cds.getPredicting());
            for (int i = 0; i < cds.getSampleSize(); i++) {
                DataPoint dp = cds.getDataPoint(i);
                score.addResult(((Classifier) treeLearner).classify(dp), cds.getDataPointCategory(i), dp.getWeight());
            }
            double baseline = score.getScore();
            boolean lowerIsBetter = score.lowerIsBetter();
            for (int feature = 0; feature < importances.length; feature++) {
                score.prepare(cds.getPredicting());
                for (int i = 0; i < cds.getSampleSize(); i++) {
                    DataPoint dp = cds.getDataPoint(i);
                    TreeNodeVisitor reached = walkCorruptedPath(treeLearner, dp, feature, random);
                    score.addResult(reached.localClassify(dp), cds.getDataPointCategory(i), dp.getWeight());
                }
                importances[feature] = relativeDecrease(baseline, score.getScore(), lowerIsBetter);
            }
        } else if (dataSet instanceof RegressionDataSet) {
            RegressionDataSet rds = (RegressionDataSet) dataSet;
            RegressionScore score = this.rs_base.m722clone();
            score.prepare();
            for (int i = 0; i < rds.getSampleSize(); i++) {
                DataPoint dp = rds.getDataPoint(i);
                score.addResult(((Regressor) treeLearner).regress(dp), rds.getTargetValue(i), dp.getWeight());
            }
            double baseline = score.getScore();
            boolean lowerIsBetter = score.lowerIsBetter();
            for (int feature = 0; feature < importances.length; feature++) {
                score.prepare();
                for (int i = 0; i < rds.getSampleSize(); i++) {
                    DataPoint dp = rds.getDataPoint(i);
                    TreeNodeVisitor reached = walkCorruptedPath(treeLearner, dp, feature, random);
                    score.addResult(reached.localRegress(dp), rds.getTargetValue(i), dp.getWeight());
                }
                importances[feature] = relativeDecrease(baseline, score.getScore(), lowerIsBetter);
            }
        }
        return importances;
    }

    /**
     * Returns how much worse {@code corrupted} is than {@code baseline}, relative to the baseline.
     *
     * @param baseline score of the uncorrupted model
     * @param corrupted score after corrupting one feature
     * @param lowerIsBetter whether smaller scores indicate better performance
     * @return positive when the corrupted score is worse, negative when it is better
     */
    private static double relativeDecrease(double baseline, double corrupted, boolean lowerIsBetter) {
        double decrease = lowerIsBetter ? corrupted - baseline : baseline - corrupted;
        return decrease / (baseline + BASELINE_EPSILON);
    }

    /**
     * Walks {@code dataPoint} from the root of the tree toward a leaf, corrupting every split on
     * {@code i}.
     * <p>
     * At each node whose {@code featuresUsed()} contains {@code i}, the path chosen by the node is
     * replaced with a uniformly random <i>other</i> child. The walk stops early, at the current
     * node, if the selected path is disabled.
     *
     * @param treeLearner the tree to walk
     * @param dataPoint the point being routed
     * @param i index of the feature whose splits are corrupted
     * @param random source of randomness for choosing the replacement child
     * @return the node where the walk ended (a leaf, or the node whose chosen path was disabled)
     */
    private TreeNodeVisitor walkCorruptedPath(TreeLearner treeLearner, DataPoint dataPoint, int i, Random random) {
        TreeNodeVisitor node = treeLearner.getTreeNodeVisitor();
        while (!node.isLeaf()) {
            int path = node.getPath(dataPoint);
            int childrenCount = node.childrenCount();
            // With fewer than two children there is no other child to divert to.
            if (childrenCount > 1 && node.featuresUsed().contains(Integer.valueOf(i))) {
                // Offset in [1, childrenCount-1], so the wrapped result is never the original path.
                // (The previous nextInt(childrenCount) could return 0 and leave the path unchanged.)
                path = (path + 1 + random.nextInt(childrenCount - 1)) % childrenCount;
            }
            if (node.isPathDisabled(path)) {
                break;
            }
            node = node.getChild(path);
        }
        return node;
    }
}
