package jsat.clustering.kmeans;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.clustering.ClusterFailureException;
import jsat.distributions.kernels.KernelTrick;
import jsat.exceptions.FailedToFitException;
import jsat.linear.Vec;
import jsat.utils.FakeExecutor;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.ParallelUtils;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/clustering/kmeans/ElkanKernelKMeans.class */
/**
 * Kernel k-means clustering accelerated with Elkan-style triangle-inequality bounds.
 *
 * <p>For every point the algorithm keeps an upper bound on the distance to its assigned mean and,
 * for every cluster, a lower bound on the distance to that cluster's mean, so that most
 * kernel-space distance evaluations can be skipped. Distances between the cluster means are kept
 * in {@link #centroidSelfDistances}. They are derived from the incrementally maintained
 * {@link #centroidPairDots}.
 */
public class ElkanKernelKMeans extends KernelKMeans {
    private static final long serialVersionUID = 4998832201379993827L;

    /** Symmetric k-by-k matrix of distances between the current cluster means. */
    private double[][] centroidSelfDistances;

    /**
     * k-by-k accumulator of {@code W[a] * W[b] * K(x_a, x_b)} over point pairs with {@code a} in
     * cluster {@code i} and {@code b} in cluster {@code j}. It is updated incrementally by
     * {@link #update_centroid_pair_dots} and normalized by the cluster weights in
     * {@link #calculateCentroidDistances}. Only the off-diagonal entries are read.
     */
    private double[][] centroidPairDots;

    /**
     * Creates a new clusterer.
     *
     * @param kernelTrick the kernel that defines the feature space
     */
    public ElkanKernelKMeans(KernelTrick kernelTrick) {
        super(kernelTrick);
    }

    /**
     * Copy constructor. Copies whatever the {@code KernelKMeans} copy constructor copies; the
     * Elkan-specific caches are not copied and are rebuilt on the next clustering run.
     *
     * @param toCopy the object to copy
     */
    public ElkanKernelKMeans(ElkanKernelKMeans toCopy) {
        super(toCopy);
    }

    /**
     * Returns the index of the cluster whose mean is closest to {@code x}.
     *
     * <p>Clusters whose owned weight ({@code ownes}) is not above {@code 1e-15} are ignored. The
     * triangle inequality on {@link #centroidSelfDistances} is used to skip clusters that cannot be
     * closer than one already evaluated: if {@code d(c, j) >= 2 d(x, c)} then {@code d(x, j) >= d(x, c)}.
     *
     * @param x the point to assign
     * @param qi per-point cached values forwarded to {@code distance} (defined by KernelKMeans —
     *     presumably kernel acceleration values; confirm there)
     * @return the nearest cluster index, or -1 if no cluster has enough weight
     */
    @Override
    public int findClosestCluster(Vec x, List<Double> qi) {
        final int k = this.meanSqrdNorms.length;
        double bestDist = Double.MAX_VALUE;
        int best = -1;
        // pruned[j] becomes true once some evaluated cluster proves j cannot be closer
        final boolean[] pruned = new boolean[k];
        for (int c = 0; c < k; c++) {
            if (this.ownes[c] > 1.0E-15d && !pruned[c]) {
                final double dist = distance(x, qi, c);
                if (dist < bestDist) {
                    bestDist = dist;
                    best = c;
                }
                for (int j = c + 1; j < k; j++) {
                    if (this.centroidSelfDistances[c][j] >= 2.0d * dist) {
                        pruned[j] = true;
                    }
                }
            }
        }
        return best;
    }

    /**
     * Incrementally updates {@link #centroidPairDots} for the points whose designation changed.
     *
     * <p>For every point pair where at least one point moved, the pair's weighted kernel value is
     * removed from the old cluster pair (when both old designations are {@code >= 0}) and added to
     * the new cluster pair. Passing a previous designation of -1 for every point therefore rebuilds
     * the matrix from scratch.
     *
     * @param prevDesignations designation of each point before the change (-1 for "none")
     * @param curDesignations designation of each point after the change
     * @param threadpool executor to use; {@code null} or a {@code FakeExecutor} runs in one task
     * @throws FailedToFitException if a worker fails or the waiting thread is interrupted
     */
    private void update_centroid_pair_dots(final int[] prevDesignations, final int[] curDesignations, ExecutorService threadpool) {
        final int N = this.X.size();
        final int workers;
        if (threadpool == null || (threadpool instanceof FakeExecutor)) {
            threadpool = new FakeExecutor();
            workers = 1;
        } else {
            workers = SystemInfo.LogicalCores;
        }
        final int k = this.centroidPairDots.length;

        final List<Future<double[][]>> futures = new ArrayList<Future<double[][]>>(workers);
        for (int w = 0; w < workers; w++) {
            final int offset = w;
            futures.add(threadpool.submit(new Callable<double[][]>() {
                @Override
                public double[][] call() throws Exception {
                    // per-task delta, merged into centroidPairDots by the calling thread
                    final double[][] delta = new double[k][k];
                    for (int a = offset; a < N; a += workers) {
                        final double weightA = W.get(a);
                        final int oldA = prevDesignations[a];
                        final int newA = curDesignations[a];
                        for (int b = a; b < N; b++) {
                            final int oldB = prevDesignations[b];
                            final int newB = curDesignations[b];
                            if (oldA == newA && oldB == newB) {
                                continue; // neither point moved: contribution unchanged
                            }
                            final double dot = weightA * W.get(b) * kernel.eval(a, b, X, accel);
                            if (oldA >= 0 && oldB >= 0) {
                                delta[oldA][oldB] -= dot;
                                delta[oldB][oldA] -= dot;
                            }
                            delta[newA][newB] += dot;
                            delta[newB][newA] += dot;
                        }
                    }
                    return delta;
                }
            }));
        }

        for (Future<double[][]> future : futures) {
            final double[][] delta = awaitResult(future);
            for (int a = 0; a < delta.length; a++) {
                for (int b = 0; b < delta[a].length; b++) {
                    this.centroidPairDots[a][b] += delta[a][b];
                }
            }
        }
    }

    /**
     * Runs the bound-accelerated kernel k-means iterations.
     *
     * @param dataSet the data to cluster
     * @param k the number of clusters
     * @param assignment designation array handed to {@code setup}; on return it holds the final
     *     cluster of each point
     * @param returnError currently has no effect; the returned value is computed either way
     * @param threadpool executor for parallel work, or {@code null} to run serially
     * @return the sum of the squared per-point upper bounds on the distance to the assigned mean
     *     (an upper bound on the unweighted clustering error), or {@link Double#MAX_VALUE} if an
     *     exception occurred (it is logged, not rethrown)
     */
    protected double cluster(DataSet dataSet, final int k, final int[] assignment, boolean returnError, ExecutorService threadpool) {
        try {
            final int N = dataSet.getSampleSize();
            if (N < k) {
                throw new ClusterFailureException("Fewer data points then desired clusters, decrease cluster size");
            }
            this.X = dataSet.getDataVectors();
            setup(k, assignment, dataSet.getDataWeights());

            // lowerBound[q][c]: lower bound on the distance from point q to mean c
            final double[][] lowerBound = new double[N][k];
            // upperBound[q]: upper bound on the distance from point q to its assigned mean
            final double[] upperBound = new double[N];
            this.centroidSelfDistances = new double[k][k];
            this.centroidPairDots = new double[k][k];
            // sC[c]: half the distance from mean c to its nearest other mean
            final double[] sC = new double[k];
            calculateCentroidDistances(k, this.centroidSelfDistances, sC, assignment, null, threadpool);

            // designations at the start of the current pass, for incremental centroidPairDots updates
            final int[] prevAssignment = new int[N];
            // r[q]: upperBound[q] was loosened and must be recomputed exactly before it is trusted
            final boolean[] r = new boolean[N];
            final AtomicBoolean changeOccurred = new AtomicBoolean(true);
            // keep going for at least two passes even without changes (subject to maximumIterations)
            int atLeast = 2;

            if (threadpool == null) {
                initialClusterSetUp(k, N, lowerBound, upperBound, this.centroidSelfDistances, assignment);
            } else {
                initialClusterSetUp(k, N, lowerBound, upperBound, this.centroidSelfDistances, assignment, threadpool);
            }

            boolean iterated = false;
            int iterLimit = this.maximumIterations;
            while ((changeOccurred.get() || atLeast > 0) && iterLimit-- >= 0) {
                atLeast--;
                changeOccurred.set(false);
                // the first pass reuses the distances computed above; later passes refresh them
                // because the previous pass moved the means
                if (iterLimit < this.maximumIterations - 1) {
                    calculateCentroidDistances(k, this.centroidSelfDistances, sC, assignment, prevAssignment, threadpool);
                }
                System.arraycopy(assignment, 0, prevAssignment, 0, assignment.length);

                if (threadpool == null) {
                    for (int q = 0; q < N; q++) {
                        step3PointUpdate(q, k, assignment, upperBound, lowerBound, sC, r, changeOccurred);
                    }
                } else {
                    final int P = SystemInfo.LogicalCores;
                    final List<Future<?>> tasks = new ArrayList<Future<?>>(P);
                    for (int id = 0; id < P; id++) {
                        final int first = id;
                        tasks.add(threadpool.submit(new Runnable() {
                            @Override
                            public void run() {
                                for (int q = first; q < N; q += P) {
                                    step3PointUpdate(q, k, assignment, upperBound, lowerBound, sC, r, changeOccurred);
                                }
                            }
                        }));
                    }
                    // surfaces worker failures instead of hanging on a latch that never reaches zero
                    awaitAll(tasks);
                }

                step4_5_6_distanceMovedBoundsUpdate(k, N, lowerBound, upperBound, assignment, r, threadpool);
                iterated = true;
            }

            // If the last pass changed designations, its step 4 moved the means after the last
            // distance refresh; refresh once more so findClosestCluster does not prune with stale values.
            if (iterated && !Arrays.equals(prevAssignment, assignment)) {
                calculateCentroidDistances(k, this.centroidSelfDistances, sC, assignment, prevAssignment, threadpool);
            }

            double totalDistance = 0.0d;
            for (int q = 0; q < N; q++) {
                totalDistance += Math.pow(upperBound[q], 2.0d);
            }
            return totalDistance;
        } catch (Exception e) {
            Logger.getLogger(ElkanKernelKMeans.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
            return Double.MAX_VALUE;
        }
    }

    /**
     * Performs the step-3 bound checks for a single point. For every other cluster not ruled out
     * by the bounds, it tightens the point's upper bound (step 3a) and tests that cluster as a new
     * designation (step 3b).
     */
    private void step3PointUpdate(int q, int k, int[] assignment, double[] upperBound, double[][] lowerBound,
            double[] sC, boolean[] r, AtomicBoolean changeOccurred) {
        if (upperBound[q] > sC[assignment[q]]) {
            for (int c = 0; c < k; c++) {
                if (c != assignment[q]
                        && upperBound[q] > lowerBound[q][c]
                        && upperBound[q] > this.centroidSelfDistances[assignment[q]][c] * 0.5d) {
                    step3aBoundsUpdate(r, q, assignment, upperBound, lowerBound);
                    step3bUpdate(upperBound, q, lowerBound, c, this.centroidSelfDistances, assignment, changeOccurred);
                }
            }
        }
    }

    /**
     * Serially computes each point's nearest cluster and stores it in {@code newDesignations}. It
     * also records the evaluated distances as lower bounds and the best one as the upper bound.
     */
    private void initialClusterSetUp(int k, int N, double[][] lowerBound, double[] upperBound, double[][] centroidDistances, int[] assignment) {
        final boolean[] skip = new boolean[k];
        for (int q = 0; q < N; q++) {
            initialAssign(q, k, lowerBound, upperBound, centroidDistances, assignment, skip);
        }
    }

    /**
     * Parallel version of the serial initial assignment. Points are split into
     * {@code SystemInfo.LogicalCores} contiguous blocks.
     *
     * @throws FailedToFitException if a worker fails or the waiting thread is interrupted
     */
    private void initialClusterSetUp(final int k, final int N, final double[][] lowerBound, final double[] upperBound,
            final double[][] centroidDistances, final int[] assignment, ExecutorService threadpool) {
        final int P = SystemInfo.LogicalCores;
        final List<Future<?>> tasks = new ArrayList<Future<?>>(P);
        for (int id = 0; id < P; id++) {
            final int start = ParallelUtils.getStartBlock(N, id, P);
            final int end = ParallelUtils.getEndBlock(N, id, P);
            tasks.add(threadpool.submit(new Runnable() {
                @Override
                public void run() {
                    // scratch array is per task because it is rewritten for every point
                    final boolean[] skip = new boolean[k];
                    for (int q = start; q < end; q++) {
                        initialAssign(q, k, lowerBound, upperBound, centroidDistances, assignment, skip);
                    }
                }
            }));
        }
        awaitAll(tasks);
    }

    /**
     * Finds the nearest cluster for point {@code q} and stores it in {@code newDesignations[q]}.
     * Clusters are pruned with the triangle inequality whenever a new best distance is found;
     * pruned clusters keep their existing lower bound. {@code skip} is scratch space of length k.
     */
    private void initialAssign(int q, int k, double[][] lowerBound, double[] upperBound, double[][] centroidDistances,
            int[] assignment, boolean[] skip) {
        double best = Double.MAX_VALUE;
        int bestCluster = -1;
        Arrays.fill(skip, false);
        for (int c = 0; c < k; c++) {
            if (skip[c]) {
                continue;
            }
            final double dist = distance(q, c, assignment);
            lowerBound[q][c] = dist;
            if (dist < best) {
                upperBound[q] = dist;
                best = dist;
                bestCluster = c;
                for (int j = c + 1; j < k; j++) {
                    if (centroidDistances[c][j] >= 2.0d * dist) {
                        skip[j] = true;
                    }
                }
            }
        }
        this.newDesignations[q] = bestCluster;
    }

    /**
     * Elkan steps 4-6.
     * <ul>
     *   <li>Step 4 applies the pending designation changes to the means.</li>
     *   <li>Step 5 lowers every lower bound by how far its mean moved (clamped at 0).</li>
     *   <li>Step 6 raises every upper bound by how far the point's new mean moved and marks it in
     *       {@code r}.</li>
     * </ul>
     * Finally, {@code newDesignations} is copied into {@code assignment}.
     *
     * @return the sum of the values returned by {@code updateMeansFromChange}
     * @throws FailedToFitException if a worker fails or the waiting thread is interrupted
     */
    private int step4_5_6_distanceMovedBoundsUpdate(final int k, final int N, final double[][] lowerBound,
            final double[] upperBound, final int[] assignment, final boolean[] r, ExecutorService threadpool) {
        // moved[c]: distance mean c moves in this update
        final double[] moved = new double[k];
        // scaled squared norms of the means *before* the update, needed to measure the movement
        final double[] oldNorms = new double[this.meanSqrdNorms.length];
        for (int c = 0; c < oldNorms.length; c++) {
            oldNorms[c] = this.meanSqrdNorms[c] * this.normConsts[c];
        }

        if (threadpool == null) {
            int changes = 0;
            for (int q = 0; q < N; q++) {
                changes += updateMeansFromChange(q, assignment);
            }
            updateNormConsts();
            for (int c = 0; c < k; c++) {
                moved[c] = meanToMeanDistance(c, c, this.newDesignations, assignment, oldNorms[c]);
            }
            System.arraycopy(this.newDesignations, 0, assignment, 0, N);
            for (int c = 0; c < k; c++) {
                for (int q = 0; q < N; q++) {
                    lowerBound[q][c] = Math.max(lowerBound[q][c] - moved[c], 0.0d);
                }
            }
            for (int q = 0; q < N; q++) {
                upperBound[q] += moved[assignment[q]];
                r[q] = true;
            }
            return changes;
        }

        final int P = SystemInfo.LogicalCores;

        // Step 4: each task accumulates its block's changes locally, then merges them
        final List<Future<Integer>> meanTasks = new ArrayList<Future<Integer>>(P);
        for (int id = 0; id < P; id++) {
            final int start = ParallelUtils.getStartBlock(N, id, P);
            final int end = ParallelUtils.getEndBlock(N, id, P);
            meanTasks.add(threadpool.submit(new Callable<Integer>() {
                @Override
                public Integer call() {
                    final double[] localA = new double[k];
                    final double[] localB = new double[k];
                    int localChanges = 0;
                    for (int q = start; q < end; q++) {
                        localChanges += updateMeansFromChange(q, assignment, localA, localB);
                    }
                    // all tasks merge into shared state, so merges are serialized on one lock
                    synchronized (assignment) {
                        applyMeanUpdates(localA, localB);
                    }
                    return Integer.valueOf(localChanges);
                }
            }));
        }
        int changes = 0;
        // every task must finish before the norms are recomputed; failures are rethrown, never
        // swallowed, so partially applied updates are not silently used
        for (Future<Integer> future : meanTasks) {
            changes += awaitResult(future).intValue();
        }
        updateNormConsts();
        for (int c = 0; c < k; c++) {
            moved[c] = meanToMeanDistance(c, c, this.newDesignations, assignment, oldNorms[c], threadpool);
        }

        // Step 5: one task per cluster column of the lower-bound matrix
        final List<Future<?>> lowerTasks = new ArrayList<Future<?>>(k);
        for (int c = 0; c < k; c++) {
            final int clusterIdx = c;
            lowerTasks.add(threadpool.submit(new Runnable() {
                @Override
                public void run() {
                    for (int q = 0; q < N; q++) {
                        lowerBound[q][clusterIdx] = Math.max(lowerBound[q][clusterIdx] - moved[clusterIdx], 0.0d);
                    }
                }
            }));
        }
        awaitAll(lowerTasks);
        System.arraycopy(this.newDesignations, 0, assignment, 0, N);

        // Step 6: uses the new designations just copied into assignment
        final List<Future<?>> upperTasks = new ArrayList<Future<?>>(P);
        for (int id = 0; id < P; id++) {
            final int start = ParallelUtils.getStartBlock(N, id, P);
            final int end = ParallelUtils.getEndBlock(N, id, P);
            upperTasks.add(threadpool.submit(new Runnable() {
                @Override
                public void run() {
                    for (int q = start; q < end; q++) {
                        upperBound[q] += moved[assignment[q]];
                        r[q] = true;
                    }
                }
            }));
        }
        awaitAll(upperTasks);
        return changes;
    }

    /**
     * Step 3a. If {@code r[i]} is set, recomputes the exact distance from point {@code i} to its
     * current cluster {@code assignment[i]}. That value is stored as both the point's lower bound
     * for that cluster and its upper bound, and the flag is cleared.
     * (Public only for compatibility with the decompiled API.)
     */
    public void step3aBoundsUpdate(boolean[] r, int i, int[] assignment, double[] upperBound, double[][] lowerBound) {
        if (!r[i]) {
            return;
        }
        r[i] = false;
        final int current = assignment[i];
        final double dist = distance(i, current, assignment);
        lowerBound[i][current] = dist;
        upperBound[i] = dist;
    }

    /**
     * Step 3b. The test uses the possibly tightened upper bound. If that bound still exceeds
     * either the lower bound for cluster {@code c} or half the distance between the point's
     * current cluster and {@code c}, the exact distance to {@code c} is computed and stored as
     * the lower bound. If it is strictly smaller than the upper bound, {@code c} becomes the point's
     * entry in {@code newDesignations}, the upper bound is updated, and {@code changeOccurred} is set.
     * (Public only for compatibility with the decompiled API.)
     */
    public void step3bUpdate(double[] upperBound, int i, double[][] lowerBound, int c, double[][] centroidDistances,
            int[] assignment, AtomicBoolean changeOccurred) {
        if (upperBound[i] > lowerBound[i][c] || upperBound[i] > centroidDistances[assignment[i]][c] / 2.0d) {
            final double dist = distance(i, c, assignment);
            lowerBound[i][c] = dist;
            if (dist < upperBound[i]) {
                this.newDesignations[i] = c;
                upperBound[i] = dist;
                changeOccurred.set(true);
            }
        }
    }

    /**
     * Refreshes {@link #centroidPairDots} and recomputes all mean-to-mean distances.
     *
     * @param k number of clusters
     * @param centroidDistances output k-by-k symmetric distance matrix (diagonal untouched)
     * @param sC output: half the distance from each mean to its nearest other mean
     * @param assignment current designation of each point
     * @param prevAssignment designations the current {@code centroidPairDots} reflect, or
     *     {@code null} to rebuild it from scratch
     * @param threadpool executor forwarded to {@link #update_centroid_pair_dots}
     */
    private void calculateCentroidDistances(int k, double[][] centroidDistances, double[] sC, int[] assignment,
            int[] prevAssignment, ExecutorService threadpool) {
        if (prevAssignment == null) {
            // -1 for every point makes every pair count as changed, so the matrix is rebuilt
            prevAssignment = new int[assignment.length];
            Arrays.fill(prevAssignment, -1);
        }
        update_centroid_pair_dots(prevAssignment, assignment, threadpool);

        // total weight owned by each cluster under the current assignment
        final double[] ownerWeight = new double[k];
        for (int q = 0; q < assignment.length; q++) {
            ownerWeight[assignment[q]] += this.W.get(q);
        }
        for (int a = 0; a < k; a++) {
            for (int b = a + 1; b < k; b++) {
                // Expands the squared distance as |m_a|^2 + |m_b|^2 - 2<m_a, m_b>, assuming
                // meanSqrdNorms[c] * normConsts[c] is |m_c|^2 (defined in KernelKMeans — confirm).
                // The result is clamped at 0 to absorb round-off.
                // NOTE(review): an empty cluster (ownerWeight == 0) makes this division produce NaN/Infinity.
                final double sqrdDist = ((this.meanSqrdNorms[a] * this.normConsts[a]) + (this.meanSqrdNorms[b] * this.normConsts[b]))
                        - (2.0d * (this.centroidPairDots[a][b] / (ownerWeight[a] * ownerWeight[b])));
                final double dist = Math.sqrt(Math.max(0.0d, sqrdDist));
                centroidDistances[a][b] = dist;
                centroidDistances[b][a] = dist;
            }
        }
        for (int a = 0; a < k; a++) {
            double nearest = Double.MAX_VALUE;
            for (int b = 0; b < k; b++) {
                if (a != b) {
                    nearest = Math.min(nearest, centroidDistances[a][b]);
                }
            }
            sC[a] = nearest / 2.0d;
        }
    }

    /** Blocks until {@code future} completes and returns its value, wrapping any failure. */
    private static <T> T awaitResult(Future<T> future) {
        try {
            return future.get();
        } catch (InterruptedException e) {
            // preserve the interrupt for callers further up the stack
            Thread.currentThread().interrupt();
            throw new FailedToFitException(e);
        } catch (ExecutionException e) {
            throw new FailedToFitException(e);
        }
    }

    /** Blocks until every future completes, rethrowing the first failure encountered. */
    private static void awaitAll(List<? extends Future<?>> futures) {
        for (Future<?> future : futures) {
            awaitResult(future);
        }
    }

    /**
     * Clusters {@code dataSet} into {@code k} clusters using {@code threadpool}.
     *
     * @param designations output array, allocated if {@code null}
     * @return the designation array
     * @throws ClusterFailureException if there are fewer points than clusters
     */
    @Override
    public int[] cluster(DataSet dataSet, int k, ExecutorService threadpool, int[] designations) {
        if (designations == null) {
            designations = new int[dataSet.getSampleSize()];
        }
        if (dataSet.getSampleSize() < k) {
            throw new ClusterFailureException("Fewer data points then desired clusters, decrease cluster size");
        }
        cluster(dataSet, k, designations, false, threadpool);
        return designations;
    }

    /**
     * Clusters {@code dataSet} into {@code k} clusters on the calling thread.
     *
     * @param designations output array, allocated if {@code null}
     * @return the designation array
     * @throws ClusterFailureException if there are fewer points than clusters
     */
    @Override
    public int[] cluster(DataSet dataSet, int k, int[] designations) {
        if (designations == null) {
            designations = new int[dataSet.getSampleSize()];
        }
        if (dataSet.getSampleSize() < k) {
            throw new ClusterFailureException("Fewer data points then desired clusters, decrease cluster size");
        }
        cluster(dataSet, k, designations, false, (ExecutorService) null);
        return designations;
    }

    /** Returns a copy made with the copy constructor (decompiled name of {@code clone}). */
    @Override
    public ElkanKernelKMeans mo590clone() {
        return new ElkanKernelKMeans(this);
    }
}
