package jsat.clustering.kmeans;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.clustering.SeedSelectionMethods;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.utils.DoubleList;
import jsat.utils.FakeExecutor;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.AtomicDoubleArray;
import jsat.utils.random.XORWOW;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/clustering/kmeans/HamerlyKMeans.class */
/**
 * k-means clustering accelerated with Hamerly's bounds.
 *
 * <p>For every data point the algorithm keeps one upper bound on the distance to its assigned
 * center ({@code upper}) and one lower bound on the distance to any other center ({@code lower}).
 * For every center it keeps {@code s[j]}, the distance from center {@code j} to its closest other
 * center. A point whose upper bound does not exceed {@code max(s[a]/2, lower)} keeps its
 * assignment, so no distances are computed for it in that iteration.
 *
 * <p>When an {@link ExecutorService} is supplied, the per-point work is split across
 * {@link SystemInfo#LogicalCores} tasks. Each task accumulates into thread-local cluster sums,
 * which are then merged into the shared sums under a per-vector lock.
 */
public class HamerlyKMeans extends KMeans {
    private static final long serialVersionUID = -4960453870335145091L;

    /**
     * Creates a new Hamerly k-means clusterer.
     *
     * @param distanceMetric the distance metric used between points and centers
     * @param seedSelection the method used to choose the initial centers
     * @param random the source of randomness passed on to {@link KMeans}
     */
    public HamerlyKMeans(DistanceMetric distanceMetric, SeedSelectionMethods.SeedSelection seedSelection, Random random) {
        super(distanceMetric, seedSelection, random);
    }

    /**
     * Creates a new Hamerly k-means clusterer that uses an {@link XORWOW} random source.
     *
     * @param distanceMetric the distance metric used between points and centers
     * @param seedSelection the method used to choose the initial centers
     */
    public HamerlyKMeans(DistanceMetric distanceMetric, SeedSelectionMethods.SeedSelection seedSelection) {
        this(distanceMetric, seedSelection, new XORWOW());
    }

    /** Creates a Euclidean-distance clusterer seeded with {@code KPP}. */
    public HamerlyKMeans() {
        this(new EuclideanDistance(), SeedSelectionMethods.SeedSelection.KPP);
    }

    /**
     * Copy constructor.
     *
     * @param hamerlyKMeans the clusterer to copy
     */
    public HamerlyKMeans(HamerlyKMeans hamerlyKMeans) {
        super(hamerlyKMeans);
    }

    /**
     * Runs Hamerly's k-means until an iteration reassigns no point.
     *
     * @param dataSet the data to cluster
     * @param accelCache a precomputed acceleration cache for the data vectors, or {@code null} to
     *        compute one here
     * @param k the number of clusters
     * @param means the cluster centers. They are re-seeded if the list does not hold exactly
     *        {@code k} vectors, and are overwritten in place with the final centers
     * @param assignment output array that receives the cluster index of every data point
     * @param exactTotal if {@code true}, the returned error uses freshly computed distances;
     *        otherwise it uses the (possibly loose) upper bounds
     * @param threadpool executor for parallel work, or {@code null} to run serially
     * @param returnError if {@code false}, the error is not computed and {@code 0} is returned
     * @param dataPointWeights per-point weights, or {@code null} to use the data set's weights
     * @return the sum of squared point-to-center distances, or {@code 0} when {@code returnError}
     *         is {@code false}
     */
    /* JADX INFO: Access modifiers changed from: protected */
    @Override // jsat.clustering.kmeans.KMeans
    public double cluster(final DataSet dataSet, List<Double> accelCache, final int k, final List<Vec> means, final int[] assignment, boolean exactTotal, ExecutorService threadpool, boolean returnError, Vec dataPointWeights) {
        final int n = dataSet.getSampleSize();
        final int dimension = dataSet.getNumNumericalVars();
        TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet, threadpool);
        final Vec weights = dataPointWeights == null ? dataSet.getDataWeights() : dataPointWeights;
        final List<Vec> X = dataSet.getDataVectors();

        final List<Double> cache;
        if (accelCache != null) {
            cache = accelCache;
        } else if (threadpool == null || threadpool instanceof FakeExecutor) {
            cache = this.dm.getAccelerationCache(X);
        } else {
            cache = this.dm.getAccelerationCache(X, threadpool);
        }

        if (means.size() != k) {
            means.clear();
            means.addAll(SeedSelectionMethods.selectIntialPoints(dataSet, k, this.dm, cache, this.rand, this.seedSelection, threadpool));
        }
        // Centers are overwritten in place (copyTo in moveCenters), so keep them dense.
        for (int j = 0; j < means.size(); j++) {
            if (means.get(j).isSparse()) {
                means.set(j, new DenseVector(means.get(j)));
            }
        }

        final List<List<Double>> meanQueryInfo = new ArrayList<List<Double>>(k);
        final Vec[] meanSums = new Vec[k];        // weighted sum of the points in each cluster
        final Vec[] newMeans = new Vec[k];        // scratch space for the recomputed centers
        final AtomicDoubleArray clusterWeight = new AtomicDoubleArray(k);
        final double[] meanMoved = new double[k]; // how far each center moved in the last update
        final double[] s = new double[k];         // distance from each center to its closest other center
        final double[] upper = new double[n];
        final double[] lower = new double[n];
        final ThreadLocal<Vec[]> localMeanSums = new ThreadLocal<Vec[]>() {
            @Override
            protected Vec[] initialValue() {
                Vec[] sums = new Vec[k];
                for (int j = 0; j < k; j++) {
                    sums[j] = new DenseVector(dimension);
                }
                return sums;
            }
        };

        Initialize(dataSet, clusterWeight, means, newMeans, meanSums, upper, lower, assignment, threadpool, localMeanSums, X, cache, meanQueryInfo, weights);

        final AtomicInteger changes = new AtomicInteger(n);
        while (changes.get() > 0) {
            moveCenters(means, newMeans, meanSums, clusterWeight, meanMoved, meanQueryInfo);
            UpdateBounds(meanMoved, assignment, upper, lower);
            changes.set(0);
            updateS(s, means, threadpool, meanQueryInfo);

            if (threadpool == null) {
                int changed = 0;
                for (int i = 0; i < n; i++) {
                    changed += mainLoopWork(dataSet, i, s, assignment, upper, lower, clusterWeight, meanSums, X, cache, means, meanQueryInfo, weights);
                }
                changes.set(changed);
                continue;
            }

            final int threads = SystemInfo.LogicalCores;
            final CountDownLatch latch = new CountDownLatch(threads);
            final AtomicReference<Throwable> failure = new AtomicReference<Throwable>();
            for (int id = 0; id < threads; id++) {
                final int start = id;
                threadpool.submit(releasing(new Runnable() {
                    @Override
                    public void run() {
                        Vec[] localSums = localMeanSums.get();
                        int changed = 0;
                        for (int i = start; i < n; i += threads) {
                            changed += HamerlyKMeans.this.mainLoopWork(dataSet, i, s, assignment, upper, lower, clusterWeight, localSums, X, cache, means, meanQueryInfo, weights);
                        }
                        // Local sums are only touched when a point changes cluster.
                        if (changed > 0) {
                            changes.getAndAdd(changed);
                            mergeAndReset(meanSums, localSums);
                        }
                    }
                }, latch, failure));
            }
            awaitCompletion(latch, failure);
        }

        if (!returnError) {
            return 0.0d;
        }
        this.nearestCentroidDist = this.saveCentroidDistance ? new double[n] : null;
        double error = 0.0d;
        for (int i = 0; i < n; i++) {
            final double d = exactTotal
                    ? this.dm.dist(i, means.get(assignment[i]), meanQueryInfo.get(assignment[i]), X, cache)
                    : upper[i];
            error += d * d;
            if (this.saveCentroidDistance) {
                this.nearestCentroidDist[i] = d;
            }
        }
        return error;
    }

    /**
     * Processes one data point for one iteration.
     *
     * <p>The point is first tested against its bounds. If the bound test fails, its upper bound is
     * tightened to the exact distance to its current center. If the point still cannot be ruled
     * out, it is compared against every center. When the point changes cluster, its weight and
     * weighted vector are moved from the old cluster's totals to the new cluster's totals.
     *
     * @param i index of the data point
     * @param s per-center distance to the closest other center
     * @param sums the cluster sums to update (shared sums when serial, thread-local when parallel)
     * @return 1 if the point changed cluster, otherwise 0
     */
    /* JADX INFO: Access modifiers changed from: private */
    public int mainLoopWork(DataSet dataSet, int i, double[] s, int[] assignment, double[] upper, double[] lower, AtomicDoubleArray clusterWeight, Vec[] sums, List<Vec> X, List<Double> cache, List<Vec> means, List<List<Double>> meanQueryInfo, Vec weights) {
        final int current = assignment[i];
        final double bound = Math.max(s[current] / 2.0d, lower[i]);
        if (upper[i] <= bound) {
            return 0;
        }
        final Vec x = X.get(i);
        upper[i] = this.dm.dist(i, means.get(current), meanQueryInfo.get(current), X, cache);
        if (upper[i] <= bound) {
            return 0;
        }
        final int nearest = PointAllCtrs(x, i, means, assignment, upper, lower, X, cache, meanQueryInfo);
        if (nearest == current) {
            return 0;
        }
        final double w = weights.get(i);
        clusterWeight.addAndGet(current, -w);
        clusterWeight.addAndGet(nearest, w);
        sums[current].mutableSubtract(w, x);
        sums[nearest].mutableAdd(w, x);
        return 1;
    }

    /**
     * Fills {@code s[j]} with the distance from center {@code j} to its closest other center, or
     * {@link Double#MAX_VALUE} when there is no other center.
     *
     * <p>The value must be the true minimum over <em>all</em> other centers. An overestimate would
     * let {@code upper <= s/2} skip points that actually need to be reassigned.
     */
    private void updateS(final double[] s, final List<Vec> means, ExecutorService threadpool, List<List<Double>> meanQueryInfo) {
        final int k = means.size();
        Arrays.fill(s, Double.MAX_VALUE);

        // Flatten the per-center query info into one cache indexed like `means`. Metrics without
        // acceleration store empty lists, and then no cache (null) is passed.
        final DoubleList meanCache;
        if (meanQueryInfo.isEmpty() || meanQueryInfo.get(0).isEmpty()) {
            meanCache = null;
        } else {
            meanCache = new DoubleList(meanQueryInfo.size());
            for (List<Double> queryInfo : meanQueryInfo) {
                meanCache.addAll(queryInfo);
            }
        }

        if (threadpool == null) {
            // Measure each pair once and credit the distance to both centers.
            for (int a = 0; a < k; a++) {
                for (int b = a + 1; b < k; b++) {
                    final double d = this.dm.dist(a, b, means, meanCache);
                    if (d < s[a]) {
                        s[a] = d;
                    }
                    if (d < s[b]) {
                        s[b] = d;
                    }
                }
            }
            return;
        }

        // Each task owns exactly one slot of s, so no locking is needed. The latch makes the
        // writes visible to the waiting thread.
        final CountDownLatch latch = new CountDownLatch(k);
        final AtomicReference<Throwable> failure = new AtomicReference<Throwable>();
        for (int a = 0; a < k; a++) {
            final int self = a;
            threadpool.submit(releasing(new Runnable() {
                @Override
                public void run() {
                    double nearest = Double.MAX_VALUE;
                    for (int b = 0; b < k; b++) {
                        if (b != self) {
                            final double d = HamerlyKMeans.this.dm.dist(self, b, means, meanCache);
                            if (d < nearest) {
                                nearest = d;
                            }
                        }
                    }
                    s[self] = nearest;
                }
            }, latch, failure));
        }
        awaitCompletion(latch, failure);
    }

    /**
     * Allocates the per-cluster vectors and query info, then assigns every point to its nearest
     * center. This sets the point's initial upper and lower bounds and accumulates each cluster's
     * total weight and weighted sum.
     */
    private void Initialize(DataSet dataSet, final AtomicDoubleArray clusterWeight, final List<Vec> means, Vec[] newMeans, final Vec[] meanSums, final double[] upper, final double[] lower, final int[] assignment, ExecutorService threadpool, final ThreadLocal<Vec[]> localMeanSums, final List<Vec> X, final List<Double> cache, final List<List<Double>> meanQueryInfo, final Vec weights) {
        for (int j = 0; j < means.size(); j++) {
            meanSums[j] = new DenseVector(means.get(0).length());
            newMeans[j] = new DenseVector(means.get(0).length());
            if (this.dm.supportsAcceleration()) {
                meanQueryInfo.add(this.dm.getQueryInfo(means.get(j)));
            } else {
                meanQueryInfo.add(Collections.<Double>emptyList());
            }
        }

        final int n = upper.length;
        if (threadpool == null) {
            for (int i = 0; i < n; i++) {
                assignInitially(i, meanSums, clusterWeight, means, assignment, upper, lower, X, cache, meanQueryInfo, weights);
            }
            return;
        }

        final int threads = SystemInfo.LogicalCores;
        final CountDownLatch latch = new CountDownLatch(threads);
        final AtomicReference<Throwable> failure = new AtomicReference<Throwable>();
        for (int id = 0; id < threads; id++) {
            final int start = id;
            threadpool.submit(releasing(new Runnable() {
                @Override
                public void run() {
                    Vec[] localSums = localMeanSums.get();
                    for (int i = start; i < n; i += threads) {
                        assignInitially(i, localSums, clusterWeight, means, assignment, upper, lower, X, cache, meanQueryInfo, weights);
                    }
                    mergeAndReset(meanSums, localSums);
                }
            }, latch, failure));
        }
        awaitCompletion(latch, failure);
    }

    /**
     * Assigns point {@code i} to its nearest center, then adds its weight to that center's total
     * weight and its weighted vector to {@code sums}.
     */
    private void assignInitially(int i, Vec[] sums, AtomicDoubleArray clusterWeight, List<Vec> means, int[] assignment, double[] upper, double[] lower, List<Vec> X, List<Double> cache, List<List<Double>> meanQueryInfo, Vec weights) {
        final Vec x = X.get(i);
        final int c = PointAllCtrs(x, i, means, assignment, upper, lower, X, cache, meanQueryInfo);
        final double w = weights.get(i);
        clusterWeight.addAndGet(c, w);
        sums[c].mutableAdd(w, x);
    }

    /**
     * Compares point {@code i} against every center. It records the closest center in
     * {@code assignment[i]}, that distance in {@code upper[i]}, and the second-smallest distance in
     * {@code lower[i]}. With a single center, {@code lower[i]} is {@link Double#MAX_VALUE}.
     *
     * @param x the point's vector (not read; distances come from {@code X} and {@code cache})
     * @return the index of the closest center, or -1 if no distance compared below
     *         {@link Double#MAX_VALUE}
     */
    /* JADX INFO: Access modifiers changed from: private */
    public int PointAllCtrs(Vec x, int i, List<Vec> means, int[] assignment, double[] upper, double[] lower, List<Vec> X, List<Double> cache, List<List<Double>> meanQueryInfo) {
        double best = Double.MAX_VALUE;
        double secondBest = Double.POSITIVE_INFINITY;
        int bestIndex = -1;
        for (int j = 0; j < means.size(); j++) {
            final double d = this.dm.dist(i, means.get(j), meanQueryInfo.get(j), X, cache);
            if (d < best) {
                secondBest = best;
                best = d;
                bestIndex = j;
            } else if (d < secondBest) {
                secondBest = d;
            }
        }
        assignment[i] = bestIndex;
        upper[i] = best;
        lower[i] = secondBest;
        return bestIndex;
    }

    /**
     * Recomputes every center as its cluster's weighted mean and records how far each center moved
     * in {@code meanMoved}. A cluster with no positive weight has its sum and its center zeroed.
     * The query info is refreshed when the metric supports acceleration.
     */
    private void moveCenters(List<Vec> means, Vec[] newMeans, Vec[] meanSums, AtomicDoubleArray clusterWeight, double[] meanMoved, List<List<Double>> meanQueryInfo) {
        for (int j = 0; j < means.size(); j++) {
            final double w = clusterWeight.get(j);
            if (w > 0.0d) {
                meanSums[j].copyTo(newMeans[j]);
                newMeans[j].mutableDivide(w);
            } else {
                meanSums[j].zeroOut();
                newMeans[j].zeroOut();
            }
            meanMoved[j] = this.dm.dist(means.get(j), newMeans[j]);
            newMeans[j].copyTo(means.get(j));
            if (this.dm.supportsAcceleration()) {
                meanQueryInfo.set(j, this.dm.getQueryInfo(means.get(j)));
            }
        }
    }

    /**
     * Loosens the bounds after the centers moved. Each upper bound grows by its own center's
     * movement. Each lower bound shrinks by the largest movement of any other center: the
     * second-largest movement if the point's center moved most, otherwise the largest. With a
     * single center there is no second center, and lower bounds are left unchanged.
     */
    private void UpdateBounds(double[] meanMoved, int[] assignment, double[] upper, double[] lower) {
        int r = -1;      // index of the center that moved the most
        int rPrime = -1; // index of the center that moved the second-most
        for (int j = 0; j < meanMoved.length; j++) {
            if (r < 0 || meanMoved[j] > meanMoved[r]) {
                rPrime = r;
                r = j;
            } else if (rPrime < 0 || meanMoved[j] > meanMoved[rPrime]) {
                rPrime = j;
            }
        }
        // Guards against k == 1, where rPrime stays -1.
        final double largestMove = r < 0 ? 0.0d : meanMoved[r];
        final double secondLargestMove = rPrime < 0 ? 0.0d : meanMoved[rPrime];
        for (int i = 0; i < upper.length; i++) {
            final int a = assignment[i];
            upper[i] += meanMoved[a];
            lower[i] -= (a == r) ? secondLargestMove : largestMove;
        }
    }

    /** Adds each thread-local sum into the shared sum under that vector's lock, then clears it. */
    private static void mergeAndReset(Vec[] shared, Vec[] local) {
        for (int j = 0; j < shared.length; j++) {
            synchronized (shared[j]) {
                shared[j].mutableAdd(local[j]);
            }
            local[j].zeroOut();
        }
    }

    /**
     * Wraps a parallel task so that the latch is always released and the first failure is
     * recorded. Without this, a throwing task would leave the submitting thread blocked on the
     * latch forever, and the exception would stay hidden inside the executor's Future.
     */
    private static Runnable releasing(final Runnable work, final CountDownLatch latch, final AtomicReference<Throwable> failure) {
        return new Runnable() {
            @Override
            public void run() {
                try {
                    work.run();
                } catch (Throwable t) {
                    failure.compareAndSet(null, t);
                } finally {
                    latch.countDown();
                }
            }
        };
    }

    /**
     * Waits for all tasks wrapped by {@link #releasing}, then rethrows the first failure, if any.
     * As before, an interrupt is only logged.
     */
    private static void awaitCompletion(CountDownLatch latch, AtomicReference<Throwable> failure) {
        try {
            latch.await();
        } catch (InterruptedException e) {
            // NOTE(review): kept from the original; execution continues without waiting for
            // outstanding tasks.
            Logger.getLogger(HamerlyKMeans.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
        }
        final Throwable t = failure.get();
        if (t instanceof RuntimeException) {
            throw (RuntimeException) t;
        }
        if (t instanceof Error) {
            throw (Error) t;
        }
        if (t != null) {
            throw new RuntimeException(t);
        }
    }

    @Override // jsat.clustering.kmeans.KMeans, jsat.clustering.KClustererBase, jsat.clustering.ClustererBase
    /* renamed from: clone */
    public HamerlyKMeans mo590clone() {
        return new HamerlyKMeans(this);
    }
}
