package jsat.classifiers.bayesian;

import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.distributions.ContinuousDistribution;
import jsat.distributions.DistributionSearch;
import jsat.distributions.Normal;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.math.MathTricks;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;
import jsat.utils.FakeExecutor;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/bayesian/NaiveBayes.class */
/**
 * Naive Bayes classifier.
 *
 * <p>For every target class, each numerical feature is modelled by a {@link ContinuousDistribution}
 * selected according to the configured {@link NumericalHandeling}. Each categorical feature is
 * modelled by a frequency table whose counts start at 1 (add-one smoothing) and are normalized to
 * sum to 1. Class scores are combined in log space with the class priors reported by the training
 * set.
 *
 * <p>When {@link #isSparceInput()} is {@code true}, only non-zero numerical values are used, both
 * when fitting the per-class distributions and when classifying.
 */
public class NaiveBayes implements Classifier, Parameterized {
    private static final long serialVersionUID = -2437775653277531182L;

    /**
     * Log-density substituted for any numerical feature whose log-pdf is infinite or whose
     * distribution could not be fitted. This keeps one feature from driving a class score to
     * -infinity.
     */
    private static final double LOG_PDF_FLOOR = Math.log(1.0E-16d);

    /** Smoothed categorical frequency tables, indexed [class][categorical variable][category value]. */
    private double[][][] apriori;
    /** Fitted numerical distributions, indexed [class][numerical variable]. A null entry means the fit failed. */
    private ContinuousDistribution[][] distributions;
    private NumericalHandeling numericalHandling;
    /** Class priors, as returned by {@code ClassificationDataSet.getPriors()}. */
    private double[] priors;
    private boolean sparceInput;

    /** Numerical handling used by the no-argument constructor. */
    public static final NumericalHandeling defaultHandling = NumericalHandeling.NORMAL;

    /**
     * Counts the occurrences of each value of one categorical variable within the samples of one
     * class, then normalizes the (pre-initialized) counts into probabilities.
     */
    public class AprioriCounterRunable implements Runnable {
        int i;
        int j;
        List<DataPoint> dataSamples;
        CountDownLatch latch;

        /**
         * @param classIndex index of the target class
         * @param categoricalIndex index of the categorical variable
         * @param samples the training points belonging to {@code classIndex}
         * @param doneLatch latch counted down exactly once when this task finishes
         */
        public AprioriCounterRunable(int classIndex, int categoricalIndex, List<DataPoint> samples,
                CountDownLatch doneLatch) {
            this.i = classIndex;
            this.j = categoricalIndex;
            this.dataSamples = samples;
            this.latch = doneLatch;
        }

        @Override // java.lang.Runnable
        public void run() {
            try {
                double[] counts = NaiveBayes.this.apriori[this.i][this.j];
                for (DataPoint dataPoint : this.dataSamples) {
                    counts[dataPoint.getCategoricalValue(this.j)] += 1.0d;
                }
                double total = 0.0d;
                for (int k = 0; k < counts.length; k++) {
                    total += counts[k];
                }
                for (int k = 0; k < counts.length; k++) {
                    counts[k] /= total;
                }
            } finally {
                // Always release the latch. Otherwise any failure (e.g. an out-of-range
                // categorical value) would leave trainC blocked in await() forever.
                this.latch.countDown();
            }
        }
    }

    /** Fits the distribution of one numerical variable for one class. */
    public class DistributionSelectRunable implements Runnable {
        int i;
        int j;
        Vec v;
        CountDownLatch countDown;

        /**
         * @param classIndex index of the target class
         * @param numericalIndex index of the numerical variable
         * @param values observed values of that variable within the class
         * @param doneLatch latch counted down exactly once when this task finishes
         */
        public DistributionSelectRunable(int classIndex, int numericalIndex, Vec values,
                CountDownLatch doneLatch) {
            this.i = classIndex;
            this.j = numericalIndex;
            this.v = values;
            this.countDown = doneLatch;
        }

        @Override // java.lang.Runnable
        public void run() {
            try {
                NaiveBayes.this.distributions[this.i][this.j] = NaiveBayes.this.numericalHandling.fit(this.v);
            } catch (ArithmeticException e) {
                // A failed fit is recorded as null. classify() then scores this feature with LOG_PDF_FLOOR.
                NaiveBayes.this.distributions[this.i][this.j] = null;
            } finally {
                // Always release the latch so trainC cannot hang on an unexpected exception.
                this.countDown.countDown();
            }
        }
    }

    /** Strategy for choosing the distribution that models each numerical feature. */
    public enum NumericalHandeling {
        /** Passes a single Normal candidate to {@code DistributionSearch.getBestDistribution}. */
        NORMAL {
            @Override
            protected ContinuousDistribution fit(Vec vec) {
                return DistributionSearch.getBestDistribution(vec, new Normal(0.0d, 1.0d));
            }
        },
        /** Uses {@code DistributionSearch.getBestDistribution} with its default candidate set. */
        BEST_FIT {
            @Override
            protected ContinuousDistribution fit(Vec vec) {
                return DistributionSearch.getBestDistribution(vec);
            }
        },
        /**
         * Uses {@code DistributionSearch.getBestDistribution} with a cut-off of 0.9. Presumably the
         * search falls back to a kernel density estimate below that cut-off; confirm against
         * DistributionSearch.
         */
        BEST_FIT_KDE {
            private final double cutOff = 0.9d;

            @Override
            protected ContinuousDistribution fit(Vec vec) {
                return DistributionSearch.getBestDistribution(vec, this.cutOff);
            }
        };

        /**
         * Fits a distribution to the given values.
         *
         * @param vec observed values of one feature within one class
         * @return the chosen distribution
         */
        protected abstract ContinuousDistribution fit(Vec vec);
    }

    /**
     * Creates an untrained classifier. Sparse input handling is enabled by default.
     *
     * @param numericalHandeling strategy used to model numerical features
     */
    public NaiveBayes(NumericalHandeling numericalHandeling) {
        this.sparceInput = true;
        this.numericalHandling = numericalHandeling;
    }

    /** Creates an untrained classifier that uses {@link #defaultHandling}. */
    public NaiveBayes() {
        this(defaultHandling);
    }

    /** @param numericalHandeling strategy used to model numerical features on the next training run */
    public void setNumericalHandling(NumericalHandeling numericalHandeling) {
        this.numericalHandling = numericalHandeling;
    }

    /** @return the strategy used to model numerical features */
    public NumericalHandeling getNumericalHandling() {
        return this.numericalHandling;
    }

    /** @return {@code true} if only non-zero numerical values are used for training and classification */
    public boolean isSparceInput() {
        return this.sparceInput;
    }

    /** @param z whether only non-zero numerical values are used for training and classification */
    public void setSparceInput(boolean z) {
        this.sparceInput = z;
    }

    /**
     * Returns the log-density of {@code value} under {@code dist}. Uses {@link #LOG_PDF_FLOOR}
     * when the distribution is missing (failed fit) or the log-density is infinite.
     */
    private static double flooredLogPdf(ContinuousDistribution dist, double value) {
        double logPdf = dist == null ? Double.NEGATIVE_INFINITY : dist.logPdf(value);
        return Double.isInfinite(logPdf) ? LOG_PDF_FLOOR : logPdf;
    }

    /**
     * Computes the posterior class probabilities for a data point.
     *
     * <p>Per class, the score is the sum of the floored numerical log-densities, the logs of the
     * smoothed categorical probabilities, and the log prior. Scores are turned into probabilities
     * with a log-sum-exp normalization. If every class score is -infinity, a uniform distribution
     * is returned.
     *
     * @param dataPoint the point to classify. Must have the same layout as the training data.
     * @return the class probabilities
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        int numClasses = this.distributions.length;
        CategoricalResults results = new CategoricalResults(numClasses);
        double[] logScores = new double[numClasses];
        Vec numericalValues = dataPoint.getNumericalValues();
        double maxLogScore = Double.NEGATIVE_INFINITY;

        for (int c = 0; c < numClasses; c++) {
            double logScore = 0.0d;
            if (this.sparceInput) {
                // Zero entries were excluded during training, so they are skipped here too.
                Iterator<IndexValue> nonZero = numericalValues.getNonZeroIterator();
                while (nonZero.hasNext()) {
                    IndexValue entry = nonZero.next();
                    logScore += flooredLogPdf(this.distributions[c][entry.getIndex()], entry.getValue());
                }
            } else {
                for (int var = 0; var < this.distributions[c].length; var++) {
                    logScore += flooredLogPdf(this.distributions[c][var], numericalValues.get(var));
                }
            }
            for (int var = 0; var < this.apriori[c].length; var++) {
                logScore += Math.log(this.apriori[c][var][dataPoint.getCategoricalValue(var)]);
            }
            logScore += Math.log(this.priors[c]);
            logScores[c] = logScore;
            maxLogScore = Math.max(maxLogScore, logScore);
        }

        if (maxLogScore == Double.NEGATIVE_INFINITY) {
            for (int c = 0; c < results.size(); c++) {
                results.setProb(c, 1.0d / results.size());
            }
            return results;
        }

        double logNormalizer = MathTricks.logSumExp(logScores, maxLogScore);
        for (int c = 0; c < results.size(); c++) {
            results.setProb(c, Math.exp(logScores[c] - logNormalizer));
        }
        results.normalize();
        return results;
    }

    /** Trains single-threaded by delegating to {@link #trainC(ClassificationDataSet, ExecutorService)}. */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        trainC(classificationDataSet, new FakeExecutor());
    }

    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Collections.unmodifiableList(Parameter.getParamsFromMethods(this));
    }

    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }

    /**
     * Returns a deep copy of this classifier, including its configuration and any trained state.
     * (The decompiler named this method {@code m496clone}; it was originally {@code clone()}.)
     *
     * @return an independent copy. Entries that were null (failed fits) remain null.
     */
    @Override // jsat.classifiers.Classifier
    public Classifier m496clone() {
        NaiveBayes copy = new NaiveBayes(this.numericalHandling);
        copy.sparceInput = this.sparceInput;
        if (this.apriori != null) {
            copy.apriori = new double[this.apriori.length][][];
            for (int c = 0; c < this.apriori.length; c++) {
                copy.apriori[c] = new double[this.apriori[c].length][];
                for (int var = 0; var < this.apriori[c].length; var++) {
                    double[] table = this.apriori[c][var];
                    copy.apriori[c][var] = table == null ? null : Arrays.copyOf(table, table.length);
                }
            }
        }
        if (this.distributions != null) {
            copy.distributions = new ContinuousDistribution[this.distributions.length][];
            for (int c = 0; c < this.distributions.length; c++) {
                copy.distributions[c] = new ContinuousDistribution[this.distributions[c].length];
                for (int var = 0; var < this.distributions[c].length; var++) {
                    ContinuousDistribution dist = this.distributions[c][var];
                    // Failed fits are stored as null (see DistributionSelectRunable).
                    copy.distributions[c][var] = dist == null ? null : dist.mo616clone();
                }
            }
        }
        if (this.priors != null) {
            copy.priors = Arrays.copyOf(this.priors, this.priors.length);
        }
        return copy;
    }

    /** @return {@code false}; sample weights are not used during training */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Returns the values of one numerical variable for the samples of one class. When sparse
     * input is enabled, only the non-zero values are returned.
     */
    private Vec getSampleVariableVector(ClassificationDataSet classificationDataSet, int i, int i2) {
        Vec values = classificationDataSet.getSampleVariableVector(i, i2);
        if (!this.sparceInput) {
            return values;
        }
        DoubleList nonZero = new DoubleList();
        for (int k = 0; k < values.length(); k++) {
            double value = values.get(k);
            if (value != 0.0d) {
                nonZero.add(value);
            }
        }
        return new DenseVector(nonZero);
    }

    /**
     * Trains the model. For every class, one task per numerical variable and one task per
     * categorical variable are submitted to {@code executorService`, and the method blocks until
     * all of them have finished.
     *
     * <p>If the calling thread is interrupted while waiting, the interrupt status is restored and
     * the method returns early. In that case the model may be only partially trained.
     *
     * @param classificationDataSet the training data
     * @param executorService executor that runs the per-feature fitting tasks
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        int numOfCategories = classificationDataSet.getPredicting().getNumOfCategories();
        int numNumerical = classificationDataSet.getNumNumericalVars();
        int numCategorical = classificationDataSet.getNumCategoricalVars();

        this.apriori = new double[numOfCategories][numCategorical][];
        this.distributions = new ContinuousDistribution[numOfCategories][numNumerical];
        this.priors = classificationDataSet.getPriors();

        // Exactly one countDown() per submitted task.
        CountDownLatch latch = new CountDownLatch(numOfCategories * (numNumerical + numCategorical));
        for (int c = 0; c < numOfCategories; c++) {
            for (int var = 0; var < numNumerical; var++) {
                executorService.submit(new DistributionSelectRunable(c, var,
                        getSampleVariableVector(classificationDataSet, c, var), latch));
            }
            List<DataPoint> samples = classificationDataSet.getSamples(c);
            for (int var = 0; var < numCategorical; var++) {
                // Start every count at 1 (add-one smoothing) so unseen values keep non-zero probability.
                double[] counts = new double[classificationDataSet.getCategories()[var].getNumOfCategories()];
                Arrays.fill(counts, 1.0d);
                this.apriori[c][var] = counts;
                executorService.submit(new AprioriCounterRunable(c, var, samples, latch));
            }
        }
        try {
            latch.await();
        } catch (InterruptedException e) {
            // Preserve the interrupt for the caller instead of silently discarding it.
            Thread.currentThread().interrupt();
        }
    }
}
