package jsat.datatransform.kernel;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.clustering.SeedSelectionMethods;
import jsat.clustering.kmeans.HamerlyKMeans;
import jsat.datatransform.DataTransformBase;
import jsat.distributions.kernels.KernelTrick;
import jsat.distributions.kernels.RBFKernel;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.EigenValueDecomposition;
import jsat.linear.Matrix;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.parameters.Parameter;
import jsat.utils.DoubleList;
import jsat.utils.IntSet;
import jsat.utils.random.XOR96;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/datatransform/kernel/Nystrom.class */
/**
 * Nystrom kernel approximation. A set of basis vectors is sampled from the training data.
 * Each input {@code x} is mapped to the vector of kernel evaluations {@code k(x, b_i)} against
 * those basis vectors, then multiplied by {@code (W + ridge*I)^{-1/2}}. Here {@code W} is the
 * kernel matrix of the basis vectors. Inner products of the transformed vectors approximate
 * the kernel.
 */
public class Nystrom extends DataTransformBase {
    private static final long serialVersionUID = -3227844260130709773L;

    /** Non-negative value added to every eigenvalue of the basis kernel matrix before inversion. */
    private double ridge;

    @Parameter.ParameterHolder
    private KernelTrick k;
    /** Stored and exposed via get/setDimension; not read by {@link #fit(DataSet)} in this class. */
    private int dimension;
    private SamplingMethod method;
    /** Number of basis vectors requested from the sampler. */
    int basisSize;
    /** If true, basis vectors are drawn with replacement (duplicates allowed). */
    private boolean sampleWithReplacment;
    /** Basis vectors chosen by the last call to {@link #fit(DataSet)}. */
    private List<Vec> basisVecs;
    /** Kernel acceleration cache for {@link #basisVecs}. */
    private List<Double> accelCache;
    /** Symmetric matrix {@code (W + ridge*I)^{-1/2}} applied to kernel evaluations. */
    private Matrix transform;

    /** Strategy used to choose basis vectors from the data set. */
    public enum SamplingMethod {
        /** Every data point is equally likely to be chosen. */
        UNIFORM,
        /** Points are chosen with probability proportional to {@code k(x, x)}. */
        DIAGONAL,
        /** Points are chosen with probability proportional to the L2 norm of their kernel-matrix row. */
        NORM,
        /** Basis vectors are the means found by k-means clustering (Euclidean distance, k-means++ seeding). */
        KMEANS
    }

    /**
     * Creates and immediately fits a transform with no ridge, sampling without replacement.
     *
     * @param kernelTrick the kernel to approximate
     * @param dataSet the data to sample basis vectors from
     * @param i the number of basis vectors
     * @param samplingMethod how basis vectors are chosen
     */
    public Nystrom(KernelTrick kernelTrick, DataSet dataSet, int i, SamplingMethod samplingMethod) {
        this(kernelTrick, dataSet, i, samplingMethod, 0.0d, false);
    }

    /** Creates an unfitted transform using an {@link RBFKernel} and 500 basis vectors. */
    public Nystrom() {
        this(new RBFKernel(), 500);
    }

    /**
     * Creates an unfitted transform using uniform sampling without replacement and a ridge of 1e-5.
     *
     * @param kernelTrick the kernel to approximate
     * @param i the number of basis vectors
     */
    public Nystrom(KernelTrick kernelTrick, int i) {
        this(kernelTrick, i, SamplingMethod.UNIFORM, 1.0E-5d, false);
    }

    /**
     * Creates an unfitted transform.
     *
     * @param kernelTrick the kernel to approximate
     * @param i the number of basis vectors
     * @param samplingMethod how basis vectors are chosen
     * @param d the non-negative ridge added to the basis kernel matrix eigenvalues
     * @param z {@code true} to sample basis vectors with replacement
     * @throws IllegalArgumentException if {@code i < 1} or {@code d} is negative, NaN or infinite
     */
    public Nystrom(KernelTrick kernelTrick, int i, SamplingMethod samplingMethod, double d, boolean z) {
        setKernel(kernelTrick);
        setBasisSize(i);
        setBasisSamplingMethod(samplingMethod);
        setRidge(d);
        this.sampleWithReplacment = z;
    }

    /**
     * Creates a transform and immediately fits it to {@code dataSet}.
     *
     * @param kernelTrick the kernel to approximate
     * @param dataSet the data to sample basis vectors from
     * @param i the number of basis vectors
     * @param samplingMethod how basis vectors are chosen
     * @param d the non-negative ridge added to the basis kernel matrix eigenvalues
     * @param z {@code true} to sample basis vectors with replacement
     */
    public Nystrom(KernelTrick kernelTrick, DataSet dataSet, int i, SamplingMethod samplingMethod, double d, boolean z) {
        this(kernelTrick, i, samplingMethod, d, z);
        fit(dataSet);
    }

    /**
     * Samples basis vectors from {@code dataSet} and computes the symmetric matrix
     * {@code (W + ridge*I)^{-1/2}}, where {@code W} is the kernel matrix of the basis vectors.
     *
     * @param dataSet the data to sample basis vectors from
     * @throws IllegalArgumentException if the ridge is negative, the data set is empty, or more
     *     distinct basis vectors are requested than can be sampled without replacement
     */
    @Override // jsat.datatransform.DataTransform
    public void fit(DataSet dataSet) {
        if (this.ridge < 0.0d) {
            throw new IllegalArgumentException("Ridge must be non negative, not " + this.ridge);
        }
        Random rand = new XOR96();
        this.basisVecs = sampleBasisVectors(this.k, dataSet, dataSet.getDataVectors(), this.method, this.basisSize, this.sampleWithReplacment, rand);
        this.accelCache = this.k.getAccelerationCache(this.basisVecs);

        // Size by what the sampler actually returned so W always matches basisVecs.
        int m = this.basisVecs.size();
        DenseMatrix w = new DenseMatrix(m, m);
        for (int row = 0; row < m; row++) {
            w.set(row, row, this.k.eval(row, row, this.basisVecs, this.accelCache));
            for (int col = row + 1; col < m; col++) {
                double kij = this.k.eval(row, col, this.basisVecs, this.accelCache);
                w.set(row, col, kij);
                w.set(col, row, kij);
            }
        }

        // W = V diag(l) V^T, so (W + ridge*I)^{-1/2} = V diag(1/sqrt(l + ridge)) V^T.
        // This assumes getV() returns a copy whose columns are eigenvectors — NOTE(review): confirm.
        EigenValueDecomposition eig = new EigenValueDecomposition(w);
        double[] eigenvalues = eig.getRealEigenvalues();
        DenseVector invSqrt = new DenseVector(eigenvalues.length);
        for (int idx = 0; idx < eigenvalues.length; idx++) {
            // Clamp so tiny or slightly negative (numerical noise) eigenvalues do not blow up.
            invSqrt.set(idx, 1.0d / Math.sqrt(Math.max(1.0E-7d, eigenvalues[idx] + this.ridge)));
        }
        Matrix scaledV = eig.getV();
        Matrix.diagMult(scaledV, invSqrt); // scaledV = V * diag(invSqrt)
        Matrix vTransposed = eig.getV();
        vTransposed.mutableTranspose(); // square matrix, so an in-place transpose is fine
        // V diag V^T is symmetric, so no final transpose is required.
        this.transform = scaledV.multiply(vTransposed);
    }

    /**
     * Copy constructor; deep-copies the kernel, basis vectors, acceleration cache and transform.
     *
     * @param nystrom the transform to copy
     */
    protected Nystrom(Nystrom nystrom) {
        this.k = nystrom.k.m628clone();
        this.method = nystrom.method;
        this.sampleWithReplacment = nystrom.sampleWithReplacment;
        this.dimension = nystrom.dimension;
        this.ridge = nystrom.ridge;
        this.basisSize = nystrom.basisSize;
        if (nystrom.basisVecs != null) {
            this.basisVecs = new ArrayList<Vec>(nystrom.basisVecs.size());
            for (Vec v : nystrom.basisVecs) {
                this.basisVecs.add(v.mo524clone());
            }
            if (nystrom.accelCache != null) {
                this.accelCache = new DoubleList(nystrom.accelCache);
            }
        }
        if (nystrom.transform != null) {
            this.transform = nystrom.transform.mo640clone();
        }
    }

    /**
     * Selects basis vectors from a data set.
     *
     * @param kernelTrick the kernel used to weight points (DIAGONAL and NORM)
     * @param dataSet supplies the sample count and is clustered for KMEANS
     * @param list the data vectors, indexed {@code 0 .. dataSet.getSampleSize()-1}
     * @param samplingMethod how basis vectors are chosen
     * @param i the number of basis vectors to select
     * @param z {@code true} to sample with replacement; ignored by KMEANS
     * @param random the source of randomness
     * @return the chosen basis vectors; the vectors from {@code list} are not copied
     * @throws IllegalArgumentException if the data set is empty, or {@code z} is false and fewer
     *     than {@code i} points can be chosen
     */
    public static List<Vec> sampleBasisVectors(KernelTrick kernelTrick, DataSet dataSet, List<Vec> list, SamplingMethod samplingMethod, int i, boolean z, Random random) {
        List<Vec> basis = new ArrayList<Vec>(i);
        int n = dataSet.getSampleSize();
        if (n == 0) {
            throw new IllegalArgumentException("Can not sample basis vectors from an empty data set");
        }
        switch (samplingMethod) {
            case DIAGONAL: {
                double[] cumulative = new double[n];
                double total = 0.0d;
                for (int idx = 0; idx < n; idx++) {
                    Vec x = list.get(idx);
                    total += kernelTrick.eval(x, x);
                    cumulative[idx] = total;
                }
                sample(i, random, cumulative, list, z, basis);
                break;
            }
            case NORM: {
                // Accumulate squared L2 norms of the kernel-matrix rows without materializing
                // the n x n matrix (O(n) instead of O(n^2) memory).
                List<Double> cache = kernelTrick.getAccelerationCache(list);
                double[] sqNorms = new double[n];
                for (int a = 0; a < n; a++) {
                    double kaa = kernelTrick.eval(a, a, list, cache);
                    sqNorms[a] += kaa * kaa;
                    for (int b = a + 1; b < n; b++) {
                        double kab = kernelTrick.eval(a, b, list, cache);
                        double sq = kab * kab;
                        sqNorms[a] += sq;
                        sqNorms[b] += sq;
                    }
                }
                double[] cumulative = new double[n];
                double total = 0.0d;
                for (int a = 0; a < n; a++) {
                    total += Math.sqrt(sqNorms[a]);
                    cumulative[a] = total;
                }
                sample(i, random, cumulative, list, z, basis);
                break;
            }
            case KMEANS: {
                HamerlyKMeans hamerlyKMeans = new HamerlyKMeans(new EuclideanDistance(), SeedSelectionMethods.SeedSelection.KPP);
                hamerlyKMeans.setStoreMeans(true);
                hamerlyKMeans.cluster(dataSet, i);
                basis.addAll(hamerlyKMeans.getMeans());
                break;
            }
            case UNIFORM:
            default:
                if (z) {
                    // With replacement: independent draws, duplicates allowed.
                    for (int draw = 0; draw < i; draw++) {
                        basis.add(list.get(random.nextInt(n)));
                    }
                } else {
                    // Without replacement: distinct indices only.
                    if (i > n) {
                        throw new IllegalArgumentException("Can not sample " + i + " distinct basis vectors from only " + n + " points");
                    }
                    IntSet chosen = new IntSet(i);
                    while (chosen.size() < i) {
                        chosen.add(Integer.valueOf(random.nextInt(n)));
                    }
                    for (int idx : chosen) {
                        basis.add(list.get(idx));
                    }
                }
                break;
        }
        return basis;
    }

    /**
     * Draws {@code i} points with probability proportional to their weights.
     *
     * @param i number of points to draw
     * @param random source of randomness
     * @param dArr non-decreasing cumulative weights; entry {@code j} is the sum of weights {@code 0..j}
     * @param list the candidate points
     * @param z {@code true} to allow the same point to be drawn more than once
     * @param list2 destination list the drawn points are appended to
     * @throws IllegalArgumentException if {@code z} is false and fewer than {@code i} points have
     *     positive weight (the draw loop could otherwise never finish)
     */
    private static void sample(int i, Random random, double[] dArr, List<Vec> list, boolean z, List<Vec> list2) {
        double total = dArr[dArr.length - 1];
        IntSet chosen = null;
        if (!z) {
            int available = 0;
            double prev = 0.0d;
            for (double c : dArr) {
                if (c > prev) {
                    available++;
                }
                prev = c;
            }
            if (available < i) {
                throw new IllegalArgumentException("Can not sample " + i + " distinct basis vectors from only " + available + " points with positive weight");
            }
            chosen = new IntSet(i);
        }
        int added = 0;
        while (added < i) {
            int idx = Arrays.binarySearch(dArr, random.nextDouble() * total);
            if (idx < 0) {
                idx = -idx - 1; // insertion point = first index whose cumulative weight exceeds the draw
            }
            if (chosen != null) {
                int before = chosen.size();
                chosen.add(Integer.valueOf(idx));
                if (chosen.size() == before) {
                    continue; // already chosen; draw again
                }
            }
            list2.add(list.get(idx));
            added++;
        }
    }

    /**
     * Maps a data point into the Nystrom feature space. The numeric part is replaced by the
     * kernel evaluations against the basis vectors, multiplied by the fitted transform.
     * Categorical values, categorical data and weight are carried over unchanged.
     *
     * @param dataPoint the point to transform
     * @return the transformed point
     * @throws IllegalStateException if {@link #fit(DataSet)} has not been called
     */
    @Override // jsat.datatransform.DataTransform
    public DataPoint transform(DataPoint dataPoint) {
        if (this.transform == null || this.basisVecs == null) {
            throw new IllegalStateException("Nystrom transform has not been fit");
        }
        Vec x = dataPoint.getNumericalValues();
        List<Double> queryInfo = this.k.getQueryInfo(x);
        DenseVector kernelEvals = new DenseVector(this.basisVecs.size());
        for (int idx = 0; idx < this.basisVecs.size(); idx++) {
            kernelEvals.set(idx, this.k.eval(idx, x, queryInfo, this.basisVecs, this.accelCache));
        }
        return new DataPoint(kernelEvals.multiply(this.transform), dataPoint.getCategoricalValues(), dataPoint.getCategoricalData(), dataPoint.getWeight());
    }

    @Override // jsat.datatransform.DataTransformBase
    public Nystrom clone() {
        return new Nystrom(this);
    }

    /**
     * Sets the value added to each eigenvalue of the basis kernel matrix.
     *
     * @param d the ridge value
     * @throws IllegalArgumentException if {@code d} is negative, NaN or infinite
     */
    public void setRidge(double d) {
        if (d < 0.0d || Double.isNaN(d) || Double.isInfinite(d)) {
            throw new IllegalArgumentException("Ridge must be non negative, not " + d);
        }
        this.ridge = d;
    }

    /** @return the ridge value */
    public double getRidge() {
        return this.ridge;
    }

    /**
     * Sets the stored dimension value.
     *
     * @param i the dimension
     * @throws IllegalArgumentException if {@code i < 1}
     */
    public void setDimension(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("The number of dimensions must be positive, not " + i);
        }
        this.dimension = i;
    }

    /** @return the stored dimension value */
    public int getDimension() {
        return this.dimension;
    }

    /** @param samplingMethod how basis vectors will be chosen by {@link #fit(DataSet)} */
    public void setBasisSamplingMethod(SamplingMethod samplingMethod) {
        this.method = samplingMethod;
    }

    /** @return how basis vectors are chosen */
    public SamplingMethod getBasisSamplingMethod() {
        return this.method;
    }

    /**
     * Sets the number of basis vectors to sample.
     *
     * @param i the basis size
     * @throws IllegalArgumentException if {@code i < 1}
     */
    public void setBasisSize(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("The basis size must be positive, not " + i);
        }
        this.basisSize = i;
    }

    /** @return the number of basis vectors to sample */
    public int getBasisSize() {
        return this.basisSize;
    }

    /** @param kernelTrick the kernel to approximate */
    public void setKernel(KernelTrick kernelTrick) {
        this.k = kernelTrick;
    }

    /** @return the kernel being approximated */
    public KernelTrick getKernel() {
        return this.k;
    }
}
