package smile.projection;

import java.io.Serializable;
import java.util.Arrays;
import smile.math.MathEx;
import smile.math.kernel.MercerKernel;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.EVD;
import smile.math.matrix.Matrix;

/* loaded from: input_file:smile-core-2.4.0.jar:smile/projection/KPCA.class */
/**
 * Kernel principal component analysis (KPCA).
 *
 * <p>{@link #fit} builds the n-by-n kernel matrix of the training samples and double-centers it.
 * It keeps the requested leading eigenpairs whose eigenvalue divided by n exceeds a threshold.
 * Each kept eigenvector is divided by the square root of its eigenvalue to form the rows of the
 * projection matrix. To project a new point, its kernel row against the training samples is
 * centered with the same training statistics and multiplied by that projection matrix.
 *
 * <p>The class relies on default Java serialization, so the private field names and types are
 * part of the serialized form and must not be changed.
 *
 * @param <T> the type of input objects handled by the kernel.
 */
public class KPCA<T> implements Projection<T>, Serializable {
    private static final long serialVersionUID = 2;

    /** Dimension of the projected space (the number of rows of {@link #projection}). */
    private final int p;
    /** Training samples; projecting a new point needs its kernel values against each of them. */
    private final T[] data;
    /** Mercer kernel used both for training and for projection. */
    private final MercerKernel<T> kernel;
    /** Row means of the uncentered training kernel matrix. */
    private final double[] mean;
    /** Grand mean of the uncentered training kernel matrix. */
    private final double mu;
    /** Selected eigenvalues of the centered training kernel matrix. */
    private final double[] latent;
    /** p-by-n matrix; row r is the r-th selected eigenvector divided by sqrt(eigenvalue). */
    private final DenseMatrix projection;
    /** n-by-p coordinates of the training samples in the projected space. */
    private final double[][] coordinates;

    /**
     * Constructor. All arguments are stored by reference; no copies are made.
     *
     * @param data the training samples.
     * @param kernel the Mercer kernel.
     * @param mean the row means of the uncentered training kernel matrix.
     * @param mu the grand mean of the uncentered training kernel matrix.
     * @param coordinates the training samples in the projected space.
     * @param latent the selected eigenvalues.
     * @param projection the projection matrix; its row count becomes the output dimension.
     */
    public KPCA(T[] data, MercerKernel<T> kernel, double[] mean, double mu,
                double[][] coordinates, double[] latent, DenseMatrix projection) {
        this.data = data;
        this.kernel = kernel;
        this.mean = mean;
        this.mu = mu;
        this.coordinates = coordinates;
        this.latent = latent;
        this.projection = projection;
        this.p = projection.nrows();
    }

    /**
     * Fits kernel PCA with the default eigenvalue threshold {@code 1E-4}.
     *
     * @param data the training samples.
     * @param kernel the Mercer kernel.
     * @param k the maximum dimension of the projected space.
     * @param <T> the type of input objects.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code k < 1} or {@code k > data.length}.
     */
    public static <T> KPCA<T> fit(T[] data, MercerKernel<T> kernel, int k) {
        return fit(data, kernel, k, 1.0E-4d);
    }

    /**
     * Fits kernel PCA.
     *
     * <p>Among the first {@code k} eigenpairs returned by the eigen decomposition, only those with
     * {@code eigenvalue / n > threshold} are kept. The resulting dimension may therefore be smaller
     * than {@code k}, and may even be zero.
     *
     * @param data the training samples.
     * @param kernel the Mercer kernel.
     * @param k the maximum dimension of the projected space.
     * @param threshold keep only eigenpairs whose eigenvalue divided by the sample count exceeds
     *     this value; must be non-negative and not NaN.
     * @param <T> the type of input objects.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code threshold} is negative or NaN, or if
     *     {@code k < 1} or {@code k > data.length}.
     */
    public static <T> KPCA<T> fit(T[] data, MercerKernel<T> kernel, int k, double threshold) {
        // Written as !(x >= 0) so that NaN is rejected too; "x < 0" is false for NaN.
        if (!(threshold >= 0.0d)) {
            throw new IllegalArgumentException("Invalid threshold = " + threshold);
        }
        if (k < 1 || k > data.length) {
            throw new IllegalArgumentException("Invalid dimension of feature space: " + k);
        }

        int n = data.length;
        DenseMatrix gram = kernelMatrix(data, kernel);

        // Statistics of the UNcentered matrix; project() reuses them to center new kernel rows.
        double[] rowMeans = gram.rowMeans();
        double grandMean = MathEx.mean(rowMeans);
        doubleCenter(gram, rowMeans, grandMean);
        gram.setSymmetric(true);

        // NOTE(review): assumes eigen(k) returns the k largest eigenpairs -- confirm in Smile docs.
        EVD eigen = gram.eigen(k);
        double[] eigenvalues = eigen.getEigenValues();
        DenseMatrix eigenvectors = eigen.getEigenVectors();

        int[] selected = selectComponents(eigenvalues, k, n, threshold);
        int dim = selected.length;

        double[] latent = new double[dim];
        DenseMatrix projection = Matrix.zeros(dim, n);
        for (int r = 0; r < dim; r++) {
            int c = selected[r];
            latent[r] = eigenvalues[c];
            // latent[r] > 0 is guaranteed by the threshold test, so the scale is finite.
            double scale = Math.sqrt(latent[r]);
            for (int j = 0; j < n; j++) {
                projection.set(r, j, eigenvectors.get(j, c) / scale);
            }
        }

        // Training coordinates: (projection * centered gram) transposed into n-by-dim rows.
        DenseMatrix scores = projection.abmm(gram);
        double[][] coordinates = new double[n][dim];
        for (int i = 0; i < n; i++) {
            for (int r = 0; r < dim; r++) {
                coordinates[i][r] = scores.get(r, i);
            }
        }

        return new KPCA<>(data, kernel, rowMeans, grandMean, coordinates, latent, projection);
    }

    /** Builds the symmetric n-by-n kernel matrix, evaluating the kernel once per pair. */
    private static <T> DenseMatrix kernelMatrix(T[] data, MercerKernel<T> kernel) {
        int n = data.length;
        DenseMatrix gram = Matrix.zeros(n, n);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j <= i; j++) {
                double kij = kernel.k(data[i], data[j]);
                gram.set(i, j, kij);
                gram.set(j, i, kij);
            }
        }
        return gram;
    }

    /** Double-centers the symmetric matrix in place: K[i][j] - rowMean[i] - rowMean[j] + grandMean. */
    private static void doubleCenter(DenseMatrix gram, double[] rowMeans, double grandMean) {
        int n = rowMeans.length;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j <= i; j++) {
                double v = gram.get(i, j) - rowMeans[i] - rowMeans[j] + grandMean;
                gram.set(i, j, v);
                gram.set(j, i, v);
            }
        }
    }

    /**
     * Returns the indices, in ascending order, of the eigenvalues among the first {@code k} that
     * satisfy {@code eigenvalue / n > threshold}. Choosing indices explicitly, rather than assuming
     * the passing values form a prefix, keeps each eigenvalue paired with its own eigenvector even
     * if the decomposition's output is not sorted.
     */
    private static int[] selectComponents(double[] eigenvalues, int k, int n, double threshold) {
        int m = Math.min(k, eigenvalues.length);
        int[] selected = new int[m];
        int count = 0;
        for (int c = 0; c < m; c++) {
            if (eigenvalues[c] / n > threshold) {
                selected[count++] = c;
            }
        }
        return Arrays.copyOf(selected, count);
    }

    /**
     * Returns the selected eigenvalues of the centered training kernel matrix.
     * This is the internal array, not a copy.
     *
     * @return the selected eigenvalues.
     */
    public double[] getVariances() {
        return this.latent;
    }

    /**
     * Returns the p-by-n projection matrix. This is the internal matrix, not a copy.
     *
     * @return the projection matrix.
     */
    public DenseMatrix getProjection() {
        return this.projection;
    }

    /**
     * Returns the n-by-p coordinates of the training samples in the projected space.
     * This is the internal array, not a copy.
     *
     * @return the training coordinates.
     */
    public double[][] getCoordinates() {
        return this.coordinates;
    }

    /**
     * Computes the kernel row of {@code x} against the training samples and centers it with the
     * training statistics: k(x, x_j) - mean_l k(x, x_l) - mean[j] + mu.
     */
    private double[] centeredKernelRow(T x) {
        int n = this.data.length;
        double[] row = new double[n];
        for (int j = 0; j < n; j++) {
            row[j] = this.kernel.k(x, this.data[j]);
        }
        double rowMean = MathEx.mean(row);
        for (int j = 0; j < n; j++) {
            row[j] = row[j] - rowMean - this.mean[j] + this.mu;
        }
        return row;
    }

    /**
     * Projects one point into the kernel principal component space.
     *
     * @param x the point.
     * @return its p-dimensional projection.
     */
    @Override // smile.projection.Projection
    public double[] project(T x) {
        double[] y = new double[this.p];
        this.projection.ax(centeredKernelRow(x), y);
        return y;
    }

    /**
     * Projects each point into the kernel principal component space.
     *
     * @param x the points.
     * @return one p-dimensional projection per input point.
     */
    @Override // smile.projection.Projection
    public double[][] project(T[] x) {
        double[][] y = new double[x.length][this.p];
        for (int i = 0; i < x.length; i++) {
            this.projection.ax(centeredKernelRow(x[i]), y[i]);
        }
        return y;
    }
}
