package edu.ucsf.rbvi.scNetViz.internal.algorithms.tSNE;

import edu.ucsf.rbvi.scNetViz.internal.algorithms.tSNE.ParallelVpTree;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.concurrent.RecursiveAction;

/**
 * Barnes-Hut t-SNE whose most expensive per-iteration stages run in parallel.
 *
 * <p>Overrides three stages of {@code BHTSne}:
 * <ul>
 *   <li>{@code computeGaussianPerplexity}: k-nearest-neighbour search through a
 *       {@code ParallelVpTree} backed by {@code gradientPool}, followed by a per-point binary
 *       search for the Gaussian precision that matches the requested perplexity;</li>
 *   <li>{@code computeGradient}: one non-edge-force task per row, run on
 *       {@code gradientCalculationPool};</li>
 *   <li>{@code updateGradient}: the gains/momentum update, split recursively over
 *       {@code gradientPool}.</li>
 * </ul>
 *
 * <p>Both pools are created at the start of {@code run} and shut down when it returns (normally or
 * exceptionally), so the overridden stages are only usable while {@code run} is executing.
 */
public class ParallelBHTsne extends BHTSne {
    /** Fork/join pool for the recursive gradient update and the VP-tree search; created by {@code run}. */
    private ForkJoinPool gradientPool;

    /** Fixed-size pool for the per-row non-edge force tasks; created by {@code run}. */
    private ExecutorService gradientCalculationPool;

    /**
     * Task that computes the non-edge forces for a single row of the embedding.
     *
     * <p>The row buffer {@code neg_f[row]} is handed to {@code ParallelSPTree.computeNonEdgeForces};
     * that call's return value is the task result, which {@code computeGradient} sums into its
     * normalisation term.
     */
    class ParallelGradientCalculator implements Callable<Double> {
        static final long serialVersionUID = 1;
        int row;
        /** NOTE(review): stored but never read by {@link #call()}. */
        int limit;
        ParallelSPTree tree;
        double[][] neg_f;
        double theta;

        public ParallelGradientCalculator(ParallelSPTree tree, double[][] negF, double theta, int row, int limit) {
            this.tree = tree;
            this.neg_f = negF;
            this.theta = theta;
            this.row = row;
            this.limit = limit;
        }

        /** Computes the non-edge forces for {@code row} and returns the value reported by the tree. */
        @Override
        public Double call() {
            return Double.valueOf(this.tree.computeNonEdgeForces(this.row, this.theta, this.neg_f[this.row], Double.valueOf(0.0d)));
        }
    }

    /**
     * Fork/join task that computes non-edge forces for rows {@code [startRow, endRow)} of an
     * {@code SPTree}, handing every row the same shared {@code AtomicDouble} (presumably the
     * running sum_Q). Ranges longer than {@code limit} rows are halved and processed in parallel.
     *
     * <p>NOTE(review): not used by the methods of this class ({@code computeGradient} uses
     * {@link ParallelGradientCalculator}); it may be used elsewhere in the package.
     */
    class RecursiveGradientCalculator extends RecursiveAction {
        static final long serialVersionUID = 1;
        int startRow;
        int endRow;
        int limit;
        SPTree tree;
        double[][] neg_f;
        double theta;
        AtomicDouble sum_Q;

        public RecursiveGradientCalculator(SPTree tree, double[][] negF, double theta, AtomicDouble sumQ, int startRow, int endRow, int limit) {
            this.tree = tree;
            this.neg_f = negF;
            this.theta = theta;
            this.sum_Q = sumQ;
            this.startRow = startRow;
            this.endRow = endRow;
            this.limit = limit;
        }

        @Override
        protected void compute() {
            int span = this.endRow - this.startRow;
            // Clamp the threshold to 1: with limit <= 0 a one-row range would be "split" into
            // itself and recurse forever.
            if (span <= Math.max(1, this.limit)) {
                for (int r = this.startRow; r < this.endRow; r++) {
                    this.tree.computeNonEdgeForces(r, this.theta, this.neg_f[r], this.sum_Q);
                }
                return;
            }
            int mid = this.startRow + span / 2;
            invokeAll(
                    new RecursiveGradientCalculator(this.tree, this.neg_f, this.theta, this.sum_Q, this.startRow, mid, this.limit),
                    new RecursiveGradientCalculator(this.tree, this.neg_f, this.theta, this.sum_Q, mid, this.endRow, this.limit));
        }
    }

    /**
     * Fork/join task applying the gains-and-momentum update to indices {@code [startIdx, endIdx)}
     * of the flat arrays {@code Y}, {@code dY}, {@code uY} and {@code gains}. Ranges longer than
     * {@code limit} entries are halved and processed in parallel.
     */
    class RecursiveGradientUpdater extends RecursiveAction {
        static final long serialVersionUID = 1;
        int startIdx;
        int endIdx;
        int limit;
        int N;
        int no_dims;
        double[] Y;
        double momentum;
        double eta;
        double[] dY;
        double[] uY;
        double[] gains;

        public RecursiveGradientUpdater(int n, int noDims, double[] y, double momentum, double eta, double[] dY, double[] uY, double[] gains, int startIdx, int endIdx, int limit) {
            this.N = n;
            this.no_dims = noDims;
            this.Y = y;
            this.momentum = momentum;
            this.eta = eta;
            this.dY = dY;
            this.uY = uY;
            this.gains = gains;
            this.startIdx = startIdx;
            this.endIdx = endIdx;
            this.limit = limit;
        }

        @Override
        protected void compute() {
            int span = this.endIdx - this.startIdx;
            // Clamp the threshold to 1 so a one-element range can never be split into itself.
            if (span > Math.max(1, this.limit)) {
                int mid = this.startIdx + span / 2;
                invokeAll(
                        new RecursiveGradientUpdater(this.N, this.no_dims, this.Y, this.momentum, this.eta, this.dY, this.uY, this.gains, this.startIdx, mid, this.limit),
                        new RecursiveGradientUpdater(this.N, this.no_dims, this.Y, this.momentum, this.eta, this.dY, this.uY, this.gains, mid, this.endIdx, this.limit));
                return;
            }
            for (int k = this.startIdx; k < this.endIdx; k++) {
                // Grow the gain when gradient and velocity signs differ, otherwise shrink it; floor at 0.01.
                this.gains[k] = BHTSne.sign_tsne(this.dY[k]) != BHTSne.sign_tsne(this.uY[k])
                        ? this.gains[k] + 0.2d
                        : this.gains[k] * 0.8d;
                if (this.gains[k] < 0.01d) {
                    this.gains[k] = 0.01d;
                }
                // NOTE(review): Y advances by the velocity from the previous call before uY is
                // recomputed; the reference bhtsne implementation recomputes uY first. Confirm this
                // matches BHTSne.updateGradient before changing the order.
                this.Y[k] += this.uY[k];
                this.uY[k] = this.momentum * this.uY[k] - this.eta * this.gains[k] * this.dY[k];
            }
        }
    }

    /**
     * Runs t-SNE with the parallel stages enabled.
     *
     * <p>Creates both worker pools (sized to the number of available processors), delegates to
     * {@code super.run}, and always shuts the pools down afterwards, even if the run throws.
     *
     * @param config the t-SNE configuration, passed through unchanged
     * @return the result of {@code super.run(config)}
     */
    @Override
    double[][] run(TSneConfiguration config) {
        int cores = Runtime.getRuntime().availableProcessors();
        this.gradientPool = new ForkJoinPool(cores);
        this.gradientCalculationPool = Executors.newFixedThreadPool(cores);
        try {
            return super.run(config);
        } finally {
            // Without this, a failing run would leave the worker threads alive.
            this.gradientPool.shutdown();
            this.gradientCalculationPool.shutdown();
        }
    }

    /**
     * Applies one gains-and-momentum update to all {@code N * no_dims} entries in parallel.
     *
     * <p>The index range is split recursively on {@code gradientPool}. Ranges of at most
     * {@code max(1, N / (10 * cores))} entries are processed sequentially.
     */
    @Override
    void updateGradient(int N, int no_dims, double[] Y, double momentum, double eta, double[] dY, double[] uY, double[] gains) {
        int leafSize = Math.max(1, N / (Runtime.getRuntime().availableProcessors() * 10));
        this.gradientPool.invoke(new RecursiveGradientUpdater(N, no_dims, Y, momentum, eta, dY, uY, gains, 0, N * no_dims, leafSize));
    }

    /**
     * Computes the gradient {@code dC = posF - negF / sumQ} for the current embedding.
     *
     * <p>Edge forces ({@code posF}) come from {@code ParallelSPTree.computeEdgeForces}. Non-edge
     * forces ({@code negF}) are computed by one {@link ParallelGradientCalculator} task per row on
     * {@code gradientCalculationPool}. The values those tasks return are summed into {@code sumQ}.
     *
     * @param P     not used by this implementation
     * @param rowP  sparse-matrix row offsets (presumably those built by {@code computeGaussianPerplexity})
     * @param colP  sparse-matrix column indices
     * @param valP  sparse-matrix values
     * @param Y     current embedding, used to build the tree
     * @param N     number of points (rows)
     * @param D     embedding dimensionality
     * @param dC    output gradient; entry {@code n * D + d} holds row {@code n}, dimension {@code d}
     * @param theta passed to {@code computeNonEdgeForces}
     * @throws IllegalStateException if a force task fails or the calling thread is interrupted
     */
    @Override
    void computeGradient(double[] P, int[] rowP, int[] colP, double[] valP, double[] Y, int N, int D, double[] dC, double theta) {
        ParallelSPTree tree = new ParallelSPTree(D, Y, N);
        double[] posF = new double[N * D];
        double[][] negF = new double[N][D];
        tree.computeEdgeForces(rowP, colP, valP, N, posF);

        List<ParallelGradientCalculator> tasks = new ArrayList<>(N);
        for (int n = 0; n < N; n++) {
            tasks.add(new ParallelGradientCalculator(tree, negF, theta, n, 20));
        }

        double sumQ = 0.0d;
        try {
            for (Future<Double> contribution : this.gradientCalculationPool.invokeAll(tasks)) {
                sumQ += contribution.get();
            }
        } catch (InterruptedException e) {
            // Do not System.exit(): this runs inside a host application. Restore the flag and fail.
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while computing t-SNE non-edge forces", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("Failed to compute t-SNE non-edge forces", e);
        }

        for (int n = 0; n < N; n++) {
            for (int d = 0; d < D; d++) {
                dC[n * D + d] = posF[n * D + d] - (negF[n][d] / sumQ);
            }
        }
    }

    /**
     * Builds the sparse input similarities in CSR form from the {@code K} nearest neighbours of
     * each point.
     *
     * @param X          input data; row {@code n} is read with
     *                   {@code MatrixOps.extractRowFromFlatMatrix(X, n, D)}
     * @param N          number of points
     * @param D          input dimensionality
     * @param row_P      output row offsets; needs at least {@code N + 1} entries, entry {@code n}
     *                   is set to {@code n * K}
     * @param col_P      output neighbour indices, {@code K} per point
     * @param val_P      output row-normalised similarities, {@code K} per point
     * @param perplexity target perplexity (a warning is printed if it exceeds {@code K})
     * @param K          number of neighbours kept per point
     * @throws IllegalStateException if the parallel neighbour search fails or is interrupted
     */
    @Override
    void computeGaussianPerplexity(double[] X, int N, int D, int[] row_P, int[] col_P, double[] val_P, double perplexity, int K) {
        if (perplexity > K) {
            System.out.println("Perplexity should be lower than K!");
        }

        // Every row has exactly K entries, so the offsets are a running sum.
        row_P[0] = 0;
        for (int n = 0; n < N; n++) {
            row_P[n + 1] = row_P[n] + K;
        }

        ParallelVpTree tree = new ParallelVpTree(this.gradientPool, this.distance);
        DataPoint[] points = new DataPoint[N];
        for (int n = 0; n < N; n++) {
            points[n] = new DataPoint(D, n, MatrixOps.extractRowFromFlatMatrix(X, n, D));
        }
        tree.create(points);

        // Scratch buffer for one row of kernel values; only K entries are ever used.
        double[] curP = new double[K];

        // K + 1 neighbours are requested because entry 0 is skipped (presumably the query point itself).
        Iterator<?> pending = tree.searchMultiple(tree, points, K + 1).iterator();
        while (pending.hasNext()) {
            ParallelVpTree.ParallelTreeNode.TreeSearchResult result;
            try {
                result = (ParallelVpTree.ParallelTreeNode.TreeSearchResult) ((Future<?>) pending.next()).get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interrupted during t-SNE nearest-neighbour search", e);
            } catch (ExecutionException e) {
                // Previously swallowed, which led to a NullPointerException a few lines later.
                throw new IllegalStateException("t-SNE nearest-neighbour search failed", e);
            }
            List<Double> distances = result.getDistances();
            List<DataPoint> indices = result.getIndices();
            int n = result.getIndex();

            double sumP = calibrateRow(distances, K, perplexity, curP);

            // Row-normalise and store the K neighbours of point n.
            for (int m = 0; m < K; m++) {
                curP[m] /= sumP;
                col_P[row_P[n] + m] = indices.get(m + 1).index();
                val_P[row_P[n] + m] = curP[m];
            }
        }
    }

    /**
     * Binary-searches the Gaussian precision (beta) of one row until the row's entropy is within
     * 1e-5 of {@code log(perplexity)}, giving up after 200 iterations.
     *
     * @param distances  neighbour distances for the row; entries {@code 1..k} are used
     * @param k          number of neighbours
     * @param perplexity target perplexity
     * @param curP       output buffer of length at least {@code k}; on return holds the
     *                   unnormalised kernel values {@code exp(-beta * distance)} for the last beta
     *                   that was evaluated
     * @return the sum of {@code curP[0..k)} plus {@code Double.MIN_VALUE} (guards against division by zero)
     */
    private static double calibrateRow(List<Double> distances, int k, double perplexity, double[] curP) {
        final double tol = 1.0e-5d;
        boolean found = false;
        double beta = 1.0d;
        double minBeta = -Double.MAX_VALUE;
        double maxBeta = Double.MAX_VALUE;
        double sumP = 0.0d;

        for (int iter = 0; !found && iter < 200; iter++) {
            sumP = Double.MIN_VALUE;
            double h = 0.0d;
            for (int m = 0; m < k; m++) {
                double dist = distances.get(m + 1);
                curP[m] = Math.exp(-beta * dist);
                sumP += curP[m];
                h += beta * dist * curP[m];
            }
            double hDiff = ((h / sumP) + Math.log(sumP)) - Math.log(perplexity);

            if (hDiff < tol && -hDiff < tol) {
                found = true;
            } else if (hDiff > 0.0d) {
                // Entropy too high: increase beta (double it until an upper bound is known).
                minBeta = beta;
                beta = (maxBeta == Double.MAX_VALUE || maxBeta == -Double.MAX_VALUE)
                        ? beta * 2.0d
                        : (beta + maxBeta) / 2.0d;
            } else {
                // Entropy too low (or NaN): decrease beta (halve it until a lower bound is known).
                maxBeta = beta;
                beta = (minBeta == -Double.MAX_VALUE || minBeta == Double.MAX_VALUE)
                        ? beta / 2.0d
                        : (beta + minBeta) / 2.0d;
            }
        }
        return sumP;
    }
}
