package smile.regression;

import java.util.Arrays;
import java.util.Properties;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.math.MathEx;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.QR;
import smile.math.matrix.SVD;
import smile.math.special.Beta;

/* loaded from: input_file:smile-core-2.4.0.jar:smile/regression/OLS.class */
/**
 * Ordinary least squares (OLS) linear regression.
 *
 * <p>Coefficients are solved either by QR decomposition (the default) or by SVD. Optionally
 * the inverse of X'X is computed, which is used for coefficient standard errors and t-tests
 * and stored on the resulting model.
 */
public class OLS {
    private static final Logger logger = LoggerFactory.getLogger(OLS.class);

    /**
     * Fits an OLS model with default options: QR decomposition, standard errors / t-tests
     * computed, and the inverse of X'X retained on the model.
     *
     * @param formula the model formula.
     * @param data the training data.
     * @return the fitted linear model.
     * @throws IllegalArgumentException if the design matrix has no more rows than predictors.
     */
    public static LinearModel fit(Formula formula, DataFrame data) {
        return fit(formula, data, new Properties());
    }

    /**
     * Fits an OLS model with options read from {@code prop}.
     *
     * <p>Recognized keys:
     * <ul>
     *   <li>{@code smile.ols.method} (default {@code "qr"}): {@code "svd"} selects SVD; any
     *       other value selects QR.</li>
     *   <li>{@code smile.ols.standard.error} (default {@code "true"}): whether to compute
     *       standard errors and t-tests.</li>
     *   <li>{@code smile.ols.recursive} (default {@code "true"}): whether to retain the inverse
     *       of X'X on the model.</li>
     * </ul>
     *
     * @param formula the model formula.
     * @param data the training data.
     * @param prop the hyper-parameters.
     * @return the fitted linear model.
     * @throws IllegalArgumentException if the design matrix has no more rows than predictors.
     */
    public static LinearModel fit(Formula formula, DataFrame data, Properties prop) {
        String method = prop.getProperty("smile.ols.method", "qr");
        boolean stderr = Boolean.parseBoolean(prop.getProperty("smile.ols.standard.error", "true"));
        boolean recursive = Boolean.parseBoolean(prop.getProperty("smile.ols.recursive", "true"));
        return fit(formula, data, method, stderr, recursive);
    }

    /**
     * Fits an OLS model.
     *
     * <p>The last column of the design matrix is treated as the intercept term (presumably the
     * bias column added by {@code formula.matrix(data, true)} — confirm in {@code Formula}).
     *
     * @param formula the model formula.
     * @param data the training data.
     * @param method {@code "svd"} (case-insensitive) to solve by SVD; any other value, including
     *     {@code null}, solves by QR and falls back to SVD if QR throws.
     * @param stderr whether to compute coefficient standard errors, t statistics and p-values.
     * @param recursive whether to store the inverse of X'X in the model's {@code V} field
     *     (presumably for later recursive updates — confirm in {@code LinearModel}).
     * @return the fitted linear model.
     * @throws IllegalArgumentException if the design matrix has no more rows than predictors.
     */
    public static LinearModel fit(Formula formula, DataFrame data, String method, boolean stderr, boolean recursive) {
        DenseMatrix x = formula.matrix(data, true);
        double[] y = formula.y(data).toDoubleArray();

        int n = x.nrows();
        int p = x.ncols() - 1; // number of predictors, excluding the intercept column
        if (n <= p) {
            throw new IllegalArgumentException(
                String.format("The input matrix is not over determined: %d rows, %d columns", n, p));
        }

        // Coefficients for all columns; the last entry is the intercept.
        double[] coef = new double[p + 1];
        // Null-safe: a null method selects QR instead of throwing NullPointerException.
        boolean useSvd = "svd".equalsIgnoreCase(method);
        QR qr = null;
        if (useSvd) {
            x.svd(false).solve(y, coef);
        } else {
            try {
                qr = x.qr(false);
                qr.solve(y, coef);
            } catch (RuntimeException e) {
                // Keep the cause so the reason QR failed is visible in the log.
                logger.warn("Matrix is not of full rank, try SVD instead", e);
                useSvd = true;
                SVD svd = x.svd(false);
                // The failed QR solve may have partially written coef; start from zero.
                Arrays.fill(coef, 0.0);
                svd.solve(y, coef);
            }
        }

        LinearModel model = new LinearModel();
        model.formula = formula;
        model.schema = formula.xschema();
        model.p = p;
        model.b = coef[p];
        model.w = Arrays.copyOf(coef, p);

        double[] fitted = new double[n];
        x.ax(coef, fitted);
        model.fitness(fitted, y, MathEx.mean(y));

        DenseMatrix inv = null;
        if (stderr || recursive) {
            // Inverse of X'X via its Cholesky factor. The QR path reuses its own factorization.
            // NOTE(review): on the SVD fallback path X is likely rank deficient, so the Cholesky
            // of X'X may fail here as well — confirm Cholesky's behavior on singular input.
            inv = (useSvd ? x.ata().cholesky() : qr.CholeskyOfAtA()).inverse();
            model.V = inv;
        }

        if (stderr) {
            // Each row holds {estimate, std. error, t value, two-sided p-value}.
            double[][] ttest = new double[p + 1][4];
            model.ttest = ttest;
            for (int i = 0; i <= p; i++) {
                // model.error is presumably the residual standard error set by fitness() — confirm.
                double se = model.error * Math.sqrt(inv.get(i, i));
                double t = coef[i] / se;
                ttest[i][0] = coef[i];
                ttest[i][1] = se;
                ttest[i][2] = t;
                // Two-sided Student-t p-value: I_{df/(df+t^2)}(df/2, 1/2).
                ttest[i][3] = Beta.regularizedIncompleteBetaFunction(0.5 * model.df, 0.5, model.df / (model.df + t * t));
            }
        }
        return model;
    }
}
