package jsat.classifiers.bayesian;

import java.util.Arrays;
import java.util.Iterator;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.distributions.Normal;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.math.MathTricks;
import jsat.math.OnLineStatistics;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/bayesian/NaiveBayesUpdateable.class */
/**
 * An incrementally updateable Naive Bayes classifier.
 *
 * <p>For every target class the model keeps:
 * <ul>
 * <li>one {@link OnLineStatistics} per numeric feature, whose mean and standard deviation
 * parameterize a {@link Normal} log-density;</li>
 * <li>one count table per categorical feature, holding a count for each of that feature's
 * categories;</li>
 * <li>a class prior count.</li>
 * </ul>
 * Categorical counts and class prior counts all start at {@code 1} in {@link #setUp}, i.e.
 * add-one smoothing.
 *
 * <p>In sparse mode (see {@link #isSparseInput()}) numeric features are visited through the
 * numeric vector's {@link IndexValue} iterator rather than by every dense index; entries that
 * iterator does not produce contribute nothing to training or classification.
 * NOTE(review): for JSAT vectors this is presumably the non-zero entries — confirm against
 * {@code Vec.iterator()}.
 */
public class NaiveBayesUpdateable extends BaseUpdateableClassifier {
    private static final long serialVersionUID = 1835073945715343486L;

    /**
     * Log-likelihood substituted for a numeric feature whose normal log-density comes out NaN
     * or infinite, so a single degenerate feature (presumably one whose tracked standard
     * deviation is zero or undefined) cannot turn a whole class score into NaN or infinity.
     */
    private static final double LOG_PDF_FLOOR = Math.log(1.0E-16d);

    /** Smoothed categorical counts, indexed as [class][categorical feature][category]. */
    private double[][][] apriori;
    /** Running numeric statistics, indexed as [class][numeric feature]. */
    private OnLineStatistics[][] valueStats;
    /** Running total of all entries of {@link #priors}. */
    private double priorSum = 0.0d;
    /** Smoothed per-class prior counts. */
    private double[] priors;
    /** Whether numeric features are read through the vector's iterator (sparse mode). */
    private boolean sparseInput = true;

    /** Creates a classifier that operates in sparse mode. */
    public NaiveBayesUpdateable() {
        this(true);
    }

    /**
     * Creates a classifier.
     *
     * @param sparse {@code true} to read numeric features through the vector's iterator,
     *     {@code false} to visit every numeric index
     */
    public NaiveBayesUpdateable(boolean sparse) {
        setSparse(sparse);
    }

    /**
     * Copy constructor. All count tables and priors are copied into new arrays, and every
     * numeric statistic is re-created with {@link OnLineStatistics}'s copy constructor.
     *
     * @param toCopy the classifier to copy
     */
    protected NaiveBayesUpdateable(NaiveBayesUpdateable toCopy) {
        this(toCopy.sparseInput);
        if (toCopy.apriori != null) {
            // Only the outer dimensions are allocated here; the inner arrays are copied below.
            this.apriori = new double[toCopy.apriori.length][][];
            this.valueStats = new OnLineStatistics[toCopy.valueStats.length][];
            for (int c = 0; c < toCopy.apriori.length; c++) {
                double[][] srcCounts = toCopy.apriori[c];
                this.apriori[c] = new double[srcCounts.length][];
                for (int j = 0; j < srcCounts.length; j++) {
                    this.apriori[c][j] = Arrays.copyOf(srcCounts[j], srcCounts[j].length);
                }
                OnLineStatistics[] srcStats = toCopy.valueStats[c];
                this.valueStats[c] = new OnLineStatistics[srcStats.length];
                for (int j = 0; j < srcStats.length; j++) {
                    this.valueStats[c][j] = new OnLineStatistics(srcStats[j]);
                }
            }
            this.priorSum = toCopy.priorSum;
            this.priors = Arrays.copyOf(toCopy.priors, toCopy.priors.length);
        }
    }

    /**
     * Returns a copy of this classifier made with the copy constructor.
     *
     * @return a new, independent copy
     */
    @Override // jsat.classifiers.BaseUpdateableClassifier
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public NaiveBayesUpdateable mo480clone() {
        return new NaiveBayesUpdateable(this);
    }

    /**
     * Discards any learned state and allocates fresh, smoothed tables.
     *
     * @param categoricalAttributes descriptions of the categorical input features
     * @param numericAttributes number of numeric input features
     * @param predicting description of the target; its number of categories is the number of
     *     classes
     */
    @Override // jsat.classifiers.UpdateableClassifier
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        int numClasses = predicting.getNumOfCategories();
        // Inner category tables depend on each feature, so the last dimension is left open.
        this.apriori = new double[numClasses][categoricalAttributes.length][];
        this.valueStats = new OnLineStatistics[numClasses][numericAttributes];
        this.priors = new double[numClasses];
        Arrays.fill(this.priors, 1.0d);
        this.priorSum = numClasses; // equals the sum of the all-ones priors above
        for (int c = 0; c < numClasses; c++) {
            for (int j = 0; j < categoricalAttributes.length; j++) {
                this.apriori[c][j] = new double[categoricalAttributes[j].getNumOfCategories()];
                Arrays.fill(this.apriori[c][j], 1.0d); // add-one smoothing
            }
            for (int j = 0; j < numericAttributes; j++) {
                this.valueStats[c][j] = new OnLineStatistics();
            }
        }
    }

    /**
     * Adds one training example of class {@code targetClass}. The data point's weight is
     * applied to the numeric statistics, the categorical counts and the class prior alike.
     *
     * @param dataPoint the training example
     * @param targetClass the index of its class
     */
    @Override // jsat.classifiers.UpdateableClassifier
    public void update(DataPoint dataPoint, int targetClass) {
        final double weight = dataPoint.getWeight();
        Vec numeric = dataPoint.getNumericalValues();
        OnLineStatistics[] classStats = this.valueStats[targetClass];
        if (this.sparseInput) {
            Iterator<IndexValue> iter = numeric.iterator();
            while (iter.hasNext()) {
                IndexValue iv = iter.next();
                classStats[iv.getIndex()].add(iv.getValue(), weight);
            }
        } else {
            for (int j = 0; j < numeric.length(); j++) {
                classStats[j].add(numeric.get(j), weight);
            }
        }
        int[] categorical = dataPoint.getCategoricalValues();
        double[][] classCounts = this.apriori[targetClass];
        for (int j = 0; j < classCounts.length; j++) {
            classCounts[j][categorical[j]] += weight;
        }
        // Keep priorSum equal to the sum of priors.
        this.priors[targetClass] += weight;
        this.priorSum += weight;
    }

    /**
     * Computes the posterior class probabilities for {@code dataPoint}.
     *
     * @param dataPoint the point to classify
     * @return the normalized per-class probabilities
     * @throws UntrainedModelException if {@link #setUp} has not been called
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        if (this.apriori == null) {
            throw new UntrainedModelException("Model has not been intialized");
        }
        final int numClasses = this.apriori.length;
        CategoricalResults results = new CategoricalResults(numClasses);
        double[] logProbs = new double[numClasses];
        double maxLogProb = Double.NEGATIVE_INFINITY;
        Vec numeric = dataPoint.getNumericalValues();
        for (int c = 0; c < numClasses; c++) {
            double logProb = numericLogLikelihood(numeric, this.valueStats[c])
                    + categoricalLogLikelihood(dataPoint, this.apriori[c])
                    + Math.log(this.priors[c] / this.priorSum);
            logProbs[c] = logProb;
            maxLogProb = Math.max(maxLogProb, logProb);
        }
        // Normalize in log space first so that exp() does not underflow.
        double logNormalizer = MathTricks.logSumExp(logProbs, maxLogProb);
        for (int c = 0; c < results.size(); c++) {
            results.setProb(c, Math.exp(logProbs[c] - logNormalizer));
        }
        results.normalize();
        return results;
    }

    /**
     * Sums the (floored) normal log-densities of the numeric features under one class.
     * The same NaN/infinity handling is used in both sparse and dense mode.
     */
    private double numericLogLikelihood(Vec numeric, OnLineStatistics[] classStats) {
        double sum = 0.0d;
        if (this.sparseInput) {
            Iterator<IndexValue> iter = numeric.iterator();
            while (iter.hasNext()) {
                IndexValue iv = iter.next();
                sum += flooredLogPdf(iv.getValue(), classStats[iv.getIndex()]);
            }
        } else {
            for (int j = 0; j < classStats.length; j++) {
                sum += flooredLogPdf(numeric.get(j), classStats[j]);
            }
        }
        return sum;
    }

    /** Returns the normal log-density of {@code x}, or {@link #LOG_PDF_FLOOR} if it is NaN or infinite. */
    private static double flooredLogPdf(double x, OnLineStatistics stats) {
        double logPdf = Normal.logPdf(x, stats.getMean(), stats.getStandardDeviation());
        if (Double.isNaN(logPdf) || Double.isInfinite(logPdf)) {
            return LOG_PDF_FLOOR;
        }
        return logPdf;
    }

    /** Sums log(count of observed category / total count) over every categorical feature of one class. */
    private static double categoricalLogLikelihood(DataPoint dataPoint, double[][] classCounts) {
        double sum = 0.0d;
        for (int j = 0; j < classCounts.length; j++) {
            double[] counts = classCounts[j];
            double total = 0.0d;
            for (double count : counts) {
                total += count;
            }
            sum += Math.log(counts[dataPoint.getCategoricalValue(j)] / total);
        }
        return sum;
    }

    /**
     * Reports weighted-data support; {@link #update} applies each point's weight.
     *
     * @return always {@code true}
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return true;
    }

    /**
     * Returns whether sparse mode is enabled.
     *
     * @return {@code true} if numeric features are read through the vector's iterator
     */
    public boolean isSparseInput() {
        return this.sparseInput;
    }

    /**
     * Enables or disables sparse mode.
     *
     * @param sparse {@code true} to read numeric features through the vector's iterator,
     *     {@code false} to visit every numeric index
     */
    public void setSparse(boolean sparse) {
        this.sparseInput = sparse;
    }
}
