package jsat.clustering.kmeans;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.clustering.KClustererBase;
import jsat.clustering.SeedSelectionMethods;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.utils.FakeExecutor;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.SystemInfo;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/clustering/kmeans/MiniBatchKMeans.class */
/**
 * Mini-batch k-means clustering.
 *
 * <p>The algorithm works in three stages:
 * <ol>
 * <li>Seed {@code k} means using the configured {@link SeedSelectionMethods.SeedSelection}.</li>
 * <li>Run {@link #getIterations()} rounds. Each round draws a random sample of at most
 * {@link #getBatchSize()} points and assigns each sampled point to its nearest mean. That mean
 * then moves toward the point with a step of {@code 1/n}, where {@code n} is the number of
 * sampled points the mean has absorbed so far.</li>
 * <li>After the last round, assign every point in the data set to its nearest mean.</li>
 * </ol>
 *
 * <p>Only the fixed-{@code k} {@code cluster} overloads are implemented; the others throw
 * {@link UnsupportedOperationException}.
 */
public class MiniBatchKMeans extends KClustererBase {
    private static final long serialVersionUID = 412553399508594014L;
    /** Upper bound on the number of points sampled per iteration. */
    private int batchSize;
    /** Number of mini-batch update rounds. */
    private int iterations;
    /** Metric used to find each point's nearest mean. */
    private DistanceMetric dm;
    /** Strategy used to pick the initial means. */
    private SeedSelectionMethods.SeedSelection seedSelection;
    /** Whether {@link #means} is kept after {@code cluster} returns. */
    private boolean storeMeans;
    /** Means from the most recent clustering run, or {@code null}. */
    private List<Vec> means;

    /**
     * Creates a clusterer using {@link EuclideanDistance} and {@code SeedSelection.KPP} seeding.
     *
     * @param batchSize maximum number of points sampled per iteration (must be at least 1)
     * @param iterations number of mini-batch rounds (must be at least 1)
     * @throws ArithmeticException if either argument is less than 1
     */
    public MiniBatchKMeans(int batchSize, int iterations) {
        this(new EuclideanDistance(), batchSize, iterations);
    }

    /**
     * Creates a clusterer using the given metric and {@code SeedSelection.KPP} seeding.
     *
     * @param distanceMetric metric used to compare points with means
     * @param batchSize maximum number of points sampled per iteration (must be at least 1)
     * @param iterations number of mini-batch rounds (must be at least 1)
     * @throws ArithmeticException if either numeric argument is less than 1
     */
    public MiniBatchKMeans(DistanceMetric distanceMetric, int batchSize, int iterations) {
        this(distanceMetric, batchSize, iterations, SeedSelectionMethods.SeedSelection.KPP);
    }

    /**
     * Creates a fully configured clusterer. Means are stored after clustering by default.
     *
     * @param distanceMetric metric used to compare points with means
     * @param batchSize maximum number of points sampled per iteration (must be at least 1)
     * @param iterations number of mini-batch rounds (must be at least 1)
     * @param seedSelection strategy used to pick the initial means
     * @throws ArithmeticException if either numeric argument is less than 1
     */
    public MiniBatchKMeans(DistanceMetric distanceMetric, int batchSize, int iterations, SeedSelectionMethods.SeedSelection seedSelection) {
        this.storeMeans = true;
        setBatchSize(batchSize);
        setIterations(iterations);
        setDistanceMetric(distanceMetric);
        setSeedSelection(seedSelection);
    }

    /**
     * Copy constructor.
     *
     * <p>The distance metric and any stored means are cloned, so the copy shares no mutable
     * state with the original.
     *
     * @param toCopy the clusterer to copy
     */
    public MiniBatchKMeans(MiniBatchKMeans toCopy) {
        this.batchSize = toCopy.batchSize;
        this.iterations = toCopy.iterations;
        this.dm = toCopy.dm.mo651clone();
        this.seedSelection = toCopy.seedSelection;
        this.storeMeans = toCopy.storeMeans;
        if (toCopy.means != null) {
            this.means = new ArrayList<Vec>(toCopy.means.size());
            for (Vec mean : toCopy.means) {
                this.means.add(mean.mo524clone());
            }
        }
    }

    /**
     * Controls whether the final means remain available from {@link #getMeans()} after clustering.
     *
     * @param storeMeans {@code true} to keep the means, {@code false} to discard them
     */
    public void setStoreMeans(boolean storeMeans) {
        this.storeMeans = storeMeans;
    }

    /**
     * Returns the means from the most recent clustering run.
     *
     * <p>The internal list is returned, not a copy. The result is {@code null} if no run has
     * happened or if means are not being stored.
     *
     * @return the stored means, or {@code null}
     */
    public List<Vec> getMeans() {
        return this.means;
    }

    /**
     * Sets the metric used to compare points with means.
     *
     * @param distanceMetric the metric to use
     */
    public void setDistanceMetric(DistanceMetric distanceMetric) {
        this.dm = distanceMetric;
    }

    /**
     * Returns the metric used to compare points with means.
     *
     * @return the distance metric
     */
    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }

    /**
     * Sets the maximum number of points sampled per iteration.
     *
     * @param batchSize the batch size
     * @throws ArithmeticException if {@code batchSize < 1}
     */
    public void setBatchSize(int batchSize) {
        if (batchSize < 1) {
            throw new ArithmeticException("Batch size must be a positive value, not " + batchSize);
        }
        this.batchSize = batchSize;
    }

    /**
     * Returns the maximum number of points sampled per iteration.
     *
     * @return the batch size
     */
    public int getBatchSize() {
        return this.batchSize;
    }

    /**
     * Sets the number of mini-batch rounds.
     *
     * @param iterations the iteration count
     * @throws ArithmeticException if {@code iterations < 1}
     */
    public void setIterations(int iterations) {
        if (iterations < 1) {
            throw new ArithmeticException("Iterations must be a positive value, not " + iterations);
        }
        this.iterations = iterations;
    }

    /**
     * Returns the number of mini-batch rounds.
     *
     * @return the iteration count
     */
    public int getIterations() {
        return this.iterations;
    }

    /**
     * Sets the strategy used to pick the initial means.
     *
     * @param seedSelection the seeding strategy
     */
    public void setSeedSelection(SeedSelectionMethods.SeedSelection seedSelection) {
        this.seedSelection = seedSelection;
    }

    /**
     * Returns the strategy used to pick the initial means.
     *
     * @return the seeding strategy
     */
    public SeedSelectionMethods.SeedSelection getSeedSelection() {
        return this.seedSelection;
    }

    /** Not supported: this clusterer requires the number of clusters to be given. */
    @Override // jsat.clustering.Clusterer
    public int[] cluster(DataSet dataSet, int[] iArr) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /** Not supported: this clusterer requires the number of clusters to be given. */
    @Override // jsat.clustering.Clusterer
    public int[] cluster(DataSet dataSet, ExecutorService executorService, int[] iArr) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /**
     * Clusters {@code dataSet} into {@code clusters} groups using mini-batch k-means.
     *
     * @param dataSet the data to cluster
     * @param clusters the number of means to seed
     * @param threadpool executor for the parallel assignment steps. If {@code null}, work runs
     *     on the calling thread through a {@link FakeExecutor}.
     * @param designations array that receives each point's cluster index, or {@code null} to
     *     allocate one of length {@code dataSet.getSampleSize()}
     * @return the designations array, filled with each point's nearest-mean index
     */
    @Override // jsat.clustering.KClusterer
    public int[] cluster(DataSet dataSet, int clusters, ExecutorService threadpool, int[] designations) {
        if (designations == null) {
            designations = new int[dataSet.getSampleSize()];
        }
        // The acceleration-cache logic already treats null as "single threaded"; without this,
        // the first submit() on a null pool threw a NullPointerException.
        final ExecutorService pool = threadpool == null ? new FakeExecutor() : threadpool;

        TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet, pool);
        final List<Vec> dataVectors = dataSet.getDataVectors();
        final List<Double> accelCache = pool instanceof FakeExecutor
                ? this.dm.getAccelerationCache(dataVectors)
                : this.dm.getAccelerationCache(dataVectors, pool);

        // Workers read this local list rather than the field, so clearing the field later
        // (storeMeans == false) cannot affect a task that is still running.
        final List<Vec> centers = SeedSelectionMethods.selectIntialPoints(
                dataSet, clusters, this.dm, accelCache, new Random(), this.seedSelection, pool);
        this.means = centers;

        final List<List<Double>> centerQueryInfo = new ArrayList<List<Double>>(centers.size());
        for (Vec center : centers) {
            centerQueryInfo.add(this.dm.supportsAcceleration()
                    ? this.dm.getQueryInfo(center)
                    : Collections.<Double>emptyList());
        }

        final int[] absorbed = new int[centers.size()];
        final int batch = Math.min(this.batchSize, dataSet.getSampleSize());
        final IntList sample = new IntList(batch);
        final IntList allIndices = new IntList(dataVectors.size());
        ListUtils.addRange(allIndices, 0, dataVectors.size(), 1);
        final int[] nearest = new int[batch];

        for (int iter = 0; iter < this.iterations; iter++) {
            sample.clear();
            ListUtils.randomSample(allIndices, sample, batch);
            if (!assignSample(sample, centers, centerQueryInfo, dataVectors, accelCache, nearest, pool)) {
                // A failed or interrupted assignment leaves "nearest" partly stale, and
                // applying it would corrupt the means, so stop refining.
                break;
            }
            for (int j = 0; j < sample.size(); j++) {
                final int c = nearest[j];
                // Step size 1/n, where n is how many sampled points this mean has absorbed.
                final double eta = 1.0 / ++absorbed[c];
                final Vec center = centers.get(c);
                center.mutableMultiply(1.0 - eta);
                center.mutableAdd(eta, dataVectors.get(sample.get(j).intValue()));
            }
            // The means moved, so any cached query info for them is out of date.
            if (this.dm.supportsAcceleration()) {
                for (int c = 0; c < centers.size(); c++) {
                    centerQueryInfo.set(c, this.dm.getQueryInfo(centers.get(c)));
                }
            }
        }

        assignAll(designations, centers, centerQueryInfo, dataVectors, accelCache, pool);

        if (!this.storeMeans) {
            this.means = null;
        }
        return designations;
    }

    /**
     * Splits {@code [0, size)} into at most {@link SystemInfo#LogicalCores} contiguous ranges.
     *
     * <p>The first {@code size % ranges} ranges receive one extra element. The returned array
     * holds the range boundaries, and its length is one more than the number of ranges. For
     * {@code size == 0} the result is {@code {0}}, which means no ranges.
     */
    private static int[] chunkBounds(int size) {
        final int chunks = Math.min(SystemInfo.LogicalCores, size);
        final int[] bounds = new int[chunks + 1];
        if (chunks == 0) {
            return bounds;
        }
        final int base = size / chunks;
        final int extra = size % chunks;
        for (int t = 0; t < chunks; t++) {
            bounds[t + 1] = bounds[t] + base + (t < extra ? 1 : 0);
        }
        return bounds;
    }

    /**
     * Writes the nearest-center index of every sampled point into {@code nearest[0..sample.size())}.
     *
     * <p>Exactly one task is created per range returned by {@link #chunkBounds}. The previous
     * latch was sized to the core count regardless of how many tasks ran, which could block
     * forever.
     *
     * @return {@code true} if every task completed normally
     */
    private boolean assignSample(final IntList sample, final List<Vec> centers,
            final List<List<Double>> centerQueryInfo, final List<Vec> dataVectors,
            final List<Double> accelCache, final int[] nearest, ExecutorService pool) {
        final int[] bounds = chunkBounds(sample.size());
        final List<Future<Void>> tasks = new ArrayList<Future<Void>>(bounds.length - 1);
        for (int t = 0; t + 1 < bounds.length; t++) {
            final int from = bounds[t];
            final int to = bounds[t + 1];
            tasks.add(pool.submit(new Callable<Void>() {
                @Override // java.util.concurrent.Callable
                public Void call() {
                    for (int j = from; j < to; j++) {
                        nearest[j] = nearestCenter(sample.get(j).intValue(), centers, centerQueryInfo, dataVectors, accelCache);
                    }
                    return null;
                }
            }));
        }
        return awaitAll(tasks);
    }

    /**
     * Writes the nearest-center index of every data point into {@code designations}.
     *
     * <p>Failures are logged by {@link #awaitAll} and otherwise ignored, as before.
     */
    private void assignAll(final int[] designations, final List<Vec> centers,
            final List<List<Double>> centerQueryInfo, final List<Vec> dataVectors,
            final List<Double> accelCache, ExecutorService pool) {
        final int[] bounds = chunkBounds(dataVectors.size());
        final List<Future<Void>> tasks = new ArrayList<Future<Void>>(bounds.length - 1);
        for (int t = 0; t + 1 < bounds.length; t++) {
            final int from = bounds[t];
            final int to = bounds[t + 1];
            tasks.add(pool.submit(new Callable<Void>() {
                @Override // java.util.concurrent.Callable
                public Void call() {
                    for (int i = from; i < to; i++) {
                        designations[i] = nearestCenter(i, centers, centerQueryInfo, dataVectors, accelCache);
                    }
                    return null;
                }
            }));
        }
        awaitAll(tasks);
    }

    /**
     * Returns the index of the center closest to {@code dataVectors.get(index)}.
     *
     * @return the closest center's index, or -1 if no distance is below positive infinity
     */
    private int nearestCenter(int index, List<Vec> centers, List<List<Double>> centerQueryInfo,
            List<Vec> dataVectors, List<Double> accelCache) {
        int best = -1;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int c = 0; c < centers.size(); c++) {
            final double d = this.dm.dist(index, centers.get(c), centerQueryInfo.get(c), dataVectors, accelCache);
            if (d < bestDist) {
                bestDist = d;
                best = c;
            }
        }
        return best;
    }

    /**
     * Waits for every future in order and logs the first failure at SEVERE.
     *
     * <p>If the wait is interrupted, the thread's interrupt status is restored.
     *
     * @return {@code true} if all futures completed normally
     */
    private static boolean awaitAll(List<? extends Future<?>> futures) {
        try {
            for (Future<?> future : futures) {
                future.get();
            }
            return true;
        } catch (InterruptedException e) {
            // Preserve the interruption for callers further up the stack.
            Thread.currentThread().interrupt();
            Logger.getLogger(MiniBatchKMeans.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
        } catch (ExecutionException e) {
            Logger.getLogger(MiniBatchKMeans.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
        }
        return false;
    }

    /** Single-threaded variant of {@link #cluster(DataSet, int, ExecutorService, int[])}. */
    @Override // jsat.clustering.KClusterer
    public int[] cluster(DataSet dataSet, int clusters, int[] designations) {
        return cluster(dataSet, clusters, new FakeExecutor(), designations);
    }

    /** Not supported: a range of cluster counts cannot be searched by this clusterer. */
    @Override // jsat.clustering.KClusterer
    public int[] cluster(DataSet dataSet, int i, int i2, ExecutorService executorService, int[] iArr) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /** Not supported: a range of cluster counts cannot be searched by this clusterer. */
    @Override // jsat.clustering.KClusterer
    public int[] cluster(DataSet dataSet, int i, int i2, int[] iArr) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /**
     * Returns a deep copy of this clusterer made with the copy constructor.
     *
     * @return the copy
     */
    @Override // jsat.clustering.KClustererBase, jsat.clustering.ClustererBase
    /* renamed from: clone */
    public MiniBatchKMeans mo589clone() {
        return new MiniBatchKMeans(this);
    }
}
