package com.jujutsu.tsne;

import com.jujutsu.tsne.TSne;
import com.jujutsu.tsne.barneshut.TSneConfiguration;
import com.jujutsu.utils.EjmlOps;
import com.jujutsu.utils.MatrixOps;
import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import org.apache.http.HttpStatus;
import org.cytoscape.work.TaskMonitor;
import org.ejml.data.DMatrixD1;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import org.jdesktop.swingx.JXLabel;

/* loaded from: input_file:com/jujutsu/tsne/FastTSne.class */
/**
 * Exact (dense) t-SNE built on EJML row-major matrices.
 *
 * <p>All pairwise quantities (input affinities P, output affinities Q, the Student-t kernel and
 * the gradient helper matrices) are held as full {@code n x n} matrices, so memory use grows
 * quadratically with the number of input points.
 */
public class FastTSne implements TSne {
    /** Tolerance on |H - log(perplexity)| for the per-point beta bisection in {@link #x2p}. */
    private static final double PERPLEXITY_TOLERANCE = 1.0E-5d;
    /** Maximum number of bisection steps per point when calibrating beta. */
    private static final int MAX_BETA_SEARCH_STEPS = 50;
    /** A progress message is emitted every this many points while computing P. */
    private static final int P_PROGRESS_INTERVAL = 500;
    /** Factor P is multiplied by at the start ("early exaggeration"). */
    private static final double EARLY_EXAGGERATION = 4.0d;
    /** Iteration at which the early exaggeration is undone. */
    private static final int STOP_EXAGGERATION_ITER = 100;
    /** Floor applied to P and Q entries to keep logs and ratios finite. */
    private static final double MIN_PROBABILITY = 1.0E-12d;
    /** Constant factor applied to the raw gradient. */
    private static final double GRADIENT_SCALE = 4.0d;
    /** Momentum used before {@link #MOMENTUM_SWITCH_ITER}. */
    private static final double INITIAL_MOMENTUM = 0.5d;
    /** Momentum used from {@link #MOMENTUM_SWITCH_ITER} on. */
    private static final double FINAL_MOMENTUM = 0.8d;
    /** Iteration at which momentum switches from initial to final. */
    private static final int MOMENTUM_SWITCH_ITER = 20;
    /** Step size of the gradient update. */
    private static final double LEARNING_RATE = 500.0d;
    /** Lower bound for the adaptive per-coordinate gains. */
    private static final double MIN_GAIN = 0.01d;
    /** The KL error is reported every this many iterations. */
    private static final int ERROR_REPORT_INTERVAL = 100;

    /** Helper whose instance {@code transpose} is used by {@link #x2p}. */
    MatrixOps mo = new MatrixOps();
    /** Set by {@link #abort()}; checked before every gradient-descent iteration. */
    protected volatile boolean abort = false;
    /** Monitor of the current run; assigned at the start of {@link #tsne}. */
    TaskMonitor monitor;
    /** Configuration of the current run; assigned at the start of {@link #tsne}. */
    TSneConfiguration config;

    /**
     * Reads a {@code rows x cols} matrix of doubles from a binary file.
     *
     * <p>Values are read row by row with {@link DataInputStream#readDouble()}, i.e. as consecutive
     * 8-byte big-endian IEEE-754 doubles.
     *
     * @param rows number of rows to read
     * @param cols number of columns per row
     * @param fileName path of the file to read
     * @return the matrix that was read
     * @throws FileNotFoundException if the file cannot be opened
     * @throws IOException if reading fails, including an {@code EOFException} when the file holds
     *     fewer than {@code rows * cols} doubles
     */
    public static double[][] readBinaryDoubleMatrix(int rows, int cols, String fileName) throws FileNotFoundException, IOException {
        double[][] matrix = new double[rows][cols];
        try (DataInputStream in = new DataInputStream(new BufferedInputStream(new FileInputStream(new File(fileName).getAbsolutePath())))) {
            for (double[] row : matrix) {
                for (int c = 0; c < cols; c++) {
                    row[c] = in.readDouble();
                }
            }
        }
        return matrix;
    }

    /**
     * Runs t-SNE on {@code configuration.getXin()} and returns the embedding Y.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Optionally reduce the input with PCA to {@code getInitialDims()} columns.</li>
     *   <li>Calibrate the per-point affinities P to the requested perplexity ({@link #x2p}).</li>
     *   <li>Symmetrise and normalise P, then exaggerate it until iteration
     *       {@value #STOP_EXAGGERATION_ITER}.</li>
     *   <li>Run {@code getMaxIter()} gradient-descent iterations with momentum and adaptive
     *       gains.</li>
     * </ol>
     *
     * <p>If {@link #abort()} is called, the loop stops and the current embedding is returned.
     *
     * @param configuration input data and hyper-parameters; {@code cancelled()} is polled
     *     between stages and once per iteration
     * @param taskMonitor receives progress updates and INFO/WARN messages
     * @return the embedding (one row per input point), or {@code null} if the configuration
     *     reports cancellation
     */
    @Override // com.jujutsu.tsne.TSne
    public double[][] tsne(TSneConfiguration configuration, TaskMonitor taskMonitor) {
        double[][] xin = configuration.getXin();
        int outputDims = configuration.getOutputDims();
        int initialDims = configuration.getInitialDims();
        double perplexity = configuration.getPerplexity();
        int maxIter = configuration.getMaxIter();
        boolean usePca = configuration.usePca();
        this.monitor = taskMonitor;
        this.config = configuration;
        taskMonitor.showMessage(TaskMonitor.Level.INFO, "Running " + getClass().getSimpleName() + ".");
        long startTime = System.currentTimeMillis();
        long lastReportTime = startTime;
        if (usePca && xin[0].length > initialDims && initialDims > 0) {
            taskMonitor.showMessage(TaskMonitor.Level.INFO, "Using PCA to reduce dimensions");
            xin = new PrincipalComponentAnalysis().pca(xin, initialDims);
            taskMonitor.showMessage(TaskMonitor.Level.INFO, "X:Shape after PCA is = " + xin.length + " x " + xin[0].length);
        }
        int n = xin.length;

        // Embedding and optimizer state, all shaped like the embedding.
        DMatrixRMaj y = new DMatrixRMaj(MatrixOps.rnorm(n, outputDims));
        DMatrixRMaj dY = new DMatrixRMaj(n, outputDims);          // gradient
        DMatrixRMaj iY = new DMatrixRMaj(n, outputDims);          // momentum-smoothed update
        DMatrixRMaj gains = new DMatrixRMaj(MatrixOps.fillMatrix(n, outputDims, 1.0d));
        DMatrixRMaj signFlipped = new DMatrixRMaj(n, outputDims); // 1 where sign(dY) != sign(iY)
        DMatrixRMaj signKept = new DMatrixRMaj(n, outputDims);    // 1 where sign(dY) == sign(iY)
        if (configuration.cancelled()) {
            return null;
        }

        DMatrixRMaj p = new DMatrixRMaj(x2p(xin, PERPLEXITY_TOLERANCE, perplexity).P);
        if (configuration.cancelled()) {
            return null;
        }
        symmetrizeAndNormalize(p);
        CommonOps_DDRM.scale(EARLY_EXAGGERATION, p);
        EjmlOps.maximize(p, MIN_PROBABILITY);
        if (configuration.cancelled()) {
            return null;
        }
        taskMonitor.showMessage(TaskMonitor.Level.INFO, "Y:Shape is = " + y.getNumRows() + " x " + y.getNumCols());

        // Per-iteration work buffers.
        DMatrixRMaj ySquared = new DMatrixRMaj(y.numRows, y.numCols);
        DMatrixRMaj sumY = new DMatrixRMaj(1, y.numRows);
        DMatrixRMaj kernel = new DMatrixRMaj(y.numRows, y.numRows); // squared distances, then 1/(1+d)
        DMatrixRMaj num = new DMatrixRMaj(y.numRows, y.numRows);
        DMatrixRMaj q = new DMatrixRMaj(p.numRows, p.numCols);
        DMatrixRMaj pq = new DMatrixRMaj(p.numRows, p.numCols);
        DMatrixRMaj rowSumDiag = new DMatrixRMaj(p.numRows, p.numCols); // only its diagonal is ever set
        DMatrixRMaj klWork = new DMatrixRMaj(p.numRows, p.numCols);

        for (int iter = 0; iter < maxIter && !this.abort; iter++) {
            if (configuration.cancelled()) {
                return null;
            }
            // Cast before dividing: integer division would always report 0 progress.
            taskMonitor.setProgress((double) iter / maxIter);

            // Student-t kernel: num = 1 / (1 + ||y_i - y_j||^2), Q = num / sum(num).
            CommonOps_DDRM.elementPower(y, 2.0d, ySquared);
            CommonOps_DDRM.sumRows(ySquared, sumY);
            CommonOps_DDRM.transpose(sumY);
            // multTransB overwrites the buffer; multAddTransB would add onto the kernel values
            // left in it by the previous iteration and inflate every distance.
            CommonOps_DDRM.multTransB(-2.0d, y, y, kernel);
            EjmlOps.addRowVector(kernel, sumY);
            CommonOps_DDRM.transpose(kernel);
            EjmlOps.addRowVector(kernel, sumY);
            CommonOps_DDRM.add(kernel, 1.0d);
            CommonOps_DDRM.divide(1.0d, kernel);
            num.set((DMatrixD1) kernel);
            // NOTE(review): presumably zeroes the diagonal (paired indices) — confirm in EjmlOps.
            EjmlOps.assignAtIndex(num, MatrixOps.range(n), MatrixOps.range(n), 0.0d);
            CommonOps_DDRM.divide(num, CommonOps_DDRM.elementSum(num), q);
            EjmlOps.maximize(q, MIN_PROBABILITY);

            // Gradient: dY = GRADIENT_SCALE * (diag(rowSums(PQ)) - PQ) * Y with PQ = (P - Q) .* num.
            CommonOps_DDRM.subtract(p, q, pq);
            CommonOps_DDRM.elementMult(pq, num);
            DMatrixRMaj rowSums = CommonOps_DDRM.sumRows(pq, null);
            double[] rowSumValues = new double[rowSums.numRows];
            for (int r = 0; r < rowSumValues.length; r++) {
                rowSumValues[r] = rowSums.get(r, 0);
            }
            EjmlOps.setDiag(rowSumDiag, rowSumValues);
            CommonOps_DDRM.subtract(rowSumDiag, pq, pq);
            CommonOps_DDRM.mult(pq, y, dY);
            CommonOps_DDRM.scale(GRADIENT_SCALE, dY);

            // Adaptive gains: grow by 0.2 where the gradient sign flipped, shrink by 0.8 where it held.
            double momentum = iter < MOMENTUM_SWITCH_ITER ? INITIAL_MOMENTUM : FINAL_MOMENTUM;
            boolean[][] sameSign = MatrixOps.equal(EjmlOps.biggerThan(dY, 0.0d), EjmlOps.biggerThan(iY, 0.0d));
            EjmlOps.setData(signFlipped, MatrixOps.abs(MatrixOps.negate(sameSign)));
            EjmlOps.setData(signKept, MatrixOps.abs(sameSign));
            DMatrixRMaj grown = new DMatrixRMaj(gains);
            DMatrixRMaj shrunk = new DMatrixRMaj(gains);
            CommonOps_DDRM.add(grown, 0.2d);
            CommonOps_DDRM.scale(0.8d, shrunk);
            CommonOps_DDRM.elementMult(grown, signFlipped);
            CommonOps_DDRM.elementMult(shrunk, signKept);
            CommonOps_DDRM.add(grown, shrunk, gains);
            EjmlOps.assignAllLessThan(gains, MIN_GAIN, MIN_GAIN);

            // Momentum update, then re-centre the embedding on its column means.
            CommonOps_DDRM.scale(momentum, iY);
            DMatrixRMaj step = new DMatrixRMaj(gains.numRows, dY.numCols);
            CommonOps_DDRM.elementMult(gains, dY, step);
            CommonOps_DDRM.scale(LEARNING_RATE, step);
            CommonOps_DDRM.subtractEquals(iY, step);
            CommonOps_DDRM.addEquals(y, iY);
            CommonOps_DDRM.subtractEquals(y, EjmlOps.tile(EjmlOps.colMean(y, 0), n, 1));

            if (iter % ERROR_REPORT_INTERVAL == 0) {
                double error = klDivergence(p, q, klWork);
                taskMonitor.showMessage(TaskMonitor.Level.INFO, String.format("Iteration %d: error is %f (100 iterations in %4.2f seconds)\n", Integer.valueOf(iter), Double.valueOf(error), Double.valueOf((System.currentTimeMillis() - lastReportTime) / 1000.0d)));
                if (error < 0.0d) {
                    taskMonitor.showMessage(TaskMonitor.Level.WARN, "Warning: Error is negative, this is usually a very bad sign!");
                }
                lastReportTime = System.currentTimeMillis();
            }
            if (iter == STOP_EXAGGERATION_ITER) {
                CommonOps_DDRM.divide(p, EARLY_EXAGGERATION);
            }
        }
        taskMonitor.showMessage(TaskMonitor.Level.INFO, String.format("Completed in %4.2f seconds", Double.valueOf((System.currentTimeMillis() - startTime) / 1000.0d)));
        return EjmlOps.extractDoubleArray(y);
    }

    /**
     * Symmetrises P in place and normalises it to sum 1.
     *
     * <p>Computes {@code P = (P + P^T) / sum(P + P^T)}. NaN entries (e.g. from a zero sum) are
     * then replaced by {@link Double#MIN_VALUE}.
     */
    private static void symmetrizeAndNormalize(DMatrixRMaj p) {
        DMatrixRMaj transposed = new DMatrixRMaj(p.numRows, p.numCols);
        CommonOps_DDRM.transpose(p, transposed);
        CommonOps_DDRM.addEquals(p, transposed);
        CommonOps_DDRM.divide(p, CommonOps_DDRM.elementSum(p));
        EjmlOps.replaceNaN(p, Double.MIN_VALUE);
    }

    /**
     * Returns {@code sum(P .* log(P ./ Q))}, the value reported as the optimisation error.
     *
     * <p>NaN entries are replaced by {@link Double#MIN_VALUE} after the log and after the
     * multiplication.
     *
     * @param work scratch matrix shaped like {@code p}; overwritten
     */
    private static double klDivergence(DMatrixRMaj p, DMatrixRMaj q, DMatrixRMaj work) {
        DMatrixRMaj ratio = new DMatrixRMaj(p);
        CommonOps_DDRM.elementDiv(ratio, q);
        CommonOps_DDRM.elementLog(ratio, work);
        EjmlOps.replaceNaN(work, Double.MIN_VALUE);
        CommonOps_DDRM.elementMult(work, p);
        EjmlOps.replaceNaN(work, Double.MIN_VALUE);
        return CommonOps_DDRM.elementSum(work);
    }

    /**
     * Computes the Gaussian row distribution and its entropy for one point.
     *
     * <p>{@code P = exp(-beta * D) / sum(exp(-beta * D))} and
     * {@code H = log(sumP) + beta * sum(D .* exp(-beta * D)) / sumP} (natural logarithm).
     *
     * @param distances row of squared distances from one point to the others
     * @param beta precision of the Gaussian kernel
     * @return result with {@code H} (entropy) and {@code P} (normalised row) filled in
     */
    public TSne.R Hbeta(double[][] distances, double beta) {
        DMatrixRMaj prob = new DMatrixRMaj(distances);
        CommonOps_DDRM.scale(-beta, prob);
        CommonOps_DDRM.elementExp(prob, prob);
        double sumP = CommonOps_DDRM.elementSum(prob);
        DMatrixRMaj weighted = new DMatrixRMaj(distances);
        CommonOps_DDRM.elementMult(weighted, prob);
        double entropy = Math.log(sumP) + ((beta * CommonOps_DDRM.elementSum(weighted)) / sumP);
        CommonOps_DDRM.scale(1.0d / sumP, prob);
        TSne.R r = new TSne.R();
        r.H = entropy;
        r.P = EjmlOps.extractDoubleArray(prob);
        return r;
    }

    /**
     * Computes the conditional affinity matrix P for input X at the given perplexity.
     *
     * <p>For each point i, beta_i is found by bisection (at most {@value #MAX_BETA_SEARCH_STEPS}
     * steps) so that the entropy of row i is within {@code tol} of {@code log(perplexity)}. The
     * diagonal of P stays 0. If an attached configuration reports cancellation, the loop stops
     * and the partially filled P is returned.
     *
     * @param x input points, one per row
     * @param tol tolerance on {@code |H - log(perplexity)|}
     * @param perplexity target perplexity
     * @return result with {@code P} (n x n affinities) and {@code beta} (per-point precisions)
     */
    public TSne.R x2p(double[][] x, double tol, double perplexity) {
        int n = x.length;
        // Pairwise squared Euclidean distances: ||x_i||^2 + ||x_j||^2 - 2 x_i . x_j
        double[][] sumX = MatrixOps.sum(MatrixOps.square(x), 1);
        double[][] minusTwoXXt = MatrixOps.scalarMult(MatrixOps.times(x, this.mo.transpose(x)), -2.0d);
        double[][] dist = MatrixOps.addRowVector(MatrixOps.addColumnVector(this.mo.transpose(minusTwoXXt), sumX), this.mo.transpose(sumX));
        double[][] p = MatrixOps.fillMatrix(n, n, 0.0d);
        // Only one row of ones is needed; allocating n x n just to take row 0 wastes O(n^2) memory.
        double[] beta = MatrixOps.fillMatrix(1, n, 1.0d)[0];
        double logU = Math.log(perplexity);
        for (int i = 0; i < n && !isCancelled(); i++) {
            if (i % P_PROGRESS_INTERVAL == 0) {
                showInfo("Computing P-values for point " + i + " of " + n + "...");
            }
            double betaMin = Double.NEGATIVE_INFINITY;
            double betaMax = Double.POSITIVE_INFINITY;
            double[][] rowDist = MatrixOps.getValuesFromRow(dist, i, MatrixOps.concatenate(MatrixOps.range(0, i), MatrixOps.range(i + 1, n)));
            TSne.R hb = Hbeta(rowDist, beta[i]);
            double[][] rowP = hb.P;
            double hDiff = hb.H - logU;
            // Bounded bisection; the loop must terminate once converged or out of steps.
            for (int step = 0; Math.abs(hDiff) > tol && step < MAX_BETA_SEARCH_STEPS; step++) {
                if (hDiff > 0.0d) {
                    // Entropy too high: increase beta.
                    betaMin = beta[i];
                    beta[i] = Double.isInfinite(betaMax) ? beta[i] * 2.0d : (beta[i] + betaMax) / 2.0d;
                } else {
                    // Entropy too low: decrease beta.
                    betaMax = beta[i];
                    beta[i] = Double.isInfinite(betaMin) ? beta[i] / 2.0d : (beta[i] + betaMin) / 2.0d;
                }
                hb = Hbeta(rowDist, beta[i]);
                rowP = hb.P;
                hDiff = hb.H - logU;
            }
            MatrixOps.assignValuesToRow(p, i, MatrixOps.concatenate(MatrixOps.range(0, i), MatrixOps.range(i + 1, n)), rowP[0]);
        }
        TSne.R r = new TSne.R();
        r.P = p;
        r.beta = beta;
        showInfo("Mean value of sigma: " + MatrixOps.mean(MatrixOps.sqrt(MatrixOps.scalarInverse(beta))));
        return r;
    }

    /** True when a configuration is attached (by {@link #tsne}) and reports cancellation. */
    private boolean isCancelled() {
        return this.config != null && this.config.cancelled();
    }

    /** Shows an INFO message on the attached monitor, if any. */
    private void showInfo(String message) {
        if (this.monitor != null) {
            this.monitor.showMessage(TaskMonitor.Level.INFO, message);
        }
    }

    /** Requests that the running {@link #tsne} loop stop before its next iteration. */
    @Override // com.jujutsu.tsne.TSne
    public void abort() {
        this.abort = true;
    }
}
