package jsat.regression;

import java.util.ArrayList;
import java.util.concurrent.ExecutorService;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.DenseVector;
import jsat.linear.SubVector;
import jsat.linear.Vec;
import jsat.math.Function;
import jsat.math.optimization.IterativelyReweightedLeastSquares;
import jsat.utils.FakeExecutor;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/regression/LogisticRegression.class */
/**
 * Logistic regression model whose coefficients are fitted with
 * {@link IterativelyReweightedLeastSquares}.
 *
 * <p>The model can be trained either as a {@link Regressor} or as a two-class
 * {@link Classifier}. When training, the target values are linearly mapped using
 * their minimum ({@code shift}) and their range ({@code scale}); {@link #regress(DataPoint)}
 * applies the inverse mapping to the logistic output.
 *
 * <p>The coefficient vector stores the bias term at index 0, followed by one weight per
 * numeric feature.
 */
public class LogisticRegression implements Classifier, Regressor, SingleWeightVectorModel {
    private static final long serialVersionUID = -5115807516729861730L;

    /** Bias (index 0) followed by the per-feature weights; {@code null} until the model is trained. */
    private Vec coefficents;

    /** Minimum target value seen during training; added back to the output in {@link #regress(DataPoint)}. */
    private double shift;

    /** Range (max - min) of the training targets; multiplied into the output in {@link #regress(DataPoint)}. */
    private double scale;

    /** The logistic response of the current coefficients, exposed as a {@link Function} for the optimizer. */
    private final Function logitFun = new Function() { // from class: jsat.regression.LogisticRegression.1
        private static final long serialVersionUID = -653111120605227341L;

        @Override // jsat.math.Function
        public double f(double... dArr) {
            return LogisticRegression.this.logitReg(DenseVector.toDenseVec(dArr));
        }

        @Override // jsat.math.Function
        public double f(Vec vec) {
            return LogisticRegression.this.logitReg(vec);
        }
    };

    /** Derivative of the logistic response, {@code y * (1 - y)}, exposed as a {@link Function} for the optimizer. */
    private final Function logitFunD = new Function() { // from class: jsat.regression.LogisticRegression.2
        private static final long serialVersionUID = 4844651397674391691L;

        @Override // jsat.math.Function
        public double f(double... dArr) {
            // Delegate to the Vec overload so that both overloads return the derivative;
            // previously this overload returned the logistic value itself.
            Vec input = DenseVector.toDenseVec(dArr);
            return f(input);
        }

        @Override // jsat.math.Function
        public double f(Vec vec) {
            double logitReg = LogisticRegression.this.logitReg(vec);
            return logitReg * (1.0d - logitReg);
        }
    };

    /**
     * The logistic (sigmoid) function.
     *
     * @param d the input value
     * @return {@code 1 / (1 + exp(-d))}
     */
    private static double logit(double d) {
        return 1.0d / (1.0d + Math.exp(-d));
    }

    /**
     * Evaluates the logistic function on the bias plus the weighted sum of the given inputs.
     * Feature {@code i - 1} of the input is paired with coefficient {@code i}.
     *
     * @param vec the input values; read at indices {@code 0 .. coefficents.length() - 2}
     * @return the logistic response for the input under the current coefficients
     */
    /* JADX INFO: Access modifiers changed from: private */
    public double logitReg(Vec vec) {
        double sum = this.coefficents.get(0);
        for (int i = 1; i < this.coefficents.length(); i++) {
            sum += vec.get(i - 1) * this.coefficents.get(i);
        }
        return logit(sum);
    }

    /**
     * Returns the coefficient vector held by this model (not a copy).
     *
     * @return the coefficients with the bias at index 0, or {@code null} if the model is untrained
     */
    public Vec getCoefficents() {
        return this.coefficents;
    }

    /**
     * Predicts a value for the data point: the logistic response of its numerical values,
     * multiplied by {@code scale} and offset by {@code shift}.
     *
     * @param dataPoint the data point to evaluate
     * @return the predicted value
     * @throws UntrainedModelException if the model has no coefficients yet
     */
    @Override // jsat.regression.Regressor
    public double regress(DataPoint dataPoint) {
        if (this.coefficents == null) {
            throw new UntrainedModelException("Model has not been trained");
        }
        return (logitReg(dataPoint.getNumericalValues()) * this.scale) + this.shift;
    }

    /**
     * Fits the coefficients to the data set using iteratively reweighted least squares
     * (tolerance {@code 1e-5}, at most 100 iterations).
     *
     * <p>The targets are normalised with {@code (target - min) / (max - min)} before fitting,
     * and {@code min} and {@code max - min} are remembered as {@code shift} and {@code scale}.
     *
     * @param regressionDataSet the training data
     * @param executorService the executor handed to the optimizer
     */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet, ExecutorService executorService) {
        final int sampleCount = regressionDataSet.getSampleSize();
        ArrayList<Vec> inputs = new ArrayList<Vec>(sampleCount);
        for (int i = 0; i < sampleCount; i++) {
            inputs.add(regressionDataSet.getDataPoint(i).getNumericalValues());
        }
        // One extra slot at index 0 for the bias term.
        this.coefficents = new DenseVector(regressionDataSet.getNumNumericalVars() + 1);
        Vec targetValues = regressionDataSet.getTargetValues();
        double min = targetValues.min();
        double max = targetValues.max();
        this.shift = min;
        this.scale = max - min;
        // NOTE(review): if every target is identical, scale is 0 and the division below is by zero;
        // that case is not handled here.
        // The in-place variant is required: Vec.subtract(double) returns a new vector and leaves
        // the receiver unchanged, so the shift was previously never applied to the targets.
        targetValues.mutableSubtract(this.shift);
        targetValues.mutableDivide(this.scale);
        this.coefficents = new IterativelyReweightedLeastSquares().optimize(1.0E-5d, 100, this.logitFun, this.logitFunD, this.coefficents, inputs, targetValues, executorService);
    }

    /**
     * Fits the model using a {@link FakeExecutor}.
     *
     * @param regressionDataSet the training data
     */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet) {
        train(regressionDataSet, new FakeExecutor());
    }

    /**
     * @return always {@code false}
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Returns the feature weights, i.e. the coefficients without the leading bias term, as a
     * {@link SubVector} over the coefficient vector.
     *
     * @return the weight vector
     */
    @Override // jsat.SingleWeightVectorModel
    public Vec getRawWeight() {
        return new SubVector(1, this.coefficents.length() - 1, this.coefficents);
    }

    /**
     * @return the bias term, stored at index 0 of the coefficients
     */
    @Override // jsat.SingleWeightVectorModel
    public double getBias() {
        return this.coefficents.get(0);
    }

    /**
     * Returns the single weight vector of this model.
     *
     * @param i the weight vector index; any value below 1 is accepted
     * @return the weight vector
     * @throws IndexOutOfBoundsException if {@code i >= 1}
     */
    @Override // jsat.SimpleWeightVectorModel
    public Vec getRawWeight(int i) {
        if (i < 1) {
            return getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * Returns the single bias term of this model.
     *
     * @param i the weight vector index; any value below 1 is accepted
     * @return the bias term
     * @throws IndexOutOfBoundsException if {@code i >= 1}
     */
    @Override // jsat.SimpleWeightVectorModel
    public double getBias(int i) {
        if (i < 1) {
            return getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * @return always 1
     */
    @Override // jsat.SimpleWeightVectorModel
    public int numWeightsVecs() {
        return 1;
    }

    /**
     * Creates a copy of this model: the coefficients are cloned when present, and
     * {@code scale} and {@code shift} are copied.
     *
     * @return the new, independent model
     */
    @Override // jsat.regression.Regressor
    /* renamed from: clone, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public LogisticRegression m574clone() {
        LogisticRegression copy = new LogisticRegression();
        if (this.coefficents != null) {
            copy.coefficents = this.coefficents.mo525clone();
        }
        copy.scale = this.scale;
        copy.shift = this.shift;
        return copy;
    }

    /**
     * Classifies the data point into one of two classes. The probability of class 1 is the
     * value of {@link #regress(DataPoint)}, and class 0 receives the remainder.
     *
     * @param dataPoint the data point to classify
     * @return the probabilities of the two classes
     * @throws UntrainedModelException if the model has no coefficients, or if it was trained
     *         with targets whose minimum is not 0 or whose range is not 1
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        if (this.coefficents == null) {
            throw new UntrainedModelException("Model has not yet been trained");
        }
        // A shift of 0 and a scale of 1 mean the training targets already spanned exactly [0, 1].
        if (this.shift != 0.0d || this.scale != 1.0d) {
            throw new UntrainedModelException("Model was trained for regression, not classifiaction");
        }
        CategoricalResults results = new CategoricalResults(2);
        results.setProb(1, regress(dataPoint));
        results.setProb(0, 1.0d - results.getProb(1));
        return results;
    }

    /**
     * Trains the model as a two-class classifier by converting the data set into a regression
     * problem whose target for each data point is its category, then delegating to
     * {@link #train(RegressionDataSet, ExecutorService)}.
     *
     * @param classificationDataSet the training data; must have exactly two classes
     * @param executorService the executor handed to the optimizer
     * @throws FailedToFitException if the data set does not have exactly two classes
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        if (classificationDataSet.getClassSize() != 2) {
            throw new FailedToFitException("Logistic Regression works only in the case of two classes, and can not handle " + classificationDataSet.getClassSize() + " classes");
        }
        RegressionDataSet regressionDataSet = new RegressionDataSet(classificationDataSet.getNumNumericalVars(), classificationDataSet.getCategories());
        final int sampleCount = classificationDataSet.getSampleSize();
        for (int i = 0; i < sampleCount; i++) {
            regressionDataSet.addDataPoint(classificationDataSet.getDataPoint(i), classificationDataSet.getDataPointCategory(i));
        }
        train(regressionDataSet, executorService);
    }

    /**
     * Trains the model as a two-class classifier using a {@link FakeExecutor}.
     *
     * @param classificationDataSet the training data; must have exactly two classes
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        trainC(classificationDataSet, new FakeExecutor());
    }
}
