package jsat.datatransform.kernel;

import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.datatransform.DataTransformBase;
import jsat.datatransform.kernel.Nystrom;
import jsat.distributions.Distribution;
import jsat.distributions.discrete.UniformDiscrete;
import jsat.distributions.kernels.KernelTrick;
import jsat.distributions.kernels.RBFKernel;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.EigenValueDecomposition;
import jsat.linear.Matrix;
import jsat.linear.RowColumnOps;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.utils.random.XOR96;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/datatransform/kernel/KernelPCA.class */
/**
 * Kernel Principal Component Analysis transform.
 * <p>
 * During {@link #fit(DataSet)} a set of basis vectors is chosen and their kernel
 * matrix is built and double-centered. That matrix is then eigendecomposed. During
 * {@link #transform(DataPoint)} the kernel values between a point and the basis
 * vectors are centered the same way. They are then projected onto the leading
 * eigenvectors.
 * <p>
 * If the training set has more rows than {@link #getBasisSize()}, the basis is a
 * subset chosen by {@link Nystrom#sampleBasisVectors}. Otherwise every training row
 * is used as a basis vector.
 */
public class KernelPCA extends DataTransformBase {
    private static final long serialVersionUID = 5676602024560381023L;

    /** Orders eigenvalues from largest to smallest. */
    private static final Comparator<Double> DESCENDING = new Comparator<Double>() {
        @Override // java.util.Comparator
        public int compare(Double a, Double b) {
            return Double.compare(b.doubleValue(), a.doubleValue());
        }
    };

    /** Number of output dimensions produced by {@link #transform(DataPoint)}. */
    private int dimensions;

    /** The kernel used to compare points. */
    @Parameter.ParameterHolder
    private KernelTrick k;
    /** Maximum number of training rows used directly as basis vectors. */
    private int basisSize;
    /** How the basis is chosen when the data set is larger than {@link #basisSize}. */
    private Nystrom.SamplingMethod samplingMethod;
    /** Eigenvalues of the centered kernel matrix, sorted in descending order. */
    private double[] eigenVals;
    /** Eigenvectors (as columns), each scaled by 1/sqrt of its eigenvalue, or zeroed if that eigenvalue is not positive. */
    private Matrix eigenVecs;
    /** The basis vectors chosen during fit. */
    private Vec[] vecs;
    /** Mean of each row of the uncentered training kernel matrix. */
    private double[] rowAvg;
    /** Mean of all entries of the uncentered training kernel matrix. */
    private double allAvg;

    /**
     * Creates a transform with 100 output dimensions, an {@link RBFKernel}, a basis size
     * of 1000 and uniform basis sampling.
     */
    public KernelPCA() {
        this(100);
    }

    /**
     * Creates a transform with an {@link RBFKernel}, a basis size of 1000 and uniform
     * basis sampling.
     *
     * @param i the number of output dimensions
     */
    public KernelPCA(int i) {
        this(new RBFKernel(), i);
    }

    /**
     * Creates a transform with a basis size of 1000 and uniform basis sampling.
     *
     * @param kernelTrick the kernel to use
     * @param i the number of output dimensions
     */
    public KernelPCA(KernelTrick kernelTrick, int i) {
        this(kernelTrick, i, 1000, Nystrom.SamplingMethod.UNIFORM);
    }

    /**
     * Creates a new, unfitted transform.
     *
     * @param kernelTrick the kernel to use
     * @param i the number of output dimensions
     * @param i2 the maximum number of basis vectors
     * @param samplingMethod how to choose basis vectors when the data set is larger than {@code i2}
     * @throws IllegalArgumentException if {@code i} or {@code i2} is less than 1
     */
    public KernelPCA(KernelTrick kernelTrick, int i, int i2, Nystrom.SamplingMethod samplingMethod) {
        setDimensions(i);
        setKernel(kernelTrick);
        setBasisSize(i2);
        setBasisSamplingMethod(samplingMethod);
    }

    /**
     * Creates a new transform and immediately fits it to the given data.
     *
     * @param kernelTrick the kernel to use
     * @param dataSet the data to fit
     * @param i the number of output dimensions
     * @param i2 the maximum number of basis vectors
     * @param samplingMethod how to choose basis vectors when the data set is larger than {@code i2}
     */
    public KernelPCA(KernelTrick kernelTrick, DataSet dataSet, int i, int i2, Nystrom.SamplingMethod samplingMethod) {
        this(kernelTrick, i, i2, samplingMethod);
        fit(dataSet);
    }

    /**
     * Fits the transform to the given data. It picks the basis, builds and centers
     * their kernel matrix, and eigendecomposes it.
     *
     * @param dataSet the data to fit
     * @throws IllegalArgumentException if the data set contains no rows
     */
    @Override // jsat.datatransform.DataTransform
    public void fit(DataSet dataSet) {
        if (dataSet.getSampleSize() == 0) {
            throw new IllegalArgumentException("Can not fit KernelPCA to an empty data set");
        }
        this.vecs = selectBasis(dataSet);
        DenseMatrix centered = buildCenteredKernelMatrix();

        EigenValueDecomposition evd = new EigenValueDecomposition(centered);
        evd.sortByEigenValue(DESCENDING);
        this.eigenVals = evd.getRealEigenvalues();
        this.eigenVecs = evd.getV();
        normalizeEigenVectors();
    }

    /** Returns all training rows if they fit within the basis size, otherwise a sampled subset. */
    private Vec[] selectBasis(DataSet dataSet) {
        int n = dataSet.getSampleSize();
        if (n <= this.basisSize) {
            Vec[] all = new Vec[n];
            for (int i = 0; i < n; i++) {
                all[i] = dataSet.getDataPoint(i).getNumericalValues();
            }
            return all;
        }
        List<Vec> sampled = Nystrom.sampleBasisVectors(this.k, dataSet, dataSet.getDataVectors(),
                this.samplingMethod, this.basisSize, false, new XOR96());
        return sampled.toArray(new Vec[sampled.size()]);
    }

    /**
     * Builds the kernel matrix over {@link #vecs} and stores its row means in
     * {@link #rowAvg} and its grand mean in {@link #allAvg}. It then returns the
     * double-centered matrix {@code K[i][j] - rowAvg[i] - rowAvg[j] + allAvg}.
     */
    private DenseMatrix buildCenteredKernelMatrix() {
        int n = this.vecs.length;
        DenseMatrix kernelMatrix = new DenseMatrix(n, n);
        // Only j >= i is evaluated; each value is mirrored across the diagonal.
        for (int i = 0; i < n; i++) {
            for (int j = i; j < n; j++) {
                double value = this.k.eval(this.vecs[i], this.vecs[j]);
                kernelMatrix.set(i, j, value);
                kernelMatrix.set(j, i, value);
            }
        }

        this.rowAvg = new double[n];
        double total = 0.0d;
        for (int i = 0; i < n; i++) {
            double rowSum = 0.0d;
            for (int j = 0; j < n; j++) {
                rowSum += kernelMatrix.get(i, j);
            }
            total += rowSum;
            this.rowAvg[i] = rowSum / n;
        }
        // Divisor computed in double so n*n cannot overflow int.
        this.allAvg = total / ((double) n * n);

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                kernelMatrix.set(i, j, ((kernelMatrix.get(i, j) - this.rowAvg[i]) - this.rowAvg[j]) + this.allAvg);
            }
        }
        return kernelMatrix;
    }

    /**
     * Scales each eigenvector column by 1/sqrt of its eigenvalue. A non-positive
     * eigenvalue would make that scaling produce Infinity or NaN, so that column is
     * set to zero instead.
     */
    private void normalizeEigenVectors() {
        for (int c = 0; c < this.eigenVals.length; c++) {
            double lambda = this.eigenVals[c];
            if (lambda > 0) {
                RowColumnOps.divCol(this.eigenVecs, c, Math.sqrt(lambda));
            } else {
                for (int r = 0; r < this.eigenVecs.rows(); r++) {
                    this.eigenVecs.set(r, c, 0.0d);
                }
            }
        }
    }

    /**
     * Copy constructor. The kernel, eigenvectors and basis vectors are cloned, and the
     * arrays are copied.
     *
     * @param kernelPCA the transform to copy
     */
    protected KernelPCA(KernelPCA kernelPCA) {
        this.dimensions = kernelPCA.dimensions;
        this.k = kernelPCA.k.m629clone();
        this.basisSize = kernelPCA.basisSize;
        this.samplingMethod = kernelPCA.samplingMethod;
        if (kernelPCA.eigenVals != null) {
            this.eigenVals = Arrays.copyOf(kernelPCA.eigenVals, kernelPCA.eigenVals.length);
        }
        if (kernelPCA.eigenVecs != null) {
            this.eigenVecs = kernelPCA.eigenVecs.mo641clone();
        }
        if (kernelPCA.vecs != null) {
            this.vecs = new Vec[kernelPCA.vecs.length];
            for (int i = 0; i < this.vecs.length; i++) {
                this.vecs[i] = kernelPCA.vecs[i].mo525clone();
            }
            // rowAvg is assigned in fit together with vecs
            this.rowAvg = Arrays.copyOf(kernelPCA.rowAvg, kernelPCA.rowAvg.length);
        }
        this.allAvg = kernelPCA.allAvg;
    }

    /**
     * Projects a data point onto the fitted kernel principal components.
     * <p>
     * The kernel values between the point and each basis vector are centered with the
     * training statistics. The result is a vector of length {@link #getDimensions()}. If
     * fewer components exist than requested (a small basis), the remaining entries are 0.
     *
     * @param dataPoint the point to transform
     * @return a new data point with the projected numeric values and the original
     *         categorical values and weight
     * @throws IllegalStateException if the transform has not been fit
     */
    @Override // jsat.datatransform.DataTransform
    public DataPoint transform(DataPoint dataPoint) {
        if (this.vecs == null || this.eigenVecs == null) {
            throw new IllegalStateException("KernelPCA must be fit before it can transform data");
        }
        Vec x = dataPoint.getNumericalValues();
        int n = this.vecs.length;

        double[] kernelRow = new double[n];
        double rowMean = 0.0d;
        for (int j = 0; j < n; j++) {
            kernelRow[j] = this.k.eval(this.vecs[j], x);
            rowMean += kernelRow[j];
        }
        rowMean /= n;

        DenseVector projected = new DenseVector(this.dimensions);
        int components = Math.min(this.dimensions, this.eigenVecs.cols());
        for (int c = 0; c < components; c++) {
            double value = 0.0d;
            for (int j = 0; j < n; j++) {
                // Centering must use the mean of basis vector j's training row (rowAvg[j]),
                // mirroring the rowAvg[i]/rowAvg[j] terms used when centering in fit.
                double centered = ((kernelRow[j] - rowMean) - this.rowAvg[j]) + this.allAvg;
                value += this.eigenVecs.get(j, c) * centered;
            }
            projected.set(c, value);
        }
        return new DataPoint(projected, dataPoint.getCategoricalValues(), dataPoint.getCategoricalData(), dataPoint.getWeight());
    }

    @Override // jsat.datatransform.DataTransformBase
    public KernelPCA clone() {
        return new KernelPCA(this);
    }

    /**
     * Sets the kernel used to compare points. Takes effect on the next {@link #fit(DataSet)}.
     *
     * @param kernelTrick the kernel to use
     */
    public void setKernel(KernelTrick kernelTrick) {
        this.k = kernelTrick;
    }

    /**
     * Returns the kernel in use.
     *
     * @return the kernel
     */
    public KernelTrick getKernel() {
        return this.k;
    }

    /**
     * Sets the maximum number of basis vectors.
     *
     * @param i the basis size, must be at least 1
     * @throws IllegalArgumentException if {@code i < 1}
     */
    public void setBasisSize(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("The basis size must be positive, not " + i);
        }
        this.basisSize = i;
    }

    /**
     * Returns the maximum number of basis vectors.
     *
     * @return the basis size
     */
    public int getBasisSize() {
        return this.basisSize;
    }

    /**
     * Sets the number of output dimensions.
     *
     * @param i the number of dimensions, must be at least 1
     * @throws IllegalArgumentException if {@code i < 1}
     */
    public void setDimensions(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("The number of dimensions must be positive, not " + i);
        }
        this.dimensions = i;
    }

    /**
     * Returns the number of output dimensions.
     *
     * @return the number of dimensions
     */
    public int getDimensions() {
        return this.dimensions;
    }

    /**
     * Sets how basis vectors are sampled when the data set is larger than the basis size.
     *
     * @param samplingMethod the sampling method
     */
    public void setBasisSamplingMethod(Nystrom.SamplingMethod samplingMethod) {
        this.samplingMethod = samplingMethod;
    }

    /**
     * Returns the basis sampling method.
     *
     * @return the sampling method
     */
    public Nystrom.SamplingMethod getBasisSamplingMethod() {
        return this.samplingMethod;
    }

    /**
     * Returns a distribution of candidate values for the dimensions parameter. It is a
     * {@link UniformDiscrete} built with the bounds 20 and 200. The data set is not
     * inspected.
     *
     * @param dataSet the data set (unused)
     * @return the candidate distribution
     */
    public static Distribution guessDimensions(DataSet dataSet) {
        return new UniformDiscrete(20, 200);
    }
}
