package jsat.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.SimpleDataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.clustering.SeedSelectionMethods;
import jsat.distributions.multivariate.MultivariateDistribution;
import jsat.distributions.multivariate.NormalM;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.utils.ListUtils;
import jsat.utils.SystemInfo;
import jsat.utils.random.XORWOW;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/clustering/EMGaussianMixture.class */
/**
 * Gaussian mixture model fit with the Expectation-Maximization (EM) algorithm.
 * <p>
 * Initial means come from the configured {@link SeedSelectionMethods.SeedSelection}. Every data
 * point is then hard-assigned to its nearest seed (Euclidean distance), which gives the initial
 * covariance and mixing weight of each component. EM then alternates E-steps (posterior
 * responsibilities and data log-likelihood) and M-steps (re-estimated means, covariances and
 * weights). It stops when the log-likelihood changes by less than the tolerance or when the
 * iteration limit is reached.
 * <p>
 * The fitted model is itself a density, so this class also implements
 * {@link MultivariateDistribution}. Only the numerical features of data points are used.
 * <p>
 * The model fields are mutated during a fit without any locking on the instance. An executor
 * given to the clustering methods only parallelizes the work inside a single fit.
 */
public class EMGaussianMixture extends KClustererBase implements MultivariateDistribution {
    /** Strategy used to pick the initial component means. */
    private SeedSelectionMethods.SeedSelection seedSelection;
    private static final long serialVersionUID = 2606159815670221662L;
    /** The mixture components; {@code null} until a model has been fit or copied. */
    private List<NormalM> gaussians;
    /** Mixing weight of each component (each fit normalizes them to sum to 1). */
    private double[] a_k;
    /** EM stops once the absolute change in data log-likelihood falls below this value. */
    private double tolerance = 0.001d;
    /** Maximum number of EM iterations (one E-step plus one M-step) per fit. */
    protected int MaxIterLimit = Integer.MAX_VALUE;

    /**
     * Creates a mixture model that chooses its initial means with the given strategy.
     *
     * @param seedSelection the seeding strategy for the initial component means
     */
    public EMGaussianMixture(SeedSelectionMethods.SeedSelection seedSelection) {
        setSeedSelection(seedSelection);
    }

    /** Creates a mixture model seeded with {@link SeedSelectionMethods.SeedSelection#KPP}. */
    public EMGaussianMixture() {
        this(SeedSelectionMethods.SeedSelection.KPP);
    }

    /**
     * Sets the strategy used to choose the initial component means.
     *
     * @param seedSelection the seeding strategy
     */
    public void setSeedSelection(SeedSelectionMethods.SeedSelection seedSelection) {
        this.seedSelection = seedSelection;
    }

    /**
     * Returns the strategy used to choose the initial component means.
     *
     * @return the seeding strategy
     */
    public SeedSelectionMethods.SeedSelection getSeedSelection() {
        return this.seedSelection;
    }

    /**
     * Sets the maximum number of EM iterations performed per fit.
     *
     * @param i the iteration limit; must be at least 1
     * @throws IllegalArgumentException if {@code i < 1}
     */
    public void setIterationLimit(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("Iterations must be a positive value, not " + i);
        }
        this.MaxIterLimit = i;
    }

    /**
     * Returns the maximum number of EM iterations performed per fit.
     *
     * @return the iteration limit
     */
    public int getIterationLimit() {
        return this.MaxIterLimit;
    }

    /**
     * Deep-copies another mixture: components and weights are duplicated. The seeding strategy,
     * iteration limit and tolerance are copied.
     *
     * @param toCopy the mixture to copy
     */
    public EMGaussianMixture(EMGaussianMixture toCopy) {
        if (toCopy.gaussians != null && !toCopy.gaussians.isEmpty()) {
            this.gaussians = new ArrayList<NormalM>(toCopy.gaussians.size());
            for (NormalM component : toCopy.gaussians) {
                this.gaussians.add(component.mo630clone());
            }
        }
        if (toCopy.a_k != null) {
            this.a_k = Arrays.copyOf(toCopy.a_k, toCopy.a_k.length);
        }
        // The seeding strategy used to be dropped, leaving clones unable to cluster.
        this.seedSelection = toCopy.seedSelection;
        this.MaxIterLimit = toCopy.MaxIterLimit;
        this.tolerance = toCopy.tolerance;
    }

    /**
     * Builds a mixture directly from components and weights. Each component is cloned, and the
     * weights are copied.
     *
     * @param components the mixture components (at least {@code weights.length} of them)
     * @param weights    the mixing weight of each component
     * @param ignored    NOTE(review): this argument is never read; confirm its intended meaning
     */
    private EMGaussianMixture(List<NormalM> components, double[] weights, double ignored) {
        this.gaussians = new ArrayList<NormalM>(weights.length);
        this.a_k = new double[weights.length];
        for (int k = 0; k < weights.length; k++) {
            this.gaussians.add(components.get(k).mo630clone());
            this.a_k[k] = weights[k];
        }
    }

    /**
     * Seeds the mixture, derives initial covariances and weights from a hard nearest-mean
     * assignment, and then runs EM.
     *
     * @param dataSet     the data to fit
     * @param accelCache  passed through to seeding and distance computations (the callers in
     *                    this class pass {@code null})
     * @param K           the number of mixture components
     * @param means       initial means. If it holds fewer than {@code K} entries, it is cleared
     *                    and refilled with the configured seed selection. It is overwritten
     *                    with the fitted means.
     * @param assignment  output: the most likely component of each data point
     * @param unusedFlag  not read by this implementation
     * @param threadpool  executor for parallel work, or {@code null} to run serially
     * @param unusedFlag2 not read by this implementation
     * @return the negated log-likelihood of the data under the final model
     */
    protected double cluster(DataSet dataSet, List<Double> accelCache, int K, List<Vec> means, int[] assignment, boolean unusedFlag, ExecutorService threadpool, boolean unusedFlag2) {
        EuclideanDistance dm = new EuclideanDistance();
        if (means.size() < K) {
            means.clear();
            means.addAll(SeedSelectionMethods.selectIntialPoints(dataSet, K, dm, accelCache, new XORWOW(), this.seedSelection, threadpool));
        }
        // Query info must exist for every mean, including caller-supplied ones. Previously it
        // was built only after seeding, so supplied means caused an IndexOutOfBoundsException.
        List<List<Double>> meanQueryInfo = new ArrayList<List<Double>>(means.size());
        for (Vec mean : means) {
            meanQueryInfo.add(dm.getQueryInfo(mean));
        }

        int D = dataSet.getNumNumericalVars();
        List<Matrix> covs = new ArrayList<Matrix>(K);
        for (int k = 0; k < means.size(); k++) {
            covs.add(new DenseMatrix(D, D));
        }
        this.a_k = new double[K];

        int N = dataSet.getSampleSize();
        DenseVector diff = new DenseVector(D);
        List<Vec> dataVectors = dataSet.getDataVectors();
        for (int n = 0; n < N; n++) {
            Vec x = dataSet.getDataPoint(n).getNumericalValues();
            // Nearest mean; ties go to the lowest index.
            int nearest = 0;
            double minDist = dm.dist(n, means.get(0), meanQueryInfo.get(0), dataVectors, accelCache);
            for (int k = 1; k < K; k++) {
                double dist = dm.dist(n, means.get(k), meanQueryInfo.get(k), dataVectors, accelCache);
                if (dist < minDist) {
                    minDist = dist;
                    nearest = k;
                }
            }
            assignment[n] = nearest;
            this.a_k[nearest] += 1.0d;
            // Scatter around the seed of the assigned component.
            x.copyTo(diff);
            diff.mutableSubtract(means.get(nearest));
            Matrix.OuterProductUpdate(covs.get(nearest), diff, diff, 1.0d);
        }
        for (int k = 0; k < means.size(); k++) {
            covs.get(k).mutableMultiply(1.0d / this.a_k[k]);
            this.a_k[k] /= N;
        }
        return clusterCompute(K, dataSet, assignment, means, covs, threadpool);
    }

    /**
     * Runs EM from the given initial parameters. It stops on convergence (log-likelihood change
     * below the tolerance) or after {@link #getIterationLimit()} iterations, then writes the
     * most likely component of each point into {@code assignment}.
     *
     * @param K          number of components
     * @param dataSet    the data to fit
     * @param assignment in: initial hard assignment (used to break ties); out: final assignment
     * @param means      initial means; overwritten with the fitted means
     * @param covs       initial covariances; overwritten with the fitted covariances
     * @param threadpool executor for parallel work, or {@code null} to run serially
     * @return the negated log-likelihood of the data under the final model
     * @throws ClusterFailureException if interrupted, or if a worker fails with a checked
     *                                 exception; worker runtime exceptions are rethrown as-is
     */
    protected double clusterCompute(int K, DataSet dataSet, int[] assignment, List<Vec> means, List<Matrix> covs, ExecutorService threadpool) {
        List<DataPoint> dataPoints = dataSet.getDataPoints();
        int N = dataPoints.size();

        this.gaussians = new ArrayList<NormalM>(K);
        for (int k = 0; k < means.size(); k++) {
            this.gaussians.add(new NormalM(means.get(k), covs.get(k)));
        }

        double[][] p_ik = new double[N][K];
        double prevLogLike = -Double.MAX_VALUE;
        double logLike = prevLogLike;
        try {
            boolean converged = false;
            for (int iter = 0; iter < this.MaxIterLimit && !converged; iter++) {
                logLike = eStep(N, dataPoints, K, p_ik, threadpool);
                if (Math.abs(prevLogLike - logLike) < this.tolerance) {
                    converged = true;
                } else {
                    prevLogLike = logLike;
                    mStep(means, N, dataPoints, K, p_ik, covs, threadpool);
                }
            }
            if (!converged) {
                // Limit reached right after an M-step: refresh the responsibilities so the
                // assignment and returned likelihood match the final parameters.
                logLike = eStep(N, dataPoints, K, p_ik, threadpool);
            }
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
            throw clusterFailure("EM clustering was interrupted", ex);
        } catch (ExecutionException ex) {
            throw workerFailure(ex);
        }

        // Final assignment: argmax responsibility; on ties the previous assignment is kept.
        for (int n = 0; n < N; n++) {
            double[] resp = p_ik[n];
            for (int k = 0; k < K; k++) {
                if (resp[k] > resp[assignment[n]]) {
                    assignment[n] = k;
                }
            }
        }
        return -logLike;
    }

    /**
     * M-step: re-estimates the means, covariances and mixing weights from the responsibilities,
     * then pushes them into the component distributions.
     *
     * @param means      overwritten with the responsibility-weighted means
     * @param N          number of data points
     * @param dataPoints the data
     * @param K          number of components
     * @param p_ik       responsibilities, {@code p_ik[n][k]}
     * @param covs       overwritten with the responsibility-weighted covariances
     * @param threadpool executor for parallel work, or {@code null} to run serially
     */
    private void mStep(final List<Vec> means, int N, final List<DataPoint> dataPoints, final int K, final double[][] p_ik, final List<Matrix> covs, ExecutorService threadpool) throws InterruptedException, ExecutionException {
        final int D = means.get(0).length();
        for (Vec mean : means) {
            mean.zeroOut();
        }
        Arrays.fill(this.a_k, 0.0d);

        // Pass 1: responsibility-weighted sums for the new means and unnormalized weights.
        if (threadpool == null) {
            for (int n = 0; n < N; n++) {
                Vec x = dataPoints.get(n).getNumericalValues();
                for (int k = 0; k < K; k++) {
                    this.a_k[k] += p_ik[n][k];
                    means.get(k).mutableAdd(p_ik[n][k], x);
                }
            }
        } else {
            int[] bounds = partition(N);
            List<Future<?>> tasks = new ArrayList<Future<?>>(bounds.length - 1);
            for (int c = 0; c + 1 < bounds.length; c++) {
                final int start = bounds[c];
                final int end = bounds[c + 1];
                tasks.add(threadpool.submit(new Runnable() {
                    @Override
                    public void run() {
                        Vec[] localMeans = new Vec[K];
                        for (int k = 0; k < K; k++) {
                            localMeans[k] = new DenseVector(D);
                        }
                        double[] localWeights = new double[K];
                        for (int n = start; n < end; n++) {
                            Vec x = dataPoints.get(n).getNumericalValues();
                            for (int k = 0; k < K; k++) {
                                localWeights[k] += p_ik[n][k];
                                localMeans[k].mutableAdd(p_ik[n][k], x);
                            }
                        }
                        // Merge this chunk's partial sums into the shared accumulators.
                        synchronized (means) {
                            for (int k = 0; k < K; k++) {
                                a_k[k] += localWeights[k];
                                means.get(k).mutableAdd(localMeans[k]);
                            }
                        }
                    }
                }));
            }
            awaitAll(tasks);
        }
        for (int k = 0; k < K; k++) {
            means.get(k).mutableDivide(this.a_k[k]);
        }

        // Pass 2: responsibility-weighted scatter around the new means.
        for (Matrix cov : covs) {
            cov.zeroOut();
        }
        if (threadpool == null) {
            DenseVector diff = new DenseVector(D);
            for (int k = 0; k < K; k++) {
                Matrix cov = covs.get(k);
                Vec mean = means.get(k);
                for (int n = 0; n < N; n++) {
                    dataPoints.get(n).getNumericalValues().copyTo(diff);
                    diff.mutableSubtract(mean);
                    Matrix.OuterProductUpdate(cov, diff, diff, p_ik[n][k]);
                }
            }
        } else {
            int[] bounds = partition(N);
            List<Future<?>> tasks = new ArrayList<Future<?>>(bounds.length - 1);
            for (int c = 0; c + 1 < bounds.length; c++) {
                final int start = bounds[c];
                final int end = bounds[c + 1];
                tasks.add(threadpool.submit(new Runnable() {
                    @Override
                    public void run() {
                        Matrix[] localCovs = new Matrix[K];
                        for (int k = 0; k < K; k++) {
                            localCovs[k] = new DenseMatrix(D, D);
                        }
                        DenseVector diff = new DenseVector(D);
                        for (int n = start; n < end; n++) {
                            Vec x = dataPoints.get(n).getNumericalValues();
                            for (int k = 0; k < K; k++) {
                                x.copyTo(diff);
                                diff.mutableSubtract(means.get(k));
                                Matrix.OuterProductUpdate(localCovs[k], diff, diff, p_ik[n][k]);
                            }
                        }
                        synchronized (covs) {
                            for (int k = 0; k < K; k++) {
                                covs.get(k).mutableAdd(localCovs[k]);
                            }
                        }
                    }
                }));
            }
            awaitAll(tasks);
        }

        // Normalize covariances by their total responsibility, then weights by N.
        for (int k = 0; k < K; k++) {
            covs.get(k).mutableMultiply(1.0d / this.a_k[k]);
            this.a_k[k] /= N;
        }
        for (int k = 0; k < means.size(); k++) {
            this.gaussians.get(k).setMeanCovariance(means.get(k), covs.get(k));
        }
    }

    /**
     * E-step: fills {@code p_ik} with posterior responsibilities under the current parameters.
     *
     * @param N          number of data points
     * @param dataPoints the data
     * @param K          number of components
     * @param p_ik       output responsibilities, {@code p_ik[n][k]}
     * @param threadpool executor for parallel work, or {@code null} to run serially
     * @return the log-likelihood of the data under the current mixture
     */
    private double eStep(int N, final List<DataPoint> dataPoints, final int K, final double[][] p_ik, ExecutorService threadpool) throws InterruptedException, ExecutionException {
        if (threadpool == null) {
            double logLike = 0.0d;
            for (int n = 0; n < N; n++) {
                logLike += computeResponsibilities(dataPoints.get(n).getNumericalValues(), K, p_ik[n]);
            }
            return logLike;
        }

        int[] bounds = partition(N);
        List<Future<Double>> partials = new ArrayList<Future<Double>>(bounds.length - 1);
        for (int c = 0; c + 1 < bounds.length; c++) {
            final int start = bounds[c];
            final int end = bounds[c + 1];
            partials.add(threadpool.submit(new Callable<Double>() {
                @Override
                public Double call() {
                    double partial = 0.0d;
                    for (int n = start; n < end; n++) {
                        partial += computeResponsibilities(dataPoints.get(n).getNumericalValues(), K, p_ik[n]);
                    }
                    return partial;
                }
            }));
        }
        double logLike = 0.0d;
        for (Future<Double> partial : partials) {
            logLike += partial.get();
        }
        return logLike;
    }

    /**
     * Writes the posterior probability of each component for {@code x} into {@code row} and
     * returns log p(x) under the current mixture. The computation uses log-sum-exp, so a point
     * far from every component no longer underflows to 0/0 = NaN.
     *
     * @param x   the point
     * @param K   number of components
     * @param row output responsibilities for {@code x}
     * @return log of the mixture density at {@code x}
     * @throws ArithmeticException if every component gives {@code x} zero weighted density
     */
    private double computeResponsibilities(Vec x, int K, double[] row) {
        double max = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < K; k++) {
            double logWeighted = Math.log(this.a_k[k]) + this.gaussians.get(k).logPdf(x);
            row[k] = logWeighted;
            if (logWeighted > max) {
                max = logWeighted;
            }
        }
        if (max == Double.NEGATIVE_INFINITY) {
            throw new ArithmeticException("No mixture component assigns non-zero density to a data point");
        }
        double sum = 0.0d;
        for (int k = 0; k < K; k++) {
            row[k] = Math.exp(row[k] - max);
            sum += row[k];
        }
        for (int k = 0; k < K; k++) {
            row[k] /= sum;
        }
        return max + Math.log(sum);
    }

    /**
     * Splits {@code [0, n)} into at most {@code SystemInfo.LogicalCores} contiguous ranges.
     * Range sizes differ by at most one, and ranges are non-empty when {@code n > 0}.
     *
     * @param n number of items to split
     * @return boundaries {@code b}; range {@code c} is {@code [b[c], b[c+1])}
     */
    private static int[] partition(int n) {
        int chunks = Math.max(1, Math.min(n, SystemInfo.LogicalCores));
        int[] bounds = new int[chunks + 1];
        int base = n / chunks;
        int extra = n % chunks;
        for (int c = 0; c < chunks; c++) {
            bounds[c + 1] = bounds[c] + base + (c < extra ? 1 : 0);
        }
        return bounds;
    }

    /**
     * Blocks until every task has finished. A failure in any task is propagated instead of
     * leaving a latch waiting forever.
     */
    private static void awaitAll(List<Future<?>> tasks) throws InterruptedException, ExecutionException {
        for (Future<?> task : tasks) {
            task.get();
        }
    }

    /** Creates a {@link ClusterFailureException} with the given message and cause. */
    private static ClusterFailureException clusterFailure(String message, Throwable cause) {
        ClusterFailureException failure = new ClusterFailureException(message);
        failure.initCause(cause);
        return failure;
    }

    /**
     * Converts a worker failure into an unchecked exception for the caller. Runtime exceptions
     * (e.g. {@link ArithmeticException}) are returned unchanged, so callers such as
     * {@link #setUsingData(DataSet)} see the same exception type as on the serial path.
     * Errors are rethrown.
     */
    private static RuntimeException workerFailure(ExecutionException ex) {
        Throwable cause = ex.getCause();
        if (cause instanceof RuntimeException) {
            return (RuntimeException) cause;
        }
        if (cause instanceof Error) {
            throw (Error) cause;
        }
        return clusterFailure("EM worker task failed", cause == null ? ex : cause);
    }

    /** Returns the log density at the given coordinates; see {@link #logPdf(Vec)}. */
    @Override // jsat.distributions.multivariate.MultivariateDistribution
    public double logPdf(double... dArr) {
        return logPdf(DenseVector.toDenseVec(dArr));
    }

    /**
     * Returns the log of {@link #pdf(Vec)}, or {@code -Double.MAX_VALUE} when the density is
     * exactly zero.
     */
    @Override // jsat.distributions.multivariate.MultivariateDistribution
    public double logPdf(Vec vec) {
        double pdf = pdf(vec);
        if (pdf == 0.0d) {
            return -Double.MAX_VALUE;
        }
        return Math.log(pdf);
    }

    /** Returns the mixture density at the given coordinates; see {@link #pdf(Vec)}. */
    @Override // jsat.distributions.multivariate.MultivariateDistribution
    public double pdf(double... dArr) {
        return pdf(DenseVector.toDenseVec(dArr));
    }

    /** Returns the mixture density: the weight-averaged density of all components. */
    @Override // jsat.distributions.multivariate.MultivariateDistribution
    public double pdf(Vec vec) {
        double density = 0.0d;
        for (int k = 0; k < this.a_k.length; k++) {
            density += this.a_k[k] * this.gaussians.get(k).pdf(vec);
        }
        return density;
    }

    /** Wraps each vector in a purely numerical {@link DataPoint} and fits to those points. */
    @Override // jsat.distributions.multivariate.MultivariateDistribution
    public <V extends Vec> boolean setUsingData(List<V> list) {
        List<DataPoint> points = new ArrayList<DataPoint>(list.size());
        for (V v : list) {
            points.add(new DataPoint(v, new int[0], new CategoricalData[0]));
        }
        return setUsingDataList(points);
    }

    /** Wraps the points in a {@link SimpleDataSet} and fits to it. */
    @Override // jsat.distributions.multivariate.MultivariateDistribution
    public boolean setUsingDataList(List<DataPoint> list) {
        return setUsingData(new SimpleDataSet(list));
    }

    /**
     * Fits the mixture by calling the inherited {@code cluster(DataSet)}.
     * NOTE(review): if that inherited method delegates to {@link #cluster(DataSet, int[])}, it
     * throws {@link UnsupportedOperationException}, which is not caught here. Confirm.
     *
     * @return {@code true} on success, {@code false} if an {@link ArithmeticException} occurred
     */
    @Override // jsat.distributions.multivariate.MultivariateDistribution
    public boolean setUsingData(DataSet dataSet) {
        try {
            cluster(dataSet);
            return true;
        } catch (ArithmeticException e) {
            return false;
        }
    }

    /** Same as {@link #setUsingData(DataSet)}; the executor is ignored. */
    @Override // jsat.distributions.multivariate.MultivariateDistribution
    public boolean setUsingData(DataSet dataSet, ExecutorService executorService) {
        return setUsingData(dataSet);
    }

    /** Same as {@link #setUsingData(List)}; the executor is ignored. */
    @Override // jsat.distributions.multivariate.MultivariateDistribution
    public <V extends Vec> boolean setUsingData(List<V> list, ExecutorService executorService) {
        return setUsingData(list);
    }

    /** Same as {@link #setUsingDataList(List)}; the executor is ignored. */
    @Override // jsat.distributions.multivariate.MultivariateDistribution
    public boolean setUsingDataList(List<DataPoint> list, ExecutorService executorService) {
        return setUsingDataList(list);
    }

    /** Returns a deep copy made with the copy constructor. */
    @Override // jsat.clustering.KClustererBase, jsat.clustering.ClustererBase
    /* renamed from: clone */
    public EMGaussianMixture mo590clone() {
        return new EMGaussianMixture(this);
    }

    /**
     * Draws {@code count} samples. Component labels come from the mixing weights (via sorted
     * uniforms), and each component then samples its share.
     *
     * @param count number of samples to draw
     * @param rand  source of randomness
     * @return exactly {@code count} samples, grouped by component
     */
    @Override // jsat.distributions.multivariate.MultivariateDistribution
    public List<Vec> sample(int count, Random rand) {
        List<Vec> samples = new ArrayList<Vec>(count);
        double[] targets = new double[count];
        for (int n = 0; n < count; n++) {
            targets[n] = rand.nextDouble();
        }
        Arrays.sort(targets);

        int pos = 0;
        double cumulative = 0.0d;
        for (int k = 0; k < this.a_k.length; k++) {
            cumulative += this.a_k[k];
            int start = pos;
            if (k == this.a_k.length - 1) {
                // The last component takes the rest, absorbing any rounding shortfall in sum(a_k).
                pos = count;
            } else {
                while (pos < count && targets[pos] < cumulative) {
                    pos++;
                }
            }
            int subSampleSize = pos - start;
            if (subSampleSize > 0) {
                samples.addAll(this.gaussians.get(k).sample(subSampleSize, rand));
            }
        }
        return samples;
    }

    /**
     * Delegates to the range-of-K overload, which this class does not support, so this always
     * throws {@link UnsupportedOperationException}.
     */
    @Override // jsat.clustering.Clusterer
    public int[] cluster(DataSet dataSet, int[] iArr) {
        return cluster(dataSet, 2, (int) Math.sqrt(dataSet.getSampleSize() / 2), iArr);
    }

    /**
     * Delegates to the range-of-K overload, which this class does not support, so this always
     * throws {@link UnsupportedOperationException}.
     */
    @Override // jsat.clustering.Clusterer
    public int[] cluster(DataSet dataSet, ExecutorService executorService, int[] iArr) {
        return cluster(dataSet, 2, (int) Math.sqrt(dataSet.getSampleSize() / 2), executorService, iArr);
    }

    /**
     * Fits a {@code K}-component mixture and returns each point's most likely component.
     *
     * @param dataSet         the data
     * @param K               number of components
     * @param executorService executor for parallel work, or {@code null}
     * @param iArr            output array, or {@code null} to allocate one
     * @return the assignment array
     * @throws ClusterFailureException if there are fewer points than {@code K}
     */
    @Override // jsat.clustering.KClusterer
    public int[] cluster(DataSet dataSet, int K, ExecutorService executorService, int[] iArr) {
        if (iArr == null) {
            iArr = new int[dataSet.getSampleSize()];
        }
        if (dataSet.getSampleSize() < K) {
            throw new ClusterFailureException("Fewer data points then desired clusters, decrease cluster size");
        }
        cluster(dataSet, null, K, new ArrayList<Vec>(K), iArr, false, executorService, false);
        return iArr;
    }

    /**
     * Serial version of {@link #cluster(DataSet, int, ExecutorService, int[])}.
     *
     * @throws ClusterFailureException if there are fewer points than {@code K}
     */
    @Override // jsat.clustering.KClusterer
    public int[] cluster(DataSet dataSet, int K, int[] iArr) {
        if (iArr == null) {
            iArr = new int[dataSet.getSampleSize()];
        }
        if (dataSet.getSampleSize() < K) {
            throw new ClusterFailureException("Fewer data points then desired clusters, decrease cluster size");
        }
        cluster(dataSet, null, K, new ArrayList<Vec>(K), iArr, false, null, false);
        return iArr;
    }

    /** Unsupported: this class cannot choose the number of components itself. */
    @Override // jsat.clustering.KClusterer
    public int[] cluster(DataSet dataSet, int lowK, int highK, ExecutorService executorService, int[] iArr) {
        throw new UnsupportedOperationException("EMGaussianMixture does not supported determining the number of clusters");
    }

    /** Unsupported: this class cannot choose the number of components itself. */
    @Override // jsat.clustering.KClusterer
    public int[] cluster(DataSet dataSet, int lowK, int highK, int[] iArr) {
        throw new UnsupportedOperationException("EMGaussianMixture does not supported determining the number of clusters");
    }
}
