package jsat.classifiers.trees;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicReference;
import jsat.DataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.trees.ImpurityScore;
import jsat.exceptions.FailedToFitException;
import jsat.math.OnLineStatistics;
import jsat.regression.RegressionDataSet;
import jsat.utils.FakeExecutor;
import jsat.utils.SystemInfo;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/trees/ERTrees.class */
/**
 * An ensemble of {@link ExtraTree} learners.
 * <p>
 * Every member of the ensemble is a clone of a template tree ({@code baseTree}) trained on the
 * complete data set. Classification is a normalized majority vote over the members'
 * most-likely classes. Regression is the mean of the members' predictions.
 * <p>
 * Unless disabled, training picks the per-split selection count and the stop size automatically:
 * <ul>
 * <li>Classification: selection count is round(sqrt(#features)), with a minimum of 1, and stop size is 2.</li>
 * <li>Regression: selection count is #features and stop size is 5.</li>
 * </ul>
 */
public class ERTrees extends ExtraTree {
    private static final long serialVersionUID = 7139392253403373132L;
    /** Template tree; every member of the ensemble is a clone of this tree. */
    private ExtraTree baseTree;
    /** When true, the selection count of {@code baseTree} is set automatically at train time. */
    private boolean useDefaultSelectionCount;
    /** When true, the stop size of {@code baseTree} is set automatically at train time. */
    private boolean useDefaultStopSize;
    /** Target information from the last classification training; null until trainC runs. */
    private CategoricalData predicting;
    /** The trained ensemble; null until training has run. */
    private ExtraTree[] forrest;
    /** Number of trees to build on the next training call. */
    private int forrestSize;

    /**
     * Clones {@code baseTree} into {@code forrest[start]} .. {@code forrest[end - 1]} and trains
     * each clone on {@code dataSet}.
     * <p>
     * The latch is always counted down, even when training throws. This stops the thread waiting
     * in {@code doTraining} from blocking forever. If a failure holder is supplied, the first
     * exception is recorded there so the waiting thread can rethrow it. Otherwise the exception
     * propagates out of {@link #run()}.
     */
    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/trees/ERTrees$ForrestPlanter.class */
    public class ForrestPlanter implements Runnable {
        int start;
        int end;
        DataSet dataSet;
        CountDownLatch latch;
        /** Receives the first worker failure; may be null, in which case failures propagate. */
        final AtomicReference<Throwable> failure;

        public ForrestPlanter(int start, int end, DataSet dataSet, CountDownLatch countDownLatch) {
            this(start, end, dataSet, countDownLatch, null);
        }

        public ForrestPlanter(int start, int end, DataSet dataSet, CountDownLatch countDownLatch,
                AtomicReference<Throwable> failure) {
            this.start = start;
            this.end = end;
            this.dataSet = dataSet;
            this.latch = countDownLatch;
            this.failure = failure;
        }

        @Override // java.lang.Runnable
        public void run() {
            try {
                plantTrees();
            } catch (RuntimeException e) {
                if (this.failure == null) {
                    throw e;
                }
                this.failure.compareAndSet(null, e);
            } catch (Error e) {
                if (this.failure == null) {
                    throw e;
                }
                this.failure.compareAndSet(null, e);
            } finally {
                // Must happen on every path, or doTraining's await() never returns.
                this.latch.countDown();
            }
        }

        /** Trains this worker's slice of the forest, dispatching on the data set type. */
        private void plantTrees() {
            if (this.dataSet instanceof ClassificationDataSet) {
                ClassificationDataSet classificationDataSet = (ClassificationDataSet) this.dataSet;
                for (int i = this.start; i < this.end; i++) {
                    ERTrees.this.forrest[i] = ERTrees.this.baseTree.m530clone();
                    ERTrees.this.forrest[i].trainC(classificationDataSet);
                }
            } else if (this.dataSet instanceof RegressionDataSet) {
                RegressionDataSet regressionDataSet = (RegressionDataSet) this.dataSet;
                for (int i = this.start; i < this.end; i++) {
                    ERTrees.this.forrest[i] = ERTrees.this.baseTree.m530clone();
                    ERTrees.this.forrest[i].train(regressionDataSet);
                }
            } else {
                throw new RuntimeException("BUG: Please report");
            }
        }
    }

    /** Creates an ensemble of 100 trees. */
    public ERTrees() {
        this(100);
    }

    /**
     * Creates an ensemble of the given size.
     *
     * @param i the number of trees to build
     */
    public ERTrees(int i) {
        this.baseTree = new ExtraTree();
        this.useDefaultSelectionCount = true;
        this.useDefaultStopSize = true;
        this.forrestSize = i;
    }

    /**
     * Computes per-feature importance statistics over the trained forest.
     * <p>
     * Classification data sets use {@link MDI} with Gini impurity. All other data sets use
     * {@link ImportanceByUses}.
     *
     * @param dataSet the data set to evaluate against
     * @return one {@link OnLineStatistics} per feature
     */
    public <Type extends DataSet> OnLineStatistics[] evaluateFeatureImportance(DataSet<Type> dataSet) {
        return dataSet instanceof ClassificationDataSet ? evaluateFeatureImportance(dataSet, new MDI(ImpurityScore.ImpurityMeasure.GINI)) : evaluateFeatureImportance(dataSet, new ImportanceByUses());
    }

    /**
     * Computes per-feature importance statistics using the supplied inference method.
     * <p>
     * The inference method is applied to every tree. Each feature's statistics accumulate one
     * importance value per tree. The model must already have been trained.
     *
     * @param dataSet the data set to evaluate against
     * @param treeFeatureImportanceInference the per-tree importance method
     * @return one {@link OnLineStatistics} per feature
     */
    public <Type extends DataSet> OnLineStatistics[] evaluateFeatureImportance(DataSet<Type> dataSet, TreeFeatureImportanceInference treeFeatureImportanceInference) {
        OnLineStatistics[] onLineStatisticsArr = new OnLineStatistics[dataSet.getNumFeatures()];
        for (int i = 0; i < onLineStatisticsArr.length; i++) {
            onLineStatisticsArr[i] = new OnLineStatistics();
        }
        for (ExtraTree extraTree : this.forrest) {
            double[] importanceStats = treeFeatureImportanceInference.getImportanceStats(extraTree, dataSet);
            for (int i2 = 0; i2 < onLineStatisticsArr.length; i2++) {
                onLineStatisticsArr[i2].add(importanceStats[i2]);
            }
        }
        return onLineStatisticsArr;
    }

    /** @param z whether to pick the selection count automatically at train time */
    public void setUseDefaultSelectionCount(boolean z) {
        this.useDefaultSelectionCount = z;
    }

    /** @return whether the selection count is picked automatically at train time */
    public boolean getUseDefaultSelectionCount() {
        return this.useDefaultSelectionCount;
    }

    /** @param z whether to pick the stop size automatically at train time */
    public void setUseDefaultStopSize(boolean z) {
        this.useDefaultStopSize = z;
    }

    /** @return whether the stop size is picked automatically at train time */
    public boolean getUseDefaultStopSize() {
        return this.useDefaultStopSize;
    }

    /** @param i the number of trees to build on the next training call */
    public void setForrestSize(int i) {
        this.forrestSize = i;
    }

    /** @return the number of trees built on each training call */
    public int getForrestSize() {
        return this.forrestSize;
    }

    /**
     * Classifies a point by majority vote.
     * <p>
     * Each tree votes for its most likely class, and the vote counts are normalized. Requires a
     * prior {@code trainC} call, which is what sets {@code predicting}.
     */
    @Override // jsat.classifiers.trees.ExtraTree, jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        CategoricalResults categoricalResults = new CategoricalResults(this.predicting.getNumOfCategories());
        for (ExtraTree extraTree : this.forrest) {
            categoricalResults.incProb(extraTree.classify(dataPoint).mostLikely(), 1.0d);
        }
        categoricalResults.normalize();
        return categoricalResults;
    }

    /**
     * Builds {@code forrestSize} trees, splitting the work into contiguous slices across at most
     * {@code SystemInfo.LogicalCores} tasks, and waits for all of them to finish.
     *
     * @throws FailedToFitException if the waiting thread is interrupted
     */
    private void doTraining(ExecutorService executorService, DataSet dataSet) throws FailedToFitException {
        this.forrest = new ExtraTree[this.forrestSize];
        // Never plan more tasks than trees. Otherwise the latch would count tasks that are
        // never submitted, and await() would block forever.
        int workers = Math.min(SystemInfo.LogicalCores, this.forrestSize);
        if (workers <= 0) {
            return; // empty forest: nothing to train
        }
        int baseChunk = this.forrestSize / workers;
        int remainder = this.forrestSize % workers;
        CountDownLatch countDownLatch = new CountDownLatch(workers);
        AtomicReference<Throwable> failure = new AtomicReference<Throwable>();
        int start = 0;
        for (int w = 0; w < workers; w++) {
            // The first `remainder` slices take one extra tree, so slices differ by at most one.
            int end = start + baseChunk + (w < remainder ? 1 : 0);
            executorService.submit(new ForrestPlanter(start, end, dataSet, countDownLatch, failure));
            start = end;
        }
        try {
            countDownLatch.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new FailedToFitException(e);
        }
        // Surface worker failures that would otherwise be lost inside the executor's Futures.
        Throwable t = failure.get();
        if (t != null) {
            if (t instanceof RuntimeException) {
                throw (RuntimeException) t;
            }
            if (t instanceof Error) {
                throw (Error) t;
            }
            throw new RuntimeException(t);
        }
    }

    /**
     * Trains the ensemble for classification.
     * <p>
     * When the defaults are enabled, the selection count is set to round(sqrt(#features)), with
     * a minimum of 1, and the stop size is set to 2.
     */
    @Override // jsat.classifiers.trees.ExtraTree, jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        if (this.useDefaultSelectionCount) {
            this.baseTree.setSelectionCount((int) Math.max(Math.round(Math.sqrt(classificationDataSet.getNumFeatures())), 1L));
        }
        if (this.useDefaultStopSize) {
            this.baseTree.setStopSize(2);
        }
        this.predicting = classificationDataSet.getPredicting();
        doTraining(executorService, classificationDataSet);
    }

    /** Trains the ensemble for classification on the calling thread. */
    @Override // jsat.classifiers.trees.ExtraTree, jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        trainC(classificationDataSet, new FakeExecutor());
    }

    @Override // jsat.classifiers.trees.ExtraTree, jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return true;
    }

    /** Predicts a value as the mean of every tree's regression output. */
    @Override // jsat.classifiers.trees.ExtraTree, jsat.regression.Regressor
    public double regress(DataPoint dataPoint) {
        double d = 0.0d;
        for (ExtraTree extraTree : this.forrest) {
            d += extraTree.regress(dataPoint);
        }
        return d / this.forrest.length;
    }

    /**
     * Trains the ensemble for regression.
     * <p>
     * When the defaults are enabled, the selection count is set to the number of features and
     * the stop size is set to 5.
     */
    @Override // jsat.classifiers.trees.ExtraTree, jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet, ExecutorService executorService) {
        if (this.useDefaultSelectionCount) {
            this.baseTree.setSelectionCount(regressionDataSet.getNumFeatures());
        }
        if (this.useDefaultStopSize) {
            this.baseTree.setStopSize(5);
        }
        doTraining(executorService, regressionDataSet);
    }

    /** Trains the ensemble for regression on the calling thread. */
    @Override // jsat.classifiers.trees.ExtraTree, jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet) {
        train(regressionDataSet, new FakeExecutor());
    }

    /** Returns a deep copy: configuration, template tree, target info and every trained tree. */
    @Override // jsat.classifiers.trees.ExtraTree
    /* renamed from: clone, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public ERTrees mo583clone() {
        ERTrees eRTrees = new ERTrees();
        eRTrees.forrestSize = this.forrestSize;
        eRTrees.useDefaultSelectionCount = this.useDefaultSelectionCount;
        eRTrees.useDefaultStopSize = this.useDefaultStopSize;
        eRTrees.baseTree = this.baseTree.m530clone();
        if (this.predicting != null) {
            eRTrees.predicting = this.predicting.m480clone();
        }
        if (this.forrest != null) {
            eRTrees.forrest = new ExtraTree[this.forrest.length];
            for (int i = 0; i < this.forrest.length; i++) {
                eRTrees.forrest[i] = this.forrest[i].m530clone();
            }
        }
        return eRTrees;
    }

    /** Not supported: an ensemble has no single tree to visit. */
    @Override // jsat.classifiers.trees.ExtraTree, jsat.classifiers.trees.TreeLearner
    public TreeNodeVisitor getTreeNodeVisitor() {
        throw new UnsupportedOperationException("Can not get the tree node vistor becase ERTrees is really a ensemble");
    }
}
