package jsat.classifiers.svm;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.svm.SupportVectorLearner;
import jsat.clustering.kmeans.ElkanKernelKMeans;
import jsat.clustering.kmeans.KernelKMeans;
import jsat.distributions.kernels.KernelTrick;
import jsat.distributions.kernels.RBFKernel;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;
import jsat.utils.FakeExecutor;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.SystemInfo;
import org.apache.commons.math3.analysis.interpolation.MicrosphereInterpolator;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/svm/DCSVM.class */
/**
 * Kernel SVM classifier that is trained level by level on clusters of the data.
 * <p>
 * For every level from {@code l_max} down to {@code l_early} the training set is
 * partitioned with kernel k-means (clustering a random sample, then assigning
 * every remaining point to its closest cluster), one {@link SVMnoBias} is
 * trained per cluster, and the resulting alphas are stored so that the next
 * level can warm-start from them. When {@code l_early == 0} a final
 * {@link SVMnoBias} is trained on the whole data set starting from those
 * alphas. Prediction uses the model of the cluster closest to the query.
 * <p>
 * NOTE(review): the name presumably stands for "Divide-and-Conquer SVM" -
 * confirm against the library documentation.
 * <p>
 * NOTE(review): {@code C} and {@code tolerance} are stored and copied but are
 * never handed to the {@link SVMnoBias} sub-solvers in this class - confirm
 * whether they are meant to be forwarded.
 */
public class DCSVM extends SupportVectorLearner implements Classifier, Parameterized {
    /** Soft-margin parameter exposed through {@link #setC(double)}. */
    private double C = 1.0d;
    /** Solver tolerance; there is no setter for it in this class. */
    private double tolerance = 0.001d;
    /** Kernel k-means model of the most recently processed level. */
    private KernelKMeans clusters;
    /*
     * Number of points sampled for clustering.
     * NOTE(review): the default is taken from an unrelated commons-math class,
     * which looks like a decompiler constant substitution - confirm the value.
     */
    private int m = MicrosphereInterpolator.DEFAULT_MICROSPHERE_ELEMENTS;
    /** Level at which training starts (largest number of clusters). */
    private int l_max = 4;
    /** Level at which training stops. */
    private int l_early = 3;
    /** Branching factor: level {@code l} asks for {@code k^l} clusters. */
    private int k = 4;
    /** Sub-models of the last trained level, keyed by cluster id. */
    private Map<Integer, SVMnoBias> early_models;
    /** Cache budget handed to each sub-solver; values {@code <= 0} disable caching. */
    private long cache_size = 0L;

    /**
     * Creates an untrained model that uses the given kernel. The cache budget
     * is set to half of the memory the JVM currently reports as free.
     *
     * @param kernelTrick the kernel to use
     */
    public DCSVM(KernelTrick kernelTrick) {
        super(kernelTrick, SupportVectorLearner.CacheMode.ROWS);
        this.cache_size = Runtime.getRuntime().freeMemory() / 2;
    }

    /** Creates an untrained model that uses a default-constructed {@link RBFKernel}. */
    public DCSVM() {
        this(new RBFKernel());
    }

    /**
     * Copy constructor. The cluster model and every sub-model are cloned, so
     * the copy does not share them with {@code dcsvm}.
     *
     * @param dcsvm the instance to copy
     */
    public DCSVM(DCSVM dcsvm) {
        super(dcsvm);
        this.C = dcsvm.C;
        this.tolerance = dcsvm.tolerance;
        if (dcsvm.clusters != null) {
            this.clusters = dcsvm.clusters.mo590clone();
        }
        this.cache_size = dcsvm.cache_size;
        this.m = dcsvm.m;
        this.l_early = dcsvm.l_early;
        this.l_max = dcsvm.l_max;
        this.k = dcsvm.k;
        if (dcsvm.early_models != null) {
            this.early_models = new ConcurrentHashMap<Integer, SVMnoBias>();
            for (Map.Entry<Integer, SVMnoBias> entry : dcsvm.early_models.entrySet()) {
                this.early_models.put(entry.getKey(), entry.getValue().m572clone());
            }
        }
    }

    /**
     * Sets the level training starts from ({@code l_max}).
     *
     * @param i the start level, must not be negative
     * @throws IllegalArgumentException if {@code i} is negative
     */
    public void setStartLevel(int i) {
        if (i < 0) {
            throw new IllegalArgumentException("l_max must be a non-negative integer, not " + i);
        }
        this.l_max = i;
    }

    /** @return the level training starts from */
    public int getStartLevel() {
        return this.l_max;
    }

    /**
     * Sets the level training stops at ({@code l_early}). A value of
     * {@code 0} additionally triggers a final solve over the whole data set.
     *
     * @param i the end level, must not be negative
     * @throws IllegalArgumentException if {@code i} is negative
     */
    public void setEndLevel(int i) {
        if (i < 0) {
            throw new IllegalArgumentException("l_early must be a non-negative integer, not " + i);
        }
        this.l_early = i;
    }

    /** @return the level training stops at */
    public int getEndLevel() {
        return this.l_early;
    }

    /**
     * Sets how many points are sampled for clustering at each level (unless
     * the level needs more to give every cluster a few points).
     *
     * @param i the sample size, must be positive
     * @throws IllegalArgumentException if {@code i} is not positive
     */
    public void setClusterSampleSize(int i) {
        if (i <= 0) {
            throw new IllegalArgumentException("Cluster Sample Size must be a positive integer, not " + i);
        }
        this.m = i;
    }

    /** @return the number of points sampled for clustering */
    public int getClusterSampleSize() {
        return this.m;
    }

    /**
     * Classifies a point as class 1 when its score is strictly positive and as
     * class 0 otherwise; the chosen class receives probability 1.
     *
     * @param dataPoint the point to classify
     * @return a two-class result with all mass on the predicted class
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        CategoricalResults categoricalResults = new CategoricalResults(2);
        int predicted = getScore(dataPoint) > 0.0d ? 1 : 0;
        categoricalResults.setProb(predicted, 1.0d);
        return categoricalResults;
    }

    /**
     * Returns the raw score of the sub-model responsible for the given point.
     * With several sub-models the one belonging to the closest cluster is
     * used; with exactly one sub-model that model is used directly.
     *
     * @param dataPoint the point to score
     * @return the score reported by the selected {@link SVMnoBias}
     * @throws UntrainedModelException if no sub-model is available
     */
    public double getScore(DataPoint dataPoint) {
        if (this.vecs == null || this.early_models == null || this.early_models.isEmpty()) {
            throw new UntrainedModelException("Classifier has yet to be trained");
        }
        if (this.early_models.size() == 1) {
            // Do not assume the lone model is keyed by 0: a clustered level may
            // leave a single surviving cluster with a different id.
            return this.early_models.values().iterator().next().getScore(dataPoint);
        }
        Vec numericalValues = dataPoint.getNumericalValues();
        int closest = this.clusters.findClosestCluster(numericalValues, getKernel().getQueryInfo(numericalValues));
        return this.early_models.get(Integer.valueOf(closest)).getScore(dataPoint);
    }

    /**
     * Trains the model level by level, from {@code l_max} down to
     * {@code l_early}, and finishes with a solve over the whole data set when
     * {@code l_early == 0}.
     *
     * @param classificationDataSet the training data
     * @param executorService executor used for clustering, cluster assignment
     *        and the sub-solvers; a {@link FakeExecutor} selects a single
     *        assignment worker
     * @throws FailedToFitException if the calling thread is interrupted while
     *         waiting for cluster assignment, or an assignment worker fails
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        final int workers = executorService instanceof FakeExecutor ? 1 : SystemInfo.LogicalCores;
        final int sampleSize = classificationDataSet.getSampleSize();
        this.vecs = classificationDataSet.getDataVectors();
        this.early_models = new ConcurrentHashMap<Integer, SVMnoBias>();
        setCacheMode(SupportVectorLearner.CacheMode.NONE);
        this.alphas = new double[sampleSize];
        // group[i] is the cluster id of training point i on the current level.
        final int[] group = new int[sampleSize];
        // Original indices of the points eligible for (and, after shuffling,
        // drawn into) the clustering sample.
        IntList candidates = new IntList();

        for (int level = this.l_max; level >= this.l_early; level--) {
            this.early_models.clear();
            int wantedClusters = (int) Math.pow(this.k, level);
            // Enlarge the sample when the default would leave clusters nearly empty.
            int sampleLimit = sampleSize / wantedClusters < 7 ? wantedClusters * 7 : this.m;

            if (level == this.l_max) {
                // First level: every point is a candidate.
                ListUtils.addRange(candidates, 0, sampleSize, 1);
            } else {
                // Later levels: only points with a non-zero alpha are candidates.
                candidates.clear();
                for (int i = 0; i < sampleSize; i++) {
                    if (this.alphas[i] != 0.0d) {
                        candidates.add(i);
                    }
                }
            }
            Collections.shuffle(candidates);

            // Draw through the shuffled index list so that row s of the sample
            // is point candidates.get(s); the label write-back below relies on it.
            ClassificationDataSet sample = emptyLike(classificationDataSet);
            int take = Math.min(sampleLimit, candidates.size());
            for (int s = 0; s < take; s++) {
                int source = candidates.get(s).intValue();
                sample.addDataPoint(classificationDataSet.getDataPoint(source),
                        classificationDataSet.getDataPointCategory(source));
            }

            this.clusters = new ElkanKernelKMeans(getKernel());
            this.clusters.setMaximumIterations(100);
            int clusterCount = Math.min(wantedClusters, sample.getSampleSize() / 2);
            int[] sampleLabels;
            if (clusterCount <= 1) {
                // Too few points to split: put every point in cluster 0.
                sampleLabels = new int[sampleSize];
                candidates.clear();
                ListUtils.addRange(candidates, 0, sampleSize, 1);
            } else {
                sampleLabels = this.clusters.cluster(sample, clusterCount, executorService, (int[]) null);
            }

            Arrays.fill(group, -1);
            HashSet<Integer> foundClusters = new HashSet<Integer>(clusterCount);
            for (int s = 0; s < sampleLabels.length; s++) {
                group[candidates.get(s).intValue()] = sampleLabels[s];
                foundClusters.add(Integer.valueOf(sampleLabels[s]));
            }

            assignUnlabelledPoints(group, sampleSize, workers, executorService);

            for (Integer clusterKey : foundClusters) {
                int clusterId = clusterKey.intValue();
                ClassificationDataSet subset = emptyLike(classificationDataSet);
                DoubleList warmStart = new DoubleList();
                IntList members = new IntList();
                for (int i = 0; i < sampleSize; i++) {
                    if (group[i] == clusterId) {
                        subset.addDataPoint(classificationDataSet.getDataPoint(i),
                                classificationDataSet.getDataPointCategory(i));
                        warmStart.add(Math.abs(this.alphas[i]));
                        members.add(i);
                    }
                }
                SVMnoBias subModel = newSubSolver(warmStart.size());
                if (level == this.l_max) {
                    subModel.trainC(subset, executorService);
                } else {
                    subModel.trainC(subset, warmStart.getBackingArray(), executorService);
                }
                this.early_models.put(clusterKey, subModel);
                // Scatter the sub-model's alphas back to the original indices.
                for (int j = 0; j < members.size(); j++) {
                    this.alphas[members.get(j).intValue()] = subModel.alphas[j];
                }
            }
        }

        if (this.l_early == 0) {
            // Final stage: one solver over everything, warm-started from the
            // alphas of the last level.
            SVMnoBias fullModel = newSubSolver(classificationDataSet.getSampleSize());
            fullModel.trainC(classificationDataSet, Arrays.copyOf(this.alphas, this.alphas.length), executorService);
            this.early_models.clear();
            this.early_models.put(0, fullModel);
            for (int i = 0; i < sampleSize; i++) {
                this.alphas[i] = fullModel.alphas[i];
            }
        }
    }

    /** Creates an empty data set with the same feature layout as {@code template}. */
    private static ClassificationDataSet emptyLike(ClassificationDataSet template) {
        return new ClassificationDataSet(template.getNumNumericalVars(), template.getCategories(),
                template.getPredicting());
    }

    /** Creates a sub-solver on this model's kernel, configured with the cache budget. */
    private SVMnoBias newSubSolver(int rows) {
        SVMnoBias solver = new SVMnoBias(getKernel());
        if (this.cache_size > 0) {
            solver.setCacheSize(rows, this.cache_size);
        } else {
            solver.setCacheMode(SupportVectorLearner.CacheMode.NONE);
        }
        return solver;
    }

    /**
     * Assigns every point whose entry in {@code group} is still negative to
     * its closest cluster, splitting the indices across {@code workers} tasks
     * in an interleaved fashion, and blocks until all tasks are done.
     */
    private void assignUnlabelledPoints(final int[] group, final int sampleSize, final int workers,
            ExecutorService executorService) {
        final CountDownLatch latch = new CountDownLatch(workers);
        // First failure seen by any worker; guarded by its own monitor.
        final RuntimeException[] failure = new RuntimeException[1];
        for (int w = 0; w < workers; w++) {
            final int first = w;
            executorService.submit(new Runnable() {
                @Override // java.lang.Runnable
                public void run() {
                    try {
                        for (int idx = first; idx < sampleSize; idx += workers) {
                            if (group[idx] >= 0) {
                                continue;
                            }
                            List<Double> queryInfo = null;
                            if (DCSVM.this.accelCache != null) {
                                // The cache holds an equal-sized slice per training point.
                                int perPoint = DCSVM.this.accelCache.size() / sampleSize;
                                queryInfo = DCSVM.this.accelCache.subList(idx * perPoint, (idx * perPoint) + perPoint);
                            }
                            group[idx] = DCSVM.this.clusters.findClosestCluster(DCSVM.this.vecs.get(idx), queryInfo);
                        }
                    } catch (RuntimeException ex) {
                        synchronized (failure) {
                            if (failure[0] == null) {
                                failure[0] = ex;
                            }
                        }
                    } finally {
                        // Always release the latch so a failing worker cannot
                        // leave the training thread waiting forever.
                        latch.countDown();
                    }
                }
            });
        }
        try {
            latch.await();
        } catch (InterruptedException e) {
            // Restore the interrupt flag for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new FailedToFitException(e);
        }
        synchronized (failure) {
            if (failure[0] != null) {
                throw new FailedToFitException(failure[0]);
            }
        }
    }

    /**
     * Trains on the calling thread by delegating to
     * {@link #trainC(ClassificationDataSet, ExecutorService)} with a
     * {@link FakeExecutor}.
     *
     * @param classificationDataSet the training data
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        trainC(classificationDataSet, new FakeExecutor());
    }

    /**
     * @return always {@code true}. NOTE(review): this class never reads data
     *         point weights itself; presumably the sub-solvers do - confirm.
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return true;
    }

    /**
     * @return a copy of this model created through the copy constructor
     */
    @Override // jsat.classifiers.Classifier
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public DCSVM m572clone() {
        return new DCSVM(this);
    }

    /**
     * Sets the soft-margin parameter.
     *
     * @param d the new value, must be positive and not NaN
     * @throws ArithmeticException if {@code d} is not positive or is NaN
     */
    @Parameter.WarmParameter(prefLowToHigh = true)
    public void setC(double d) {
        // NaN compares false against everything, so it needs its own check.
        if (d <= 0.0d || Double.isNaN(d)) {
            throw new ArithmeticException("C must be a positive constant");
        }
        this.C = d;
    }

    /** @return the soft-margin parameter */
    public double getC() {
        return this.C;
    }

    /** @return the parameters discovered from this object's methods */
    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    /**
     * Looks up a single parameter by name.
     *
     * @param str the parameter name
     * @return the matching parameter, as returned by the parameter map's {@code get}
     */
    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }
}
