package jsat.clustering.kmeans;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.clustering.ClusterFailureException;
import jsat.clustering.SeedSelectionMethods;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DenseSparseMetric;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.utils.FakeExecutor;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.AtomicDoubleArray;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/clustering/kmeans/ElkanKMeans.class */
/**
 * k-means clustering accelerated with Elkan's triangle-inequality bounds.
 * <p>
 * For every data point the algorithm keeps an upper bound on the distance to its assigned
 * centroid and a lower bound on the distance to every centroid. Together with the pairwise
 * centroid distances these bounds let many point-to-centroid distance computations be skipped,
 * which is only valid when the distance metric satisfies the triangle inequality (is subadditive).
 */
public class ElkanKMeans extends KMeans {
    private static final long serialVersionUID = -1629432283103273051L;

    /**
     * The distance metric viewed as a {@link DenseSparseMetric} for the current clustering run,
     * or {@code null} when the dense/sparse shortcut is not in use.
     */
    private DenseSparseMetric dmds;

    /** Whether to use the dense/sparse distance shortcut when the metric supports it. */
    private boolean useDenseSparse;

    /**
     * Creates a new Elkan k-means clusterer.
     *
     * @param distanceMetric the distance metric; must satisfy the triangle inequality
     * @param random the source of randomness passed to the seed selection
     * @param seedSelection the method used to choose the initial centroids
     * @throws ClusterFailureException if {@code distanceMetric} is not subadditive
     */
    public ElkanKMeans(DistanceMetric distanceMetric, Random random, SeedSelectionMethods.SeedSelection seedSelection) {
        super(distanceMetric, seedSelection, random);
        this.useDenseSparse = false;
        if (!distanceMetric.isSubadditive()) {
            throw new ClusterFailureException("KMeans implementation requires the triangle inequality");
        }
    }

    /**
     * Creates a new Elkan k-means clusterer using the default seed selection.
     *
     * @param distanceMetric the distance metric; must satisfy the triangle inequality
     * @param random the source of randomness passed to the seed selection
     */
    public ElkanKMeans(DistanceMetric distanceMetric, Random random) {
        this(distanceMetric, random, DEFAULT_SEED_SELECTION);
    }

    /**
     * Creates a new Elkan k-means clusterer with a fresh {@link Random} and the default seed selection.
     *
     * @param distanceMetric the distance metric; must satisfy the triangle inequality
     */
    public ElkanKMeans(DistanceMetric distanceMetric) {
        this(distanceMetric, new Random());
    }

    /** Creates a new Elkan k-means clusterer that uses {@link EuclideanDistance}. */
    public ElkanKMeans() {
        this(new EuclideanDistance());
    }

    /**
     * Copy constructor.
     *
     * @param elkanKMeans the object to copy
     */
    public ElkanKMeans(ElkanKMeans elkanKMeans) {
        super(elkanKMeans);
        this.useDenseSparse = false;
        if (elkanKMeans.dmds != null) {
            this.dmds = (DenseSparseMetric) elkanKMeans.dmds.mo651clone();
        }
        this.useDenseSparse = elkanKMeans.useDenseSparse;
    }

    /**
     * Sets whether point-to-centroid distances should be computed with
     * {@link DenseSparseMetric}'s constant-based distance when the metric is a
     * {@code DenseSparseMetric}. Has no effect for other metrics.
     *
     * @param z {@code true} to enable the dense/sparse shortcut
     */
    public void setUseDenseSparse(boolean z) {
        this.useDenseSparse = z;
    }

    /**
     * @return {@code true} if the dense/sparse shortcut is enabled
     */
    public boolean isUseDenseSparse() {
        return this.useDenseSparse;
    }

    /**
     * Runs Elkan's accelerated k-means.
     *
     * @param dataSet the data to cluster
     * @param accelCache the metric's acceleration cache for the data, or {@code null} to compute it here
     * @param k the number of clusters
     * @param means the centroids; if its size is not {@code k} it is cleared and re-seeded. It is
     *        updated in place and holds the final centroids on return
     * @param assignment output array receiving the cluster index of every data point
     * @param exactTotal if {@code true} the returned error recomputes every point's distance to its
     *        centroid; otherwise the maintained upper bounds are used
     * @param threadpool executor for parallel work, or {@code null} to run serially
     * @param returnError if {@code false} the error is not computed and 0 is returned
     * @param dataPointWeights per-point weights, or {@code null} to use the data set's weights
     * @return the sum of squared distances from each point to its centroid (0 when {@code returnError}
     *         is false), or {@link Double#MAX_VALUE} if clustering failed with an exception
     */
    @Override // jsat.clustering.kmeans.KMeans
    public double cluster(final DataSet dataSet, List<Double> accelCache, final int k, final List<Vec> means, final int[] assignment, boolean exactTotal, ExecutorService threadpool, boolean returnError, Vec dataPointWeights) {
        try {
            final int n = dataSet.getSampleSize();
            final int dim = dataSet.getNumNumericalVars();
            if (n < k) {
                throw new ClusterFailureException("Fewer data points then desired clusters, decrease cluster size");
            }
            final Vec weights = dataPointWeights == null ? dataSet.getDataWeights() : dataPointWeights;
            TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet);
            final List<Vec> dataVectors = dataSet.getDataVectors();

            // Cache building and seeding use their single-threaded variants when there is no
            // executor or when it is a FakeExecutor.
            final boolean serialSetup = threadpool == null || threadpool instanceof FakeExecutor;
            final List<Double> cache;
            if (accelCache != null) {
                cache = accelCache;
            } else if (serialSetup) {
                cache = this.dm.getAccelerationCache(dataVectors);
            } else {
                cache = this.dm.getAccelerationCache(dataVectors, threadpool);
            }

            if (means.size() != k) {
                means.clear();
                if (serialSetup) {
                    means.addAll(SeedSelectionMethods.selectIntialPoints(dataSet, k, this.dm, cache, this.rand, this.seedSelection));
                } else {
                    means.addAll(SeedSelectionMethods.selectIntialPoints(dataSet, k, this.dm, cache, this.rand, this.seedSelection, threadpool));
                }
            }
            // Centroids are overwritten in place every iteration, so they must be dense.
            for (int c = 0; c < means.size(); c++) {
                if (means.get(c).isSparse()) {
                    means.set(c, new DenseVector(means.get(c)));
                }
            }

            // Decide on the dense/sparse shortcut BEFORE the first centroid-distance pass so the
            // per-centroid constants are already valid in the first iteration (which does not
            // recompute them). Reset it otherwise so a value left from an earlier run or copied by
            // the copy constructor is never used while the option is off.
            this.dmds = (this.useDenseSparse && this.dm instanceof DenseSparseMetric) ? (DenseSparseMetric) this.dm : null;
            final double[] meanConstants = this.dmds != null ? new double[k] : null;

            final double[][] lowerBound = new double[n][k];
            final double[] upperBound = new double[n];
            final double[][] centroidSelfDistances = new double[k][k];
            // sC[c] = half the distance from centroid c to its nearest other centroid
            final double[] sC = new double[k];
            calculateCentroidDistances(k, centroidSelfDistances, means, sC, meanConstants, threadpool);

            // Running (weighted) counts and coordinate sums of each cluster; the centroid is sum / count.
            final AtomicDoubleArray meanCounts = new AtomicDoubleArray(k);
            final Vec[] oldMeans = new Vec[k];
            final Vec[] meanSums = new Vec[k];
            final List<List<Double>> meanQI = new ArrayList<List<Double>>(k);
            for (int c = 0; c < k; c++) {
                oldMeans[c] = means.get(c).mo524clone();
                if (this.dm.supportsAcceleration()) {
                    meanQI.add(this.dm.getQueryInfo(means.get(c)));
                } else {
                    meanQI.add(Collections.<Double>emptyList());
                }
                meanSums[c] = new DenseVector(dim);
            }

            final AtomicBoolean changeOccurred = new AtomicBoolean(true);
            // r[q] == true means upperBound[q] may be loose and should be tightened before use.
            final boolean[] r = new boolean[n];
            // Per-thread deltas to the cluster sums, merged into meanSums by step4UpdateCentroids.
            final ThreadLocal<Vec[]> localDeltas = new ThreadLocal<Vec[]>() {
                @Override
                protected Vec[] initialValue() {
                    Vec[] deltas = new Vec[k];
                    for (int c = 0; c < deltas.length; c++) {
                        deltas[c] = new DenseVector(dim);
                    }
                    return deltas;
                }
            };

            if (threadpool == null) {
                initialClusterSetUp(k, n, dataVectors, means, lowerBound, upperBound, centroidSelfDistances, assignment, meanCounts, meanSums, cache, meanQI, weights);
            } else {
                initialClusterSetUp(k, n, dataVectors, means, lowerBound, upperBound, centroidSelfDistances, assignment, meanCounts, meanSums, cache, meanQI, localDeltas, threadpool, weights);
            }

            int minIterationsLeft = 2;
            int iterationsLeft = this.MaxIterLimit;
            boolean firstIteration = true;
            while ((changeOccurred.get() || minIterationsLeft > 0) && iterationsLeft-- >= 0) {
                minIterationsLeft--;
                changeOccurred.set(false);
                // The centroids have not moved since the distances above were computed.
                if (!firstIteration) {
                    calculateCentroidDistances(k, centroidSelfDistances, means, sC, meanConstants, threadpool);
                }
                firstIteration = false;

                if (threadpool == null) {
                    for (int q = 0; q < n; q++) {
                        // Step 2: a point closer to its centroid than half the gap to any other centroid cannot move.
                        if (upperBound[q] > sC[assignment[q]]) {
                            Vec x = dataVectors.get(q);
                            for (int c = 0; c < k; c++) {
                                if (c != assignment[q] && upperBound[q] > lowerBound[q][c] && upperBound[q] > centroidSelfDistances[assignment[q]][c] * 0.5d) {
                                    step3aBoundsUpdate(dataVectors, r, q, x, means, assignment, upperBound, lowerBound, meanConstants, cache, meanQI);
                                    step3bUpdate(dataVectors, upperBound, q, lowerBound, c, centroidSelfDistances, assignment, x, means, localDeltas, meanCounts, changeOccurred, meanConstants, cache, meanQI, weights);
                                }
                            }
                        }
                    }
                    step4UpdateCentroids(meanSums, localDeltas);
                } else {
                    final CountDownLatch latch = new CountDownLatch(SystemInfo.LogicalCores);
                    for (int id = 0; id < SystemInfo.LogicalCores; id++) {
                        final int start = id;
                        threadpool.submit(new Runnable() {
                            @Override
                            public void run() {
                                // Points are striped across tasks: start, start + cores, start + 2*cores, ...
                                for (int q = start; q < n; q += SystemInfo.LogicalCores) {
                                    if (upperBound[q] > sC[assignment[q]]) {
                                        Vec x = dataSet.getDataPoint(q).getNumericalValues();
                                        for (int c = 0; c < k; c++) {
                                            if (c != assignment[q] && upperBound[q] > lowerBound[q][c] && upperBound[q] > centroidSelfDistances[assignment[q]][c] * 0.5d) {
                                                ElkanKMeans.this.step3aBoundsUpdate(dataVectors, r, q, x, means, assignment, upperBound, lowerBound, meanConstants, cache, meanQI);
                                                ElkanKMeans.this.step3bUpdate(dataVectors, upperBound, q, lowerBound, c, centroidSelfDistances, assignment, x, means, localDeltas, meanCounts, changeOccurred, meanConstants, cache, meanQI, weights);
                                            }
                                        }
                                    }
                                }
                                ElkanKMeans.this.step4UpdateCentroids(meanSums, localDeltas);
                                latch.countDown();
                            }
                        });
                    }
                    try {
                        latch.await();
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                        throw new ClusterFailureException("Clustering failed");
                    }
                }
                step5_6_distanceMovedBoundsUpdate(k, oldMeans, means, meanSums, meanCounts, n, lowerBound, upperBound, assignment, r, meanQI, threadpool);
            }

            double totalError = 0.0d;
            if (returnError) {
                this.nearestCentroidDist = this.saveCentroidDistance ? new double[n] : null;
                for (int q = 0; q < n; q++) {
                    double dist = exactTotal
                            ? this.dm.dist(q, means.get(assignment[q]), meanQI.get(assignment[q]), dataVectors, cache)
                            : upperBound[q];
                    totalError += Math.pow(dist, 2.0d);
                    if (this.saveCentroidDistance) {
                        this.nearestCentroidDist[q] = dist;
                    }
                }
            }
            return totalError;
        } catch (Exception e) {
            Logger.getLogger(ElkanKMeans.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
            return Double.MAX_VALUE;
        }
    }

    /**
     * Finds the nearest centroid for data point {@code q} during the initial assignment. Every
     * computed distance is stored as a lower bound. Centroids that are provably no closer (at least
     * twice the current best distance away from the current best centroid) are skipped.
     *
     * @param skip scratch array of length {@code k}; overwritten
     * @return the index of the nearest centroid found, or -1 if no distance was below {@link Double#MAX_VALUE}
     */
    private int nearestInitialCentroid(int q, int k, List<Vec> dataVectors, List<Vec> means, double[][] lowerBound, double[] upperBound, double[][] centroidSelfDistances, boolean[] skip, List<Double> accelCache, List<List<Double>> meanQI) {
        Arrays.fill(skip, false);
        double best = Double.MAX_VALUE;
        int bestIndex = -1;
        for (int c = 0; c < k; c++) {
            if (skip[c]) {
                continue;
            }
            double dist = this.dm.dist(q, means.get(c), meanQI.get(c), dataVectors, accelCache);
            lowerBound[q][c] = dist;
            if (dist < best) {
                upperBound[q] = dist;
                best = dist;
                bestIndex = c;
                // Lemma 1: if d(c, c') >= 2 d(x, c) then d(x, c') >= d(x, c).
                for (int other = c + 1; other < k; other++) {
                    if (centroidSelfDistances[c][other] >= 2.0d * dist) {
                        skip[other] = true;
                    }
                }
            }
        }
        return bestIndex;
    }

    /**
     * Serially assigns every point to its nearest initial centroid. Adds the point's weight to the
     * cluster count and the weighted point to the cluster sum.
     */
    private void initialClusterSetUp(int k, int n, List<Vec> dataVectors, List<Vec> means, double[][] lowerBound, double[] upperBound, double[][] centroidSelfDistances, int[] assignment, AtomicDoubleArray meanCounts, Vec[] meanSums, List<Double> accelCache, List<List<Double>> meanQI, Vec dataPointWeights) {
        boolean[] skip = new boolean[k];
        for (int q = 0; q < n; q++) {
            int nearest = nearestInitialCentroid(q, k, dataVectors, means, lowerBound, upperBound, centroidSelfDistances, skip, accelCache, meanQI);
            assignment[q] = nearest;
            double w = dataPointWeights.get(q);
            meanCounts.addAndGet(nearest, w);
            meanSums[nearest].mutableAdd(w, dataVectors.get(q));
        }
    }

    /**
     * Parallel version of the initial assignment. The points are split into contiguous ranges, one
     * task per logical core. Each task accumulates into its thread-local sums and merges them into
     * {@code meanSums} under a per-centroid lock.
     *
     * @throws ClusterFailureException if interrupted while waiting for the tasks
     */
    private void initialClusterSetUp(final int k, int n, final List<Vec> dataVectors, final List<Vec> means, final double[][] lowerBound, final double[] upperBound, final double[][] centroidSelfDistances, final int[] assignment, final AtomicDoubleArray meanCounts, final Vec[] meanSums, final List<Double> accelCache, final List<List<Double>> meanQI, final ThreadLocal<Vec[]> localDeltas, ExecutorService threadpool, final Vec dataPointWeights) {
        final int cores = SystemInfo.LogicalCores;
        final CountDownLatch latch = new CountDownLatch(cores);
        final int blockSize = n / cores;
        int extra = n % cores;
        int submitted = 0;
        int start = 0;
        while (start < n) {
            final int from = start;
            // The first (n % cores) ranges take one extra point each.
            final int to = from + blockSize + (extra-- > 0 ? 1 : 0);
            start = to;
            submitted++;
            threadpool.submit(new Runnable() {
                @Override
                public void run() {
                    Vec[] localSums = localDeltas.get();
                    boolean[] skip = new boolean[k];
                    for (int q = from; q < to; q++) {
                        int nearest = ElkanKMeans.this.nearestInitialCentroid(q, k, dataVectors, means, lowerBound, upperBound, centroidSelfDistances, skip, accelCache, meanQI);
                        assignment[q] = nearest;
                        double w = dataPointWeights.get(q);
                        meanCounts.addAndGet(nearest, w);
                        localSums[nearest].mutableAdd(w, dataVectors.get(q));
                    }
                    for (int c = 0; c < localSums.length; c++) {
                        synchronized (meanSums[c]) {
                            meanSums[c].mutableAdd(localSums[c]);
                        }
                        localSums[c].zeroOut();
                    }
                    latch.countDown();
                }
            });
        }
        // Fewer than `cores` tasks are submitted when n < cores; release the unused latch slots.
        for (; submitted < cores; submitted++) {
            latch.countDown();
        }
        try {
            latch.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new ClusterFailureException("Interrupted during initial cluster assignment");
        }
    }

    /**
     * Step 4: merges the calling thread's cluster-sum deltas into {@code vecArr} (under a
     * per-centroid lock) and clears the deltas. Deltas with no non-zero values are skipped.
     *
     * @param vecArr the shared cluster sums
     * @param threadLocal the per-thread deltas
     */
    public void step4UpdateCentroids(Vec[] vecArr, ThreadLocal<Vec[]> threadLocal) {
        Vec[] deltas = threadLocal.get();
        for (int c = 0; c < deltas.length; c++) {
            if (deltas[c].nnz() == 0) {
                continue;
            }
            synchronized (vecArr[c]) {
                vecArr[c].mutableAdd(deltas[c]);
            }
            deltas[c].zeroOut();
        }
    }

    /**
     * Step 5 for a single centroid. Saves the current centroid into {@code oldMeans}, recomputes it
     * as sum / count (the zero vector when the count is at most 1e-14), and records how far it moved.
     * It then refreshes its query info and loosens every point's lower bound for it by that amount.
     */
    private void moveCentroid(int c, Vec[] oldMeans, List<Vec> means, Vec[] meanSums, AtomicDoubleArray meanCounts, int n, double[][] lowerBound, double[] distanceMoved, List<List<Double>> meanQI) {
        Vec mean = means.get(c);
        mean.copyTo(oldMeans[c]);
        meanSums[c].copyTo(mean);
        double count = meanCounts.get(c);
        if (count <= 1.0E-14d) {
            mean.zeroOut();
        } else {
            mean.mutableDivide(count);
        }
        distanceMoved[c] = this.dm.dist(oldMeans[c], mean);
        if (this.dm.supportsAcceleration()) {
            meanQI.set(c, this.dm.getQueryInfo(mean));
        }
        for (int q = 0; q < n; q++) {
            lowerBound[q][c] = Math.max(lowerBound[q][c] - distanceMoved[c], 0.0d);
        }
    }

    /**
     * Step 6 for points {@code [from, to)}: grows each upper bound by the distance its centroid
     * moved and marks the bound as stale.
     */
    private static void loosenUpperBounds(int from, int to, double[] upperBound, double[] distanceMoved, int[] assignment, boolean[] r) {
        for (int q = from; q < to; q++) {
            upperBound[q] += distanceMoved[assignment[q]];
            r[q] = true;
        }
    }

    /**
     * Steps 5 and 6: moves every centroid to the mean of its points and adjusts all bounds by the
     * distance each centroid moved.
     *
     * @throws ClusterFailureException if interrupted while waiting for parallel tasks
     */
    private void step5_6_distanceMovedBoundsUpdate(int k, final Vec[] oldMeans, final List<Vec> means, final Vec[] meanSums, final AtomicDoubleArray meanCounts, final int n, final double[][] lowerBound, final double[] upperBound, final int[] assignment, final boolean[] r, final List<List<Double>> meanQI, ExecutorService threadpool) {
        final double[] distanceMoved = new double[k];
        if (threadpool == null) {
            for (int c = 0; c < k; c++) {
                moveCentroid(c, oldMeans, means, meanSums, meanCounts, n, lowerBound, distanceMoved, meanQI);
            }
            loosenUpperBounds(0, n, upperBound, distanceMoved, assignment, r);
            return;
        }
        try {
            final CountDownLatch meanLatch = new CountDownLatch(k);
            for (int c = 0; c < k; c++) {
                final int centroid = c;
                threadpool.submit(new Runnable() {
                    @Override
                    public void run() {
                        ElkanKMeans.this.moveCentroid(centroid, oldMeans, means, meanSums, meanCounts, n, lowerBound, distanceMoved, meanQI);
                        meanLatch.countDown();
                    }
                });
            }
            meanLatch.await();

            final int cores = SystemInfo.LogicalCores;
            final int blockSize = n / cores;
            final CountDownLatch boundLatch = new CountDownLatch(cores);
            for (int id = 0; id < cores; id++) {
                final int from = id * blockSize;
                // The last range also takes the n % cores leftover points.
                final int to = id == cores - 1 ? n : from + blockSize;
                threadpool.submit(new Runnable() {
                    @Override
                    public void run() {
                        loosenUpperBounds(from, to, upperBound, distanceMoved, assignment, r);
                        boundLatch.countDown();
                    }
                });
            }
            boundLatch.await();
        } catch (InterruptedException e) {
            // Do NOT fall back to the serial path: part of the update may already have been
            // applied, and redoing it would corrupt oldMeans and the bounds.
            Thread.currentThread().interrupt();
            throw new ClusterFailureException("Interrupted while updating centroids");
        }
    }

    /**
     * Distance from data point {@code q} (whose vector is {@code x}) to centroid {@code c}. Uses the
     * dense/sparse shortcut with the centroid's precomputed constant when it is enabled.
     */
    private double pointToMeanDistance(int q, Vec x, int c, List<Vec> dataVectors, List<Vec> means, double[] meanConstants, List<Double> accelCache, List<List<Double>> meanQI) {
        if (this.dmds == null) {
            return this.dm.dist(q, means.get(c), meanQI.get(c), dataVectors, accelCache);
        }
        return this.dmds.dist(meanConstants[c], means.get(c), x);
    }

    /**
     * Step 3a: if point {@code i}'s upper bound is marked stale, recomputes the exact distance to its
     * assigned centroid. The result becomes both the upper bound and the lower bound for that
     * centroid, and the stale flag is cleared.
     */
    public void step3aBoundsUpdate(List<Vec> list, boolean[] zArr, int i, Vec vec, List<Vec> list2, int[] iArr, double[] dArr, double[][] dArr2, double[] dArr3, List<Double> list3, List<List<Double>> list4) {
        if (!zArr[i]) {
            return;
        }
        zArr[i] = false;
        int assigned = iArr[i];
        double dist = pointToMeanDistance(i, vec, assigned, list, list2, dArr3, list3, list4);
        dArr2[i][assigned] = dist;
        dArr[i] = dist;
    }

    /**
     * Step 3b: if centroid {@code i2} might still be closer to point {@code i} than its current
     * centroid, computes the distance. When it is closer, the point is moved: its weighted vector
     * and weight are shifted between the thread-local sums and the counts, and the change is flagged.
     */
    public void step3bUpdate(List<Vec> list, double[] dArr, int i, double[][] dArr2, int i2, double[][] dArr3, int[] iArr, Vec vec, List<Vec> list2, ThreadLocal<Vec[]> threadLocal, AtomicDoubleArray atomicDoubleArray, AtomicBoolean atomicBoolean, double[] dArr4, List<Double> list3, List<List<Double>> list4, Vec vec2) {
        int current = iArr[i];
        if (!(dArr[i] > dArr2[i][i2] || dArr[i] > dArr3[current][i2] / 2.0d)) {
            return;
        }
        double dist = pointToMeanDistance(i, vec, i2, list, list2, dArr4, list3, list4);
        dArr2[i][i2] = dist;
        if (dist < dArr[i]) {
            Vec[] deltas = threadLocal.get();
            double w = vec2.get(i);
            deltas[current].mutableSubtract(w, vec);
            atomicDoubleArray.addAndGet(current, -w);
            deltas[i2].mutableAdd(w, vec);
            atomicDoubleArray.addAndGet(i2, w);
            iArr[i] = i2;
            dArr[i] = dist;
            atomicBoolean.set(true);
        }
    }

    /**
     * Computes all pairwise centroid distances (both triangles of {@code centroidSelfDistances}) and
     * sets {@code sC[c]} to half the distance from centroid c to its nearest other centroid. When
     * {@code meanConstants} is non-null it also fills in the dense/sparse constant of every centroid.
     *
     * @throws ClusterFailureException if interrupted while waiting for parallel tasks
     */
    private void calculateCentroidDistances(int k, final double[][] centroidSelfDistances, final List<Vec> means, double[] sC, final double[] meanConstants, ExecutorService threadpool) {
        final List<Double> meanAccelCache = this.dm.supportsAcceleration() ? this.dm.getAccelerationCache(means) : null;
        if (threadpool != null) {
            final CountDownLatch latch = new CountDownLatch(k * (k - 1) / 2);
            for (int a = 0; a < k; a++) {
                final int i = a;
                for (int b = a + 1; b < k; b++) {
                    final int j = b;
                    threadpool.submit(new Runnable() {
                        @Override
                        public void run() {
                            double dist = ElkanKMeans.this.dm.dist(i, j, means, meanAccelCache);
                            // Mirror so readers of either triangle see the up-to-date value.
                            centroidSelfDistances[i][j] = dist;
                            centroidSelfDistances[j][i] = dist;
                            latch.countDown();
                        }
                    });
                }
            }
            try {
                latch.await();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new ClusterFailureException("Interrupted while computing centroid distances");
            }
        } else {
            for (int i = 0; i < k; i++) {
                for (int j = i + 1; j < k; j++) {
                    double dist = this.dm.dist(i, j, means, meanAccelCache);
                    centroidSelfDistances[i][j] = dist;
                    centroidSelfDistances[j][i] = dist;
                }
            }
        }
        // Computed once per centroid (including the last) on both paths.
        if (meanConstants != null) {
            for (int c = 0; c < k; c++) {
                meanConstants[c] = this.dmds.getVectorConstant(means.get(c));
            }
        }
        for (int c = 0; c < k; c++) {
            double nearest = Double.MAX_VALUE;
            for (int other = 0; other < k; other++) {
                if (other != c) {
                    nearest = Math.min(nearest, centroidSelfDistances[c][other]);
                }
            }
            sC[c] = nearest / 2.0d;
        }
    }

    /**
     * @return a deep copy of this clusterer
     */
    @Override // jsat.clustering.kmeans.KMeans, jsat.clustering.KClustererBase, jsat.clustering.ClustererBase
    /* renamed from: clone */
    public ElkanKMeans mo589clone() {
        return new ElkanKMeans(this);
    }
}
