package jsat.datatransform;

import java.util.ArrayList;
import java.util.Arrays;
import jsat.DataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.Vec;
import org.apache.log4j.Priority;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/datatransform/PCA.class */
/**
 * Principal Component Analysis (PCA) transform. It projects the numerical features of each
 * {@link DataPoint} onto a set of loading vectors learned from a {@link DataSet}.
 *
 * <p>Components are extracted one at a time with a NIPALS-style power iteration:
 * <ul>
 *   <li>The loading vector {@code p} is set to the normalized {@code E' * t}.</li>
 *   <li>The score vector {@code t} is set to {@code E * p}.</li>
 *   <li>The iteration stops when the squared norm of {@code t} changes by at most
 *       {@link #getThreshold() threshold} (relative), or after 100 rounds.</li>
 *   <li>The component is then deflated out of the residual matrix {@code E}.</li>
 * </ul>
 *
 * <p>The number of components kept is at most {@link #getMaxPCs() maxPCs}. It is also at most
 * the number of samples and the number of numerical features. Extraction stops early if the
 * residual has no variance left.
 *
 * <p>This class does not center the data itself. If classical (mean-centered) PCA is wanted,
 * zero-mean the data before calling {@link #fit(DataSet)}.
 */
public class PCA implements DataTransform {
    private static final long serialVersionUID = 8736609877239941617L;

    /** Maximum number of power-iteration rounds spent on a single component. */
    private static final int MAX_ITERATIONS = 100;

    /** Convergence threshold used by the convenience constructors. */
    private static final double DEFAULT_THRESHOLD = 1.0E-4d;

    /** Maximum number of components used by the no-argument constructor. */
    private static final int DEFAULT_MAX_PCS = 50;

    /** Projection matrix: row {@code i} is the {@code i}-th loading vector; {@code null} until fit. */
    private Matrix P;
    private int maxPCs;
    private double threshold;

    /** Creates an unfitted PCA that keeps at most 50 components. */
    public PCA() {
        this(DEFAULT_MAX_PCS);
    }

    /**
     * Creates a PCA and immediately fits it to {@code dataSet}. It keeps as many components as the
     * data permits.
     *
     * @param dataSet the data to learn the projection from
     */
    public PCA(DataSet dataSet) {
        // Integer.MAX_VALUE: no explicit cap beyond what fit() derives from the data
        this(dataSet, Integer.MAX_VALUE);
    }

    /**
     * Creates a PCA and immediately fits it to {@code dataSet}.
     *
     * @param dataSet the data to learn the projection from
     * @param maxPCs the maximum number of principal components to keep; must be positive
     */
    public PCA(DataSet dataSet, int maxPCs) {
        this(dataSet, maxPCs, DEFAULT_THRESHOLD);
    }

    /**
     * Creates an unfitted PCA.
     *
     * @param maxPCs the maximum number of principal components to keep; must be positive
     */
    public PCA(int maxPCs) {
        this(maxPCs, DEFAULT_THRESHOLD);
    }

    /**
     * Creates an unfitted PCA.
     *
     * @param maxPCs the maximum number of principal components to keep; must be positive
     * @param threshold relative convergence tolerance of the power iteration; must be in (0, Inf)
     * @throws IllegalArgumentException if either argument is out of range
     */
    public PCA(int maxPCs, double threshold) {
        setMaxPCs(maxPCs);
        setThreshold(threshold);
    }

    /**
     * Creates a PCA and immediately fits it to {@code dataSet}.
     *
     * @param dataSet the data to learn the projection from
     * @param maxPCs the maximum number of principal components to keep; must be positive
     * @param threshold relative convergence tolerance of the power iteration; must be in (0, Inf)
     */
    public PCA(DataSet dataSet, int maxPCs, double threshold) {
        this(maxPCs, threshold);
        fit(dataSet);
    }

    /**
     * Learns the projection matrix from the numerical features of {@code dataSet}.
     *
     * <p>Component extraction stops early once the residual matrix carries no more variance
     * along the current direction. This happens, for example, when the data has lower rank than
     * the requested number of components. Without that check, normalizing a zero loading vector
     * would fill the projection with NaN.
     *
     * @param dataSet the data to learn from
     * @throws ArithmeticException if every column of the data matrix is zero
     */
    @Override
    public void fit(DataSet dataSet) {
        // NOTE(review): the matrix is deflated in place below, so this assumes getDataMatrix()
        // returns a fresh copy rather than a view of the data set - TODO confirm.
        Matrix residual = dataSet.getDataMatrix();
        int numComponents = Math.min(this.maxPCs,
                Math.min(dataSet.getSampleSize(), dataSet.getNumNumericalVars()));
        ArrayList<Vec> loadings = new ArrayList<Vec>(numComponents);

        // Seed the first score vector with a non-zero column of the data.
        Vec score = getColumn(residual);
        double tauOld = score.dot(score);
        DenseVector loading = new DenseVector(residual.cols());

        for (int component = 0; component < numComponents; component++) {
            boolean extracted = false;
            for (int iter = 0; iter < MAX_ITERATIONS; iter++) {
                // loading = E' * score, then normalized to unit length
                loading.zeroOut();
                residual.transposeMultiply(1.0d, score, loading);
                loading.mutableDivide(tauOld);
                double sqNorm = loading.dot(loading);
                if (!(sqNorm > 0.0d) || Double.isInfinite(sqNorm)) {
                    // Zero (or non-finite) loading: the residual is exhausted along this
                    // direction. Normalizing would produce NaN, so stop extracting components.
                    break;
                }
                loading.mutableMultiply(Math.pow(sqNorm, -0.5d));

                // score = E * loading / (loading' * loading)
                score = residual.multiply(loading);
                score.mutableDivide(loading.dot(loading));

                double tauNew = score.dot(score);
                // Require at least two rounds before testing convergence; accept the last round
                // unconditionally once the iteration budget is spent.
                if ((iter > 0 && Math.abs(tauNew - tauOld) <= this.threshold * tauNew)
                        || iter == MAX_ITERATIONS - 1) {
                    loadings.add(new DenseVector(loading));
                    extracted = true;
                    break;
                }
                tauOld = tauNew;
            }
            if (!extracted) {
                break;
            }
            // Deflate: E = E - score * loading'
            Matrix.OuterProductUpdate(residual, score, loading, -1.0d);
        }

        // Every loading has residual.cols() entries; copy them in as the rows of P.
        Matrix projection = new DenseMatrix(loadings.size(), residual.cols());
        for (int row = 0; row < loadings.size(); row++) {
            Vec vec = loadings.get(row);
            for (int col = 0; col < vec.length(); col++) {
                projection.set(row, col, vec.get(col));
            }
        }
        this.P = projection;
    }

    /** Copy constructor backing {@link #clone()}; deep-copies the projection matrix if present. */
    private PCA(PCA pca) {
        if (pca.P != null) {
            this.P = pca.P.mo640clone();
        }
        this.maxPCs = pca.maxPCs;
        this.threshold = pca.threshold;
    }

    /**
     * Sets the maximum number of principal components to keep.
     *
     * @param maxPCs a positive component count
     * @throws IllegalArgumentException if {@code maxPCs <= 0}
     */
    public void setMaxPCs(int maxPCs) {
        if (maxPCs <= 0) {
            throw new IllegalArgumentException("number of principal components must be a positive number, not " + maxPCs);
        }
        this.maxPCs = maxPCs;
    }

    /** @return the maximum number of principal components to keep */
    public int getMaxPCs() {
        return this.maxPCs;
    }

    /**
     * Sets the relative convergence tolerance. A component is accepted when
     * {@code |tauNew - tauOld| <= threshold * tauNew}, where tau is the squared norm of the score
     * vector.
     *
     * @param threshold a finite, strictly positive tolerance
     * @throws IllegalArgumentException if {@code threshold} is not in (0, Inf)
     */
    public void setThreshold(double threshold) {
        if (threshold <= 0.0d || Double.isInfinite(threshold) || Double.isNaN(threshold)) {
            throw new IllegalArgumentException("threshold must be in the range (0, Inf), not " + threshold);
        }
        this.threshold = threshold;
    }

    /** @return the relative convergence tolerance of the power iteration */
    public double getThreshold() {
        return this.threshold;
    }

    /**
     * Returns the first column of {@code matrix} with a positive squared norm.
     *
     * @throws ArithmeticException if every column is zero (or the matrix has no columns)
     */
    private static Vec getColumn(Matrix matrix) {
        for (int i = 0; i < matrix.cols(); i++) {
            Vec column = matrix.getColumn(i);
            if (column.dot(column) > 0.0d) {
                return column;
            }
        }
        throw new ArithmeticException("Matrix is essentially zero");
    }

    /**
     * Projects the numerical features of {@code dataPoint} with the learned matrix. Categorical
     * values, categorical metadata and the weight are copied over unchanged.
     *
     * @throws IllegalStateException if {@link #fit(DataSet)} has not been called
     */
    @Override
    public DataPoint transform(DataPoint dataPoint) {
        if (this.P == null) {
            throw new IllegalStateException("PCA has not been fit to a data set");
        }
        Vec projected = this.P.multiply(dataPoint.getNumericalValues());
        return new DataPoint(projected,
                Arrays.copyOf(dataPoint.getCategoricalValues(), dataPoint.numCategoricalValues()),
                CategoricalData.copyOf(dataPoint.getCategoricalData()),
                dataPoint.getWeight());
    }

    /** @return an independent copy of this transform, including any learned projection */
    @Override
    public DataTransform clone() {
        return new PCA(this);
    }
}
