package jsat.regression;

import java.util.concurrent.ExecutorService;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.QRDecomposition;
import jsat.linear.Vec;
import jsat.utils.FakeExecutor;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/regression/MultipleLinearRegression.class */
/**
 * Linear regression over purely numerical features. Training builds a design
 * matrix with a leading column of ones (for the intercept), optionally scales
 * each row by the square root of its data point weight, and solves the
 * resulting system through a QR decomposition. The first solved coefficient
 * becomes the bias term and the remaining ones the weight vector.
 */
public class MultipleLinearRegression implements Regressor, SingleWeightVectorModel {
    private static final long serialVersionUID = 7694194181910565061L;

    /** Coefficient vector, one entry per numerical feature; {@code null} until trained. */
    private Vec B;

    /** Intercept (bias) term. */
    private double a;

    /** Whether data point weights are applied during training. */
    private boolean useWeights;

    /**
     * Creates a regressor that applies data point weights during training.
     */
    public MultipleLinearRegression() {
        this(true);
    }

    /**
     * Creates a regressor.
     *
     * @param z {@code true} to scale each training row by the square root of
     *          its data point weight, {@code false} to ignore the weights
     */
    public MultipleLinearRegression(boolean z) {
        this.useWeights = z;
    }

    /**
     * Computes the prediction as the dot product of the coefficient vector and
     * the numerical values of the data point, plus the bias.
     *
     * @param dataPoint the data point to predict for
     * @return the predicted value
     */
    @Override // jsat.regression.Regressor
    public double regress(DataPoint dataPoint) {
        return this.B.dot(dataPoint.getNumericalValues()) + this.a;
    }

    /**
     * Fits the bias and coefficient vector to the given data set.
     *
     * @param regressionDataSet the training data; must not contain categorical variables
     * @param executorService   executor handed to the QR decomposition of the design matrix
     * @throws RuntimeException if the data set has one or more categorical variables
     */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet, ExecutorService executorService) {
        if (regressionDataSet.getNumCategoricalVars() > 0) {
            throw new RuntimeException("Multiple Linear Regression only works with numerical values");
        }
        final int sampleSize = regressionDataSet.getSampleSize();
        final int numFeatures = regressionDataSet.getNumNumericalVars();

        // Column 0 holds the constant 1 so that the first solved coefficient is the intercept.
        DenseMatrix design = new DenseMatrix(sampleSize, numFeatures + 1);
        DenseVector targets = new DenseVector(sampleSize);
        for (int row = 0; row < sampleSize; row++) {
            DataPointPair<Double> pair = regressionDataSet.getDataPointPair(row);
            targets.set(row, pair.getPair().doubleValue());
            design.set(row, 0, 1.0d);
            Vec features = pair.getVector();
            for (int col = 0; col < features.length(); col++) {
                design.set(row, col + 1, features.get(col));
            }
        }

        if (this.useWeights) {
            // Scale both the design matrix rows and the targets by sqrt(weight).
            DenseVector sqrtWeights = new DenseVector(sampleSize);
            for (int row = 0; row < sampleSize; row++) {
                sqrtWeights.set(row, Math.sqrt(regressionDataSet.getDataPoint(row).getWeight()));
            }
            Matrix.diagMult(sqrtWeights, design);
            targets.mutablePairwiseMultiply(sqrtWeights);
        }

        Matrix[] qr = design.qr(executorService);
        Vec solution = new QRDecomposition(qr[0], qr[1]).solve(targets);

        // Split the solution into the intercept (index 0) and the per-feature coefficients.
        this.a = solution.get(0);
        this.B = new DenseVector(numFeatures);
        for (int i = 1; i < solution.length(); i++) {
            this.B.set(i - 1, solution.get(i));
        }
    }

    /**
     * Fits the model using a {@link FakeExecutor} as the executor.
     *
     * @param regressionDataSet the training data
     */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet) {
        train(regressionDataSet, new FakeExecutor());
    }

    /**
     * @return {@code true} if this instance applies data point weights during training
     */
    @Override // jsat.regression.Regressor
    public boolean supportsWeightedData() {
        return this.useWeights;
    }

    /**
     * @return the internal coefficient vector (not a copy), or {@code null} if untrained
     */
    @Override // jsat.SingleWeightVectorModel
    public Vec getRawWeight() {
        return this.B;
    }

    /**
     * @return the intercept term
     */
    @Override // jsat.SingleWeightVectorModel
    public double getBias() {
        return this.a;
    }

    /**
     * Returns the single coefficient vector of this model.
     *
     * @param i index of the requested weight vector; any value below 1 is accepted
     * @return the internal coefficient vector
     * @throws IndexOutOfBoundsException if {@code i} is 1 or greater
     */
    @Override // jsat.SimpleWeightVectorModel
    public Vec getRawWeight(int i) {
        if (i < 1) {
            return getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * Returns the single bias term of this model.
     *
     * @param i index of the requested bias; any value below 1 is accepted
     * @return the intercept term
     * @throws IndexOutOfBoundsException if {@code i} is 1 or greater
     */
    @Override // jsat.SimpleWeightVectorModel
    public double getBias(int i) {
        if (i < 1) {
            return getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * @return always 1, as this model has a single weight vector
     */
    @Override // jsat.SimpleWeightVectorModel
    public int numWeightsVecs() {
        return 1;
    }

    /**
     * Creates an independent copy of this regressor, including its weighting
     * configuration, bias, and (if trained) a copy of the coefficient vector.
     *
     * @return the copy
     */
    @Override // jsat.regression.Regressor
    public MultipleLinearRegression clone() {
        // Pass useWeights explicitly: the no-arg constructor would force it to true
        // and silently change the behavior of a clone of an unweighted model.
        MultipleLinearRegression copy = new MultipleLinearRegression(this.useWeights);
        if (this.B != null) {
            copy.B = this.B.mo525clone();
        }
        copy.a = this.a;
        return copy;
    }
}
