package jsat.clustering.kmeans;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.DataSet;
import jsat.SimpleDataSet;
import jsat.classifiers.DataPoint;
import jsat.clustering.SeedSelectionMethods;
import jsat.linear.MatrixStatistics;
import jsat.linear.Vec;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/clustering/kmeans/XMeans.class */
/**
 * X-Means clustering (Pelleg &amp; Moore). Starting from {@code lowK} clusters (or a single
 * cluster holding the mean of all data when {@code lowK < 2}), every cluster is tentatively
 * split in two with the wrapped {@link KMeans} implementation. The split is kept when the
 * Bayesian Information Criterion (BIC) of the two-cluster model is at least that of the
 * one-cluster model. Passes repeat until a pass adds no cluster or {@code highK} clusters
 * exist.
 */
public class XMeans extends KMeans {
    private static final long serialVersionUID = -2577160317892141870L;

    /** 2&pi;, used by the Gaussian log-likelihood terms of the BIC. */
    private static final double TWO_PI = 2.0 * Math.PI;

    /** When true, a cluster that fails a split test is never tested again. */
    private boolean stopAfterFail = false;
    /** When true, all means are re-optimized with k-means after every splitting pass. */
    private boolean iterativeRefine = true;
    /** Clusters with fewer points than this are never split. */
    private int minClusterSize = 25;
    /** The k-means implementation used for the initial clustering, splits and refinement. */
    private KMeans kmeans;

    /** Creates an X-Means clusterer backed by a {@link HamerlyKMeans}. */
    public XMeans() {
        this(new HamerlyKMeans());
    }

    /**
     * Creates an X-Means clusterer that delegates all k-means work to {@code kMeans}.
     * The delegate is configured to save per-point centroid distances and to store its
     * means, because the BIC computation reads both.
     *
     * @param kMeans the k-means implementation to wrap (not copied; it is mutated)
     */
    public XMeans(KMeans kMeans) {
        super(kMeans.dm, kMeans.seedSelection, kMeans.rand);
        this.kmeans = kMeans;
        this.kmeans.saveCentroidDistance = true;
        this.kmeans.setStoreMeans(true);
    }

    /**
     * Copy constructor; the wrapped k-means object is cloned, not shared.
     *
     * @param xMeans the instance to copy
     */
    public XMeans(XMeans xMeans) {
        super(xMeans);
        this.kmeans = xMeans.kmeans.mo589clone();
        this.stopAfterFail = xMeans.stopAfterFail;
        this.iterativeRefine = xMeans.iterativeRefine;
        this.minClusterSize = xMeans.minClusterSize;
    }

    /**
     * Sets whether a cluster that once failed to split is excluded from later split tests.
     *
     * @param z {@code true} to stop testing a cluster after its first failed split
     */
    public void setStopAfterFail(boolean z) {
        this.stopAfterFail = z;
    }

    /** @return whether clusters are excluded from testing after a failed split */
    public boolean isStopAfterFail() {
        return this.stopAfterFail;
    }

    /**
     * Sets the minimum number of points a cluster must have to be considered for splitting.
     *
     * @param i the minimum cluster size
     * @throws IllegalArgumentException if {@code i < 2}
     */
    public void setMinClusterSize(int i) {
        if (i < 2) {
            throw new IllegalArgumentException("min cluster size that could be split is 2, not " + i);
        }
        this.minClusterSize = i;
    }

    /** @return the minimum number of points a cluster needs to be considered for a split */
    public int getMinClusterSize() {
        return this.minClusterSize;
    }

    /**
     * Sets whether all means are re-optimized with k-means after every splitting pass.
     * If disabled, a single final k-means run is performed once splitting finishes.
     *
     * @param z {@code true} to refine after every pass
     */
    public void setIterativeRefine(boolean z) {
        this.iterativeRefine = z;
    }

    /** @return whether all means are refined after every splitting pass */
    public boolean getIterativeRefine() {
        return this.iterativeRefine;
    }

    /**
     * Clusters with {@code lowK = 2} and {@code highK = max(N / 20, 10)}, single-threaded.
     */
    @Override
    public int[] cluster(DataSet dataSet, int[] iArr) {
        return cluster(dataSet, 2, Math.max(dataSet.getSampleSize() / 20, 10), iArr);
    }

    /**
     * Clusters with {@code lowK = 2} and {@code highK = max(N / 20, 10)} using the given pool.
     */
    @Override
    public int[] cluster(DataSet dataSet, ExecutorService executorService, int[] iArr) {
        return cluster(dataSet, 2, Math.max(dataSet.getSampleSize() / 20, 10), executorService, iArr);
    }

    /**
     * Number of free parameters of a k-cluster spherical Gaussian model in d dimensions:
     * k-1 mixing proportions, k*d mean coordinates and one shared variance.
     */
    private static int freeParameters(int k, int d) {
        return (k - 1) + (d * k) + 1;
    }

    /**
     * BIC of modelling {@code n} points in {@code d} dimensions as a single cluster.
     *
     * @param sqDistSum sum of squared distances of the points to their mean
     */
    private static double parentBic(int n, int d, double sqDistSum) {
        double logLikelihood = (((-n) * d) / 2.0) * Math.log((TWO_PI * sqDistSum) / (d * (n - 1)))
                - (d / 2.0) * (n - 1);
        return logLikelihood - (freeParameters(1, d) / 2.0) * Math.log(n);
    }

    /**
     * BIC of modelling {@code n} points in {@code d} dimensions as two clusters, the first of
     * which holds {@code nLeft} points. Both halves must be non-empty.
     *
     * @param sqDistSum sum of squared distances of every point to its own half's mean
     */
    private static double splitBic(int n, int nLeft, int d, double sqDistSum) {
        int nRight = n - nLeft;
        double variance = sqDistSum / (d * (n - 2));
        double logLikelihood = nLeft * Math.log(nLeft) + nRight * Math.log(nRight)
                - n * Math.log(n)
                - ((n * d) / 2.0) * Math.log(TWO_PI * variance)
                - (d / 2.0) * (n - 2);
        return logLikelihood - (freeParameters(2, d) / 2.0) * Math.log(n);
    }

    /**
     * Recomputes, per cluster, the sum of squared distances between each point and its
     * assigned mean, using the distances saved by the wrapped k-means' last run.
     *
     * @param designations cluster index of each point
     * @param count number of points to accumulate
     * @param localVar per-cluster output array; zeroed before accumulating
     */
    private void accumulateLocalVariance(int[] designations, int count, double[] localVar) {
        Arrays.fill(localVar, 0.0d);
        for (int p = 0; p < count; p++) {
            localVar[designations[p]] += Math.pow(this.kmeans.nearestCentroidDist[p], 2.0d);
        }
    }

    /**
     * Runs X-Means, growing the model from {@code lowK} towards at most {@code highK} clusters.
     *
     * @param dataSet the data to cluster
     * @param lowK initial number of clusters; values below 2 start from the single global mean
     * @param highK maximum number of clusters
     * @param threadpool pool passed to the wrapped k-means; may be {@code null}
     * @param designations output array, reused if non-null and long enough
     * @return the cluster index of each data point
     */
    @Override
    public int[] cluster(DataSet dataSet, int lowK, int highK, ExecutorService threadpool, int[] designations) {
        final int n = dataSet.getSampleSize();
        final int d = dataSet.getNumNumericalVars();
        if (designations == null || designations.length < n) {
            designations = new int[n];
        }
        List<Vec> data = dataSet.getDataVectors();
        List<Double> accelCache = this.dm.getAccelerationCache(data, threadpool);

        // localVar[c] = sum of squared distances from the points of cluster c to its mean.
        double[] localVar = new double[highK];

        if (lowK >= 2) {
            this.means = new ArrayList<Vec>();
            this.kmeans.cluster(dataSet, accelCache, lowK, this.means, designations, true, threadpool, true, null);
            accumulateLocalVariance(designations, data.size(), localVar);
        } else {
            // One cluster: the mean of all the data.
            Arrays.fill(designations, 0);
            this.means = new ArrayList<Vec>(Arrays.asList(MatrixStatistics.meanVector(dataSet)));
            List<Double> queryInfo = this.dm.getQueryInfo(this.means.get(0));
            for (int p = 0; p < data.size(); p++) {
                localVar[0] += Math.pow(this.dm.dist(p, this.means.get(0), queryInfo, data, accelCache), 2.0d);
            }
        }

        // subToFull[j] maps the j-th point of the cluster being tested back to its index in dataSet.
        int[] subToFull = new int[designations.length];
        int[] subDesignations = new int[designations.length];
        // dontRedo[c] is set once cluster c fails a split test while stopAfterFail is on.
        List<Boolean> dontRedo = new ArrayList<Boolean>(Collections.nCopies(this.means.size(), false));

        int meansAtPassStart;
        do {
            meansAtPassStart = this.means.size();
            for (int c = 0; c < meansAtPassStart; c++) {
                if (dontRedo.get(c)) {
                    continue;
                }
                List<DataPoint> points = getDatapointsFromCluster(c, designations, dataSet, subToFull);
                int clusterSize = points.size();
                if (clusterSize < this.minClusterSize || this.means.size() >= highK) {
                    continue;
                }

                // NOTE(review): the first call's clustering appears to be redone by the second
                // (which starts from an empty mean list); kept because the effect of removing it
                // on the delegate's state/RNG cannot be verified from this file.
                subDesignations = this.kmeans.cluster(new SimpleDataSet(points), 2, threadpool, subDesignations);
                // Explicit call so the delegate returns the two means and saves point distances.
                List<Vec> subMeans = new ArrayList<Vec>(2);
                this.kmeans.cluster(new SimpleDataSet(points), null, 2, subMeans, subDesignations, true, threadpool, true, null);
                double[] subDist = this.kmeans.nearestCentroidDist;
                Vec leftMean = subMeans.get(0);
                Vec rightMean = subMeans.get(1);

                double splitSqDist = 0.0d;
                double leftSqDist = 0.0d;
                double rightSqDist = 0.0d;
                int leftCount = 0;
                for (int j = 0; j < clusterSize; j++) {
                    double sq = Math.pow(subDist[j], 2.0d);
                    splitSqDist += sq;
                    if (subDesignations[j] == 0) {
                        leftSqDist += sq;
                        leftCount++;
                    } else {
                        rightSqDist += sq;
                    }
                }
                int rightCount = clusterSize - leftCount;

                // A split leaving one half empty is not a split (its BIC would be 0*log(0) = NaN).
                boolean splitWins = leftCount > 0 && rightCount > 0
                        && parentBic(clusterSize, d, localVar[c]) <= splitBic(clusterSize, leftCount, d, splitSqDist);

                if (splitWins) {
                    int newIndex = this.means.size();
                    for (int j = 0; j < clusterSize; j++) {
                        if (subDesignations[j] == 1) {
                            designations[subToFull[j]] = newIndex;
                        }
                    }
                    this.means.set(c, leftMean.mo524clone());
                    this.means.add(rightMean.mo524clone());
                    dontRedo.add(false);
                    // Keep the per-cluster variances in sync with the new assignment so later
                    // passes (notably when iterativeRefine is off) test the right values.
                    localVar[c] = leftSqDist;
                    localVar[newIndex] = rightSqDist;
                } else if (this.stopAfterFail) {
                    dontRedo.set(c, true);
                }
            }

            if (this.iterativeRefine && this.means.size() > 1) {
                this.kmeans.cluster(dataSet, accelCache, this.means.size(), this.means, designations, true, threadpool, true, null);
                accumulateLocalVariance(designations, data.size(), localVar);
            }
        } while (meansAtPassStart < this.means.size());

        if (!this.iterativeRefine) {
            this.kmeans.cluster(dataSet, accelCache, this.means.size(), this.means, designations, false, threadpool, false, null);
        }
        return designations;
    }

    /** Single-threaded variant of {@link #cluster(DataSet, int, int, ExecutorService, int[])}. */
    @Override
    public int[] cluster(DataSet dataSet, int i, int i2, int[] iArr) {
        return cluster(dataSet, i, i2, null, iArr);
    }

    /** @return the iteration limit of the wrapped k-means */
    @Override
    public int getIterationLimit() {
        return this.kmeans.getIterationLimit();
    }

    /** Sets the iteration limit of the wrapped k-means. */
    @Override
    public void setIterationLimit(int i) {
        this.kmeans.setIterationLimit(i);
    }

    /**
     * Sets the seed selection of the wrapped k-means. The null check is needed because
     * this may run (via the superclass constructor) before {@code kmeans} is assigned.
     */
    @Override
    public void setSeedSelection(SeedSelectionMethods.SeedSelection seedSelection) {
        if (this.kmeans != null) {
            this.kmeans.setSeedSelection(seedSelection);
        }
    }

    /** @return the seed selection of the wrapped k-means */
    @Override
    public SeedSelectionMethods.SeedSelection getSeedSelection() {
        return this.kmeans.getSeedSelection();
    }

    /**
     * Delegates the core k-means routine to the wrapped implementation. Declared
     * {@code public} here; the decompiler notes it was originally {@code protected}.
     * NOTE(review): {@code vec} (the last argument) is not forwarded; {@code null} is passed
     * instead, as in the original. Confirm whether the delegate supports it before changing.
     */
    @Override
    public double cluster(DataSet dataSet, List<Double> list, int i, List<Vec> list2, int[] iArr, boolean z, ExecutorService executorService, boolean z2, Vec vec) {
        return this.kmeans.cluster(dataSet, list, i, list2, iArr, z, executorService, z2, null);
    }

    /** @return a deep copy of this clusterer (the wrapped k-means is cloned too) */
    @Override
    public XMeans mo589clone() {
        return new XMeans(this);
    }
}
