package jsat.classifiers.trees;

import java.util.Arrays;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/trees/ImpurityScore.class */
/**
 * Accumulates weighted per-class counts for a set of data points and scores their impurity under
 * a chosen {@link ImpurityMeasure}. The static {@link #gain} methods compare a parent set against
 * the sets produced by splitting it.
 *
 * <p>Instances are mutable ({@link #addPoint}/{@link #removePoint}) and perform no
 * synchronization.
 */
public class ImpurityScore implements Cloneable {
    /** ln(2); dividing a natural log by this converts it to log base 2. */
    private static final double LOG_2 = Math.log(2.0d);

    /** Total weight currently accumulated across all classes. */
    private double sumOfWeights;
    /** Weight currently accumulated for each class index. */
    private final double[] counts;
    /** Measure used by {@link #getScore()}; the first split's measure also selects the {@link #gain} formula. */
    private final ImpurityMeasure impurityMeasure;

    /** Impurity measures supported by {@link ImpurityScore}. */
    public enum ImpurityMeasure {
        /** Score is Shannon entropy in bits; gain is the weighted reduction in entropy. */
        INFORMATION_GAIN,
        /** Score is Shannon entropy in bits; gain is the entropy reduction divided by a split-information term. */
        INFORMATION_GAIN_RATIO,
        /** Score is Shannon entropy in bits; gain is the normalized mutual information between class and split. */
        NMI,
        /** Score is the Gini impurity, 1 minus the sum of squared class proportions. */
        GINI,
        /** Score is 1 minus the largest class proportion. */
        CLASSIFICATION_ERROR
    }

    /**
     * Creates an empty score with no accumulated weight.
     *
     * @param numClasses number of classes; class indices used later must lie in {@code [0, numClasses)}
     * @param impurityMeasure measure used when scoring
     */
    public ImpurityScore(int numClasses, ImpurityMeasure impurityMeasure) {
        this.sumOfWeights = 0.0d;
        this.counts = new double[numClasses];
        this.impurityMeasure = impurityMeasure;
    }

    /** Copy constructor: the per-class counts are copied, not shared. */
    private ImpurityScore(ImpurityScore other) {
        this.sumOfWeights = other.sumOfWeights;
        this.counts = Arrays.copyOf(other.counts, other.counts.length);
        this.impurityMeasure = other.impurityMeasure;
    }

    /**
     * Removes a data point's weight from the given class.
     *
     * @param dataPoint point whose {@code getWeight()} is subtracted
     * @param classIndex class the point belongs to
     */
    public void removePoint(DataPoint dataPoint, int classIndex) {
        removePoint(dataPoint.getWeight(), classIndex);
    }

    /**
     * Subtracts {@code weight} from the given class and from the total weight.
     *
     * @param weight weight to remove
     * @param classIndex class to remove it from
     * @throws ArrayIndexOutOfBoundsException if {@code classIndex} is out of range
     */
    public void removePoint(double weight, int classIndex) {
        counts[classIndex] -= weight;
        sumOfWeights -= weight;
    }

    /**
     * Adds a data point's weight to the given class.
     *
     * @param dataPoint point whose {@code getWeight()} is added
     * @param classIndex class the point belongs to
     */
    public void addPoint(DataPoint dataPoint, int classIndex) {
        addPoint(dataPoint.getWeight(), classIndex);
    }

    /**
     * Adds {@code weight} to the given class and to the total weight.
     *
     * @param weight weight to add
     * @param classIndex class to add it to
     * @throws ArrayIndexOutOfBoundsException if {@code classIndex} is out of range
     */
    public void addPoint(double weight, int classIndex) {
        counts[classIndex] += weight;
        sumOfWeights += weight;
    }

    /**
     * Computes the impurity of the accumulated class distribution.
     *
     * <ul>
     *   <li>INFORMATION_GAIN, INFORMATION_GAIN_RATIO, NMI: entropy {@code -sum(p * log2 p)}</li>
     *   <li>GINI: {@code 1 - sum(p^2)}</li>
     *   <li>CLASSIFICATION_ERROR: {@code 1 - max(p)}</li>
     * </ul>
     *
     * where {@code p} is each class's weight divided by the total weight.
     *
     * @return the non-negative impurity, or 0 if the total weight is not positive
     */
    public double getScore() {
        if (sumOfWeights <= 0.0d) {
            return 0.0d;
        }
        double score = 0.0d;
        if (impurityMeasure == ImpurityMeasure.INFORMATION_GAIN
                || impurityMeasure == ImpurityMeasure.INFORMATION_GAIN_RATIO
                || impurityMeasure == ImpurityMeasure.NMI) {
            // Accumulates sum(p * log2 p), which is <= 0; Math.abs below turns it into entropy.
            for (double count : counts) {
                double p = count / sumOfWeights;
                if (p > 0.0d) { // log(0) is -Infinity; zero-probability classes contribute 0
                    score += (p * Math.log(p)) / LOG_2;
                }
            }
        } else if (impurityMeasure == ImpurityMeasure.GINI) {
            score = 1.0d;
            for (double count : counts) {
                double p = count / sumOfWeights;
                score -= p * p;
            }
        } else if (impurityMeasure == ImpurityMeasure.CLASSIFICATION_ERROR) {
            double maxP = 0.0d;
            for (double count : counts) {
                maxP = Math.max(maxP, count / sumOfWeights);
            }
            score = 1.0d - maxP;
        }
        // abs() flips the entropy sign and also removes tiny negative results from rounding.
        return Math.abs(score);
    }

    /** @return total weight currently accumulated */
    public double getSumOfWeights() {
        return sumOfWeights;
    }

    /** @return the measure this score was created with */
    public ImpurityMeasure getImpurityMeasure() {
        return impurityMeasure;
    }

    /**
     * Returns the class distribution as proportions of the total weight.
     *
     * <p>If no positive weight has been accumulated, the returned results are left as freshly
     * constructed; dividing would otherwise make every probability NaN.
     * NOTE(review): presumably CategoricalResults initializes probabilities to 0 — confirm.
     *
     * @return per-class proportions of the accumulated weight
     */
    public CategoricalResults getResults() {
        CategoricalResults results = new CategoricalResults(counts.length);
        if (sumOfWeights <= 0.0d) {
            return results;
        }
        for (int c = 0; c < counts.length; c++) {
            results.setProb(c, counts[c] / sumOfWeights);
        }
        return results;
    }

    /**
     * Equivalent to {@link #gain(ImpurityScore, double, ImpurityScore...)} with a scale of 1.
     *
     * @param wholeData score of the unsplit data
     * @param splits scores of each split; the first one's measure selects the formula
     * @return the gain of the split
     */
    public static double gain(ImpurityScore wholeData, ImpurityScore... splits) {
        return gain(wholeData, 1.0d, splits);
    }

    /**
     * Computes how much splitting {@code wholeData} into {@code splits} improves purity. The
     * formula is selected by the measure of {@code splits[0]}. Every proportion is taken
     * relative to {@code total = wholeScale * wholeData.getSumOfWeights()}.
     *
     * <ul>
     *   <li>NMI: {@code 2 * I(C;S) / (H(C) + H(S))}, using natural logarithms; 0 when both
     *       entropies are 0.</li>
     *   <li>INFORMATION_GAIN_RATIO: {@code (score(whole) - sum(p_s * score(s))) / (1 + sum(-p_s ln p_s))}.
     *       The split-information term starts at 1, which keeps the denominator non-zero.</li>
     *   <li>Any other measure: {@code score(whole) - sum(p_s * score(s))}.</li>
     * </ul>
     *
     * Splits whose weight proportion {@code p_s} is not positive are ignored.
     *
     * @param wholeData score of the unsplit data
     * @param wholeScale multiplier applied to {@code wholeData}'s total weight when forming the
     *     normalizing total
     * @param splits scores of each split; must be non-empty
     * @return the gain of the split
     * @throws IllegalArgumentException if {@code splits} is empty
     */
    public static double gain(ImpurityScore wholeData, double wholeScale, ImpurityScore... splits) {
        if (splits.length == 0) {
            throw new IllegalArgumentException("at least one split is required");
        }
        final double total = wholeScale * wholeData.sumOfWeights;
        final ImpurityMeasure measure = splits[0].impurityMeasure;

        if (measure == ImpurityMeasure.NMI) {
            return normalizedMutualInformation(wholeData, wholeScale, total, splits);
        }

        double weightedSplitScore = 0.0d;
        if (measure != ImpurityMeasure.INFORMATION_GAIN_RATIO) {
            for (ImpurityScore split : splits) {
                double p = split.getSumOfWeights() / total;
                if (p > 0.0d) {
                    weightedSplitScore += p * split.getScore();
                }
            }
            return wholeData.getScore() - weightedSplitScore;
        }

        // Starting at 1.0 (not 0) keeps the denominator non-zero when one split holds all the weight.
        double splitInfo = 1.0d;
        for (ImpurityScore split : splits) {
            double p = split.getSumOfWeights() / total;
            if (p > 0.0d) {
                weightedSplitScore += p * split.getScore();
                splitInfo += p * (-Math.log(p));
            }
        }
        return (wholeData.getScore() - weightedSplitScore) / splitInfo;
    }

    /**
     * Computes {@code 2 * I(C;S) / (H(C) + H(S))} between the class variable C and the split
     * variable S, using natural logarithms and proportions relative to {@code total}.
     */
    private static double normalizedMutualInformation(
            ImpurityScore wholeData, double wholeScale, double total, ImpurityScore[] splits) {
        double mutualInfo = 0.0d;
        double classEntropy = 0.0d; // sum(p_c * ln p_c), <= 0
        for (int c = 0; c < wholeData.counts.length; c++) {
            double pClass = (wholeScale * wholeData.counts[c]) / total;
            // The "> 0" tests also skip NaN (e.g. total == 0), as well as log(0).
            if (pClass > 0.0d) {
                double logPClass = Math.log(pClass);
                classEntropy += pClass * logPClass;
                for (ImpurityScore split : splits) {
                    double pSplit = split.sumOfWeights / total;
                    if (pSplit > 0.0d) {
                        double pJoint = split.counts[c] / total;
                        if (pJoint > 0.0d) {
                            mutualInfo += pJoint * ((Math.log(pJoint) - logPClass) - Math.log(pSplit));
                        }
                    }
                }
            }
        }

        // H(S) is computed in its own pass so that every non-empty split contributes,
        // whatever classes it contains.
        double splitEntropy = 0.0d; // sum(p_s * ln p_s), <= 0
        for (ImpurityScore split : splits) {
            double pSplit = split.sumOfWeights / total;
            if (pSplit > 0.0d) {
                splitEntropy += pSplit * Math.log(pSplit);
            }
        }

        double denominator = Math.abs(classEntropy) + Math.abs(splitEntropy);
        if (!(denominator > 0.0d)) {
            return 0.0d; // both entropies are zero: nothing to normalize, no information shared
        }
        // Mutual information is non-negative; clamp away rounding error.
        return (2.0d * Math.max(0.0d, mutualInfo)) / denominator;
    }

    /**
     * Returns an independent copy: later changes to either instance's counts do not affect the
     * other.
     *
     * @return a deep copy of this score
     */
    @Override
    public ImpurityScore clone() {
        return new ImpurityScore(this);
    }

    /**
     * Alias of {@link #clone()}, kept for existing callers that use this name.
     *
     * @return a deep copy of this score
     */
    public ImpurityScore m586clone() {
        return clone();
    }
}
