package jsat.classifiers;

import java.util.concurrent.ExecutorService;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.Vec;
import jsat.regression.LogisticRegression;
import jsat.regression.RegressionDataSet;
import jsat.utils.FakeExecutor;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/MultinomialLogisticRegression.class */
/**
 * Multinomial (baseline-category) logistic regression classifier.
 *
 * <p>Category {@code 0} is used as the reference category. For every other category {@code k},
 * the model keeps a coefficient vector {@code beta_k}. It treats {@code beta_k . [1, x]} as the
 * log-odds {@code log(P(k | x) / P(0 | x))}, and turns the log-odds of all categories into
 * probabilities with a softmax.
 *
 * <p>Each {@code beta_k} is fitted by a binary {@link LogisticRegression} trained only on the data
 * points labelled {@code 0} or {@code k}, so that the fitted log-odds are relative to category 0.
 * Fitting category {@code k} against <em>all</em> other categories would estimate
 * {@code log(P(k) / (1 - P(k)))} instead, which is not what {@link #classify} assumes.
 */
public class MultinomialLogisticRegression implements Classifier {
    private static final long serialVersionUID = -9168502043850569017L;

    /**
     * Coefficient vectors of the non-reference categories: entry {@code k - 1} belongs to category
     * {@code k}. Element 0 of each vector is the intercept; element {@code j >= 1} multiplies
     * numeric feature {@code j - 1}. This layout is assumed to match what
     * {@link LogisticRegression#getCoefficents()} returns. {@code null} until trained.
     */
    private Vec[] classCoefficents;

    /**
     * Computes the probability of every category for the given data point.
     *
     * @param dataPoint the point to classify; only its numeric values are used
     * @return probabilities for categories {@code 0 .. classCoefficents.length}, normalized to sum
     *     to one
     * @throws UntrainedModelException if no training run has completed yet
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        if (this.classCoefficents == null) {
            throw new UntrainedModelException("Model has not yet been trained");
        }
        Vec numericalValues = dataPoint.getNumericalValues();

        // Log-odds of every category relative to the reference category 0, whose log-odds are 0.
        double[] logOdds = new double[this.classCoefficents.length + 1];
        double maxLogOdds = 0.0d;
        for (int k = 1; k < logOdds.length; k++) {
            Vec coefs = this.classCoefficents[k - 1];
            double z = coefs.get(0);
            for (int j = 1; j < coefs.length(); j++) {
                z += numericalValues.get(j - 1) * coefs.get(j);
            }
            logOdds[k] = z;
            maxLogOdds = Math.max(maxLogOdds, z);
        }

        // Subtract the largest log-odds before exponentiating. This gives the same softmax result,
        // but Math.exp can no longer overflow to +Infinity, which would make the probabilities NaN.
        // The largest term becomes exp(0) = 1, so the normalizer is always >= 1.
        CategoricalResults categoricalResults = new CategoricalResults(logOdds.length);
        double normalizer = 0.0d;
        for (int k = 0; k < logOdds.length; k++) {
            double weight = Math.exp(logOdds[k] - maxLogOdds);
            normalizer += weight;
            categoricalResults.setProb(k, weight);
        }
        categoricalResults.divideConst(normalizer);
        return categoricalResults;
    }

    /**
     * Fits one binary logistic regression per non-reference category {@code k}. Each regression is
     * trained on the points whose category is {@code 0} (target 0.0) or {@code k} (target 1.0).
     *
     * <p>The new model replaces the previous one only after every category has been trained. If
     * training fails part-way, the previous model (or the untrained state) is left in place.
     *
     * @param classificationDataSet the labelled training data
     * @param executorService passed through to each {@link LogisticRegression#train} call
     * @throws IllegalArgumentException if the data set declares no categories
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        int classSize = classificationDataSet.getClassSize();
        if (classSize < 1) {
            throw new IllegalArgumentException(
                    "Data set must have at least one category, got " + classSize);
        }
        Vec[] coefficients = new Vec[classSize - 1];
        for (int k = 1; k < classSize; k++) {
            RegressionDataSet pairData = new RegressionDataSet(
                    classificationDataSet.getNumNumericalVars(), classificationDataSet.getCategories());
            for (int i = 0; i < classificationDataSet.getSampleSize(); i++) {
                int category = classificationDataSet.getDataPointCategory(i);
                // Only the reference category and category k take part, so the fitted log-odds
                // are log(P(k) / P(0)), as classify() assumes.
                if (category == 0 || category == k) {
                    pairData.addDataPoint(classificationDataSet.getDataPoint(i), category == k ? 1.0d : 0.0d);
                }
            }
            // Use a fresh model per category. Then a later train() call cannot overwrite or alias
            // the coefficient vector already stored for an earlier category.
            LogisticRegression logisticRegression = new LogisticRegression();
            logisticRegression.train(pairData, executorService);
            coefficients[k - 1] = logisticRegression.getCoefficents();
        }
        this.classCoefficents = coefficients;
    }

    /**
     * Single-threaded training; delegates to
     * {@link #trainC(ClassificationDataSet, ExecutorService)} with a {@link FakeExecutor}.
     *
     * @param classificationDataSet the labelled training data
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        trainC(classificationDataSet, new FakeExecutor());
    }

    /**
     * Reports that per-point weights are not supported.
     *
     * @return always {@code false}
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Returns a copy of this classifier; each trained coefficient vector is cloned rather than
     * shared. An untrained instance yields an untrained copy.
     *
     * <p>The method name is a decompiler rename of {@code clone()}. It is kept as-is so that
     * callers in the same decompiled tree still link.
     *
     * @return an independent copy of this model
     */
    @Override // jsat.classifiers.Classifier
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public MultinomialLogisticRegression m487clone() {
        MultinomialLogisticRegression multinomialLogisticRegression = new MultinomialLogisticRegression();
        if (this.classCoefficents != null) {
            multinomialLogisticRegression.classCoefficents = new Vec[this.classCoefficents.length];
            for (int i = 0; i < this.classCoefficents.length; i++) {
                multinomialLogisticRegression.classCoefficents[i] = this.classCoefficents[i].mo524clone();
            }
        }
        return multinomialLogisticRegression;
    }
}
