package jsat.clustering.kmeans;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.clustering.KClustererBase;
import jsat.distributions.kernels.KernelTrick;
import jsat.linear.ConstantVector;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;
import jsat.utils.ListUtils;
import jsat.utils.random.XOR96;
import org.apache.log4j.Priority;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/clustering/kmeans/KernelKMeans.class */
/**
 * Base class for kernel k-means clusterers.
 * <p>
 * Cluster means live in the feature space induced by {@link #kernel} and are never formed
 * explicitly. For every cluster {@code c} this class tracks the total member weight
 * ({@link #ownes}) and the weighted double sum of kernel values between members
 * ({@link #meanSqrdNorms}). Together with each point's self-similarity ({@link #selfK}) that is
 * enough to evaluate point-to-mean and mean-to-mean distances using only kernel evaluations:
 * <pre>
 * ||phi(x) - mu_c||^2 = K(x,x) - 2/W_c * sum_{j in c} w_j K(x, x_j)
 *                     + 1/W_c^2 * sum_{i,j in c} w_i w_j K(x_i, x_j)
 * </pre>
 * where {@code W_c} is the total weight of cluster {@code c}.
 * <p>
 * The {@code cluster(...)} entry points declared in this class all throw
 * {@link UnsupportedOperationException}.
 */
public abstract class KernelKMeans extends KClustererBase implements Parameterized {
    private static final long serialVersionUID = -5294680202634779440L;

    /** Default iteration cap: effectively unbounded. */
    private static final int DEFAULT_MAX_ITERATIONS = Integer.MAX_VALUE;

    /** The kernel used to compare data points. */
    @Parameter.ParameterHolder
    protected KernelTrick kernel;
    /** The training points; must be set before {@link #setup} is called. */
    protected List<Vec> X;
    /** Weight of each training point; all ones when {@link #setup} receives no weights. */
    protected Vec W;
    /** Acceleration cache for {@link #X}, as returned by {@code kernel.getAccelerationCache}. */
    protected List<Double> accel;
    /** {@code selfK[i] = K(x_i, x_i)} for every training point. */
    protected double[] selfK;
    /**
     * Per cluster, the unnormalized squared norm of its mean,
     * {@code sum_{i,j in c} w_i w_j K(x_i, x_j)}. Multiply by {@link #normConsts} to normalize.
     */
    protected double[] meanSqrdNorms;
    /** Per cluster, {@code 1 / ownes[c]^2}; refreshed by {@link #updateNormConsts()}. */
    protected double[] normConsts;
    /** Per cluster, the total weight of the points assigned to it. */
    protected double[] ownes;
    /** The most recent cluster assignment of every training point. */
    protected int[] newDesignations;
    /** Iteration limit (not read in this class; presumably consumed by subclasses). */
    protected int maximumIterations;

    /**
     * Creates a new kernel k-means base using the given kernel.
     *
     * @param kernelTrick the kernel used to compare points
     */
    public KernelKMeans(KernelTrick kernelTrick) {
        this.maximumIterations = DEFAULT_MAX_ITERATIONS;
        this.kernel = kernelTrick;
    }

    /**
     * Copy constructor. The kernel, every training point, the weight vector and all
     * per-point / per-cluster arrays are deep-copied; {@code null} state stays {@code null}.
     *
     * @param kernelKMeans the object to copy
     */
    public KernelKMeans(KernelKMeans kernelKMeans) {
        this.kernel = kernelKMeans.kernel.mo622clone();
        this.maximumIterations = kernelKMeans.maximumIterations;
        if (kernelKMeans.X != null) {
            this.X = new ArrayList<Vec>(kernelKMeans.X.size());
            for (Vec point : kernelKMeans.X) {
                this.X.add(point.mo524clone());
            }
        }
        if (kernelKMeans.accel != null) {
            this.accel = new DoubleList(kernelKMeans.accel);
        }
        this.selfK = copyOrNull(kernelKMeans.selfK);
        this.meanSqrdNorms = copyOrNull(kernelKMeans.meanSqrdNorms);
        this.normConsts = copyOrNull(kernelKMeans.normConsts);
        this.ownes = copyOrNull(kernelKMeans.ownes);
        this.newDesignations = copyOrNull(kernelKMeans.newDesignations);
        if (kernelKMeans.W != null) {
            this.W = kernelKMeans.W.mo524clone();
        }
    }

    /** Returns a copy of {@code src}, or {@code null} when {@code src} is {@code null}. */
    private static double[] copyOrNull(double[] src) {
        return src == null ? null : Arrays.copyOf(src, src.length);
    }

    /** Returns a copy of {@code src}, or {@code null} when {@code src} is {@code null}. */
    private static int[] copyOrNull(int[] src) {
        return src == null ? null : Arrays.copyOf(src, src.length);
    }

    /**
     * Sets the maximum number of iterations.
     *
     * @param i the new limit
     * @throws IllegalArgumentException if {@code i} is not positive
     */
    public void setMaximumIterations(int i) {
        if (i <= 0) {
            throw new IllegalArgumentException("iterations must be a positive value, not " + i);
        }
        this.maximumIterations = i;
    }

    /**
     * Returns the maximum number of iterations.
     *
     * @return the iteration limit
     */
    public int getMaximumIterations() {
        return this.maximumIterations;
    }

    /** Not supported by this class. @throws UnsupportedOperationException always */
    @Override // jsat.clustering.Clusterer
    public int[] cluster(DataSet dataSet, int[] iArr) {
        throw new UnsupportedOperationException("Not supported.");
    }

    /** Not supported by this class. @throws UnsupportedOperationException always */
    @Override // jsat.clustering.Clusterer
    public int[] cluster(DataSet dataSet, ExecutorService executorService, int[] iArr) {
        throw new UnsupportedOperationException("Not supported.");
    }

    /** Not supported by this class. @throws UnsupportedOperationException always */
    @Override // jsat.clustering.KClusterer
    public int[] cluster(DataSet dataSet, int i, int i2, ExecutorService executorService, int[] iArr) {
        throw new UnsupportedOperationException("Not supported.");
    }

    /** Not supported by this class. @throws UnsupportedOperationException always */
    @Override // jsat.clustering.KClusterer
    public int[] cluster(DataSet dataSet, int i, int i2, int[] iArr) {
        throw new UnsupportedOperationException("Not supported.");
    }

    /**
     * Computes {@code sum_{j in cluster} w_j K(x_point, x_j)} over the training points.
     *
     * @param point index of a training point
     * @param cluster the cluster to sum over
     * @param assignment cluster assignment of each training point
     * @return the weighted kernel sum
     */
    protected double evalSumK(int point, int cluster, int[] assignment) {
        double sum = 0.0;
        for (int j = 0; j < X.size(); j++) {
            if (assignment[j] == cluster) {
                sum += W.get(j) * kernel.eval(point, j, X, accel);
            }
        }
        return sum;
    }

    /**
     * Computes {@code sum_{j in cluster} w_j K(vec, x_j)} for an arbitrary query vector.
     *
     * @param vec the query vector
     * @param qi query information for {@code vec} from {@code kernel.getQueryInfo}
     * @param cluster the cluster to sum over
     * @param assignment cluster assignment of each training point
     * @return the weighted kernel sum
     */
    protected double evalSumK(Vec vec, List<Double> qi, int cluster, int[] assignment) {
        double sum = 0.0;
        for (int j = 0; j < X.size(); j++) {
            if (assignment[j] == cluster) {
                sum += W.get(j) * kernel.eval(j, vec, qi, X, accel);
            }
        }
        return sum;
    }

    /**
     * Initializes all bookkeeping for {@link #X}. It builds the acceleration cache and the
     * self-similarities, assigns every point to a uniformly random cluster, and computes the
     * per-cluster weights, normalization constants and unnormalized mean squared norms.
     *
     * @param k the number of clusters
     * @param designations output: receives the initial cluster of each point; its length must be
     *        at least {@code X.size()}
     * @param dataPointWeights per-point weights, or {@code null} for unit weights
     */
    public void setup(int k, int[] designations, Vec dataPointWeights) {
        accel = kernel.getAccelerationCache(X);
        final int n = X.size();
        selfK = new double[n];
        for (int i = 0; i < n; i++) {
            selfK[i] = kernel.eval(i, i, X, accel);
        }
        ownes = new double[k];
        meanSqrdNorms = new double[k];
        newDesignations = new int[n];
        if (dataPointWeights == null) {
            W = new ConstantVector(1.0d, n);
        } else {
            W = dataPointWeights;
        }

        XOR96 rand = new XOR96();
        for (int i = 0; i < n; i++) {
            final int c = rand.nextInt(k);
            ownes[c] += W.get(i);
            designations[i] = c;
            newDesignations[i] = c;
        }
        normConsts = new double[k];
        updateNormConsts();

        for (int i = 0; i < n; i++) {
            final int c = designations[i];
            final double wi = W.get(i);
            // Diagonal term of ||sum_j w_j phi(x_j)||^2 is w_i^2 K(x_i, x_i).
            meanSqrdNorms[c] += wi * wi * selfK[i];
            for (int j = i + 1; j < n; j++) {
                if (c == designations[j]) {
                    // Each unordered pair (i, j) appears twice in the double sum.
                    meanSqrdNorms[c] += 2.0d * wi * W.get(j) * kernel.eval(i, j, X, accel);
                }
            }
        }
    }

    /**
     * Recomputes {@code normConsts[c] = 1 / ownes[c]^2} for every cluster. An empty cluster
     * (zero weight) yields an infinite constant.
     */
    public void updateNormConsts() {
        for (int c = 0; c < normConsts.length; c++) {
            normConsts[c] = 1.0d / (ownes[c] * ownes[c]);
        }
    }

    /**
     * Feature-space distance between training point {@code i} and the mean of {@code cluster}.
     *
     * @param i index of a training point
     * @param cluster the cluster whose mean is measured against
     * @param assignment cluster assignment of each training point
     * @return the distance; negative squared values from rounding are clamped to 0
     */
    public double distance(int i, int cluster, int[] assignment) {
        return Math.sqrt(Math.max(
                (selfK[i] - ((2.0d / ownes[cluster]) * evalSumK(i, cluster, assignment)))
                        + (meanSqrdNorms[cluster] * normConsts[cluster]),
                0.0d));
    }

    /**
     * Feature-space distance between {@code vec} and the mean of cluster {@code i}, using the
     * current {@link #newDesignations}.
     *
     * @param vec the query vector
     * @param i the cluster index
     * @return the distance
     * @throws IndexOutOfBoundsException if {@code i} is not a valid cluster index
     */
    public double distance(Vec vec, int i) {
        return distance(vec, kernel.getQueryInfo(vec), i);
    }

    /**
     * Feature-space distance between {@code vec} and the mean of cluster {@code i}, using the
     * current {@link #newDesignations}.
     *
     * @param vec the query vector
     * @param list query information for {@code vec} from {@code kernel.getQueryInfo}
     * @param i the cluster index
     * @return the distance
     * @throws IndexOutOfBoundsException if {@code i} is not a valid cluster index
     */
    public double distance(Vec vec, List<Double> list, int i) {
        if (i >= meanSqrdNorms.length || i < 0) {
            throw new IndexOutOfBoundsException("Only " + meanSqrdNorms.length + " clusters. " + i + " is not a valid index");
        }
        // K(vec, vec), evaluated by treating vec as a one-element dataset.
        final double selfSim = kernel.eval(0, 0, Arrays.asList(vec), list);
        return Math.sqrt(Math.max(
                (selfSim - ((2.0d / ownes[i]) * evalSumK(vec, list, i, newDesignations)))
                        + (meanSqrdNorms[i] * normConsts[i]),
                0.0d));
    }

    /**
     * Returns the cluster whose mean is closest to {@code vec}.
     *
     * @param vec the query vector
     * @return the closest cluster index, or -1 if no distance compares below
     *         {@code Double.MAX_VALUE} (e.g. all distances are NaN)
     */
    public int findClosestCluster(Vec vec) {
        return findClosestCluster(vec, kernel.getQueryInfo(vec));
    }

    /**
     * Returns the cluster whose mean is closest to {@code vec}.
     *
     * @param vec the query vector
     * @param list query information for {@code vec} from {@code kernel.getQueryInfo}
     * @return the closest cluster index, or -1 if no distance compares below
     *         {@code Double.MAX_VALUE} (e.g. all distances are NaN)
     */
    public int findClosestCluster(Vec vec, List<Double> list) {
        double best = Double.MAX_VALUE;
        int bestCluster = -1;
        for (int c = 0; c < meanSqrdNorms.length; c++) {
            final double dist = distance(vec, list, c);
            if (dist < best) {
                best = dist;
                bestCluster = c;
            }
        }
        return bestCluster;
    }

    /**
     * Applies the move of point {@code i} (from {@code assignment[i]} to
     * {@code newDesignations[i]}) directly to {@link #meanSqrdNorms} and {@link #ownes}.
     * {@link #normConsts} is not refreshed.
     *
     * @param i index of the point
     * @param assignment the previous assignment of every point
     * @return 1 if the point changed cluster, 0 otherwise
     */
    public int updateMeansFromChange(int i, int[] assignment) {
        return updateMeansFromChange(i, assignment, meanSqrdNorms, ownes);
    }

    /**
     * Accumulates into {@code sqrdNorms} and {@code ownership} the change caused by moving
     * point {@code i} from {@code assignment[i]} to {@code newDesignations[i]}. The targets may
     * be the live arrays or delta buffers later applied with {@link #applyMeanUpdates}.
     * <p>
     * When several points move in one pass, calling this for each of them counts every affected
     * pair exactly once. The pair is handled here if partner {@code j} did not move; otherwise it
     * is handled only when processing the lower index.
     *
     * @param i index of the point
     * @param assignment the previous assignment of every point
     * @param sqrdNorms per-cluster unnormalized squared-norm accumulator
     * @param ownership per-cluster weight accumulator
     * @return 1 if the point changed cluster, 0 otherwise
     */
    public int updateMeansFromChange(int i, int[] assignment, double[] sqrdNorms, double[] ownership) {
        final int oldC = assignment[i];
        final int newC = newDesignations[i];
        if (oldC == newC) {
            return 0;
        }
        final int n = X.size();
        final double wi = W.get(i);
        ownership[oldC] -= wi;
        ownership[newC] += wi;
        for (int j = 0; j < n; j++) {
            final double wj = W.get(j);
            final int oldCj = assignment[j];
            final int newCj = newDesignations[j];
            if (i == j) {
                // Diagonal term w_i^2 K(x_i, x_i) leaves the old cluster and joins the new one.
                sqrdNorms[oldC] -= wi * wi * selfK[i];
                sqrdNorms[newC] += wi * wi * selfK[i];
            } else {
                final boolean countPair = i <= j || oldCj == newCj;
                if (oldC == oldCj && countPair) {
                    sqrdNorms[oldC] -= 2.0d * wi * wj * kernel.eval(i, j, X, accel);
                }
                if (newC == newCj && countPair) {
                    sqrdNorms[newC] += 2.0d * wi * wj * kernel.eval(i, j, X, accel);
                }
            }
        }
        return 1;
    }

    /**
     * Adds accumulated deltas into {@link #meanSqrdNorms} and {@link #ownes}.
     *
     * @param normDeltas per-cluster change of the unnormalized squared norms
     * @param weightDeltas per-cluster change of the cluster weights
     */
    public void applyMeanUpdates(double[] normDeltas, double[] weightDeltas) {
        for (int c = 0; c < normDeltas.length; c++) {
            meanSqrdNorms[c] += normDeltas[c];
            ownes[c] += weightDeltas[c];
        }
    }

    /**
     * Feature-space distance between the means of clusters {@code i} and {@code i2}, using the
     * current {@link #newDesignations}.
     *
     * @param i the first cluster
     * @param i2 the second cluster
     * @return the distance between the two means
     * @throws IndexOutOfBoundsException if either index is not a valid cluster
     */
    public double meanToMeanDistance(int i, int i2) {
        if (i >= meanSqrdNorms.length || i < 0) {
            throw new IndexOutOfBoundsException("Only " + meanSqrdNorms.length + " clusters. " + i + " is not a valid index");
        }
        if (i2 >= meanSqrdNorms.length || i2 < 0) {
            throw new IndexOutOfBoundsException("Only " + meanSqrdNorms.length + " clusters. " + i2 + " is not a valid index");
        }
        return meanToMeanDistance(i, i2, newDesignations);
    }

    /** Distance between the means of clusters {@code i} and {@code i2} under {@code assignment}. */
    protected double meanToMeanDistance(int i, int i2, int[] assignment) {
        return Math.sqrt(Math.max(0.0d,
                ((meanSqrdNorms[i] * normConsts[i]) + (meanSqrdNorms[i2] * normConsts[i2]))
                        - (2.0d * dot(i, i2, assignment))));
    }

    /** Parallel variant of {@link #meanToMeanDistance(int, int, int[])}. */
    protected double meanToMeanDistance(int i, int i2, int[] assignment, ExecutorService ex) {
        return Math.sqrt(Math.max(0.0d,
                ((meanSqrdNorms[i] * normConsts[i]) + (meanSqrdNorms[i2] * normConsts[i2]))
                        - (2.0d * dot(i, i2, assignment, ex))));
    }

    /**
     * Distance between mean {@code i} (under {@code assignment1}) and mean {@code i2} (under
     * {@code assignment2}).
     *
     * @param otherSqrdNorm the normalized squared norm of mean {@code i2}, supplied by the caller
     *        (presumably computed under {@code assignment2} — confirm at call sites)
     */
    public double meanToMeanDistance(int i, int i2, int[] assignment1, int[] assignment2, double otherSqrdNorm) {
        return Math.sqrt(Math.max(0.0d,
                ((meanSqrdNorms[i] * normConsts[i]) + otherSqrdNorm)
                        - (2.0d * dot(i, i2, assignment1, assignment2))));
    }

    /** Parallel variant of {@link #meanToMeanDistance(int, int, int[], int[], double)}. */
    public double meanToMeanDistance(int i, int i2, int[] assignment1, int[] assignment2, double otherSqrdNorm, ExecutorService ex) {
        return Math.sqrt(Math.max(0.0d,
                ((meanSqrdNorms[i] * normConsts[i]) + otherSqrdNorm)
                        - (2.0d * dot(i, i2, assignment1, assignment2, ex))));
    }

    /** Inner product of means {@code c1} and {@code c2}, both under {@code assignment}. */
    private double dot(int c1, int c2, int[] assignment) {
        return dot(c1, c2, assignment, assignment);
    }

    /** Parallel inner product of means {@code c1} and {@code c2}, both under {@code assignment}. */
    private double dot(int c1, int c2, int[] assignment, ExecutorService ex) {
        return dot(c1, c2, assignment, assignment, ex);
    }

    /** Total weight of the points assigned to {@code cluster} under {@code assignment}. */
    private double clusterWeight(int cluster, int[] assignment) {
        double total = 0.0d;
        for (int j = 0; j < X.size(); j++) {
            if (assignment[j] == cluster) {
                total += W.get(j);
            }
        }
        return total;
    }

    /**
     * Feature-space inner product of mean {@code c1} (under {@code assign1}) and mean
     * {@code c2} (under {@code assign2}):
     * {@code sum_{i in c1} sum_{j in c2} w_i w_j K(x_i, x_j) / (W_c1 * W_c2)}.
     */
    private double dot(int c1, int c2, int[] assign1, int[] assign2) {
        final int n = X.size();
        double sum = 0.0d;
        double weight1 = 0.0d;
        for (int i = 0; i < n; i++) {
            if (assign1[i] != c1) {
                continue;
            }
            final double wi = W.get(i);
            weight1 += wi;
            for (int j = 0; j < n; j++) {
                if (assign2[j] == c2) {
                    sum += wi * W.get(j) * kernel.eval(i, j, X, accel);
                }
            }
        }
        return sum / (weight1 * clusterWeight(c2, assign2));
    }

    /**
     * Parallel version of {@link #dot(int, int, int[], int[])}: one task per member of
     * {@code c1}.
     * <p>
     * A task failure is rethrown unchecked, exactly as the serial version would surface it. If
     * the calling thread is interrupted while waiting, its interrupt flag is restored, the
     * interruption is logged, and the returned value may be incomplete.
     */
    private double dot(int c1, final int c2, int[] assign1, final int[] assign2, ExecutorService ex) {
        final int n = X.size();
        double weight1 = 0.0d;
        final List<Future<Double>> rowSums = new ArrayList<Future<Double>>();
        for (int i = 0; i < n; i++) {
            if (assign1[i] != c1) {
                continue;
            }
            final double wi = W.get(i);
            final int row = i;
            weight1 += wi;
            rowSums.add(ex.submit(new Callable<Double>() {
                @Override
                public Double call() {
                    double partial = 0.0d;
                    for (int j = 0; j < n; j++) {
                        if (assign2[j] == c2) {
                            partial += wi * W.get(j) * kernel.eval(row, j, X, accel);
                        }
                    }
                    return partial;
                }
            }));
        }
        final double weight2 = clusterWeight(c2, assign2);

        double sum = 0.0d;
        try {
            for (Future<Double> rowSum : rowSums) {
                sum += rowSum.get();
            }
        } catch (InterruptedException e) {
            // Preserve the interrupt so the owner of this thread can react to it.
            Thread.currentThread().interrupt();
            Logger.getLogger(KernelKMeans.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
        } catch (ExecutionException e) {
            throw rethrowUnchecked(e);
        }
        return sum / (weight1 * weight2);
    }

    /**
     * Rethrows the cause of a failed task if it is unchecked. Otherwise returns a wrapping
     * {@link RuntimeException} for the caller to throw.
     */
    private static RuntimeException rethrowUnchecked(ExecutionException e) {
        final Throwable cause = e.getCause();
        if (cause instanceof RuntimeException) {
            throw (RuntimeException) cause;
        }
        if (cause instanceof Error) {
            throw (Error) cause;
        }
        return new RuntimeException(e);
    }

    @Override // jsat.clustering.KClustererBase, jsat.clustering.ClustererBase
    /* renamed from: clone */
    public abstract KernelKMeans mo589clone();

    /** Point weights are honored: they scale each point's contribution to its cluster mean. */
    @Override // jsat.clustering.ClustererBase, jsat.clustering.Clusterer
    public boolean supportsWeightedData() {
        return true;
    }

    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }
}
