package jsat.classifiers.trees;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.classifiers.trees.ImpurityScore;
import jsat.classifiers.trees.TreePruner;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.ModelMismatchException;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.FakeExecutor;
import jsat.utils.IntList;
import jsat.utils.IntSet;
import jsat.utils.ModifiableCountDownLatch;
import jsat.utils.SystemInfo;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/trees/DecisionTree.class */
public class DecisionTree implements Classifier, Regressor, Parameterized, TreeLearner {
    private static final long serialVersionUID = 9220980056440500214L;
    /** Depth limit: node construction returns no node once the depth exceeds this value. */
    private int maxDepth;
    /** Minimum number of data points a node must receive for it to be built. */
    private int minSamples;
    /** Root of the learned tree; stays {@code null} until one of the training methods has run. */
    private Node root;
    /** Class information taken from the training set by classification training. */
    private CategoricalData predicting;
    /** Pruning strategy handed to {@link TreePruner} after classification training. */
    private TreePruner.PruningMethod pruningMethod;
    /** Fraction, in [0, 1], of the classification training data used as the pruning test set. */
    private double testProportion;
    /** Template stump; a clone of it is trained for every node of the tree. */
    private DecisionStump baseStump;

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/trees/DecisionTree$Node.class */
    public static class Node extends TreeNodeVisitor {
        private static final long serialVersionUID = -7507748424627088734L;
        /** Stump that performs the local decision, classification and regression for this node. */
        protected final DecisionStump stump;
        /** One slot per stump path; a {@code null} slot means that path has no child node. */
        protected Node[] paths;

        /**
         * Creates a node around the given stump with one (initially empty) child slot per stump path.
         *
         * @param decisionStump the stump backing this node
         */
        public Node(DecisionStump decisionStump) {
            this.stump = decisionStump;
            this.paths = new Node[decisionStump.getNumberOfPaths()];
        }

        /** Returns the stump's path ratio for the given path index. */
        @Override // jsat.classifiers.trees.TreeNodeVisitor
        public double getPathWeight(int i) {
            return stump.pathRatio[i];
        }

        /** A node is a leaf when it has no child array or every child slot is empty. */
        @Override // jsat.classifiers.trees.TreeNodeVisitor
        public boolean isLeaf() {
            if (paths != null) {
                for (Node child : paths) {
                    if (child != null) {
                        return false;
                    }
                }
            }
            return true;
        }

        /** Returns the number of child slots, whether or not they are occupied. */
        @Override // jsat.classifiers.trees.TreeNodeVisitor
        public int childrenCount() {
            return paths.length;
        }

        /** Classifies the point using only this node's stump. */
        @Override // jsat.classifiers.trees.TreeNodeVisitor
        public CategoricalResults localClassify(DataPoint dataPoint) {
            return stump.classify(dataPoint);
        }

        /** Regresses the point using only this node's stump. */
        @Override // jsat.classifiers.trees.TreeNodeVisitor
        public double localRegress(DataPoint dataPoint) {
            return stump.regress(dataPoint);
        }

        /** Deep copy: clones the stump and, recursively, every present child. */
        @Override // jsat.classifiers.trees.TreeNodeVisitor
        /* renamed from: clone */
        public Node mo583clone() {
            final Node copy = new Node(stump.clone());
            for (int slot = 0; slot < paths.length; slot++) {
                final Node child = paths[slot];
                copy.paths[slot] = (child != null) ? child.mo583clone() : null;
            }
            return copy;
        }

        /** Returns the child on path {@code i}, or {@code null} when this node is a leaf or the slot is empty. */
        @Override // jsat.classifiers.trees.TreeNodeVisitor
        public TreeNodeVisitor getChild(int i) {
            return isLeaf() ? null : paths[i];
        }

        /** Stores a {@code Node} child directly; any other visitor type is passed to the superclass. */
        @Override // jsat.classifiers.trees.TreeNodeVisitor
        public void setPath(int i, TreeNodeVisitor treeNodeVisitor) {
            if (!(treeNodeVisitor instanceof Node)) {
                super.setPath(i, treeNodeVisitor);
                return;
            }
            paths[i] = (Node) treeNodeVisitor;
        }

        /** Removes the child on path {@code i}. */
        @Override // jsat.classifiers.trees.TreeNodeVisitor
        public void disablePath(int i) {
            paths[i] = null;
        }

        /** Asks the stump which path the data point follows. */
        @Override // jsat.classifiers.trees.TreeNodeVisitor
        public int getPath(DataPoint dataPoint) {
            return stump.whichPath(dataPoint);
        }

        /** A path is disabled when the node is a leaf or that path's slot is empty. */
        @Override // jsat.classifiers.trees.TreeNodeVisitor
        public boolean isPathDisabled(int i) {
            if (isLeaf()) {
                return true;
            }
            return paths[i] == null;
        }

        /** Returns a single-element collection holding the stump's splitting attribute. */
        @Override // jsat.classifiers.trees.TreeNodeVisitor
        public Collection<Integer> featuresUsed() {
            final IntList used = new IntList(1);
            used.add(stump.getSplittingAttribute());
            return used;
        }
    }

    @Override // jsat.regression.Regressor
    /**
     * Predicts a numeric value for the data point using the trained tree.
     *
     * @param dataPoint the point to regress
     * @return the value produced by the tree for this point
     * @throws ModelMismatchException if the point's numeric or categorical feature counts differ
     *         from those of the root stump
     */
    public double regress(DataPoint dataPoint) {
        if (dataPoint.numNumericalValues() != this.root.stump.numNumeric()
                || dataPoint.numCategoricalValues() != this.root.stump.numCategorical()) {
            throw new ModelMismatchException("Tree expected " + this.root.stump.numNumeric()
                    + " numeric and " + this.root.stump.numCategorical()
                    + " categorical features, instead received data with " + dataPoint.numNumericalValues()
                    + " and " + dataPoint.numCategoricalValues()
                    + " features respectively");
        }
        return this.root.regress(dataPoint);
    }

    @Override // jsat.regression.Regressor
    /**
     * Trains a regression tree that may split on every feature of the data set.
     *
     * @param regressionDataSet the training data
     * @param executorService source of threads used while building the tree
     */
    public void train(RegressionDataSet regressionDataSet, ExecutorService executorService) {
        final int featureCount = regressionDataSet.getNumFeatures();
        IntSet intSet = new IntSet(featureCount);
        for (int i = 0; i < featureCount; i++) {
            // The decompiler emitted "(IntSet) Integer.valueOf(i)", an inconvertible cast that does not compile.
            intSet.add(Integer.valueOf(i));
        }
        train(regressionDataSet, intSet, executorService);
    }

    /**
     * Trains a regression tree restricted to the given feature indices, using a {@link FakeExecutor}.
     *
     * @param regressionDataSet the training data
     * @param set indices of the features the tree may split on
     */
    public void train(RegressionDataSet regressionDataSet, Set<Integer> set) {
        final ExecutorService executor = new FakeExecutor();
        train(regressionDataSet, set, executor);
    }

    /**
     * Trains a regression tree restricted to the given feature indices.
     * <p>
     * Node construction may hand work to the executor, so this method blocks on a latch until
     * every node has reported completion. If no root could be built, a single default
     * {@link DecisionStump} trained on the whole data set becomes the root.
     *
     * @param regressionDataSet the training data
     * @param set indices of the features the tree may split on
     * @param executorService source of threads used while building the tree
     */
    public void train(RegressionDataSet regressionDataSet, Set<Integer> set, ExecutorService executorService) {
        ModifiableCountDownLatch modifiableCountDownLatch = new ModifiableCountDownLatch(1);
        this.root = makeNodeR(regressionDataSet.getDPPList(), set, 0, executorService, modifiableCountDownLatch);
        try {
            modifiableCountDownLatch.await();
        } catch (InterruptedException e) {
            Logger.getLogger(DecisionTree.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
            // Restore the interrupt status so callers further up the stack can still observe it.
            Thread.currentThread().interrupt();
        }
        if (this.root == null) {
            DecisionStump decisionStump = new DecisionStump();
            decisionStump.train(regressionDataSet, executorService);
            this.root = new Node(decisionStump);
        }
    }

    @Override // jsat.regression.Regressor
    /**
     * Trains a regression tree on all features, using a {@link FakeExecutor}.
     *
     * @param regressionDataSet the training data
     */
    public void train(RegressionDataSet regressionDataSet) {
        final ExecutorService executor = new FakeExecutor();
        train(regressionDataSet, executor);
    }

    /**
     * Creates a tree with a maximum depth of {@code Integer.MAX_VALUE}, a minimum of 10 samples
     * per node, reduced-error pruning and a test proportion of 0.1.
     */
    public DecisionTree() {
        this(Integer.MAX_VALUE,
                10,
                TreePruner.PruningMethod.REDUCED_ERROR,
                0.1);
    }

    /**
     * Creates an unpruned tree with the given maximum depth, a minimum of 10 samples per node
     * and a test proportion of 1e-5.
     *
     * @param i the maximum depth of the tree
     */
    public DecisionTree(int i) {
        this(i,
                10,
                TreePruner.PruningMethod.NONE,
                1e-5);
    }

    /**
     * Creates a tree with a fresh default {@link DecisionStump} as node template and the given
     * settings, each applied through its setter.
     *
     * @param i the maximum depth of the tree
     * @param i2 the minimum number of samples needed to build a node
     * @param pruningMethod the pruning strategy applied after classification training
     * @param d the proportion of training data used as the pruning test set, in [0, 1]
     */
    public DecisionTree(int i, int i2, TreePruner.PruningMethod pruningMethod, double d) {
        baseStump = new DecisionStump();

        setMaxDepth(i);
        setMinSamples(i2);
        setPruningMethod(pruningMethod);
        setTestProportion(d);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Copy constructor: duplicates the settings of the given tree and deep-copies its root,
     * predicted-class information and base stump.
     *
     * @param decisionTree the tree to copy
     */
    public DecisionTree(DecisionTree decisionTree) {
        this.maxDepth = decisionTree.maxDepth;
        this.minSamples = decisionTree.minSamples;
        if (decisionTree.root != null) {
            this.root = decisionTree.root.mo583clone();
        }
        if (decisionTree.predicting != null) {
            this.predicting = decisionTree.predicting.m481clone();
        }
        this.pruningMethod = decisionTree.pruningMethod;
        this.testProportion = decisionTree.testProportion;
        // Assigned exactly once: the previous version first allocated a throw-away DecisionStump here.
        this.baseStump = decisionTree.baseStump.clone();
    }

    /**
     * Builds a tree pre-configured with: minimum result split size 2, minimum samples 3,
     * test proportion 1.0, error-based pruning and information gain ratio as the gain method.
     *
     * @return a new, untrained tree with those settings
     */
    public static DecisionTree getC45Tree() {
        DecisionTree decisionTree = new DecisionTree();
        // The minimum result split size was previously set twice to the same value; once is enough.
        decisionTree.setMinResultSplitSize(2);
        decisionTree.setMinSamples(3);
        decisionTree.setTestProportion(1.0d);
        decisionTree.setPruningMethod(TreePruner.PruningMethod.ERROR_BASED);
        decisionTree.baseStump.setGainMethod(ImpurityScore.ImpurityMeasure.INFORMATION_GAIN_RATIO);
        return decisionTree;
    }

    /**
     * Sets the impurity measure on the base stump that is cloned for every node.
     *
     * @param impurityMeasure the gain method to use
     */
    public void setGainMethod(ImpurityScore.ImpurityMeasure impurityMeasure) {
        final DecisionStump template = this.baseStump;
        template.setGainMethod(impurityMeasure);
    }

    /**
     * Returns the impurity measure configured on the base stump.
     *
     * @return the gain method in use
     */
    public ImpurityScore.ImpurityMeasure getGainMethod() {
        final DecisionStump template = this.baseStump;
        return template.getGainMethod();
    }

    /**
     * Sets the minimum result split size on the base stump that is cloned for every node.
     *
     * @param i the minimum result split size
     */
    public void setMinResultSplitSize(int i) {
        final DecisionStump template = this.baseStump;
        template.setMinResultSplitSize(i);
    }

    /**
     * Returns the minimum result split size configured on the base stump.
     *
     * @return the minimum result split size
     */
    public int getMinResultSplitSize() {
        final DecisionStump template = this.baseStump;
        return template.getMinResultSplitSize();
    }

    /**
     * Sets the maximum depth of the tree. Zero is accepted; only negative values are rejected.
     *
     * @param i the maximum depth, must be {@code >= 0}
     * @throws IllegalArgumentException if {@code i} is negative
     */
    public void setMaxDepth(int i) {
        if (i < 0) {
            // IllegalArgumentException is a RuntimeException, so existing catch blocks still match.
            // The old message said "positive" although 0 passes this check.
            throw new IllegalArgumentException("The maximum depth must be a non-negative number, not " + i);
        }
        this.maxDepth = i;
    }

    /**
     * Returns the maximum depth the tree may reach.
     *
     * @return the maximum depth
     */
    public int getMaxDepth() {
        return maxDepth;
    }

    /**
     * Sets the minimum number of data points required to build a node. The value is stored
     * without validation.
     *
     * @param i the minimum sample count
     */
    public void setMinSamples(int i) {
        minSamples = i;
    }

    /**
     * Returns the minimum number of data points required to build a node.
     *
     * @return the minimum sample count
     */
    public int getMinSamples() {
        return minSamples;
    }

    /**
     * Sets the pruning strategy applied after classification training.
     *
     * @param pruningMethod the pruning method to use
     */
    public void setPruningMethod(TreePruner.PruningMethod pruningMethod) {
        final TreePruner.PruningMethod chosen = pruningMethod;
        this.pruningMethod = chosen;
    }

    /**
     * Returns the pruning strategy applied after classification training.
     *
     * @return the pruning method
     */
    public TreePruner.PruningMethod getPruningMethod() {
        return pruningMethod;
    }

    /**
     * Returns the proportion of training data used as the pruning test set.
     *
     * @return the test proportion, in [0, 1]
     */
    public double getTestProportion() {
        return testProportion;
    }

    /**
     * Sets the proportion of training data used as the pruning test set.
     *
     * @param d the proportion, which must lie in [0, 1]
     * @throws ArithmeticException if {@code d} is outside [0, 1], infinite or NaN
     */
    public void setTestProportion(double d) {
        // Both comparisons are false for NaN, and an infinity fails one of them, so this single
        // predicate rejects exactly the values outside [0, 1] plus NaN and the infinities.
        final boolean withinUnitInterval = d >= 0.0d && d <= 1.0d;
        if (!withinUnitInterval) {
            throw new ArithmeticException("Proportion must be in the range [0, 1], not " + d);
        }
        this.testProportion = d;
    }

    @Override // jsat.classifiers.Classifier
    /**
     * Classifies the data point using the trained tree.
     *
     * @param dataPoint the point to classify
     * @return the classification results produced by the tree
     * @throws ModelMismatchException if the point's numeric or categorical feature counts differ
     *         from those of the root stump
     */
    public CategoricalResults classify(DataPoint dataPoint) {
        if (dataPoint.numNumericalValues() != this.root.stump.numNumeric()
                || dataPoint.numCategoricalValues() != this.root.stump.numCategorical()) {
            throw new ModelMismatchException("Tree expected " + this.root.stump.numNumeric()
                    + " numeric and " + this.root.stump.numCategorical()
                    + " categorical features, instead received data with " + dataPoint.numNumericalValues()
                    + " and " + dataPoint.numCategoricalValues()
                    + " features respectively");
        }
        return this.root.classify(dataPoint);
    }

    @Override // jsat.classifiers.Classifier
    /**
     * Trains a classification tree that may split on every feature of the data set.
     *
     * @param classificationDataSet the training data
     * @param executorService source of threads used while building the tree
     */
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        final int featureCount = classificationDataSet.getNumFeatures();
        IntSet intSet = new IntSet(featureCount);
        for (int i = 0; i < featureCount; i++) {
            // The decompiler emitted "(IntSet) Integer.valueOf(i)", an inconvertible cast that does not compile.
            intSet.add(Integer.valueOf(i));
        }
        trainC(classificationDataSet, intSet, executorService);
    }

    /**
     * Trains a classification tree restricted to the given feature indices, then prunes it.
     * <p>
     * When pruning is enabled and the test proportion is non-zero, a pruning test set is formed:
     * for a proportion of exactly 1.0 every training point is also used for pruning; otherwise
     * that fraction of points is drawn at random and removed from the training list. If no root
     * could be built, a single default {@link DecisionStump} trained on the whole data set
     * becomes the root and no pruning is performed.
     *
     * @param classificationDataSet the training data
     * @param set indices of the features the tree may split on
     * @param executorService source of threads used while building the tree
     * @throws FailedToFitException if the data set holds fewer points than the minimum sample count
     */
    protected void trainC(ClassificationDataSet classificationDataSet, Set<Integer> set, ExecutorService executorService) {
        if (classificationDataSet.getSampleSize() < this.minSamples) {
            throw new FailedToFitException("There are only " + classificationDataSet.getSampleSize() + " data points in the sample set, at least " + this.minSamples + " are needed to make a tree");
        }
        this.predicting = classificationDataSet.getPredicting();
        ModifiableCountDownLatch modifiableCountDownLatch = new ModifiableCountDownLatch(1);
        List<DataPointPair<Integer>> trainPoints = classificationDataSet.getAsDPPList();
        // Typed instead of a raw ArrayList so the elements stay checked by the compiler.
        List<DataPointPair<Integer>> testPoints = new ArrayList<DataPointPair<Integer>>();
        if (this.pruningMethod != TreePruner.PruningMethod.NONE && this.testProportion != 0.0d) {
            if (this.testProportion != 1.0d) {
                int testSize = (int) (trainPoints.size() * this.testProportion);
                // Seeded with the hold-out size, so the same split is drawn for the same data size.
                Random random = new Random(testSize);
                for (int i = 0; i < testSize; i++) {
                    testPoints.add(trainPoints.remove(random.nextInt(trainPoints.size())));
                }
            } else {
                testPoints.addAll(trainPoints);
            }
        }
        this.root = makeNodeC(trainPoints, set, 0, executorService, modifiableCountDownLatch);
        try {
            modifiableCountDownLatch.await();
        } catch (InterruptedException e) {
            // The logger already records the exception, so the extra System.err print was dropped.
            Logger.getLogger(DecisionTree.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
            // Restore the interrupt status so callers further up the stack can still observe it.
            Thread.currentThread().interrupt();
        }
        if (this.root != null) {
            TreePruner.prune(this.root, this.pruningMethod, testPoints);
            return;
        }
        DecisionStump decisionStump = new DecisionStump();
        decisionStump.trainC(classificationDataSet, executorService);
        this.root = new Node(decisionStump);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Recursively builds the classification subtree for the given data at the given depth.
     * <p>
     * Every invocation counts the latch down exactly once on a normal exit, and counts it up once
     * for each child before that child is built or submitted.
     *
     * @param list the data points that reached this node
     * @param set indices of the features that may be split on
     * @param i the depth of this node (the root is built at depth 0)
     * @param executorService executor used for stump training and for child-building tasks
     * @param modifiableCountDownLatch latch tracking outstanding node constructions
     * @return the new node, or {@code null} when the depth exceeds the maximum, no features
     *         remain, or the data is empty or smaller than the minimum sample count
     */
    public Node makeNodeC(List<DataPointPair<Integer>> list, final Set<Integer> set, final int i, final ExecutorService executorService, final ModifiableCountDownLatch modifiableCountDownLatch) {
        // An int shift only uses the low five bits of its distance, so "1 << i" turns negative at
        // i == 31 and wraps back to 1 at i == 32, resetting both flags to their shallow-depth
        // values. Shifting a long with a capped distance keeps the comparison monotonic.
        final int cappedDepth = Math.min(i, 61);
        final long parallelTarget = SystemInfo.LogicalCores * 2L;
        // True while 2^depth is below twice the core count: train this node's stump with the executor.
        boolean trainStumpWithExecutor = (1L << cappedDepth) < parallelTarget;
        // True once 2^(depth+1) reaches twice the core count: build children as executor tasks.
        boolean submitChildren = (1L << (cappedDepth + 1)) >= parallelTarget;
        if (i > this.maxDepth || set.isEmpty() || list.size() < this.minSamples || list.isEmpty()) {
            modifiableCountDownLatch.countDown();
            return null;
        }
        DecisionStump stump = this.baseStump.clone();
        stump.setPredicting(this.predicting);
        List<List<DataPointPair<Integer>>> splits = trainStumpWithExecutor ? stump.trainC(list, set, executorService) : stump.trainC(list, set);
        final Node node = new Node(stump);
        if (stump.getNumberOfPaths() > 1) {
            for (int path = 0; path < node.paths.length; path++) {
                final int childIndex = path;
                final List<DataPointPair<Integer>> childData = splits.get(path);
                modifiableCountDownLatch.countUp();
                if (submitChildren) {
                    // NOTE(review): if this task throws, its countDown() is never reached and the
                    // caller's await() would keep blocking. TODO confirm how failures should surface.
                    executorService.submit(new Runnable() {
                        @Override // java.lang.Runnable
                        public void run() {
                            // Each child receives its own copy of the feature set.
                            node.paths[childIndex] = DecisionTree.this.makeNodeC(childData, new IntSet(set), i + 1, executorService, modifiableCountDownLatch);
                        }
                    });
                } else {
                    node.paths[childIndex] = makeNodeC(childData, new IntSet(set), i + 1, executorService, modifiableCountDownLatch);
                }
            }
        }
        modifiableCountDownLatch.countDown();
        return node;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Recursively builds the regression subtree for the given data at the given depth.
     * <p>
     * Every invocation counts the latch down exactly once on a normal exit, and counts it up once
     * for each child before that child is built or submitted.
     *
     * @param list the data points that reached this node
     * @param set indices of the features that may be split on
     * @param i the depth of this node (the root is built at depth 0)
     * @param executorService executor used for stump training and for child-building tasks
     * @param modifiableCountDownLatch latch tracking outstanding node constructions
     * @return the new node, or {@code null} when the depth exceeds the maximum, no features
     *         remain, the data is empty or smaller than the minimum sample count, or the stump's
     *         training returned no splits
     */
    public Node makeNodeR(List<DataPointPair<Double>> list, final Set<Integer> set, final int i, final ExecutorService executorService, final ModifiableCountDownLatch modifiableCountDownLatch) {
        // An int shift only uses the low five bits of its distance, so "1 << i" turns negative at
        // i == 31 and wraps back to 1 at i == 32, resetting both flags to their shallow-depth
        // values. Shifting a long with a capped distance keeps the comparison monotonic.
        final int cappedDepth = Math.min(i, 61);
        final long parallelTarget = SystemInfo.LogicalCores * 2L;
        // True while 2^depth is below twice the core count: train this node's stump with the executor.
        boolean trainStumpWithExecutor = (1L << cappedDepth) < parallelTarget;
        // True once 2^(depth+1) reaches twice the core count: build children as executor tasks.
        boolean submitChildren = (1L << (cappedDepth + 1)) >= parallelTarget;
        if (i > this.maxDepth || set.isEmpty() || list.size() < this.minSamples || list.isEmpty()) {
            modifiableCountDownLatch.countDown();
            return null;
        }
        DecisionStump stump = this.baseStump.clone();
        List<List<DataPointPair<Double>>> splits = trainStumpWithExecutor ? stump.trainR(list, set, executorService) : stump.trainR(list, set);
        if (splits == null) {
            // The stump produced no splits, so no node is created for this data.
            modifiableCountDownLatch.countDown();
            return null;
        }
        final Node node = new Node(stump);
        if (stump.getNumberOfPaths() > 1) {
            for (int path = 0; path < node.paths.length; path++) {
                final int childIndex = path;
                final List<DataPointPair<Double>> childData = splits.get(path);
                modifiableCountDownLatch.countUp();
                if (submitChildren) {
                    // NOTE(review): if this task throws, its countDown() is never reached and the
                    // caller's await() would keep blocking. TODO confirm how failures should surface.
                    executorService.submit(new Runnable() {
                        @Override // java.lang.Runnable
                        public void run() {
                            // Each child receives its own copy of the feature set.
                            node.paths[childIndex] = DecisionTree.this.makeNodeR(childData, new IntSet(set), i + 1, executorService, modifiableCountDownLatch);
                        }
                    });
                } else {
                    node.paths[childIndex] = makeNodeR(childData, new IntSet(set), i + 1, executorService, modifiableCountDownLatch);
                }
            }
        }
        modifiableCountDownLatch.countDown();
        return node;
    }

    @Override // jsat.classifiers.Classifier
    /**
     * Trains a classification tree on all features, using a {@link FakeExecutor}.
     *
     * @param classificationDataSet the training data
     */
    public void trainC(ClassificationDataSet classificationDataSet) {
        final ExecutorService executor = new FakeExecutor();
        trainC(classificationDataSet, executor);
    }

    /**
     * Trains a classification tree restricted to the given feature indices, using a
     * {@link FakeExecutor}.
     *
     * @param classificationDataSet the training data
     * @param set indices of the features the tree may split on
     */
    public void trainC(ClassificationDataSet classificationDataSet, Set<Integer> set) {
        final ExecutorService executor = new FakeExecutor();
        trainC(classificationDataSet, set, executor);
    }

    @Override // jsat.classifiers.Classifier
    /**
     * Reports that this learner supports weighted data.
     *
     * @return always {@code true}
     */
    public boolean supportsWeightedData() {
        final boolean supported = true;
        return supported;
    }

    @Override // jsat.regression.Regressor
    /**
     * Creates an independent copy of this tree: same settings, with deep copies of the
     * predicted-class information, the learned root (when present) and the base stump.
     *
     * @return the copy
     */
    public DecisionTree clone() {
        final DecisionTree copy =
                new DecisionTree(this.maxDepth, this.minSamples, this.pruningMethod, this.testProportion);
        // The constructor leaves predicting and root unset, so assigning null here changes nothing.
        copy.predicting = (this.predicting == null) ? null : this.predicting.m481clone();
        copy.root = (this.root == null) ? null : this.root.mo583clone();
        copy.baseStump = this.baseStump.clone();
        return copy;
    }

    @Override // jsat.classifiers.trees.TreeLearner
    /**
     * Exposes the root node of the learned tree as a visitor.
     *
     * @return the root node, or {@code null} if no tree has been trained yet
     */
    public TreeNodeVisitor getTreeNodeVisitor() {
        final TreeNodeVisitor visitor = this.root;
        return visitor;
    }

    @Override // jsat.parameters.Parameterized
    /**
     * Collects the parameters discovered on this object plus those of the base stump, leaving
     * out stump parameters whose name contains "Gain Method" or "Numeric Handling".
     *
     * @return an unmodifiable list of the parameters
     */
    public List<Parameter> getParameters() {
        // Typed list instead of a raw ArrayList, which forced an unchecked conversion on return.
        List<Parameter> parameters = new ArrayList<Parameter>(Parameter.getParamsFromMethods(this));
        for (Parameter stumpParameter : this.baseStump.getParameters()) {
            String parameterName = stumpParameter.getName();
            // NOTE(review): presumably skipped because this class exposes its own gain-method
            // accessors - confirm the reason for excluding "Numeric Handling".
            if (!parameterName.contains("Gain Method") && !parameterName.contains("Numeric Handling")) {
                parameters.add(stumpParameter);
            }
        }
        return Collections.unmodifiableList(parameters);
    }

    @Override // jsat.parameters.Parameterized
    /**
     * Looks up one parameter by name among those returned by {@link #getParameters()}.
     *
     * @param str the parameter name
     * @return the result of looking the name up in the parameter map
     */
    public Parameter getParameter(String str) {
        final List<Parameter> available = getParameters();
        return Parameter.toParameterMap(available).get(str);
    }
}
