package jsat.classifiers.trees;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.Stack;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.classifiers.trees.ImpurityScore;
import jsat.math.OnLineStatistics;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.IntList;
import jsat.utils.IntSet;
import jsat.utils.ListUtils;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/trees/ExtraTree.class */
public class ExtraTree implements Classifier, Regressor, TreeLearner, Parameterized {
    /** Serialization version identifier. */
    private static final long serialVersionUID = 7433728970041876327L;
    /** Nodes receiving fewer data points than this become leaves during training. */
    private int stopSize;
    /** Upper bound on how many of the shuffled features are evaluated for each split. */
    private int selectionCount;
    /** Target class information, captured from the data set by {@code trainC}. */
    private CategoricalData predicting;
    /** When true, categorical features are split into two random groups of categories. */
    private boolean binaryCategoricalSplitting;
    /** Numeric feature count of the last training set; offsets categorical indices in {@code featuresUsed}. */
    private int numNumericFeatures;
    /** Impurity measure handed to every {@code ImpurityScore} created for classification. */
    private ImpurityScore.ImpurityMeasure impMeasure;
    /** Root of the learned tree; {@code null} before any training call. */
    private TreeNodeVisitor root;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/trees/ExtraTree$NodeBase.class */
    /**
     * Shared base of every node produced by {@link ExtraTree}. It owns the
     * array of child sub-trees and implements the navigation related parts of
     * {@link TreeNodeVisitor}; prediction and routing are left to subclasses.
     */
    public static abstract class NodeBase extends TreeNodeVisitor {
        private static final long serialVersionUID = 6783491817922690901L;
        /** Child sub-trees, or {@code null} when the node was created as a leaf. */
        protected TreeNodeVisitor[] children;

        /** Creates a node without a child array (a leaf). */
        public NodeBase() {
        }

        /**
         * Creates a node with room for {@code i} children, all initially absent.
         *
         * @param i the number of child slots
         */
        public NodeBase(int i) {
            this.children = new TreeNodeVisitor[i];
        }

        /**
         * Copy constructor; every present child is cloned, absent ones stay absent.
         *
         * @param nodeBase the node to copy
         */
        public NodeBase(NodeBase nodeBase) {
            if (nodeBase.children != null) {
                this.children = new TreeNodeVisitor[nodeBase.children.length];
                for (int i = 0; i < nodeBase.children.length; i++) {
                    if (nodeBase.children[i] != null) {
                        this.children[i] = nodeBase.children[i].mo583clone();
                    }
                }
            }
        }

        /**
         * @return the number of child slots, {@code 0} for a node without a child array
         */
        @Override // jsat.classifiers.trees.TreeNodeVisitor
        public int childrenCount() {
            // Leaves are built with a null child array; report zero instead of throwing.
            return this.children == null ? 0 : this.children.length;
        }

        /**
         * @return {@code true} when there is no child array or every slot is empty
         */
        @Override // jsat.classifiers.trees.TreeNodeVisitor
        public boolean isLeaf() {
            if (this.children == null) {
                return true;
            }
            for (int i = 0; i < this.children.length; i++) {
                if (this.children[i] != null) {
                    return false;
                }
            }
            return true;
        }

        /**
         * @param i the child index
         * @return the child stored at {@code i}, or {@code null} when the index
         *         is out of range or the slot is empty
         */
        @Override // jsat.classifiers.trees.TreeNodeVisitor
        public TreeNodeVisitor getChild(int i) {
            // Valid indices are 0 .. childrenCount() - 1.
            if (i < 0 || i >= childrenCount()) {
                return null;
            }
            return this.children[i];
        }

        /**
         * Removes the child at index {@code i}; does nothing on a leaf.
         *
         * @param i the child index to disable
         */
        @Override // jsat.classifiers.trees.TreeNodeVisitor
        public void disablePath(int i) {
            if (isLeaf()) {
                return;
            }
            this.children[i] = null;
        }

        /**
         * @param i the child index
         * @return {@code true} when this node is a leaf or slot {@code i} is empty
         */
        @Override // jsat.classifiers.trees.TreeNodeVisitor
        public boolean isPathDisabled(int i) {
            return isLeaf() || this.children[i] == null;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/trees/ExtraTree$NodeC.class */
    /**
     * Classification node. On its own it acts as a leaf that always answers
     * with the stored {@link CategoricalResults}; splitting subclasses override
     * {@link #getPath(DataPoint)}.
     */
    public static class NodeC extends NodeBase {
        private static final long serialVersionUID = -3977497656918695759L;
        /** Result returned by {@link #localClassify(DataPoint)}. */
        private CategoricalResults crResult;

        /**
         * Builds a leaf; the inherited child array stays {@code null}.
         *
         * @param categoricalResults the result this node reports
         */
        public NodeC(CategoricalResults categoricalResults) {
            super();
            crResult = categoricalResults;
        }

        /**
         * Builds a node with {@code i} child slots.
         *
         * @param categoricalResults the result this node reports
         * @param i the number of child slots
         */
        public NodeC(CategoricalResults categoricalResults, int i) {
            super(i);
            crResult = categoricalResults;
        }

        /**
         * Copy constructor; the stored result is cloned.
         *
         * @param nodeC the node to copy
         */
        public NodeC(NodeC nodeC) {
            super(nodeC);
            crResult = nodeC.crResult.m482clone();
        }

        /** @return the stored result, independent of the given point */
        @Override // jsat.classifiers.trees.TreeNodeVisitor
        public CategoricalResults localClassify(DataPoint dataPoint) {
            return crResult;
        }

        /** @return always {@code -1}; this node type does not route points */
        @Override // jsat.classifiers.trees.TreeNodeVisitor
        public int getPath(DataPoint dataPoint) {
            return -1;
        }

        @Override // jsat.classifiers.trees.TreeNodeVisitor
        /* renamed from: clone */
        public TreeNodeVisitor mo583clone() {
            return new NodeC(this);
        }

        /** @return an empty set, since no feature is inspected here */
        @Override // jsat.classifiers.trees.TreeNodeVisitor
        public Collection<Integer> featuresUsed() {
            return Collections.<Integer>emptySet();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/trees/ExtraTree$NodeCCat.class */
    /**
     * Classification node that routes on one categorical attribute. Without a
     * left set the category value itself is the branch index; with a left set
     * the split is binary (branch 0 for members of the set, branch 1 otherwise).
     */
    public class NodeCCat extends NodeC {
        private static final long serialVersionUID = 7413428280703235600L;
        /** Index of the categorical attribute inspected by this node. */
        private int catAtt;
        /** Sorted category values sent to branch 0, or {@code null} for a multi-way split. */
        private int[] leftBranch;

        /**
         * Multi-way split with one branch per category value.
         *
         * @param i index of the categorical attribute
         * @param i2 number of child slots
         * @param categoricalResults the result this node reports
         */
        public NodeCCat(int i, int i2, CategoricalResults categoricalResults) {
            super(categoricalResults, i2);
            catAtt = i;
            leftBranch = null;
        }

        /**
         * Binary split: values contained in {@code set} go to branch 0.
         *
         * @param i index of the categorical attribute
         * @param set category values routed to branch 0
         * @param categoricalResults the result this node reports
         */
        public NodeCCat(int i, Set<Integer> set, CategoricalResults categoricalResults) {
            super(categoricalResults, 2);
            catAtt = i;
            leftBranch = new int[set.size()];
            int next = 0;
            for (Integer category : set) {
                leftBranch[next++] = category.intValue();
            }
            // Sorted so that getPath can use a binary search.
            Arrays.sort(leftBranch);
        }

        /**
         * Copy constructor.
         *
         * @param nodeCCat the node to copy
         */
        public NodeCCat(NodeCCat nodeCCat) {
            super(nodeCCat);
            catAtt = nodeCCat.catAtt;
            if (nodeCCat.leftBranch != null) {
                leftBranch = nodeCCat.leftBranch.clone();
            }
        }

        /**
         * @param dataPoint the point to route
         * @return the category value for a multi-way split; otherwise 0 when
         *         the value is in the left set and 1 when it is not
         */
        @Override // jsat.classifiers.trees.ExtraTree.NodeC, jsat.classifiers.trees.TreeNodeVisitor
        public int getPath(DataPoint dataPoint) {
            int[] categories = dataPoint.getCategoricalValues();
            int value = categories[catAtt];
            if (leftBranch == null) {
                return value;
            }
            boolean onLeft = Arrays.binarySearch(leftBranch, value) >= 0;
            return onLeft ? 0 : 1;
        }

        @Override // jsat.classifiers.trees.ExtraTree.NodeC, jsat.classifiers.trees.TreeNodeVisitor
        /* renamed from: clone */
        public TreeNodeVisitor mo583clone() {
            return new NodeCCat(this);
        }

        /**
         * @return a single element list holding the attribute index shifted by
         *         the enclosing tree's numeric feature count
         */
        @Override // jsat.classifiers.trees.ExtraTree.NodeC, jsat.classifiers.trees.TreeNodeVisitor
        public Collection<Integer> featuresUsed() {
            IntList used = new IntList(1);
            used.add(catAtt + ExtraTree.this.numNumericFeatures);
            return used;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/trees/ExtraTree$NodeCNum.class */
    /**
     * Classification node performing a binary split on a numeric attribute:
     * values not greater than the threshold go to branch 0, all others to branch 1.
     */
    public static class NodeCNum extends NodeC {
        private static final long serialVersionUID = 3967180517059509869L;
        /** Index of the numeric attribute inspected by this node. */
        private int numerAtt;
        /** Split point; values {@code <=} this go to branch 0. */
        private double threshold;

        /**
         * @param i index of the numeric attribute
         * @param d the split threshold
         * @param categoricalResults the result this node reports
         */
        public NodeCNum(int i, double d, CategoricalResults categoricalResults) {
            super(categoricalResults, 2);
            numerAtt = i;
            threshold = d;
        }

        /**
         * Copy constructor.
         *
         * @param nodeCNum the node to copy
         */
        public NodeCNum(NodeCNum nodeCNum) {
            super(nodeCNum);
            numerAtt = nodeCNum.numerAtt;
            threshold = nodeCNum.threshold;
        }

        /**
         * @param dataPoint the point to route
         * @return 0 when the attribute value is {@code <=} the threshold, else 1
         */
        @Override // jsat.classifiers.trees.ExtraTree.NodeC, jsat.classifiers.trees.TreeNodeVisitor
        public int getPath(DataPoint dataPoint) {
            double value = dataPoint.getNumericalValues().get(numerAtt);
            if (value <= threshold) {
                return 0;
            }
            return 1;
        }

        @Override // jsat.classifiers.trees.ExtraTree.NodeC, jsat.classifiers.trees.TreeNodeVisitor
        /* renamed from: clone */
        public TreeNodeVisitor mo583clone() {
            return new NodeCNum(this);
        }

        /** @return a single element list holding the numeric attribute index */
        @Override // jsat.classifiers.trees.ExtraTree.NodeC, jsat.classifiers.trees.TreeNodeVisitor
        public Collection<Integer> featuresUsed() {
            IntList used = new IntList(1);
            used.add(numerAtt);
            return used;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/trees/ExtraTree$NodeR.class */
    /**
     * Regression node. On its own it acts as a leaf that always answers with
     * the stored value; splitting subclasses override {@link #getPath(DataPoint)}.
     */
    public static class NodeR extends NodeBase {
        private static final long serialVersionUID = -2461046505444129890L;
        /** Value returned by {@link #localRegress(DataPoint)}. */
        private double result;

        /**
         * Builds a leaf without a child array.
         *
         * @param d the value this node reports
         */
        public NodeR(double d) {
            super();
            result = d;
        }

        /**
         * Builds a node with {@code i} child slots.
         *
         * @param d the value this node reports
         * @param i the number of child slots
         */
        public NodeR(double d, int i) {
            super(i);
            result = d;
        }

        /**
         * Copy constructor.
         *
         * @param nodeR the node to copy
         */
        public NodeR(NodeR nodeR) {
            super(nodeR);
            result = nodeR.result;
        }

        /** @return the stored value, independent of the given point */
        @Override // jsat.classifiers.trees.TreeNodeVisitor
        public double localRegress(DataPoint dataPoint) {
            return result;
        }

        /** @return always {@code -1}; this node type does not route points */
        @Override // jsat.classifiers.trees.TreeNodeVisitor
        public int getPath(DataPoint dataPoint) {
            return -1;
        }

        @Override // jsat.classifiers.trees.TreeNodeVisitor
        /* renamed from: clone */
        public TreeNodeVisitor mo583clone() {
            return new NodeR(this);
        }

        /** @return an empty set, since no feature is inspected here */
        @Override // jsat.classifiers.trees.TreeNodeVisitor
        public Collection<Integer> featuresUsed() {
            return Collections.<Integer>emptySet();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/trees/ExtraTree$NodeRCat.class */
    /**
     * Regression node that routes on one categorical attribute. Without a left
     * set the category value itself is the branch index; with a left set the
     * split is binary (branch 0 for members of the set, branch 1 otherwise).
     */
    public class NodeRCat extends NodeR {
        private static final long serialVersionUID = 5868393594474661054L;
        /** Index of the categorical attribute inspected by this node. */
        private int catAtt;
        /** Sorted category values sent to branch 0, or {@code null} for a multi-way split. */
        private int[] leftBranch;

        /**
         * Multi-way split with one branch per category value.
         *
         * @param i index of the categorical attribute
         * @param i2 number of child slots
         * @param d the value this node reports
         */
        public NodeRCat(int i, int i2, double d) {
            super(d, i2);
            catAtt = i;
            leftBranch = null;
        }

        /**
         * Binary split: values contained in {@code set} go to branch 0.
         *
         * @param i index of the categorical attribute
         * @param set category values routed to branch 0
         * @param d the value this node reports
         */
        public NodeRCat(int i, Set<Integer> set, double d) {
            super(d, 2);
            catAtt = i;
            leftBranch = new int[set.size()];
            int next = 0;
            for (Integer category : set) {
                leftBranch[next++] = category.intValue();
            }
            // Sorted so that getPath can use a binary search.
            Arrays.sort(leftBranch);
        }

        /**
         * Copy constructor.
         *
         * @param nodeRCat the node to copy
         */
        public NodeRCat(NodeRCat nodeRCat) {
            super(nodeRCat);
            catAtt = nodeRCat.catAtt;
            if (nodeRCat.leftBranch != null) {
                leftBranch = nodeRCat.leftBranch.clone();
            }
        }

        /**
         * @param dataPoint the point to route
         * @return the category value for a multi-way split; otherwise 0 when
         *         the value is in the left set and 1 when it is not
         */
        @Override // jsat.classifiers.trees.ExtraTree.NodeR, jsat.classifiers.trees.TreeNodeVisitor
        public int getPath(DataPoint dataPoint) {
            int[] categories = dataPoint.getCategoricalValues();
            int value = categories[catAtt];
            if (leftBranch == null) {
                return value;
            }
            boolean onLeft = Arrays.binarySearch(leftBranch, value) >= 0;
            return onLeft ? 0 : 1;
        }

        /**
         * @return a single element list holding the attribute index shifted by
         *         the enclosing tree's numeric feature count
         */
        @Override // jsat.classifiers.trees.ExtraTree.NodeR, jsat.classifiers.trees.TreeNodeVisitor
        public Collection<Integer> featuresUsed() {
            IntList used = new IntList(1);
            used.add(catAtt + ExtraTree.this.numNumericFeatures);
            return used;
        }

        @Override // jsat.classifiers.trees.ExtraTree.NodeR, jsat.classifiers.trees.TreeNodeVisitor
        /* renamed from: clone */
        public TreeNodeVisitor mo583clone() {
            return new NodeRCat(this);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/trees/ExtraTree$NodeRNum.class */
    /**
     * Regression node performing a binary split on a numeric attribute:
     * values not greater than the threshold go to branch 0, all others to branch 1.
     */
    public static class NodeRNum extends NodeR {
        private static final long serialVersionUID = -6775472771777960211L;
        /** Index of the numeric attribute inspected by this node. */
        private int numerAtt;
        /** Split point; values {@code <=} this go to branch 0. */
        private double threshold;

        /**
         * @param i index of the numeric attribute
         * @param d the split threshold
         * @param d2 the value this node reports
         */
        public NodeRNum(int i, double d, double d2) {
            super(d2, 2);
            numerAtt = i;
            threshold = d;
        }

        /**
         * Copy constructor.
         *
         * @param nodeRNum the node to copy
         */
        public NodeRNum(NodeRNum nodeRNum) {
            super(nodeRNum);
            numerAtt = nodeRNum.numerAtt;
            threshold = nodeRNum.threshold;
        }

        /**
         * @param dataPoint the point to route
         * @return 0 when the attribute value is {@code <=} the threshold, else 1
         */
        @Override // jsat.classifiers.trees.ExtraTree.NodeR, jsat.classifiers.trees.TreeNodeVisitor
        public int getPath(DataPoint dataPoint) {
            double value = dataPoint.getNumericalValues().get(numerAtt);
            if (value <= threshold) {
                return 0;
            }
            return 1;
        }

        @Override // jsat.classifiers.trees.ExtraTree.NodeR, jsat.classifiers.trees.TreeNodeVisitor
        /* renamed from: clone */
        public TreeNodeVisitor mo583clone() {
            return new NodeRNum(this);
        }

        /** @return a single element list holding the numeric attribute index */
        @Override // jsat.classifiers.trees.ExtraTree.NodeR, jsat.classifiers.trees.TreeNodeVisitor
        public Collection<Integer> featuresUsed() {
            IntList used = new IntList(1);
            used.add(numerAtt);
            return used;
        }
    }

    /**
     * Creates a tree with a selection count of {@link Integer#MAX_VALUE}
     * and a stop size of 5.
     */
    public ExtraTree() {
        this(Integer.MAX_VALUE, 5);
    }

    /**
     * Creates a tree using binary categorical splitting and the NMI impurity
     * measure.
     *
     * @param i the selection count (features evaluated per split)
     * @param i2 the stop size (nodes with fewer points become leaves)
     */
    public ExtraTree(int i, int i2) {
        selectionCount = i;
        stopSize = i2;
        binaryCategoricalSplitting = true;
        impMeasure = ImpurityScore.ImpurityMeasure.NMI;
    }

    /**
     * Sets the impurity measure used when scoring classification splits.
     *
     * @param impurityMeasure the measure to use
     */
    public void setImpurityMeasure(ImpurityScore.ImpurityMeasure impurityMeasure) {
        impMeasure = impurityMeasure;
    }

    /**
     * @return the impurity measure used when scoring classification splits
     */
    public ImpurityScore.ImpurityMeasure getImpurityMeasure() {
        return impMeasure;
    }

    /**
     * Sets the stop size; nodes holding fewer points than this become leaves.
     *
     * @param i the new stop size, must be positive
     * @throws ArithmeticException if {@code i} is zero or negative
     */
    public void setStopSize(int i) {
        if (i < 1) {
            throw new ArithmeticException("The stopping size must be a positive value");
        }
        stopSize = i;
    }

    /**
     * @return the stop size below which a node becomes a leaf
     */
    public int getStopSize() {
        return stopSize;
    }

    /**
     * Sets how many of the shuffled features are evaluated for each split.
     * The value is stored without validation.
     *
     * @param i the new selection count
     */
    public void setSelectionCount(int i) {
        selectionCount = i;
    }

    /**
     * @return the number of features evaluated for each split
     */
    public int getSelectionCount() {
        return selectionCount;
    }

    /**
     * Chooses between binary splits on categorical features ({@code true})
     * and one branch per category ({@code false}).
     *
     * @param z whether categorical splits should be binary
     */
    public void setBinaryCategoricalSplitting(boolean z) {
        binaryCategoricalSplitting = z;
    }

    /**
     * @return {@code true} when categorical features are split in a binary fashion
     */
    public boolean isBinaryCategoricalSplitting() {
        return binaryCategoricalSplitting;
    }

    /**
     * Classifies a point by delegating to the root of the learned tree.
     *
     * @param dataPoint the point to classify
     * @return the result produced by the tree
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        TreeNodeVisitor tree = root;
        return tree.classify(dataPoint);
    }

    /**
     * Trains the classifier. The executor is not used; the work is delegated
     * to {@link #trainC(ClassificationDataSet)}.
     *
     * @param classificationDataSet the training data
     * @param executorService ignored
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        this.trainC(classificationDataSet);
    }

    /**
     * Trains a classification tree on the given data set: collects the feature
     * indices, computes the impurity of the full data set and then grows the
     * tree recursively from the root.
     *
     * @param classificationDataSet the training data
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        Random rng = new Random();
        // Pool of lists recycled between recursive calls.
        Stack<List<DataPointPair<Integer>>> reusableLists = new Stack<>();
        IntList featureIndices = new IntList(classificationDataSet.getNumFeatures());
        ListUtils.addRange(featureIndices, 0, classificationDataSet.getNumFeatures(), 1);
        List<DataPointPair<Integer>> points = classificationDataSet.getAsDPPList();
        predicting = classificationDataSet.getPredicting();
        ImpurityScore rootScore = new ImpurityScore(predicting.getNumOfCategories(), impMeasure);
        for (DataPointPair<Integer> point : points) {
            rootScore.addPoint(point.getDataPoint(), point.getPair().intValue());
        }
        numNumericFeatures = classificationDataSet.getNumNumericalVars();
        root = trainC(rootScore, points, featureIndices, classificationDataSet.getCategories(), rng, reusableLists);
    }

    /**
     * Recursively grows a classification sub-tree.
     * <p>
     * Up to {@code selectionCount} features are drawn from the shuffled feature
     * list. Each gets one random split: a random threshold between the observed
     * minimum and maximum for numeric features, a random subset of the observed
     * categories for binary categorical splits, or one branch per category
     * otherwise. The split with the largest gain is kept.
     * <p>
     * Feature indices below {@code categoricalDataArr.length} are categorical;
     * the remaining ones are numeric, offset by that length.
     *
     * @param impurityScore impurity of the points in {@code list}
     * @param list the points reaching this node; cleared and recycled into
     *        {@code stack} once a split has been evaluated
     * @param list2 candidate feature indices; shuffled in place, and a feature
     *        used for a multi-way categorical split is removed from it
     * @param categoricalDataArr description of the categorical features
     * @param random source of randomness for thresholds and category subsets
     * @param stack pool of empty lists reused to hold candidate partitions
     * @return the root of the sub-tree, or {@code null} if {@code list} is empty
     */
    private TreeNodeVisitor trainC(ImpurityScore impurityScore, List<DataPointPair<Integer>> list, List<Integer> list2, CategoricalData[] categoricalDataArr, Random random, Stack<List<DataPointPair<Integer>>> stack) {
        if (list.size() < this.stopSize || impurityScore.getScore() == 0.0d) {
            if (list.isEmpty()) {
                return null;
            }
            return new NodeC(impurityScore.getResults());
        }
        double bestGain = Double.NEGATIVE_INFINITY;
        double bestThreshold = Double.NaN;
        int bestAttribute = -1;
        ImpurityScore[] bestScores = null;
        List<List<DataPointPair<Integer>>> bestSplit = null;
        IntSet bestLeftSide = null;
        Collections.shuffle(list2);
        int candidates = Math.min(this.selectionCount, list2.size());
        for (int c = 0; c < candidates; c++) {
            double threshold = Double.NaN;
            IntSet leftSide = null;
            int attribute = list2.get(c).intValue();
            ImpurityScore[] scores;
            List<List<DataPointPair<Integer>>> split;
            if (attribute < categoricalDataArr.length) {
                int numOfCategories = categoricalDataArr[attribute].getNumOfCategories();
                if (this.binaryCategoricalSplitting || numOfCategories == 2) {
                    scores = createScores(2);
                    // Categories that actually occur among the points of this node.
                    IntSet observed = new IntSet(numOfCategories * 2);
                    for (DataPointPair<Integer> dpp : list) {
                        observed.add(Integer.valueOf(dpp.getDataPoint().getCategoricalValue(attribute)));
                    }
                    if (observed.size() == 1) {
                        // NOTE(review): a single constant categorical candidate ends the
                        // search and yields a leaf even if other features could still
                        // split; kept as in the original.
                        return new NodeC(impurityScore.getResults());
                    }
                    leftSide = new IntSet(numOfCategories);
                    // Between 1 and observed.size()-1 categories go left, so both sides are non-empty.
                    ListUtils.randomSample(observed, leftSide, random.nextInt(observed.size() - 1) + 1, random);
                    split = new ArrayList<>(2);
                    fillList(2, stack, split);
                    for (DataPointPair<Integer> dpp : list) {
                        DataPoint dataPoint = dpp.getDataPoint();
                        int side = leftSide.contains(Integer.valueOf(dataPoint.getCategoricalValue(attribute))) ? 0 : 1;
                        scores[side].addPoint(dataPoint, dpp.getPair().intValue());
                        split.get(side).add(dpp);
                    }
                } else {
                    scores = createScores(numOfCategories);
                    split = new ArrayList<>(numOfCategories);
                    fillList(numOfCategories, stack, split);
                    for (DataPointPair<Integer> dpp : list) {
                        DataPoint dataPoint = dpp.getDataPoint();
                        int category = dataPoint.getCategoricalValue(attribute);
                        scores[category].addPoint(dataPoint, dpp.getPair().intValue());
                        split.get(category).add(dpp);
                    }
                }
            } else {
                int numericIndex = attribute - categoricalDataArr.length;
                double min = Double.POSITIVE_INFINITY;
                double max = Double.NEGATIVE_INFINITY;
                for (DataPointPair<Integer> dpp : list) {
                    double value = dpp.getVector().get(numericIndex);
                    min = Math.min(min, value);
                    max = Math.max(max, value);
                }
                // Uniformly random cut point inside the observed value range.
                threshold = (random.nextDouble() * (max - min)) + min;
                scores = createScores(2);
                split = new ArrayList<>(2);
                fillList(2, stack, split);
                for (DataPointPair<Integer> dpp : list) {
                    int side = dpp.getVector().get(numericIndex) <= threshold ? 0 : 1;
                    split.get(side).add(dpp);
                    scores[side].addPoint(dpp.getDataPoint(), dpp.getPair().intValue());
                }
            }
            double gain = ImpurityScore.gain(impurityScore, scores);
            if (gain > bestGain) {
                bestGain = gain;
                bestAttribute = attribute;
                bestThreshold = threshold;
                bestScores = scores;
                if (bestSplit != null) {
                    // The previous winner is no longer needed; recycle its lists.
                    fillStack(stack, bestSplit);
                }
                bestSplit = split;
                bestLeftSide = leftSide;
            } else {
                fillStack(stack, split);
            }
        }
        // The points now live in the partition lists, so the input list can be recycled.
        fillStack(stack, Collections.singletonList(list));
        if (bestAttribute < 0) {
            // No candidate was available or none produced a comparable gain: make a leaf.
            return new NodeC(impurityScore.getResults());
        }
        NodeBase node;
        if (bestAttribute >= categoricalDataArr.length) {
            node = new NodeCNum(bestAttribute - categoricalDataArr.length, bestThreshold, impurityScore.getResults());
        } else if (bestSplit.size() == 2) {
            node = new NodeCCat(bestAttribute, bestLeftSide, impurityScore.getResults());
        } else {
            // The node must route on the attribute that produced the split.
            node = new NodeCCat(bestAttribute, bestSplit.size(), impurityScore.getResults());
            // NOTE(review): the feature list is shared by the whole recursion, so this
            // removal also affects sibling sub-trees; confirm that is intended.
            list2.remove(Integer.valueOf(bestAttribute));
        }
        for (int child = 0; child < node.children.length; child++) {
            node.children[child] = trainC(bestScores[child], bestSplit.get(child), list2, categoricalDataArr, random, stack);
        }
        return node;
    }

    /**
     * Recursively grows a regression sub-tree.
     * <p>
     * Up to {@code selectionCount} features are drawn from the shuffled feature
     * list and each gets one random split (random threshold for numeric
     * features, random category subset or one branch per category for
     * categorical ones). A split is scored by the weighted relative variance
     * reduction of the target, and the best one is kept.
     * <p>
     * Feature indices below {@code categoricalDataArr.length} are categorical;
     * the remaining ones are numeric, offset by that length.
     *
     * @param onLineStatistics target statistics of the points in {@code list}
     * @param list the points reaching this node; cleared and recycled into
     *        {@code stack} once a split has been evaluated
     * @param list2 candidate feature indices; shuffled in place, and a feature
     *        used for a multi-way categorical split is removed from it
     * @param categoricalDataArr description of the categorical features
     * @param random source of randomness for thresholds and category subsets
     * @param stack pool of empty lists reused to hold candidate partitions
     * @return the root of the sub-tree
     */
    private TreeNodeVisitor train(OnLineStatistics onLineStatistics, List<DataPointPair<Double>> list, List<Integer> list2, CategoricalData[] categoricalDataArr, Random random, Stack<List<DataPointPair<Double>>> stack) {
        if (list.size() < this.stopSize || onLineStatistics.getVarance() <= 0.0d || Double.isNaN(onLineStatistics.getVarance())) {
            return new NodeR(onLineStatistics.getMean());
        }
        double bestGain = Double.NEGATIVE_INFINITY;
        double bestThreshold = Double.NaN;
        int bestAttribute = -1;
        OnLineStatistics[] bestStats = null;
        List<List<DataPointPair<Double>>> bestSplit = null;
        IntSet bestLeftSide = null;
        Collections.shuffle(list2);
        int candidates = Math.min(this.selectionCount, list2.size());
        for (int c = 0; c < candidates; c++) {
            double threshold = Double.NaN;
            IntSet leftSide = null;
            int attribute = list2.get(c).intValue();
            OnLineStatistics[] stats;
            List<List<DataPointPair<Double>>> split;
            if (attribute < categoricalDataArr.length) {
                int numOfCategories = categoricalDataArr[attribute].getNumOfCategories();
                if (this.binaryCategoricalSplitting || numOfCategories == 2) {
                    stats = createStats(2);
                    // Categories that actually occur among the points of this node.
                    IntSet observed = new IntSet(numOfCategories * 2);
                    for (DataPointPair<Double> dpp : list) {
                        observed.add(Integer.valueOf(dpp.getDataPoint().getCategoricalValue(attribute)));
                    }
                    if (observed.size() == 1) {
                        // NOTE(review): a single constant categorical candidate ends the
                        // search and yields a leaf even if other features could still
                        // split; kept as in the original.
                        return new NodeR(onLineStatistics.getMean());
                    }
                    leftSide = new IntSet(numOfCategories);
                    // Between 1 and observed.size()-1 categories go left, so both sides are non-empty.
                    ListUtils.randomSample(observed, leftSide, random.nextInt(observed.size() - 1) + 1, random);
                    split = new ArrayList<>(2);
                    fillList(2, stack, split);
                    for (DataPointPair<Double> dpp : list) {
                        DataPoint dataPoint = dpp.getDataPoint();
                        int side = leftSide.contains(Integer.valueOf(dataPoint.getCategoricalValue(attribute))) ? 0 : 1;
                        stats[side].add(dpp.getPair().doubleValue(), dataPoint.getWeight());
                        split.get(side).add(dpp);
                    }
                } else {
                    stats = createStats(numOfCategories);
                    split = new ArrayList<>(numOfCategories);
                    fillList(numOfCategories, stack, split);
                    for (DataPointPair<Double> dpp : list) {
                        DataPoint dataPoint = dpp.getDataPoint();
                        int category = dataPoint.getCategoricalValue(attribute);
                        stats[category].add(dpp.getPair().doubleValue(), dataPoint.getWeight());
                        split.get(category).add(dpp);
                    }
                }
            } else {
                int numericIndex = attribute - categoricalDataArr.length;
                double min = Double.POSITIVE_INFINITY;
                double max = Double.NEGATIVE_INFINITY;
                for (DataPointPair<Double> dpp : list) {
                    double value = dpp.getVector().get(numericIndex);
                    min = Math.min(min, value);
                    max = Math.max(max, value);
                }
                // Uniformly random cut point inside the observed value range.
                threshold = (random.nextDouble() * (max - min)) + min;
                stats = createStats(2);
                split = new ArrayList<>(2);
                fillList(2, stack, split);
                for (DataPointPair<Double> dpp : list) {
                    int side = dpp.getVector().get(numericIndex) <= threshold ? 0 : 1;
                    split.get(side).add(dpp);
                    stats[side].add(dpp.getPair().doubleValue(), dpp.getDataPoint().getWeight());
                }
            }
            // Gain = 1 - sum over branches of (weight share) * (branch variance / parent variance).
            double gain = 1.0d;
            double parentVariance = onLineStatistics.getVarance();
            double parentWeight = onLineStatistics.getSumOfWeights();
            for (OnLineStatistics branch : stats) {
                gain -= (branch.getSumOfWeights() / parentWeight) * (branch.getVarance() / parentVariance);
            }
            if (gain > bestGain) {
                bestGain = gain;
                bestAttribute = attribute;
                bestThreshold = threshold;
                bestStats = stats;
                if (bestSplit != null) {
                    // The previous winner is no longer needed; recycle its lists.
                    fillStack(stack, bestSplit);
                }
                bestSplit = split;
                bestLeftSide = leftSide;
            } else {
                fillStack(stack, split);
            }
        }
        // The points now live in the partition lists, so the input list can be recycled.
        fillStack(stack, Collections.singletonList(list));
        if (bestAttribute < 0) {
            // No candidate was available or none produced a comparable gain: make a leaf.
            return new NodeR(onLineStatistics.getMean());
        }
        NodeBase node;
        if (bestAttribute >= categoricalDataArr.length) {
            node = new NodeRNum(bestAttribute - categoricalDataArr.length, bestThreshold, onLineStatistics.getMean());
        } else if (bestSplit.size() == 2) {
            node = new NodeRCat(bestAttribute, bestLeftSide, onLineStatistics.getMean());
        } else {
            // The node must route on the attribute that produced the split.
            node = new NodeRCat(bestAttribute, bestSplit.size(), onLineStatistics.getMean());
            // NOTE(review): the feature list is shared by the whole recursion, so this
            // removal also affects sibling sub-trees; confirm that is intended.
            list2.remove(Integer.valueOf(bestAttribute));
        }
        for (int child = 0; child < node.children.length; child++) {
            node.children[child] = train(bestStats[child], bestSplit.get(child), list2, categoricalDataArr, random, stack);
        }
        return node;
    }

    /**
     * @return always {@code true}
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return true;
    }

    /**
     * Creates a copy of this learner: the settings are copied, and the class
     * information and the learned tree are cloned when present.
     *
     * @return the copy
     */
    @Override // jsat.regression.Regressor
    /* renamed from: clone, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public ExtraTree m571clone() {
        ExtraTree copy = new ExtraTree(selectionCount, stopSize);
        copy.impMeasure = impMeasure;
        copy.binaryCategoricalSplitting = binaryCategoricalSplitting;
        copy.predicting = (predicting == null) ? null : predicting.m481clone();
        copy.root = (root == null) ? null : root.mo583clone();
        copy.numNumericFeatures = numNumericFeatures;
        return copy;
    }

    /**
     * @return the root of the learned tree, {@code null} before training
     */
    public TreeNodeVisitor getTreeNodeVisitor() {
        return root;
    }

    /**
     * Appends {@code i} lists to {@code list}, taking them from the pool
     * while it has entries and creating new ones once it is empty.
     *
     * @param i how many lists to append
     * @param stack pool of reusable lists
     * @param list destination receiving the lists
     */
    private static <T> void fillList(int i, Stack<List<T>> stack, List<List<T>> list) {
        for (int remaining = i; remaining > 0; remaining--) {
            List<T> next;
            if (!stack.isEmpty()) {
                next = stack.pop();
            } else {
                next = new ArrayList<T>();
            }
            list.add(next);
        }
    }

    /**
     * Empties every list in {@code list} and pushes it onto the pool so that
     * it can be handed out again by {@code fillList}.
     *
     * @param stack pool of reusable lists
     * @param list the lists to clear and recycle
     */
    private static <T> void fillStack(Stack<List<T>> stack, List<List<T>> list) {
        Iterator<List<T>> it = list.iterator();
        while (it.hasNext()) {
            List<T> recycled = it.next();
            recycled.clear();
            stack.push(recycled);
        }
    }

    /**
     * Creates {@code i} fresh impurity scores for the current target classes
     * and impurity measure.
     *
     * @param i the number of scores to create
     * @return an array of {@code i} new scores
     */
    private ImpurityScore[] createScores(int i) {
        ImpurityScore[] scores = new ImpurityScore[i];
        int slot = 0;
        while (slot < scores.length) {
            scores[slot] = new ImpurityScore(predicting.getNumOfCategories(), impMeasure);
            slot++;
        }
        return scores;
    }

    /**
     * Predicts a numeric value by delegating to the root of the learned tree.
     *
     * @param dataPoint the point to predict for
     * @return the value produced by the tree
     */
    public double regress(DataPoint dataPoint) {
        TreeNodeVisitor tree = root;
        return tree.regress(dataPoint);
    }

    /**
     * Trains the regressor. The executor is not used; the work is delegated
     * to {@link #train(RegressionDataSet)}.
     *
     * @param regressionDataSet the training data
     * @param executorService ignored
     */
    public void train(RegressionDataSet regressionDataSet, ExecutorService executorService) {
        this.train(regressionDataSet);
    }

    /**
     * Trains a regression tree on the given data set: collects the feature
     * indices, accumulates the weighted target statistics of the full data set
     * and then grows the tree recursively from the root.
     *
     * @param regressionDataSet the training data
     */
    public void train(RegressionDataSet regressionDataSet) {
        Random rng = new Random();
        // Pool of lists recycled between recursive calls.
        Stack<List<DataPointPair<Double>>> reusableLists = new Stack<>();
        IntList featureIndices = new IntList(regressionDataSet.getNumFeatures());
        ListUtils.addRange(featureIndices, 0, regressionDataSet.getNumFeatures(), 1);
        List<DataPointPair<Double>> points = regressionDataSet.getAsDPPList();
        OnLineStatistics rootStats = new OnLineStatistics();
        for (DataPointPair<Double> point : points) {
            rootStats.add(point.getPair().doubleValue(), point.getDataPoint().getWeight());
        }
        numNumericFeatures = regressionDataSet.getNumNumericalVars();
        root = train(rootStats, points, featureIndices, regressionDataSet.getCategories(), rng, reusableLists);
    }

    /**
     * Creates {@code i} empty statistics accumulators.
     *
     * @param i the number of accumulators to create
     * @return an array of {@code i} new accumulators
     */
    private OnLineStatistics[] createStats(int i) {
        OnLineStatistics[] stats = new OnLineStatistics[i];
        int slot = 0;
        while (slot < stats.length) {
            stats[slot] = new OnLineStatistics();
            slot++;
        }
        return stats;
    }

    /**
     * @return the parameters obtained from {@code Parameter.getParamsFromMethods}
     *         for this object
     */
    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        List<Parameter> params = Parameter.getParamsFromMethods(this);
        return params;
    }

    /**
     * Looks up a parameter by name in the map built from {@link #getParameters()}.
     *
     * @param str the parameter name
     * @return the map entry for {@code str}
     */
    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        List<Parameter> params = getParameters();
        return Parameter.toParameterMap(params).get(str);
    }
}
