package jsat.classifiers;

import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/RegressorToClassifier.class */
/**
 * Adapts a {@link Regressor} so that it can be used as a binary classifier.
 * <p>
 * During training, the class index of each data point is converted to a
 * regression target: class 0 becomes -1 and class 1 becomes +1. The wrapped
 * regressor is then trained on those targets. At prediction time, the raw
 * regression output is used as the score. A point is assigned to class 1 when
 * the score is strictly greater than zero, and to class 0 otherwise.
 * <p>
 * If the wrapped regressor is itself {@link Parameterized}, parameter queries
 * are delegated to it. Otherwise this class exposes no parameters.
 */
public class RegressorToClassifier implements BinaryScoreClassifier, Parameterized {
    private static final long serialVersionUID = -2607433019826385335L;

    /** The wrapped regressor; training and scoring are delegated to it. */
    private final Regressor regressor;

    /**
     * Creates a classifier backed by the given regressor.
     * <p>
     * The regressor is used directly, not copied, so training this classifier
     * trains the supplied object.
     *
     * @param regressor the regressor to wrap
     * @throws NullPointerException if {@code regressor} is {@code null}
     */
    public RegressorToClassifier(Regressor regressor) {
        // Fail fast. A null regressor would otherwise only surface later as an
        // NPE inside getScore/train/clone.
        if (regressor == null) {
            throw new NullPointerException("regressor must not be null");
        }
        this.regressor = regressor;
    }

    /**
     * Returns the raw output of the wrapped regressor for the given point.
     *
     * @param dataPoint the point to score
     * @return the regression value; positive values correspond to class 1
     */
    @Override // jsat.classifiers.calibration.BinaryScoreClassifier
    public double getScore(DataPoint dataPoint) {
        return this.regressor.regress(dataPoint);
    }

    /**
     * Creates a copy of this classifier that wraps a clone of the regressor.
     * <p>
     * This is the {@code clone()} override. The decompiler renamed it because
     * it was merged with its bridge method; the name is kept so that other
     * code referring to it still compiles.
     *
     * @return a new classifier wrapping {@code regressor.clone()}
     */
    @Override // jsat.classifiers.Classifier
    public RegressorToClassifier m491clone() {
        return new RegressorToClassifier(this.regressor.clone());
    }

    /**
     * Classifies a point by the sign of its score.
     *
     * @param dataPoint the point to classify
     * @return a 2-class result with probability 1.0 on class 1 if the score is
     *         {@code > 0}, otherwise probability 1.0 on class 0
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        CategoricalResults categoricalResults = new CategoricalResults(2);
        if (getScore(dataPoint) > 0.0d) {
            categoricalResults.setProb(1, 1.0d);
        } else {
            // A score of exactly 0 (or NaN) falls through to class 0.
            categoricalResults.setProb(0, 1.0d);
        }
        return categoricalResults;
    }

    /**
     * Trains the wrapped regressor on the +/-1 encoding of the data set,
     * passing the executor through to the regressor.
     *
     * @param classificationDataSet binary training data
     * @param executorService executor handed to the regressor's training
     * @throws IllegalArgumentException if the data set does not have exactly 2 classes
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        this.regressor.train(getRegressionDataSet(classificationDataSet), executorService);
    }

    /**
     * Trains the wrapped regressor on the +/-1 encoding of the data set.
     *
     * @param classificationDataSet binary training data
     * @throws IllegalArgumentException if the data set does not have exactly 2 classes
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        this.regressor.train(getRegressionDataSet(classificationDataSet));
    }

    /**
     * Reports weighted-data support exactly as the wrapped regressor does.
     *
     * @return the wrapped regressor's {@code supportsWeightedData()} result
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return this.regressor.supportsWeightedData();
    }

    /**
     * Builds a regression view of a binary classification data set.
     * <p>
     * Each point keeps its features, and its class index {@code c} becomes the
     * target {@code 2c - 1}. So class 0 maps to -1 and class 1 maps to +1.
     *
     * @param classificationDataSet the binary data set to convert
     * @return a regression data set with the same features and +/-1 targets
     * @throws IllegalArgumentException if the data set does not have exactly 2 classes
     */
    private RegressionDataSet getRegressionDataSet(ClassificationDataSet classificationDataSet) {
        // The 2c-1 encoding is only meaningful for two classes. With more
        // classes it would silently produce targets such as 3, 5, ..., while
        // classify() can only ever answer class 0 or 1.
        int numClasses = classificationDataSet.getClassSize();
        if (numClasses != 2) {
            throw new IllegalArgumentException(
                    "RegressorToClassifier only supports binary problems, but the data set has "
                    + numClasses + " classes");
        }
        RegressionDataSet regressionDataSet = new RegressionDataSet(
                classificationDataSet.getNumNumericalVars(), classificationDataSet.getCategories());
        for (int i = 0; i < classificationDataSet.getSampleSize(); i++) {
            regressionDataSet.addDataPoint(classificationDataSet.getDataPoint(i),
                    (classificationDataSet.getDataPointCategory(i) * 2) - 1);
        }
        return regressionDataSet;
    }

    /**
     * Returns the wrapped regressor's parameters.
     *
     * @return the regressor's parameters if it is {@link Parameterized},
     *         otherwise an immutable empty list
     */
    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        if (this.regressor instanceof Parameterized) {
            return ((Parameterized) this.regressor).getParameters();
        }
        // Type-safe replacement for the raw Collections.EMPTY_LIST (same instance).
        return Collections.<Parameter>emptyList();
    }

    /**
     * Looks up a parameter by name on the wrapped regressor.
     *
     * @param str the parameter name
     * @return the matching parameter, or {@code null} if the regressor is not
     *         {@link Parameterized} (or, presumably, has no such parameter)
     */
    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        if (this.regressor instanceof Parameterized) {
            return ((Parameterized) this.regressor).getParameter(str);
        }
        return null;
    }
}
