package smile.classification;

import java.util.Properties;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.math.MathEx;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.EVD;
import smile.math.matrix.Matrix;
import smile.math.matrix.SVD;
import smile.projection.Projection;
import smile.util.IntSet;

/* loaded from: input_file:smile-core-2.4.0.jar:smile/classification/FLD.class */
/**
 * Fisher's linear discriminant (FLD), usable both as a classifier and as a
 * linear projection.
 *
 * <p>Inputs of dimension {@code p} are mapped through a projection matrix
 * ({@link #getProjection()}). A sample is classified as the class whose
 * projected mean has the smallest {@code MathEx.distance} to the projected
 * sample.
 */
public class FLD implements Classifier<double[]>, Projection<double[]> {
    private static final long serialVersionUID = 2;
    /** The dimensionality of the input space. */
    private final int p;
    /** The number of classes. */
    private final int k;
    /** The projection matrix (assumed p x L, one column per discriminant direction). */
    private final DenseMatrix scaling;
    /** The overall mean mapped through {@code scaling.atx}. */
    private final double[] mean;
    /** The class means mapped through {@code scaling.atx}, one row per class. */
    private final double[][] mu;
    /** Maps internal class indices back to the original class labels. */
    private final IntSet labels;

    /**
     * Constructor with a default label set built by {@code IntSet.of(mu.length)}.
     *
     * @param mean the mean vector of all samples (length p).
     * @param mu the per-class mean vectors. Note that {@link #fit} passes class
     *           means that have already been centered by {@code mean}.
     * @param scaling the projection matrix.
     */
    public FLD(double[] mean, double[][] mu, DenseMatrix scaling) {
        this(mean, mu, scaling, IntSet.of(mu.length));
    }

    /**
     * Constructor.
     *
     * @param mean the mean vector of all samples (length p).
     * @param mu the per-class mean vectors. Note that {@link #fit} passes class
     *           means that have already been centered by {@code mean}.
     * @param scaling the projection matrix.
     * @param labels maps class indices to class labels.
     */
    public FLD(double[] mean, double[][] mu, DenseMatrix scaling, IntSet labels) {
        this.k = mu.length;
        this.p = mean.length;
        this.scaling = scaling;
        this.labels = labels;

        int dim = scaling.ncols();
        this.mean = new double[dim];
        scaling.atx(mean, this.mean);

        this.mu = new double[k][dim];
        for (int i = 0; i < k; i++) {
            scaling.atx(mu[i], this.mu[i]);
        }
    }

    /**
     * Fits an FLD model with default hyper-parameters.
     *
     * @param formula the model formula selecting the response and predictors.
     * @param data the training data.
     * @return the model.
     */
    public static FLD fit(Formula formula, DataFrame data) {
        return fit(formula, data, new Properties());
    }

    /**
     * Fits an FLD model.
     *
     * <p>Recognized properties are {@code smile.fld.dimension} (default
     * {@code -1}, meaning k - 1) and {@code smile.fld.tolerance} (default
     * {@code 1E-4}).
     *
     * @param formula the model formula selecting the response and predictors.
     * @param data the training data.
     * @param prop the hyper-parameters.
     * @return the model.
     * @throws NumberFormatException if a property value cannot be parsed.
     */
    public static FLD fit(Formula formula, DataFrame data, Properties prop) {
        int dim = Integer.parseInt(prop.getProperty("smile.fld.dimension", "-1"));
        double tol = Double.parseDouble(prop.getProperty("smile.fld.tolerance", "1E-4"));
        return fit(formula.x(data).toArray(), formula.y(data).toIntArray(), dim, tol);
    }

    /**
     * Fits an FLD model that maps to k - 1 dimensions, with tolerance 1E-4.
     *
     * @param x the training samples.
     * @param y the training labels.
     * @return the model.
     */
    public static FLD fit(double[][] x, int[] y) {
        return fit(x, y, -1, 1.0E-4d);
    }

    /**
     * Fits an FLD model.
     *
     * @param x the training samples.
     * @param y the training labels.
     * @param dim the dimensionality of the mapped space; a value {@code <= 0}
     *            selects k - 1, where k is the number of classes.
     * @param tol the tolerance used to decide whether a scatter matrix is singular.
     * @return the model.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length,
     *         if {@code dim >= k}, or if the scatter matrix is close to singular.
     */
    public static FLD fit(double[][] x, int[] y, int dim, double tol) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }

        DiscriminantAnalysis da = DiscriminantAnalysis.fit(x, y, null, tol);
        int n = x.length;
        int k = da.k;
        int p = da.mean.length;

        if (dim >= k) {
            throw new IllegalArgumentException(String.format("The dimensionality of mapped space is too high: %d >= %d", dim, k));
        }
        int L = dim <= 0 ? k - 1 : dim;

        double[] mean = da.mean;
        // fld()/small() center these class means in place; the constructor below
        // relies on receiving the centered means.
        double[][] mu = da.mu;
        // With fewer degrees of freedom (n - k) than features, use the SVD-based
        // variant (presumably because the scatter matrix is rank deficient then).
        DenseMatrix scaling = n - k < p
                ? small(L, x, mean, mu, da.priori, tol)
                : fld(L, x, mean, mu, tol);
        return new FLD(mean, mu, scaling, da.labels);
    }

    /**
     * Computes the projection from the leading eigenvectors of St^-1 * Sb,
     * where St is the scatter matrix from {@code DiscriminantAnalysis.St} and
     * Sb is the between-class scatter matrix.
     *
     * <p>Side effect: subtracts {@code mean} from every row of {@code mu} in place.
     *
     * @param L the dimensionality of the mapped space.
     * @param x the training samples.
     * @param mean the overall mean.
     * @param mu the class means; centered in place.
     * @param tol the singularity tolerance; eigenvalues below tol^2 are rejected.
     * @return the p x L projection matrix.
     */
    private static DenseMatrix fld(int L, double[][] x, double[] mean, double[][] mu, double tol) {
        int k = mu.length;
        int p = mean.length;

        EVD eigen = DiscriminantAnalysis.St(x, mean, k, tol).eigen();

        double tol2 = tol * tol;
        double[] s = eigen.getEigenValues();
        for (int i = 0; i < s.length; i++) {
            if (s[i] < tol2) {
                throw new IllegalArgumentException("The covariance matrix is close to singular.");
            }
            // Keep the reciprocal eigenvalues to build the inverse of St.
            s[i] = 1.0d / s[i];
        }

        // Center the class means in place.
        for (double[] mui : mu) {
            for (int j = 0; j < p; j++) {
                mui[j] -= mean[j];
            }
        }

        // Between-class scatter: Sb = (1/k) * sum_c mu_c * mu_c' over the centered means.
        // Only the upper triangle is accumulated; it is then mirrored.
        DenseMatrix sb = Matrix.zeros(p, p);
        for (double[] mui : mu) {
            for (int j = 0; j < p; j++) {
                for (int i = 0; i <= j; i++) {
                    sb.add(i, j, mui[i] * mui[j]);
                }
            }
        }
        for (int j = 0; j < p; j++) {
            for (int i = 0; i <= j; i++) {
                sb.div(i, j, k);
                sb.set(j, i, sb.get(i, j));
            }
        }

        // St^-1 * Sb = U * diag(1/s) * U' * Sb (assuming U holds the eigenvectors of
        // the symmetric St). Left-multiplying by diag(1/s) scales ROW i of U' * Sb
        // by 1/s[i] across all p columns.
        DenseMatrix u = eigen.getEigenVectors();
        DenseMatrix ub = u.atbmm(sb);
        for (int i = 0; i < p; i++) {
            double si = s[i];
            for (int j = 0; j < p; j++) {
                ub.mul(i, j, si);
            }
        }

        DenseMatrix w = u.abmm(ub);
        // NOTE(review): St^-1 * Sb is not symmetric in general; forcing the symmetric
        // eigen solver here should be confirmed (or replaced by the symmetric form
        // St^-1/2 * Sb * St^-1/2).
        w.setSymmetric(true);
        return w.eigen().getEigenVectors().submat(0, 0, p, L);
    }

    /**
     * Computes the projection via SVD. {@link #fit} uses this variant when the
     * number of samples minus classes is smaller than the number of features.
     *
     * <p>Side effect: subtracts {@code mean} from every row of {@code mu} in place.
     *
     * @param L the dimensionality of the mapped space.
     * @param x the training samples.
     * @param mean the overall mean.
     * @param mu the class means; centered in place.
     * @param priori the class prior probabilities.
     * @param tol the tolerance; singular values not above tol^2 are treated as zero.
     * @return the projection matrix.
     */
    private static DenseMatrix small(int L, double[][] x, double[] mean, double[][] mu, double[] priori, double tol) {
        int k = mu.length;
        int p = mean.length;
        int n = x.length;
        double sqrtn = Math.sqrt(n);

        // Centered samples as columns, scaled by 1/sqrt(n) so that
        // X * X' = (1/n) * sum_i (x_i - mean)(x_i - mean)'.
        DenseMatrix X = Matrix.zeros(p, n);
        for (int i = 0; i < n; i++) {
            double[] xi = x[i];
            for (int j = 0; j < p; j++) {
                X.set(j, i, (xi[j] - mean[j]) / sqrtn);
            }
        }

        // Center the class means in place.
        for (double[] mui : mu) {
            for (int j = 0; j < p; j++) {
                mui[j] -= mean[j];
            }
        }

        // Centered class means as columns, weighted by sqrt(prior).
        DenseMatrix M = Matrix.zeros(p, k);
        for (int i = 0; i < k; i++) {
            double w = Math.sqrt(priori[i]);
            double[] mui = mu[i];
            for (int j = 0; j < p; j++) {
                M.set(j, i, w * mui[j]);
            }
        }

        SVD svd = X.svd(true);
        DenseMatrix U = svd.getU();
        double[] s = svd.getSingularValues();
        double tol2 = tol * tol;

        // NOTE(review): if getSingularValues() returns the singular values sigma of X,
        // then X * X' = U * sigma^2 * U' and its inverse square root would scale by
        // 1/sigma, not 1/sqrt(sigma). The tol^2 threshold and the power used here
        // should be confirmed.
        DenseMatrix UTM = U.atbmm(M);
        for (int i = 0; i < n; i++) {
            // Near-zero singular values are dropped (rank-deficient scatter).
            double si = s[i] > tol2 ? 1.0d / Math.sqrt(s[i]) : 0.0d;
            for (int j = 0; j < k; j++) {
                UTM.mul(i, j, si);
            }
        }

        // NOTE(review): the row bound here is p + 1, while fld() uses p; confirm it
        // against DenseMatrix.submat's bound semantics.
        DenseMatrix V = U.abmm(UTM).svd(true).getU().submat(0, 0, p + 1, L);
        DenseMatrix UTV = U.atbmm(V);
        for (int i = 0; i < n; i++) {
            double si = s[i] > tol2 ? 1.0d / Math.sqrt(s[i]) : 0.0d;
            for (int j = 0; j < L; j++) {
                UTV.mul(i, j, si);
            }
        }
        return U.abmm(UTV);
    }

    /**
     * Predicts the class label of a sample as the class whose projected mean
     * is nearest (by {@code MathEx.distance}) to the projected sample.
     *
     * @param x the sample (length p).
     * @return the predicted class label.
     * @throws IllegalArgumentException if {@code x.length != p}.
     */
    @Override
    public int predict(double[] x) {
        // project() validates the input length.
        double[] z = project(x);

        int best = 0;
        double minDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < k; i++) {
            double d = MathEx.distance(z, mu[i]);
            if (d < minDist) {
                minDist = d;
                best = i;
            }
        }
        return labels.valueOf(best);
    }

    /**
     * Projects a sample: computes {@code scaling.atx(x)} and subtracts the
     * projected overall mean.
     *
     * @param x the sample (length p).
     * @return the projected sample.
     * @throws IllegalArgumentException if {@code x.length != p}.
     */
    @Override
    public double[] project(double[] x) {
        if (x.length != p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, p));
        }

        double[] z = new double[scaling.ncols()];
        scaling.atx(x, z);
        MathEx.sub(z, mean);
        return z;
    }

    /**
     * Projects each sample with {@link #project(double[])}.
     *
     * @param x the samples, each of length p.
     * @return the projected samples.
     * @throws IllegalArgumentException if any sample's length is not p.
     */
    @Override
    public double[][] project(double[][] x) {
        double[][] z = new double[x.length][];
        for (int i = 0; i < x.length; i++) {
            z[i] = project(x[i]);
        }
        return z;
    }

    /**
     * Returns the projection matrix. The internal matrix is returned, not a copy.
     *
     * @return the projection matrix.
     */
    public DenseMatrix getProjection() {
        return scaling;
    }
}
