package jsat.classifiers.svm;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import jsat.DataSet;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.WarmClassifier;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.classifiers.svm.SupportVectorLearner;
import jsat.distributions.Distribution;
import jsat.distributions.kernels.KernelTrick;
import jsat.distributions.kernels.LinearKernel;
import jsat.exceptions.FailedToFitException;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.regression.WarmRegressor;
import jsat.utils.FakeExecutor;
import jsat.utils.PairedReturn;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.ParallelUtils;

/**
 * Least-Squares Support Vector Machine (LS-SVM) usable both as a binary classifier and as a regressor.
 * <p>
 * Training works on the dual multipliers {@code alphas}. Each step picks the pair of points whose cached
 * values {@code fcache} are currently largest ({@code i_up}) and smallest ({@code i_low}). It moves their
 * multipliers while keeping the pair's sum fixed. It stops when the duality gap falls within {@code TOL}
 * of the dual objective, or when a step no longer changes anything meaningfully.
 * <p>
 * Invariant relied on by the warm start, the step, and the duality gap:
 * {@code fcache[i] = sum_j alphas[j] * (K(x_i, x_j) + [i == j] / C) - y_i},
 * i.e. the kernel matrix with {@code 1/C} added on its diagonal.
 */
public class LSSVM extends SupportVectorLearner implements BinaryScoreClassifier, Regressor, Parameterized, WarmRegressor, WarmClassifier {
    private static final long serialVersionUID = -7569924400631719451L;

    /** Bias subtracted from the kernel sum in {@link #regress(DataPoint)}. */
    protected double b;
    /** Smallest {@code fcache} value seen at the last update. */
    protected double b_low;
    /** Largest {@code fcache} value seen at the last update. */
    protected double b_up;
    /** Regularization parameter; must be in (0, Infinity). */
    private double C;
    /** Index of the largest {@code fcache} entry. */
    private int i_up;
    /** Index of the smallest {@code fcache} entry. */
    private int i_low;
    /** Cached values F_i (see class documentation). */
    private double[] fcache;
    /** Current value of the dual objective, updated incrementally by each step. */
    private double dualObjective;
    /** Relative threshold below which a multiplier change is treated as no progress. */
    private static final double EPSILON = 1.0E-12d;
    /** Relative duality-gap tolerance used as the stopping criterion. */
    private static final double TOL = 0.001d;

    /**
     * Sums {@code fcache[i] - alphas[i] / C} over the index range {@code [from, to)}. Averaging this sum over
     * all points gives the bias that minimizes the primal squared-error term for the current multipliers.
     */
    public class BiasGapCallable implements Callable<Double> {
        int from;
        int to;

        public BiasGapCallable(int from, int to) {
            this.from = from;
            this.to = to;
        }

        @Override
        public Double call() throws Exception {
            double sum = 0.0d;
            for (int i = from; i < to; i++) {
                sum += fcache[i] - alphas[i] / C;
            }
            return sum;
        }
    }

    /**
     * Computes this range's share, over {@code [from, to)}, of the primal-minus-dual gap for the current
     * multipliers and bias {@code b}.
     */
    public class DualityGapCallable implements Callable<Double> {
        int from;
        int to;

        public DualityGapCallable(int from, int to) {
            this.from = from;
            this.to = to;
        }

        @Override
        public Double call() throws Exception {
            double sum = 0.0d;
            for (int i = from; i < to; i++) {
                // slack of point i for the current bias
                final double xi = (b + (alphas[i] / C)) - fcache[i];
                sum += (alphas[i] * (fcache[i] - ((0.5d * alphas[i]) / C))) + (((C * xi) * xi) / 2.0d);
            }
            return sum;
        }
    }

    /**
     * Applies the change in the multipliers of {@code i1} and {@code i2} to {@code fcache} over
     * {@code [from, to)}. Returns the indices of the largest and smallest updated entries, as (up, low).
     * Both indices are -1 for an empty range.
     */
    public class TakeStepLoop implements Callable<PairedReturn<Integer, Integer>> {
        int from;
        int to;
        int i1;
        int i2;
        /** Old (pre-step) multiplier of {@code i1}. */
        double alph1;
        /** Old (pre-step) multiplier of {@code i2}. */
        double alph2;
        int i_low_p;
        int i_up_p;

        public TakeStepLoop(int from, int to, int i1, int i2, double alph1, double alph2) {
            this.from = from;
            this.to = to;
            this.i1 = i1;
            this.i2 = i2;
            this.alph1 = alph1;
            this.alph2 = alph2;
        }

        @Override
        public PairedReturn<Integer, Integer> call() throws Exception {
            final double delta1 = alphas[i1] - alph1;
            final double delta2 = alphas[i2] - alph2;
            double maxF = Double.NEGATIVE_INFINITY;
            double minF = Double.POSITIVE_INFINITY;
            i_up_p = -1;
            i_low_p = -1;
            for (int i = from; i < to; i++) {
                double f = fcache[i] + (delta1 * kEval(i1, i)) + (delta2 * kEval(i2, i));
                // the 1/C diagonal term of the regularized kernel only touches the two stepped points
                if (i == i1) {
                    f += delta1 / C;
                }
                if (i == i2) {
                    f += delta2 / C;
                }
                fcache[i] = f;
                if (f > maxF) {
                    maxF = f;
                    i_up_p = i;
                }
                if (f < minF) {
                    minF = f;
                    i_low_p = i;
                }
            }
            return new PairedReturn<>(i_up_p, i_low_p);
        }
    }

    /** Creates an LS-SVM with a linear kernel and no kernel caching. */
    public LSSVM() {
        this(new LinearKernel());
    }

    /**
     * Creates an LS-SVM with the given kernel and no kernel caching.
     *
     * @param kernelTrick the kernel to use
     */
    public LSSVM(KernelTrick kernelTrick) {
        this(kernelTrick, SupportVectorLearner.CacheMode.NONE);
    }

    /**
     * Creates an LS-SVM with C = 1.
     *
     * @param kernelTrick the kernel to use
     * @param cacheMode the kernel cache mode to use during training
     */
    public LSSVM(KernelTrick kernelTrick, SupportVectorLearner.CacheMode cacheMode) {
        super(kernelTrick, cacheMode);
        this.b = 0.0d;
        this.C = 1.0d;
    }

    /**
     * Copy constructor. Copies the bias, solver state, C, and defensive copies of the multipliers and
     * {@code fcache}.
     *
     * @param toCopy the model to copy
     */
    public LSSVM(LSSVM toCopy) {
        super(toCopy.getKernel().m628clone(), toCopy.getCacheMode());
        this.b = toCopy.b;
        this.b_low = toCopy.b_low;
        this.b_up = toCopy.b_up;
        this.i_up = toCopy.i_up;
        this.i_low = toCopy.i_low;
        this.C = toCopy.C;
        this.dualObjective = toCopy.dualObjective;
        if (toCopy.alphas != null) {
            this.alphas = Arrays.copyOf(toCopy.alphas, toCopy.alphas.length);
        }
        if (toCopy.fcache != null) {
            this.fcache = Arrays.copyOf(toCopy.fcache, toCopy.fcache.length);
        }
        // NOTE(review): the training vectors (vecs) are not copied here; confirm whether SupportVectorLearner
        // expects subclasses to copy them for a trained clone to be able to regress.
    }

    /**
     * Sets the regularization parameter C.
     *
     * @param C the new value, in (0, Infinity)
     * @throws IllegalArgumentException if C is not positive and finite
     */
    @Parameter.WarmParameter(prefLowToHigh = true)
    public void setC(double C) {
        if (C <= 0.0d || Double.isNaN(C) || Double.isInfinite(C)) {
            throw new IllegalArgumentException("C must be in (0, Infty), not " + C);
        }
        this.C = C;
    }

    /**
     * @return the regularization parameter C
     */
    public double getC() {
        return this.C;
    }

    /**
     * Optimizes the multipliers of {@code i1} and {@code i2} jointly while keeping their sum fixed. It then
     * updates {@code fcache} in parallel and recomputes {@code b_up}/{@code i_up} and {@code b_low}/{@code i_low}.
     *
     * @return {@code false} if no step was taken (same index, non-negative curvature, or a negligible
     *         change); {@code true} otherwise
     */
    private boolean takeStep(int i1, int i2, ExecutorService threadPool, int P) throws InterruptedException, ExecutionException {
        if (i1 == i2) {
            // all cached values are equal; a pair step cannot make progress (and eta would be 0)
            return false;
        }
        final double alph1 = alphas[i1];
        final double alph2 = alphas[i2];
        final double F1 = fcache[i1];
        final double F2 = fcache[i2];
        final double gamma = alph1 + alph2;
        // eta = 2*K~12 - K~11 - K~22 with K~ = K + I/C; the -2/C comes from the regularized diagonal
        final double eta = ((2.0d * kEval(i2, i1)) - kEval(i1, i1)) - kEval(i2, i2) - 2.0d / C;
        if (!(eta < 0.0d)) {
            // non-negative (or NaN) curvature: the 1-D update below would not be a maximization step
            return false;
        }
        final double a2 = alph2 - ((F1 - F2) / eta);
        // LS-SVM multipliers may be negative, so use magnitudes for the relative-change test
        if (Math.abs(a2 - alph2) < EPSILON * (Math.abs(a2) + Math.abs(alph2) + EPSILON)) {
            return false;
        }
        alphas[i1] = gamma - a2;
        alphas[i2] = a2;

        final double t = (F1 - F2) / eta;
        dualObjective -= ((eta / 2.0d) * t) * t;

        b_up = Double.NEGATIVE_INFINITY;
        b_low = Double.POSITIVE_INFINITY;
        final int n = fcache.length;
        final List<Future<PairedReturn<Integer, Integer>>> futures = new ArrayList<>(P);
        for (int id = 0; id < P; id++) {
            futures.add(threadPool.submit(new TakeStepLoop(ParallelUtils.getStartBlock(n, id, P), ParallelUtils.getEndBlock(n, id, P), i1, i2, alph1, alph2)));
        }
        for (Future<PairedReturn<Integer, Integer>> future : futures) {
            final PairedReturn<Integer, Integer> extremes = future.get();
            final int up = extremes.getFirstItem();
            final int low = extremes.getSecondItem();
            // -1 marks an empty block
            if (up >= 0 && fcache[up] > b_up) {
                b_up = fcache[up];
                i_up = up;
            }
            if (low >= 0 && fcache[low] < b_low) {
                b_low = fcache[low];
                i_low = low;
            }
        }
        return true;
    }

    @Override
    public boolean warmFromSameDataOnly() {
        return true;
    }

    /**
     * Sets the bias {@code b} and returns the duality gap for the current multipliers.
     *
     * @param fast if true, {@code b} is the midpoint of {@code b_up} and {@code b_low}; otherwise it is the
     *        mean of {@code fcache[i] - alphas[i] / C}
     * @return the primal objective minus the dual objective
     */
    private double computeDualityGap(boolean fast, ExecutorService threadPool, int P) throws InterruptedException, ExecutionException {
        final int n = alphas.length;
        if (fast) {
            b = (b_up + b_low) / 2.0d;
        } else {
            final List<Future<Double>> biasParts = new ArrayList<>(P);
            for (int id = 0; id < P; id++) {
                biasParts.add(threadPool.submit(new BiasGapCallable(ParallelUtils.getStartBlock(n, id, P), ParallelUtils.getEndBlock(n, id, P))));
            }
            double biasSum = 0.0d;
            for (Future<Double> part : biasParts) {
                biasSum += part.get();
            }
            b = biasSum / n;
        }
        final List<Future<Double>> gapParts = new ArrayList<>(P);
        for (int id = 0; id < P; id++) {
            gapParts.add(threadPool.submit(new DualityGapCallable(ParallelUtils.getStartBlock(n, id, P), ParallelUtils.getEndBlock(n, id, P))));
        }
        double gap = 0.0d;
        for (Future<Double> part : gapParts) {
            gap += part.get();
        }
        return gap;
    }

    /**
     * Initializes the multipliers, {@code fcache}, the dual objective and the extreme indices. This is
     * either a cold start (all multipliers zero) or a warm start from a previously trained model, whose
     * {@code fcache} is adjusted for a change in C.
     *
     * @param targets the target value of each training point
     * @param warmSolution a previously trained model to start from, or {@code null}
     * @param dataSet currently unused
     * @throws FailedToFitException if the warm model is untrained or has a different number of multipliers
     */
    private void initializeVariables(double[] targets, LSSVM warmSolution, DataSet dataSet) {
        final int n = targets.length;
        if (warmSolution != null) {
            if (warmSolution.alphas == null || warmSolution.fcache == null) {
                throw new FailedToFitException("Warm LS-SVM solution has not been trained");
            }
            if (warmSolution.alphas.length != n) {
                throw new FailedToFitException("Warm LS-SVM solution could not have been trained on the same data, different number of alpha values present");
            }
        }
        // capture the warm state before our own arrays are replaced, in case warmSolution == this
        final double[] warmAlphas = warmSolution == null ? null : warmSolution.alphas;
        final double[] warmF = warmSolution == null ? null : warmSolution.fcache;
        final double warmC = warmSolution == null ? C : warmSolution.C;

        alphas = new double[n];
        fcache = new double[n];
        dualObjective = 0.0d;
        if (warmAlphas == null) {
            // with all multipliers zero, F_i = -y_i
            for (int i = 0; i < n; i++) {
                fcache[i] = -targets[i];
            }
        } else {
            // swap the old 1/C diagonal contribution for the new one
            final double ratio = C / warmC;
            for (int i = 0; i < n; i++) {
                alphas[i] = warmAlphas[i];
                fcache[i] = warmF[i] - (((ratio - 1.0d) * warmAlphas[i]) / C);
                dualObjective += alphas[i] * (targets[i] - fcache[i]);
            }
            dualObjective /= 2.0d;
        }
        b_up = Double.NEGATIVE_INFINITY;
        b_low = Double.POSITIVE_INFINITY;
        for (int i = 0; i < n; i++) {
            final double f = fcache[i];
            if (f > b_up) {
                b_up = f;
                i_up = i;
            }
            if (f < b_low) {
                b_low = f;
                i_low = i;
            }
        }
        setCacheMode(getCacheMode());
    }

    /** Returns the raw regression output, whose sign determines the predicted class. */
    @Override
    public double getScore(DataPoint dataPoint) {
        return regress(dataPoint);
    }

    /**
     * Predicts class 1 if the regression output is positive, and class 0 otherwise. All probability mass is
     * assigned to the chosen class.
     */
    @Override
    public CategoricalResults classify(DataPoint dataPoint) {
        final CategoricalResults result = new CategoricalResults(2);
        if (regress(dataPoint) > 0.0d) {
            result.setProb(1, 1.0d);
        } else {
            result.setProb(0, 1.0d);
        }
        return result;
    }

    @Override
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        trainC(classificationDataSet, null, executorService);
    }

    @Override
    public void trainC(ClassificationDataSet classificationDataSet) {
        trainC(classificationDataSet, (ExecutorService) null);
    }

    /**
     * Trains as a regressor, optionally warm-started from another LS-SVM.
     *
     * @throws FailedToFitException if {@code regressor} is non-null and not an {@code LSSVM}
     */
    @Override
    public void train(RegressionDataSet regressionDataSet, Regressor regressor, ExecutorService executorService) {
        if (regressor != null && !(regressor instanceof LSSVM)) {
            throw new FailedToFitException("Warm solution must be an implementation of LS-SVM, not " + regressor.getClass());
        }
        mainLoop(regressionDataSet, (LSSVM) regressor, regressionDataSet.getTargetValues().arrayCopy(), executorService);
    }

    @Override
    public void train(RegressionDataSet regressionDataSet, Regressor regressor) {
        train(regressionDataSet, regressor, null);
    }

    /**
     * Trains as a binary classifier, optionally warm-started from another LS-SVM. Category 0 is mapped to
     * target -1 and category 1 to target +1.
     *
     * @throws FailedToFitException if the data does not have exactly two classes, or if {@code classifier}
     *         is non-null and not an {@code LSSVM}
     */
    @Override
    public void trainC(ClassificationDataSet classificationDataSet, Classifier classifier, ExecutorService executorService) {
        if (classificationDataSet.getClassSize() != 2) {
            throw new FailedToFitException("LS-SVM only supports binary classification problems");
        }
        if (classifier != null && !(classifier instanceof LSSVM)) {
            throw new FailedToFitException("Warm solution must be an implementation of LS-SVM, not " + classifier.getClass());
        }
        final int n = classificationDataSet.getSampleSize();
        final double[] targets = new double[n];
        for (int i = 0; i < n; i++) {
            targets[i] = (classificationDataSet.getDataPointCategory(i) * 2) - 1;
        }
        mainLoop(classificationDataSet, (LSSVM) classifier, targets, executorService);
    }

    @Override
    public void trainC(ClassificationDataSet classificationDataSet, Classifier classifier) {
        trainC(classificationDataSet, classifier, null);
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /** Returns the kernel-weighted sum over the training vectors minus the bias {@code b}. */
    @Override
    public double regress(DataPoint dataPoint) {
        return kEvalSum(dataPoint.getNumericalValues()) - this.b;
    }

    @Override
    public void train(RegressionDataSet regressionDataSet, ExecutorService executorService) {
        train(regressionDataSet, null, executorService);
    }

    @Override
    public void train(RegressionDataSet regressionDataSet) {
        train(regressionDataSet, (ExecutorService) null);
    }

    @Override
    public LSSVM clone() {
        return new LSSVM(this);
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }

    /**
     * Runs the solver until the duality-gap stopping criterion is met or no step makes progress. A
     * {@code null} or {@link FakeExecutor} pool runs single-threaded; otherwise the work is split into
     * {@code SystemInfo.LogicalCores} blocks.
     *
     * @throws FailedToFitException if a worker fails or the thread is interrupted (the interrupt flag is
     *         restored before rethrowing)
     */
    private void mainLoop(DataSet dataSet, LSSVM warmSolution, double[] targets, ExecutorService executorService) {
        final ExecutorService threadPool;
        final int P;
        if (executorService == null || (executorService instanceof FakeExecutor)) {
            threadPool = new FakeExecutor();
            P = 1;
        } else {
            threadPool = executorService;
            P = SystemInfo.LogicalCores;
        }
        try {
            this.vecs = dataSet.getDataVectors();
            initializeVariables(targets, warmSolution, dataSet);
            boolean progress = true;
            double gap = computeDualityGap(true, threadPool, P);
            while (gap > TOL * dualObjective && progress) {
                progress = takeStep(i_up, i_low, threadPool, P);
                gap = computeDualityGap(true, threadPool, P);
            }
            setCacheMode(null);
            setAlphas(this.alphas);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new FailedToFitException(e);
        } catch (ExecutionException e) {
            throw new FailedToFitException(e);
        }
    }

    /**
     * Returns a distribution of candidate C values for the data set, delegating to {@code PlattSMO.guessC}.
     */
    public static Distribution guessC(DataSet dataSet) {
        return PlattSMO.guessC(dataSet);
    }
}
