package jsat.classifiers.trees;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.boosting.Bagging;
import jsat.classifiers.trees.ImpurityScore;
import jsat.classifiers.trees.TreePruner;
import jsat.math.OnLineStatistics;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.FakeExecutor;
import jsat.utils.IntSet;
import jsat.utils.ListUtils;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.AtomicDoubleArray;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/trees/RandomForest.class */
/**
 * An ensemble of {@link RandomDecisionTree}s, each trained on a bootstrap
 * sample of the data and a random subset of the features. The forest can be
 * trained either as a classifier (majority vote of the trees) or as a
 * regressor (mean of the tree outputs).
 *
 * <p>Optionally, the data points left out of each bootstrap sample ("out of
 * bag" points) are used to estimate the generalization error and per-feature
 * importance during training.
 */
public class RandomForest implements Classifier, Regressor, Parameterized {
    private static final long serialVersionUID = 2725020584282958141L;

    /** Target category info; {@code null} when trained for regression. */
    private CategoricalData predicting;
    /** Added to the data set size to get the bootstrap sample size. */
    private int extraSamples;
    /** Features offered to each tree, or {@code -1} for automatic selection. */
    private int featureSamples;
    /** Number of trees to train. */
    private int maxForestSize;
    private boolean useOutOfBagError;
    private boolean useOutOfBagImportance;
    private TreeFeatureImportanceInference importanceMeasure;
    /** Per-feature importance statistics from the last training run that requested them. */
    private OnLineStatistics[] feature_importance;
    private double outOfBagError;
    /** Prototype tree that is cloned for every member of the forest. */
    private RandomDecisionTree baseLearner;
    private List<DecisionTree> forest;

    /**
     * Unit of work that trains a batch of trees and, when enabled, records
     * out-of-bag votes and feature importance for them.
     *
     * <p>Note: the decompiler reports that this class was originally
     * {@code private}; it is left {@code public} here to match the rest of the
     * decompiled sources.
     */
    public class LearningWorker implements Callable<LearningWorker> {
        /** Number of trees this worker must train. */
        int toLearn;
        /** Trees trained by this worker, in training order. */
        List<DecisionTree> learned;
        DataSet dataSet;
        Random random;
        /** Per-feature importance statistics; only allocated when importance is enabled. */
        OnLineStatistics[] fi;
        /** Shared running sum of out-of-bag regression outputs, indexed by data point. */
        private AtomicDoubleArray votes;
        /** Shared out-of-bag vote counts, indexed by data point then class (or 0 for regression). */
        private int[][] counts;

        /**
         * @param dataSet the full training set to draw bootstrap samples from
         * @param toLearn the number of trees this worker trains
         * @param random the source of randomness for this worker
         * @param counts shared out-of-bag vote counts, used only when out-of-bag error is enabled
         * @param votes shared out-of-bag regression sums, used only when out-of-bag error is enabled
         */
        public LearningWorker(DataSet dataSet, int toLearn, Random random, int[][] counts, AtomicDoubleArray votes) {
            this.dataSet = dataSet;
            this.toLearn = toLearn;
            this.random = random;
            this.learned = new ArrayList<DecisionTree>(toLearn);
            if (RandomForest.this.useOutOfBagError) {
                this.votes = votes;
                this.counts = counts;
            }
            if (RandomForest.this.useOutOfBagImportance) {
                this.fi = new OnLineStatistics[dataSet.getNumFeatures()];
                for (int f = 0; f < this.fi.length; f++) {
                    this.fi[f] = new OnLineStatistics();
                }
            }
        }

        /**
         * Trains {@link #toLearn} trees, each on a fresh bootstrap sample and a
         * fresh random feature subset.
         *
         * @return this worker, so the caller can read {@link #learned} and {@link #fi}
         */
        @Override
        public LearningWorker call() throws Exception {
            final int numFeatures = this.dataSet.getNumFeatures();
            final int featuresPerTree = Math.min(RandomForest.this.baseLearner.getRandomFeatureCount(), numFeatures);
            IntSet features = new IntSet(RandomForest.this.baseLearner.getRandomFeatureCount());
            // sampledCounts[i] == how many times point i was drawn into the current bootstrap sample
            int[] sampledCounts = new int[this.dataSet.getSampleSize()];

            for (int t = 0; t < this.toLearn; t++) {
                Bagging.sampleWithReplacement(sampledCounts, sampledCounts.length + RandomForest.this.extraSamples, this.random);

                features.clear();
                while (features.size() < featuresPerTree) {
                    features.add(Integer.valueOf(this.random.nextInt(numFeatures)));
                }

                RandomDecisionTree tree = RandomForest.this.baseLearner.mo582clone();
                if (this.dataSet instanceof ClassificationDataSet) {
                    tree.trainC(Bagging.getWeightSampledDataSet((ClassificationDataSet) this.dataSet, sampledCounts), features);
                } else {
                    tree.train(Bagging.getWeightSampledDataSet((RegressionDataSet) this.dataSet, sampledCounts), features);
                }
                this.learned.add(tree);

                if (RandomForest.this.useOutOfBagError) {
                    recordOutOfBagVotes(tree, sampledCounts);
                }
                if (RandomForest.this.useOutOfBagImportance) {
                    DataSet outOfBag = buildOutOfBagSet(sampledCounts);
                    double[] importance = RandomForest.this.importanceMeasure.getImportanceStats(tree, outOfBag);
                    for (int f = 0; f < this.fi.length; f++) {
                        this.fi[f].add(importance[f]);
                    }
                }
            }
            return this;
        }

        /** Adds the tree's prediction for every point absent from its bootstrap sample to the shared tallies. */
        private void recordOutOfBagVotes(RandomDecisionTree tree, int[] sampledCounts) {
            final boolean classification = this.dataSet instanceof ClassificationDataSet;
            for (int p = 0; p < sampledCounts.length; p++) {
                if (sampledCounts[p] != 0) {
                    continue; // point was used for training this tree
                }
                DataPoint point = this.dataSet.getDataPoint(p);
                if (classification) {
                    int vote = tree.classify(point).mostLikely();
                    // the row is shared between workers, so guard the increment
                    synchronized (this.counts[p]) {
                        this.counts[p][vote]++;
                    }
                } else {
                    this.votes.getAndAdd(p, tree.regress(point));
                    synchronized (this.counts[p]) {
                        this.counts[p][0]++;
                    }
                }
            }
        }

        /** Builds a data set holding only the points absent from the current bootstrap sample. */
        private DataSet buildOutOfBagSet(int[] sampledCounts) {
            if (this.dataSet instanceof ClassificationDataSet) {
                ClassificationDataSet source = (ClassificationDataSet) this.dataSet;
                ClassificationDataSet outOfBag = new ClassificationDataSet(source.getNumNumericalVars(), source.getCategories(), source.getPredicting());
                for (int p = 0; p < sampledCounts.length; p++) {
                    if (sampledCounts[p] == 0) {
                        outOfBag.addDataPoint(source.getDataPoint(p), source.getDataPointCategory(p));
                    }
                }
                return outOfBag;
            }
            RegressionDataSet source = (RegressionDataSet) this.dataSet;
            RegressionDataSet outOfBag = new RegressionDataSet(source.getNumNumericalVars(), source.getCategories());
            for (int p = 0; p < sampledCounts.length; p++) {
                if (sampledCounts[p] == 0) {
                    outOfBag.addDataPoint(source.getDataPoint(p), source.getTargetValue(p));
                }
            }
            return outOfBag;
        }
    }

    /** Creates a forest of 100 trees. */
    public RandomForest() {
        this(100);
    }

    /**
     * Creates a forest with automatic feature sampling, no extra samples and
     * GINI impurity as the split criterion.
     *
     * @param maxForestSize the number of trees to train
     * @throws ArithmeticException if {@code maxForestSize} is not positive
     */
    public RandomForest(int maxForestSize) {
        this.useOutOfBagError = false;
        this.useOutOfBagImportance = false;
        this.importanceMeasure = new MDI();
        this.feature_importance = null;
        setExtraSamples(0);
        setMaxForestSize(maxForestSize);
        autoFeatureSample();
        this.baseLearner = new RandomDecisionTree(1, Integer.MAX_VALUE, 3, TreePruner.PruningMethod.NONE, 1.0E-15d);
        this.baseLearner.setGainMethod(ImpurityScore.ImpurityMeasure.GINI);
    }

    /**
     * @param extraSamples value added to the data set size to obtain the size
     *        of each bootstrap sample
     */
    public void setExtraSamples(int extraSamples) {
        this.extraSamples = extraSamples;
    }

    /** @return the value added to the data set size for each bootstrap sample */
    public int getExtraSamples() {
        return this.extraSamples;
    }

    /**
     * Fixes the number of random features offered to each tree, disabling
     * automatic selection.
     *
     * @param featureSamples the number of features per tree
     * @throws ArithmeticException if {@code featureSamples} is not positive
     */
    public void setFeatureSamples(int featureSamples) {
        if (featureSamples <= 0) {
            throw new ArithmeticException("A positive number of features must be given");
        }
        this.featureSamples = featureSamples;
    }

    /**
     * Switches to automatic feature sampling: at training time the number of
     * features per tree is the square root of the feature count (at least 1).
     */
    public void autoFeatureSample() {
        this.featureSamples = -1;
    }

    /** @return {@code true} if the number of features per tree is chosen automatically */
    public boolean isAutoFeatureSample() {
        return this.featureSamples == -1;
    }

    /**
     * @param maxForestSize the number of trees to train
     * @throws ArithmeticException if {@code maxForestSize} is not positive
     */
    public void setMaxForestSize(int maxForestSize) {
        if (maxForestSize <= 0) {
            throw new ArithmeticException("Must train a positive number of learners");
        }
        this.maxForestSize = maxForestSize;
    }

    /** @return the number of trees to train */
    public int getMaxForestSize() {
        return this.maxForestSize;
    }

    /** @param useOutOfBagError whether to estimate the out-of-bag error during training */
    public void setUseOutOfBagError(boolean useOutOfBagError) {
        this.useOutOfBagError = useOutOfBagError;
    }

    /** @return whether the out-of-bag error is estimated during training */
    public boolean isUseOutOfBagError() {
        return this.useOutOfBagError;
    }

    /**
     * @return the per-feature importance statistics from the last training run
     *         that computed them, or {@code null} if none has; the internal
     *         array is returned without copying
     */
    public OnLineStatistics[] getFeatureImportance() {
        return this.feature_importance;
    }

    /** @param useOutOfBagImportance whether to compute feature importance from out-of-bag data during training */
    public void setUseOutOfBagImportance(boolean useOutOfBagImportance) {
        this.useOutOfBagImportance = useOutOfBagImportance;
    }

    /** @return whether feature importance is computed from out-of-bag data during training */
    public boolean isUseOutOfBagImportance() {
        return this.useOutOfBagImportance;
    }

    /**
     * @return the out-of-bag error from the last training run: the
     *         misclassification rate for classification or the mean squared
     *         error for regression, over the points that received at least one
     *         out-of-bag vote. It is {@code 0} when the estimate was not
     *         requested and {@code NaN} when no point received a vote.
     */
    public double getOutOfBagError() {
        return this.outOfBagError;
    }

    /**
     * Classifies a point by letting each tree cast one vote for its most
     * likely class and normalizing the vote counts.
     *
     * @throws RuntimeException if the forest is untrained or was trained for regression
     */
    @Override
    public CategoricalResults classify(DataPoint dataPoint) {
        if (this.forest == null || this.forest.isEmpty()) {
            throw new RuntimeException("Classifier has not yet been trained");
        }
        if (this.predicting == null) {
            throw new RuntimeException("Classifier has been trained for regression");
        }
        CategoricalResults result = new CategoricalResults(this.predicting.getNumOfCategories());
        for (DecisionTree tree : this.forest) {
            result.incProb(tree.classify(dataPoint).mostLikely(), 1.0d);
        }
        result.normalize();
        return result;
    }

    /** Trains the forest as a classifier, distributing tree construction over the given executor. */
    @Override
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        this.predicting = classificationDataSet.getPredicting();
        this.forest = new ArrayList<DecisionTree>(this.maxForestSize);
        trainStep(classificationDataSet, executorService);
    }

    /** Trains the forest as a classifier on the calling thread. */
    @Override
    public void trainC(ClassificationDataSet classificationDataSet) {
        trainC(classificationDataSet, new FakeExecutor());
    }

    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    /**
     * Predicts a value as the mean of the individual tree outputs.
     *
     * @throws RuntimeException if the forest is untrained or was trained for classification
     */
    @Override
    public double regress(DataPoint dataPoint) {
        if (this.forest == null || this.forest.isEmpty()) {
            throw new RuntimeException("Classifier has not yet been trained");
        }
        if (this.predicting != null) {
            throw new RuntimeException("Classifier has been trained for classification");
        }
        OnLineStatistics outputs = new OnLineStatistics();
        for (DecisionTree tree : this.forest) {
            outputs.add(tree.regress(dataPoint));
        }
        return outputs.getMean();
    }

    /** Trains the forest as a regressor, distributing tree construction over the given executor. */
    @Override
    public void train(RegressionDataSet regressionDataSet, ExecutorService executorService) {
        this.predicting = null;
        this.forest = new ArrayList<DecisionTree>(this.maxForestSize);
        trainStep(regressionDataSet, executorService);
    }

    /** Trains the forest as a regressor on the calling thread. */
    @Override
    public void train(RegressionDataSet regressionDataSet) {
        train(regressionDataSet, new FakeExecutor());
    }

    /**
     * Shared training routine: splits the trees into batches, runs one
     * {@link LearningWorker} per batch, then merges the trees and any
     * out-of-bag statistics.
     *
     * <p>Failures while collecting results are logged rather than thrown, so
     * the forest may be empty or partial afterwards.
     *
     * @param dataSet a {@link ClassificationDataSet} or {@link RegressionDataSet}
     * @param executorService where to run the workers; {@code null} or a
     *        {@link FakeExecutor} trains everything in a single batch
     */
    private void trainStep(DataSet dataSet, ExecutorService executorService) {
        if (isAutoFeatureSample()) {
            this.baseLearner.setRandomFeatureCount(Math.max((int) Math.sqrt(dataSet.getNumFeatures()), 1));
        } else {
            this.baseLearner.setRandomFeatureCount(this.featureSamples);
        }

        final boolean serial = executorService == null || executorService instanceof FakeExecutor;
        // A null executor used to fall through to submit() and throw a NullPointerException.
        final ExecutorService pool = executorService == null ? new FakeExecutor() : executorService;

        int remaining = this.maxForestSize;
        final int share;
        int extra;
        if (serial) {
            // One batch with every tree. The remainder must be zero here,
            // otherwise the single worker would train one tree too many.
            share = remaining;
            extra = 0;
        } else {
            share = remaining / SystemInfo.LogicalCores;
            extra = remaining % SystemInfo.LogicalCores;
        }

        Random seedSource = new Random();
        List<Future<LearningWorker>> futures = new ArrayList<Future<LearningWorker>>(SystemInfo.LogicalCores);

        AtomicDoubleArray oobVotes = null;
        final int[][] oobCounts;
        if (dataSet instanceof RegressionDataSet) {
            oobVotes = new AtomicDoubleArray(dataSet.getSampleSize());
            oobCounts = new int[oobVotes.length()][1];
        } else {
            oobCounts = new int[dataSet.getSampleSize()][((ClassificationDataSet) dataSet).getClassSize()];
        }

        while (remaining > 0) {
            // the first `extra` batches each take one of the leftover trees
            int batch = share + (extra-- > 0 ? 1 : 0);
            remaining -= batch;
            futures.add(pool.submit(new LearningWorker(dataSet, batch, new Random(seedSource.nextInt()), oobCounts, oobVotes)));
        }

        this.outOfBagError = 0.0d;
        try {
            List<LearningWorker> workers = ListUtils.collectFutures(futures);
            for (LearningWorker worker : workers) {
                this.forest.addAll(worker.learned);
            }

            if (this.useOutOfBagError) {
                if (dataSet instanceof ClassificationDataSet) {
                    this.outOfBagError = classificationOutOfBagError((ClassificationDataSet) dataSet, oobCounts);
                } else {
                    this.outOfBagError = regressionOutOfBagError((RegressionDataSet) dataSet, oobVotes, oobCounts);
                }
            }

            if (this.useOutOfBagImportance) {
                final int numFeatures = dataSet.getNumFeatures();
                this.feature_importance = new OnLineStatistics[numFeatures];
                for (int f = 0; f < numFeatures; f++) {
                    this.feature_importance[f] = new OnLineStatistics();
                }
                for (LearningWorker worker : workers) {
                    for (int f = 0; f < numFeatures; f++) {
                        this.feature_importance[f].add(worker.fi[f]);
                    }
                }
            }
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // keep the interruption visible to the caller's thread
                Thread.currentThread().interrupt();
            }
            Logger.getLogger(RandomForest.class.getName()).log(Level.SEVERE, (String) null, e);
        }
    }

    /**
     * Misclassification rate over the points that received at least one
     * out-of-bag vote; {@code NaN} if no point did.
     */
    private static double classificationOutOfBagError(ClassificationDataSet dataSet, int[][] counts) {
        int evaluated = 0;
        int wrong = 0;
        for (int p = 0; p < counts.length; p++) {
            int best = 0;
            int totalVotes = counts[p].length > 0 ? counts[p][0] : 0;
            for (int c = 1; c < counts[p].length; c++) {
                totalVotes += counts[p][c];
                if (counts[p][c] > counts[p][best]) {
                    best = c;
                }
            }
            if (totalVotes == 0) {
                continue; // never out of bag: there is no prediction to score
            }
            evaluated++;
            if (best != dataSet.getDataPointCategory(p)) {
                wrong++;
            }
        }
        return evaluated == 0 ? Double.NaN : (double) wrong / evaluated;
    }

    /**
     * Mean squared error over the points that received at least one
     * out-of-bag vote; {@code NaN} if no point did.
     */
    private static double regressionOutOfBagError(RegressionDataSet dataSet, AtomicDoubleArray votes, int[][] counts) {
        int evaluated = 0;
        double squaredError = 0.0d;
        for (int p = 0; p < counts.length; p++) {
            int n = counts[p][0];
            if (n == 0) {
                continue; // avoid 0/0 turning the whole estimate into NaN
            }
            evaluated++;
            squaredError += Math.pow((votes.get(p) / n) - dataSet.getTargetValue(p), 2.0d);
        }
        return evaluated == 0 ? Double.NaN : squaredError / evaluated;
    }

    /**
     * Creates a deep copy of this forest: settings, trained trees, the base
     * learner, the out-of-bag error and the feature importance statistics.
     */
    @Override
    public RandomForest mo582clone() {
        RandomForest copy = new RandomForest(this.maxForestSize);
        copy.extraSamples = this.extraSamples;
        copy.featureSamples = this.featureSamples;
        if (this.predicting != null) {
            copy.predicting = this.predicting.m481clone();
        }
        if (this.forest != null) {
            copy.forest = new ArrayList<DecisionTree>(this.forest.size());
            for (DecisionTree tree : this.forest) {
                copy.forest.add(tree.mo582clone());
            }
        }
        copy.baseLearner = this.baseLearner.mo582clone();
        copy.useOutOfBagImportance = this.useOutOfBagImportance;
        copy.useOutOfBagError = this.useOutOfBagError;
        copy.outOfBagError = this.outOfBagError;
        if (this.feature_importance != null) {
            copy.feature_importance = new OnLineStatistics[this.feature_importance.length];
            for (int f = 0; f < this.feature_importance.length; f++) {
                copy.feature_importance[f] = this.feature_importance[f].m690clone();
            }
        }
        return copy;
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }
}
