package jsat.classifiers.boosting;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.MultipleLinearRegression;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/boosting/LogitBoost.class */
/**
 * Binary LogitBoost classifier built from weighted regression base learners.
 *
 * <p>On each boosting iteration every training point receives a working response {@code z}
 * ({@code 1/p} for class 1, {@code -1/(1-p)} for class 0, clipped to {@code [-zMax, zMax]}) and a
 * weight {@code p(1-p)}. Here {@code p} is the current probability of class 1. A fresh clone of
 * the base learner is then fit to those targets. The model score is
 * {@code F(x) = fScaleConstant * sum_m f_m(x)}, and {@code P(class 1 | x) = e^F / (e^F + e^-F)}.
 *
 * <p>Only two-class problems are supported. The base learner must support weighted data points.
 */
public class LogitBoost implements Classifier, Parameterized {
    private static final long serialVersionUID = 1621062168467402062L;

    /** Lower bound applied to each training point's weight {@code p(1-p)}. */
    private static final double MIN_WEIGHT = 2.0E-15d;

    /** Multiplier applied to the summed base-learner outputs to form the score F(x). */
    protected double fScaleConstant = 0.5d;
    /** Fitted base learners in training order; {@code null} until {@link #trainC} runs. */
    protected List<Regressor> baseLearners;
    /** Prototype regressor that is cloned once per boosting iteration. */
    protected Regressor baseLearner;
    /** Number of boosting iterations (base learners) to fit. */
    private int maxIterations;
    /** Cap on the magnitude of the working response z. */
    private double zMax = 3.0d;

    /**
     * Creates a LogitBoost model whose base learner is a {@link MultipleLinearRegression}.
     *
     * @param i the number of boosting iterations
     */
    public LogitBoost(int i) {
        this(new MultipleLinearRegression(true), i);
    }

    /**
     * Creates a LogitBoost model.
     *
     * @param regressor the base learner cloned on every iteration; must support weighted data
     * @param i the number of boosting iterations
     * @throws RuntimeException if {@code regressor} does not support weighted data points
     */
    public LogitBoost(Regressor regressor, int i) {
        if (!regressor.supportsWeightedData()) {
            throw new RuntimeException("Base Learner must support weighted data points to be boosted");
        }
        this.baseLearner = regressor;
        this.maxIterations = i;
    }

    /**
     * Returns a read-only view of the trained base learners.
     *
     * @return the base learners in training order, or an empty list if the model is untrained
     */
    public List<Regressor> getModels() {
        if (this.baseLearners == null) {
            return Collections.emptyList();
        }
        return Collections.unmodifiableList(this.baseLearners);
    }

    /**
     * Sets the number of boosting iterations used by the next call to {@link #trainC}.
     *
     * @param i the number of iterations
     */
    public void setMaxIterations(int i) {
        this.maxIterations = i;
    }

    /**
     * Returns the number of boosting iterations.
     *
     * @return the number of iterations
     */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Sets the cap on the magnitude of the per-point working response.
     *
     * @param d the new cap; must be finite and strictly positive
     * @throws ArithmeticException if {@code d} is NaN, infinite, or not positive
     */
    public void setzMax(double d) {
        if (Double.isInfinite(d) || Double.isNaN(d) || d <= 0.0d) {
            throw new ArithmeticException("Invalid penalty given: " + d);
        }
        this.zMax = d;
    }

    /**
     * Returns the cap on the magnitude of the per-point working response.
     *
     * @return the current cap
     */
    public double getzMax() {
        return this.zMax;
    }

    /**
     * Classifies a point into the two classes.
     *
     * @param dataPoint the point to classify
     * @return a two-class result whose index 1 holds P(class 1 | x)
     * @throws UntrainedModelException if {@link #trainC} has not been called
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        // baseLearner is assigned in the constructor, so it cannot signal "untrained";
        // baseLearners is only created by trainC.
        if (this.baseLearners == null) {
            throw new UntrainedModelException("Model has not yet been trained");
        }
        double p = P(dataPoint);
        CategoricalResults results = new CategoricalResults(2);
        results.setProb(1, p);
        results.setProb(0, 1.0d - p);
        return results;
    }

    /**
     * Trains the model. The executor is ignored; training runs on the calling thread.
     *
     * @param classificationDataSet the two-class training data
     * @param executorService ignored
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        trainC(classificationDataSet);
    }

    /**
     * Fits {@link #getMaxIterations()} base learners by LogitBoost.
     *
     * <p>NOTE(review): each point's weight is overwritten via {@code DataPoint.setWeight}. If
     * {@code getAsFloatDPPList()} shares DataPoint instances with the input set, the caller's
     * weights are modified as well. Confirm against ClassificationDataSet.
     *
     * @param classificationDataSet the training data; must have exactly two classes
     * @throws FailedToFitException if the data set does not have exactly two classes
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        if (classificationDataSet.getClassSize() != 2) {
            throw new FailedToFitException("LogitBoost only supports binary decision tasks, not " + classificationDataSet.getClassSize() + " class problems");
        }
        List<DataPointPair<Double>> targets = classificationDataSet.getAsFloatDPPList();
        int sampleSize = classificationDataSet.getSampleSize();
        this.baseLearners = new ArrayList<Regressor>(this.maxIterations);
        // Running unscaled sum of base-learner outputs for each training point. It is
        // accumulated in the same order F() uses, so each iteration needs only one regress
        // call per point instead of re-evaluating every earlier learner.
        double[] rawScores = new double[sampleSize];
        for (int iter = 0; iter < this.maxIterations; iter++) {
            for (int i = 0; i < sampleSize; i++) {
                DataPointPair<Double> pair = targets.get(i);
                double p = probabilityFromScore(rawScores[i] * this.fScaleConstant);
                double z = classificationDataSet.getDataPointCategory(i) == 1
                        ? Math.min(this.zMax, 1.0d / p)
                        : Math.max(-this.zMax, -1.0d / (1.0d - p));
                pair.getDataPoint().setWeight(Math.max(p * (1.0d - p), MIN_WEIGHT));
                pair.setPair(Double.valueOf(z));
            }
            Regressor learner = this.baseLearner.clone();
            learner.train(new RegressionDataSet(targets));
            this.baseLearners.add(learner);
            // Only needed if another iteration will read the scores.
            if (iter + 1 < this.maxIterations) {
                for (int i = 0; i < sampleSize; i++) {
                    rawScores[i] += learner.regress(targets.get(i).getDataPoint());
                }
            }
        }
    }

    /** Computes the score F(x): the sum of base-learner outputs times fScaleConstant. */
    private double F(DataPoint dataPoint) {
        double sum = 0.0d;
        for (Regressor learner : this.baseLearners) {
            sum += learner.regress(dataPoint);
        }
        return sum * this.fScaleConstant;
    }

    /**
     * Computes the model's probability that {@code dataPoint} belongs to class 1.
     *
     * @param dataPoint the point to score
     * @return {@code e^F / (e^F + e^-F)} for the current score F
     */
    protected double P(DataPoint dataPoint) {
        return probabilityFromScore(F(dataPoint));
    }

    /**
     * Maps a score F to {@code e^F / (e^F + e^-F)}.
     *
     * <p>When {@code e^F} overflows to +Infinity the ratio would be Infinity/Infinity = NaN, so
     * 1.0 is returned instead.
     */
    private static double probabilityFromScore(double f) {
        double ePos = Math.exp(f);
        if (ePos == Double.POSITIVE_INFINITY) {
            return 1.0d;
        }
        double eNeg = Math.exp(-f);
        return ePos / (ePos + eNeg);
    }

    /**
     * Reports that this classifier does not accept weighted training data. {@link #trainC}
     * overwrites every point's weight, so input weights would not be honored.
     *
     * @return {@code false}
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Returns a deep copy of this model: settings, prototype base learner, and trained learners.
     * The method name is the decompiled name of {@code clone()}; it is kept so the override
     * still matches the interface.
     *
     * @return an independent copy of this model
     */
    @Override // jsat.classifiers.Classifier
    public LogitBoost mo505clone() {
        LogitBoost copy = this.baseLearner != null
                ? new LogitBoost(this.baseLearner.clone(), this.maxIterations)
                : new LogitBoost(this.maxIterations);
        copy.zMax = this.zMax;
        copy.fScaleConstant = this.fScaleConstant;
        if (this.baseLearners != null) {
            copy.baseLearners = new ArrayList<Regressor>(this.baseLearners.size());
            for (Regressor learner : this.baseLearners) {
                copy.baseLearners.add(learner.clone());
            }
        }
        return copy;
    }

    /**
     * Lists the tunable parameters discovered from this object's getter/setter methods.
     *
     * @return the parameters
     */
    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    /**
     * Looks up a parameter by name.
     *
     * @param str the parameter name
     * @return the matching parameter, or whatever the parameter map returns for an unknown name
     */
    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }
}
