package jsat.classifiers.calibration;

import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPointPair;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.FakeExecutor;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/calibration/BinaryCalibration.class */
/**
 * Skeleton for calibrators built on top of a {@link BinaryScoreClassifier}.
 * <p>
 * Training fits the wrapped classifier, gathers the raw scores it produces
 * along with a boolean label per scored point (label is {@code true} when the
 * point's category equals 1), and passes both arrays to
 * {@link #calibrate(boolean[], double[], int)}, which subclasses implement.
 * How the scored points are chosen is controlled by the
 * {@link CalibrationMode}.
 */
public abstract class BinaryCalibration implements Classifier, Parameterized {
    private static final long serialVersionUID = 2356311701854978890L;

    /** The score-producing classifier that gets trained and then calibrated. */
    @Parameter.ParameterHolder
    protected BinaryScoreClassifier base;
    /** Fold count handed to {@code cvSet} when the mode is {@link CalibrationMode#CV}. */
    protected int folds = 3;
    /** Fraction of the shuffled data scored (not trained on) in {@link CalibrationMode#HOLD_OUT}. */
    protected double holdOut = 0.3d;
    /** Selects the branch taken by {@link #trainC(ClassificationDataSet, ExecutorService)}. */
    protected CalibrationMode mode;

    /**
     * Strategies for picking the points whose scores feed the calibration.
     */
    /* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/calibration/BinaryCalibration$CalibrationMode.class */
    public enum CalibrationMode {
        /** Train the base on the full data set and score those same points. */
        NAIVE,
        /** Score each cross-validation fold after training the base on the remaining data. */
        CV,
        /** Score a shuffled trailing portion after training the base on the leading portion. */
        HOLD_OUT
    }

    /**
     * Stores the classifier to calibrate and applies the requested mode
     * through {@link #setCalibrationMode(CalibrationMode)}.
     *
     * @param binaryScoreClassifier the classifier whose scores will be calibrated
     * @param calibrationMode the strategy used to collect calibration scores
     */
    public BinaryCalibration(BinaryScoreClassifier binaryScoreClassifier, CalibrationMode calibrationMode) {
        base = binaryScoreClassifier;
        setCalibrationMode(calibrationMode);
    }

    /**
     * Trains {@link #base}, using the executor-taking overload only when a
     * real executor was supplied (neither {@code null} nor a
     * {@link FakeExecutor}).
     */
    private void trainBase(ClassificationDataSet data, ExecutorService threadPool) {
        final boolean useThreadPool = threadPool != null && !(threadPool instanceof FakeExecutor);
        if (useThreadPool) {
            base.trainC(data, threadPool);
        } else {
            base.trainC(data);
        }
    }

    /**
     * Scores the first {@code count} points of {@code source} with
     * {@link #base} and writes score/label pairs starting at {@code offset}.
     *
     * @return the index just past the last slot written
     */
    private int appendScores(ClassificationDataSet source, int count, double[] scores, boolean[] labels, int offset) {
        int slot = offset;
        for (int row = 0; row < count; row++, slot++) {
            scores[slot] = base.getScore(source.getDataPoint(row));
            labels[slot] = source.getDataPointCategory(row) == 1;
        }
        return slot;
    }

    /**
     * Trains the base classifier and then calibrates on scores gathered
     * according to the current mode. In the CV and hold-out modes the base is
     * retrained on the complete data set once the scores have been collected.
     * Any mode other than {@link CalibrationMode#CV} or
     * {@link CalibrationMode#HOLD_OUT} (including {@code null}) takes the
     * naive path.
     *
     * @param classificationDataSet the training data
     * @param executorService source of threads, may be {@code null}
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        final double[] scores = new double[classificationDataSet.getSampleSize()];
        final boolean[] labels = new boolean[scores.length];
        // Number of leading entries that calibrate() is told to use.
        int usable = labels.length;

        if (mode == CalibrationMode.CV) {
            final List<ClassificationDataSet> parts = classificationDataSet.cvSet(folds);
            int cursor = 0;
            for (int held = 0; held < parts.size(); held++) {
                final ClassificationDataSet heldFold = parts.get(held);
                trainBase(ClassificationDataSet.comineAllBut(parts, held), executorService);
                cursor = appendScores(heldFold, heldFold.getSampleSize(), scores, labels, cursor);
            }
            // Final model is fit on everything.
            trainBase(classificationDataSet, executorService);
        } else if (mode == CalibrationMode.HOLD_OUT) {
            final List<DataPointPair<Integer>> shuffled = classificationDataSet.getAsDPPList();
            Collections.shuffle(shuffled);
            // Index separating the portion trained on from the portion scored.
            final int cut = (int) (shuffled.size() * (1.0d - holdOut));
            final ClassificationDataSet fitPart =
                    new ClassificationDataSet(shuffled.subList(0, cut), classificationDataSet.getPredicting());
            final ClassificationDataSet heldPart =
                    new ClassificationDataSet(shuffled.subList(cut, shuffled.size()), classificationDataSet.getPredicting());
            trainBase(fitPart, executorService);
            // Only the held-out points contribute, so fewer entries are usable.
            usable = appendScores(heldPart, heldPart.getSampleSize(), scores, labels, 0);
            // Final model is fit on everything.
            trainBase(classificationDataSet, executorService);
        } else {
            trainBase(classificationDataSet, executorService);
            appendScores(classificationDataSet, usable, scores, labels, 0);
        }

        calibrate(labels, scores, usable);
    }

    /**
     * Trains without an executor; delegates to
     * {@link #trainC(ClassificationDataSet, ExecutorService)} with {@code null}.
     *
     * @param classificationDataSet the training data
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        trainC(classificationDataSet, null);
    }

    /**
     * Fits the calibration from the collected scores.
     *
     * @param zArr per-point labels, {@code true} when the point's category was 1
     * @param dArr per-point scores produced by {@link #base}
     * @param i how many leading entries of the two arrays were filled in
     */
    protected abstract void calibrate(boolean[] zArr, double[] dArr, int i);

    /**
     * Sets the fold count used by {@link CalibrationMode#CV}.
     *
     * @param i the number of folds, at least 1
     * @throws IllegalArgumentException if {@code i} is less than 1
     */
    public void setCalibrationFolds(int i) {
        final boolean positive = i >= 1;
        if (!positive) {
            throw new IllegalArgumentException("Folds must be a positive value, not " + i);
        }
        folds = i;
    }

    /**
     * @return the fold count used by {@link CalibrationMode#CV}
     */
    public int getCalibrationFolds() {
        return folds;
    }

    /**
     * Sets the fraction of data held out in {@link CalibrationMode#HOLD_OUT}.
     *
     * @param d the fraction, strictly between 0 and 1
     * @throws IllegalArgumentException if {@code d} is NaN or outside (0, 1)
     */
    public void setCalibrationHoldOut(double d) {
        // NaN fails both comparisons, so it is rejected along with out-of-range values.
        final boolean insideOpenUnitInterval = d > 0.0d && d < 1.0d;
        if (!insideOpenUnitInterval) {
            throw new IllegalArgumentException("HoldOut must be in (0, 1), not " + d);
        }
        holdOut = d;
    }

    /**
     * @return the fraction of data held out in {@link CalibrationMode#HOLD_OUT}
     */
    public double getCalibrationHoldOut() {
        return holdOut;
    }

    /**
     * Sets the strategy used to collect calibration scores. The value is
     * stored as given, without validation.
     *
     * @param calibrationMode the mode to use
     */
    public void setCalibrationMode(CalibrationMode calibrationMode) {
        mode = calibrationMode;
    }

    /**
     * @return the strategy used to collect calibration scores
     */
    public CalibrationMode getCalibrationMode() {
        return mode;
    }

    @Override // jsat.classifiers.Classifier
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public abstract BinaryCalibration mo511clone();

    /**
     * @return the parameters obtained from {@link Parameter#getParamsFromMethods(Object)}
     */
    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    /**
     * Looks a parameter up by key in the map built from {@link #getParameters()}.
     *
     * @param str the key to look up
     * @return the matching parameter, or whatever the map returns for a missing key
     */
    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        final List<Parameter> available = getParameters();
        return Parameter.toParameterMap(available).get(str);
    }
}
