package jsat.classifiers.bayesian;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.distributions.multivariate.MultivariateDistribution;
import jsat.exceptions.FailedToFitException;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/bayesian/BestClassDistribution.class */
/**
 * A generative classifier that fits a separate clone of a base {@link MultivariateDistribution}
 * to the samples of each class. A point is scored against every fitted class distribution by its
 * density (optionally multiplied by the class prior), and the scores are normalized into a
 * {@link CategoricalResults}.
 */
public class BestClassDistribution implements Classifier, Parameterized {
    private static final long serialVersionUID = -1746145372146154228L;

    /** Prototype distribution; only ever cloned, never fitted itself. */
    private MultivariateDistribution baseDist;
    /** One fitted distribution per class index, or {@code null} for a class with no samples. */
    private List<MultivariateDistribution> dists;
    /** Class priors as returned by {@code ClassificationDataSet.getPriors()} at training time. */
    private double[] priors;
    /** Whether class densities are multiplied by {@link #priors} when classifying. */
    private boolean usePriors;

    /** Default value of the "use priors" setting. */
    public static final boolean USE_PRIORS = true;

    /**
     * Creates a classifier that weights each class density by its prior ({@link #USE_PRIORS}).
     *
     * @param multivariateDistribution the prototype distribution cloned and fitted per class
     */
    public BestClassDistribution(MultivariateDistribution multivariateDistribution) {
        this(multivariateDistribution, USE_PRIORS);
    }

    /**
     * Creates a classifier with an explicit prior-weighting setting.
     *
     * @param multivariateDistribution the prototype distribution cloned and fitted per class
     * @param z {@code true} to multiply each class density by the class prior when classifying
     */
    public BestClassDistribution(MultivariateDistribution multivariateDistribution, boolean z) {
        this.baseDist = multivariateDistribution;
        this.usePriors = z;
    }

    /**
     * Copy constructor. Produces a deep copy: the prototype, the priors array and every fitted
     * class distribution are cloned, and the prior-weighting setting is carried over.
     *
     * @param bestClassDistribution the classifier to copy
     */
    public BestClassDistribution(BestClassDistribution bestClassDistribution) {
        if (bestClassDistribution.priors != null) {
            this.priors = Arrays.copyOf(bestClassDistribution.priors, bestClassDistribution.priors.length);
        }
        this.baseDist = bestClassDistribution.baseDist.mo630clone();
        // Must be copied explicitly; otherwise the copy would default to false and classify
        // differently from the original.
        this.usePriors = bestClassDistribution.usePriors;
        if (bestClassDistribution.dists != null) {
            this.dists = new ArrayList<MultivariateDistribution>(bestClassDistribution.dists.size());
            for (MultivariateDistribution dist : bestClassDistribution.dists) {
                this.dists.add(dist == null ? null : dist.mo630clone());
            }
        }
    }

    /**
     * Sets whether class densities are multiplied by the class priors when classifying.
     *
     * @param z {@code true} to weight by priors
     */
    public void setUsePriors(boolean z) {
        this.usePriors = z;
    }

    /**
     * Returns whether class densities are multiplied by the class priors when classifying.
     *
     * @return {@code true} if priors are used
     */
    public boolean isUsePriors() {
        return this.usePriors;
    }

    /**
     * Scores {@code dataPoint} against each fitted class distribution and normalizes the scores.
     * A class that had no training samples is skipped, so its probability is not set here. A class
     * whose density evaluation throws {@link ArithmeticException} scores zero.
     *
     * @param dataPoint the point to classify; its numerical values are passed to each class pdf
     * @return the normalized per-class results
     */
    @Override
    public CategoricalResults classify(DataPoint dataPoint) {
        int numClasses = this.dists.size();
        CategoricalResults results = new CategoricalResults(numClasses);
        for (int i = 0; i < numClasses; i++) {
            MultivariateDistribution dist = this.dists.get(i);
            if (dist == null) {
                continue;
            }
            double score = 0.0d;
            try {
                score = dist.pdf(dataPoint.getNumericalValues());
            } catch (ArithmeticException ignored) {
                // Deliberate best-effort: a density that cannot be evaluated for this point
                // counts as zero instead of failing the whole classification.
            }
            if (this.usePriors) {
                score *= this.priors[i];
            }
            results.setProb(i, score);
        }
        results.normalize();
        return results;
    }

    /**
     * Trains one distribution per class, fitting the classes concurrently on
     * {@code executorService}. The classifier's state is replaced only if every class fits
     * successfully.
     *
     * @param classificationDataSet the labeled training data
     * @param executorService executor used to fit the per-class distributions
     * @throws FailedToFitException if any fit fails or the calling thread is interrupted
     */
    @Override
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        try {
            final MultivariateDistribution base = this.baseDist;
            int numClasses = classificationDataSet.getClassSize();
            double[] classPriors = classificationDataSet.getPriors();

            List<Future<MultivariateDistribution>> futures =
                    new ArrayList<Future<MultivariateDistribution>>(numClasses);
            for (int i = 0; i < numClasses; i++) {
                final List<DataPoint> samples = classificationDataSet.getSamples(i);
                futures.add(executorService.submit(new Callable<MultivariateDistribution>() {
                    @Override
                    public MultivariateDistribution call() {
                        return fitClass(base, samples);
                    }
                }));
            }

            // Collect in class-index order so fitted.get(i) corresponds to class i.
            List<MultivariateDistribution> fitted = new ArrayList<MultivariateDistribution>(numClasses);
            for (Future<MultivariateDistribution> future : futures) {
                fitted.add(future.get());
            }

            this.priors = classPriors;
            this.dists = fitted;
        } catch (InterruptedException e) {
            // Restore the interrupt status so callers further up can still observe it.
            Thread.currentThread().interrupt();
            throw fitFailure(e);
        } catch (Exception e) {
            throw fitFailure(e);
        }
    }

    /**
     * Trains one distribution per class on the calling thread. The classifier's state is replaced
     * only after every class has been fitted.
     *
     * @param classificationDataSet the labeled training data
     */
    @Override
    public void trainC(ClassificationDataSet classificationDataSet) {
        int numClasses = classificationDataSet.getClassSize();
        double[] classPriors = classificationDataSet.getPriors();
        List<MultivariateDistribution> fitted = new ArrayList<MultivariateDistribution>(numClasses);
        for (int i = 0; i < numClasses; i++) {
            fitted.add(fitClass(this.baseDist, classificationDataSet.getSamples(i)));
        }
        this.priors = classPriors;
        this.dists = fitted;
    }

    /**
     * Fits a fresh clone of {@code base} to a single class's samples.
     *
     * @param base the prototype distribution (left unmodified)
     * @param samples the samples belonging to one class
     * @return the fitted clone, or {@code null} if {@code samples} is empty
     */
    private static MultivariateDistribution fitClass(MultivariateDistribution base, List<DataPoint> samples) {
        if (samples.isEmpty()) {
            return null;
        }
        MultivariateDistribution dist = base.mo630clone();
        dist.setUsingDataList(samples);
        return dist;
    }

    /** Logs a training failure under this class's logger and wraps it for rethrowing. */
    private static FailedToFitException fitFailure(Exception e) {
        Logger.getLogger(BestClassDistribution.class.getName())
                .log(Level.SEVERE, "Failed to fit per-class distributions", e);
        return new FailedToFitException(e);
    }

    /**
     * Returns false: this classifier does not use data point weights.
     *
     * @return {@code false}
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Returns a deep copy of this classifier, including any fitted state.
     *
     * @return the copy
     */
    @Override
    public Classifier mo493clone() {
        return new BestClassDistribution(this);
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }
}
