package jsat.classifiers;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.FakeExecutor;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/OneVSAll.class */
/**
 * Multi-class classifier built from binary classifiers using the one-vs-all (one-vs-rest)
 * reduction.
 *
 * <p>For each category {@code i} of the training data, a clone of the base classifier is trained
 * on a two-class data set in which samples of category {@code i} carry label 0 and every other
 * sample carries label 1. Prediction combines the outputs of these per-category classifiers.
 */
public class OneVSAll implements Classifier, Parameterized {
    private static final long serialVersionUID = -326668337438092217L;

    /** One binary classifier per category, indexed by category; {@code null} until trained. */
    private Classifier[] oneVsAlls;

    /** Prototype that is cloned once per category during training. */
    @Parameter.ParameterHolder
    private Classifier baseClassifier;

    /** Target variable of the training data; its category count sizes the results. */
    private CategoricalData predicting;

    /** Whether each binary problem is submitted as its own task to the executor. */
    private boolean concurrentTraining;

    /**
     * Whether to decide by raw scores when the binary models are {@link BinaryScoreClassifier}s.
     * Always {@code true}: no setter for it exists in this class.
     */
    private boolean useScoreIfAvailable;

    /**
     * Creates a one-vs-all classifier that trains its binary models concurrently.
     *
     * @param classifier the base classifier to clone for every category
     */
    public OneVSAll(Classifier classifier) {
        this(classifier, true);
    }

    /**
     * Creates a one-vs-all classifier.
     *
     * @param classifier the base classifier to clone for every category
     * @param concurrentTraining {@code true} to submit each binary training problem as a separate
     *     task to the executor given to {@link #trainC(ClassificationDataSet, ExecutorService)}
     */
    public OneVSAll(Classifier classifier, boolean concurrentTraining) {
        this.useScoreIfAvailable = true;
        this.baseClassifier = classifier;
        this.concurrentTraining = concurrentTraining;
    }

    /**
     * Selects how the per-category binary classifiers are trained.
     *
     * @param z {@code true} to submit each binary problem as its own executor task; {@code false}
     *     to train them one after another, handing the executor (unless it is {@code null} or a
     *     {@link FakeExecutor}) to each base classifier instead
     */
    public void setConcurrentTraining(boolean z) {
        this.concurrentTraining = z;
    }

    /**
     * Classifies a data point by consulting every per-category binary model.
     *
     * <p>If score-based prediction is enabled and the binary models are
     * {@link BinaryScoreClassifier}s, all probability mass goes to the single category whose
     * negated score is largest. Otherwise each model's probability for label 0 ("belongs to this
     * category") is used as that category's weight, and the weights are normalized.
     *
     * @param dataPoint the point to classify
     * @return the per-category results
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        int numCategories = this.predicting.getNumOfCategories();
        CategoricalResults results = new CategoricalResults(numCategories);
        if (this.useScoreIfAvailable && (this.oneVsAlls[0] instanceof BinaryScoreClassifier)) {
            int bestCategory = 0;
            double bestScore = Double.NEGATIVE_INFINITY;
            for (int c = 0; c < numCategories; c++) {
                // Category c is label 0 in its binary problem. Negating assumes lower scores favour
                // label 0. NOTE(review): confirm against BinaryScoreClassifier's sign convention.
                double score = -((BinaryScoreClassifier) this.oneVsAlls[c]).getScore(dataPoint);
                if (score > bestScore) {
                    bestCategory = c;
                    bestScore = score;
                }
            }
            results.setProb(bestCategory, 1.0d);
        } else {
            for (int c = 0; c < numCategories; c++) {
                double prob = this.oneVsAlls[c].classify(dataPoint).getProb(0);
                if (prob > 0.0d) {
                    results.setProb(c, prob);
                }
            }
            results.normalize();
        }
        return results;
    }

    /**
     * Trains one binary classifier per category of {@code classificationDataSet}.
     *
     * <p>With concurrent training enabled and a non-null executor, each binary problem is
     * submitted as a separate task and this method waits for all of them. Any exception thrown
     * by a task is rethrown here. Otherwise the problems are trained sequentially on the calling
     * thread.
     *
     * @param classificationDataSet the multi-class training data
     * @param executorService executor for concurrent work; may be {@code null}
     * @throws RuntimeException if a binary classifier fails to train, or if the calling thread is
     *     interrupted while waiting for concurrent training (the interrupt flag is restored)
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        final int numClasses = classificationDataSet.getClassSize();
        this.oneVsAlls = new Classifier[numClasses];
        this.predicting = classificationDataSet.getPredicting();

        // Snapshot each category's samples once; every binary problem reuses them.
        List<List<DataPoint>> samplesByClass = new ArrayList<List<DataPoint>>(numClasses);
        for (int c = 0; c < numClasses; c++) {
            samplesByClass.add(new ArrayList<DataPoint>(classificationDataSet.getSamples(c)));
        }

        int numNumericalVars = classificationDataSet.getNumNumericalVars();
        CategoricalData[] categories = classificationDataSet.getCategories();

        // A null executor cannot accept tasks, so fall back to sequential training.
        final boolean runConcurrently = this.concurrentTraining && executorService != null;
        final CountDownLatch latch = new CountDownLatch(runConcurrently ? numClasses : 0);
        final AtomicReference<Throwable> failure = new AtomicReference<Throwable>();

        for (int c = 0; c < numClasses; c++) {
            final ClassificationDataSet binaryData =
                    buildBinaryDataSet(samplesByClass, c, numNumericalVars, categories);
            if (runConcurrently) {
                final Classifier model = this.baseClassifier.m490clone();
                final int target = c;
                executorService.submit(new Runnable() {
                    @Override // java.lang.Runnable
                    public void run() {
                        try {
                            model.trainC(binaryData);
                            OneVSAll.this.oneVsAlls[target] = model;
                        } catch (Throwable t) {
                            // submit() would hide this inside an unread Future; keep the first
                            // failure so the caller can rethrow it.
                            failure.compareAndSet(null, t);
                        } finally {
                            // Always count down, otherwise a failed task deadlocks await().
                            latch.countDown();
                        }
                    }
                });
            } else {
                Classifier model = this.baseClassifier.m490clone();
                if (executorService == null || (executorService instanceof FakeExecutor)) {
                    model.trainC(binaryData);
                } else {
                    model.trainC(binaryData, executorService);
                }
                this.oneVsAlls[c] = model;
            }
        }

        if (runConcurrently) {
            try {
                latch.await();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException("Interrupted while training one-vs-all classifiers", e);
            }
            Throwable t = failure.get();
            if (t != null) {
                if (t instanceof RuntimeException) {
                    throw (RuntimeException) t;
                }
                if (t instanceof Error) {
                    throw (Error) t;
                }
                throw new RuntimeException("Failed to train a one-vs-all classifier", t);
            }
        }
    }

    /**
     * Builds the binary training set for one category.
     *
     * <p>Samples of {@code target} are added with label 0, followed by all other categories'
     * samples with label 1.
     *
     * <p>NOTE(review): only numerical and categorical values are copied. Any {@link DataPoint}
     * weight is not carried over, although {@link #supportsWeightedData()} delegates to the base
     * classifier. TODO confirm whether a weighted {@code addDataPoint} overload should be used.
     */
    private static ClassificationDataSet buildBinaryDataSet(List<List<DataPoint>> samplesByClass,
            int target, int numNumericalVars, CategoricalData[] categories) {
        ClassificationDataSet binary =
                new ClassificationDataSet(numNumericalVars, categories, new CategoricalData(2));
        for (DataPoint dp : samplesByClass.get(target)) {
            binary.addDataPoint(dp.getNumericalValues(), dp.getCategoricalValues(), 0);
        }
        for (int c = 0; c < samplesByClass.size(); c++) {
            if (c == target) {
                continue;
            }
            for (DataPoint dp : samplesByClass.get(c)) {
                binary.addDataPoint(dp.getNumericalValues(), dp.getCategoricalValues(), 1);
            }
        }
        return binary;
    }

    /**
     * Trains on the calling thread by delegating to
     * {@link #trainC(ClassificationDataSet, ExecutorService)} with a {@link FakeExecutor}.
     *
     * @param classificationDataSet the multi-class training data
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        trainC(classificationDataSet, new FakeExecutor());
    }

    /**
     * Returns a deep copy.
     *
     * <p>The base classifier, the target variable and every trained binary classifier are cloned.
     * Configuration flags are copied.
     *
     * @return an independent copy of this classifier
     */
    @Override // jsat.classifiers.Classifier
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public OneVSAll m490clone() {
        OneVSAll copy = new OneVSAll(this.baseClassifier.m490clone(), this.concurrentTraining);
        copy.useScoreIfAvailable = this.useScoreIfAvailable;
        if (this.predicting != null) {
            copy.predicting = this.predicting.m480clone();
        }
        if (this.oneVsAlls != null) {
            copy.oneVsAlls = new Classifier[this.oneVsAlls.length];
            for (int i = 0; i < this.oneVsAlls.length; i++) {
                if (this.oneVsAlls[i] != null) {
                    copy.oneVsAlls[i] = this.oneVsAlls[i].m490clone();
                }
            }
        }
        return copy;
    }

    /**
     * Reports weighted-data support, delegating to the base classifier.
     *
     * @return whatever the base classifier reports
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return this.baseClassifier.supportsWeightedData();
    }

    /**
     * Lists this object's tunable parameters as discovered by
     * {@link Parameter#getParamsFromMethods}.
     *
     * @return the discovered parameters
     */
    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    /**
     * Looks up a parameter by name among {@link #getParameters()}.
     *
     * @param str the parameter name
     * @return the matching parameter, or whatever the map lookup yields when absent
     */
    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }
}
