package jsat.classifiers.svm;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.classifiers.svm.SupportVectorLearner;
import jsat.distributions.kernels.KernelTrick;
import jsat.distributions.kernels.NormalizedKernel;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.Vec;
import jsat.utils.FakeExecutor;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.AtomicDouble;
import jsat.utils.concurrent.ParallelUtils;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/svm/SVMnoBias.class */
/**
 * A binary kernel SVM whose decision function has no bias (offset) term.
 *
 * <p>Training runs a dual coordinate-ascent solver. Each iteration does three things:
 * <ul>
 *   <li>picks the single dual variable whose box-clipped update gives the largest gain;</li>
 *   <li>applies that update;</li>
 *   <li>refreshes the gradient vector, optionally split across an {@link ExecutorService}.</li>
 * </ul>
 * Iteration stops once the statistic {@code S_a} is at most {@code tolerance * C * N}, where
 * {@code N} is the number of training points.
 *
 * <p>Categories 0/1 are mapped to labels -1/+1. Each dual variable is kept in the box
 * {@code [0, weight_i * C]}. After training, {@code alphas} holds the signed values
 * {@code label_i * alpha_i}.
 */
public class SVMnoBias extends SupportVectorLearner implements BinaryScoreClassifier {
    /** Regularization constant; point i's dual variable is limited to weights[i] * C. */
    private double C;
    /** Stopping tolerance; training ends once S_a <= tolerance * C * N. */
    private double tolerance;
    /** Training labels in {-1, +1}, derived from categories 0 and 1. */
    protected short[] label;
    /** Per-point data weights taken from the training set. */
    protected Vec weights;
    /** Running value of -sum_i alpha_i * grad_i, maintained during training. */
    private double T_a;
    /** Stopping statistic: T_a + sum_i weights[i] * C * clip(grad_i, 0, 2). */
    private double S_a;

    /**
     * Creates an SVM without bias using the given kernel.
     *
     * <p>Caching is disabled, C is 1, and the tolerance is 0.001.
     *
     * @param kernelTrick the kernel to use
     */
    public SVMnoBias(KernelTrick kernelTrick) {
        super(kernelTrick, SupportVectorLearner.CacheMode.NONE);
        this.C = 1.0d;
        this.tolerance = 0.001d;
    }

    /**
     * Copy constructor.
     *
     * <p>Copies C and the tolerance, clones the weight vector, and copies the label array,
     * so the new instance shares no mutable training state with {@code sVMnoBias}.
     *
     * @param sVMnoBias the instance to copy
     */
    public SVMnoBias(SVMnoBias sVMnoBias) {
        super(sVMnoBias);
        this.C = sVMnoBias.C;
        this.tolerance = sVMnoBias.tolerance;
        if (sVMnoBias.weights != null) {
            this.weights = sVMnoBias.weights.mo525clone();
        }
        if (sVMnoBias.label != null) {
            this.label = Arrays.copyOf(sVMnoBias.label, sVMnoBias.label.length);
        }
    }

    /**
     * Sets the kernel, wrapping it in a {@link NormalizedKernel} unless it already reports
     * itself as normalized.
     *
     * <p>The solver's single-variable gain and {@code T_a} update formulas treat each point's
     * self-similarity k(x_i, x_i) as 1. That is presumably why normalization is enforced here.
     *
     * @param kernelTrick the kernel to use
     */
    @Override // jsat.classifiers.svm.SupportVectorLearner
    public void setKernel(KernelTrick kernelTrick) {
        if (kernelTrick.normalized()) {
            super.setKernel(kernelTrick);
        } else {
            super.setKernel(new NormalizedKernel(kernelTrick));
        }
    }

    /**
     * Returns the raw decision value for a data point.
     *
     * @param dataPoint the point to score
     * @return the kernel sum over the learned support vectors; positive means class 1
     */
    @Override // jsat.classifiers.calibration.BinaryScoreClassifier
    public double getScore(DataPoint dataPoint) {
        return kEvalSum(dataPoint.getNumericalValues());
    }

    @Override // jsat.classifiers.Classifier
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public SVMnoBias m579clone() {
        return new SVMnoBias(this);
    }

    /**
     * Hard-classifies a point.
     *
     * @param dataPoint the point to classify
     * @return results with probability 1 on class 1 if the score is strictly positive,
     *         otherwise probability 1 on class 0
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        if (this.vecs == null) {
            throw new UntrainedModelException("Classifier has yet to be trained");
        }
        CategoricalResults categoricalResults = new CategoricalResults(2);
        if (getScore(dataPoint) > 0.0d) {
            categoricalResults.setProb(1, 1.0d);
        } else {
            categoricalResults.setProb(0, 1.0d);
        }
        return categoricalResults;
    }

    /**
     * Trains from scratch on a single thread.
     *
     * @param classificationDataSet a binary training set
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        trainC(classificationDataSet, new FakeExecutor());
    }

    /**
     * Trains from scratch, starting with all dual variables at zero.
     *
     * @param classificationDataSet a binary training set
     * @param executorService the executor used for the parallel gradient updates
     * @throws FailedToFitException if the data set is not binary or a worker fails
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        bookKeepingInit(classificationDataSet);
        solver_1d(procedure3_init(), executorService);
        setCacheMode(null);
    }

    /**
     * Trains on a single thread, warm-starting the dual variables from {@code dArr}.
     *
     * @param classificationDataSet a binary training set
     * @param dArr warm-start values; their absolute values are used
     */
    protected void trainC(ClassificationDataSet classificationDataSet, double[] dArr) {
        trainC(classificationDataSet, dArr, new FakeExecutor());
    }

    /**
     * Trains, warm-starting the dual variables from {@code dArr}.
     *
     * <p>Each starting value is {@code |dArr[i]|}, projected onto the feasible box
     * {@code [0, weight_i * C]}. This keeps the start feasible when, for example, the warm
     * start came from a model trained with a larger C.
     *
     * @param classificationDataSet a binary training set
     * @param dArr warm-start values, one per data point; extra trailing values are ignored
     * @param executorService the executor used for the parallel gradient computations
     * @throws IllegalArgumentException if {@code dArr} has fewer values than there are points
     * @throws FailedToFitException if the data set is not binary or a worker fails
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public void trainC(ClassificationDataSet classificationDataSet, double[] dArr, ExecutorService executorService) {
        int n = classificationDataSet.getSampleSize();
        if (dArr.length < n) {
            throw new IllegalArgumentException("Warm start has " + dArr.length
                    + " values but the data set has " + n + " points");
        }
        bookKeepingInit(classificationDataSet);
        for (int i = 0; i < this.alphas.length; i++) {
            // Project onto the dual box; Math.abs already guarantees the lower bound of 0.
            this.alphas[i] = Math.min(Math.abs(dArr[i]), this.weights.get(i) * this.C);
        }
        solver_1d(procedure4m_init(executorService), executorService);
        setCacheMode(null);
    }

    /** Number of worker tasks to split per-point loops into: 1 for a FakeExecutor. */
    private static int threadCount(ExecutorService executorService) {
        return executorService instanceof FakeExecutor ? 1 : SystemInfo.LogicalCores;
    }

    /** Sums the results of the given futures in order, wrapping failures. */
    private static double sumFutures(List<Future<Double>> futures) {
        double total = 0.0d;
        for (Future<Double> future : futures) {
            try {
                total += future.get();
            } catch (InterruptedException e) {
                // Restore the interrupt flag so callers up the stack can still observe it.
                Thread.currentThread().interrupt();
                throw new FailedToFitException(e);
            } catch (ExecutionException e) {
                throw new FailedToFitException(e);
            }
        }
        return total;
    }

    /**
     * Runs coordinate ascent on the dual until {@code S_a <= tolerance * C * N}.
     *
     * <p>On return, {@code alphas} holds the signed values {@code label_i * alpha_i}.
     *
     * @param grad the gradient vector {@code 1 - Q alpha}; it is updated in place
     * @param executorService the executor used for the per-iteration gradient refresh
     */
    private void solver_1d(final double[] grad, ExecutorService executorService) {
        final int cores = threadCount(executorService);
        final int n = this.alphas.length;
        final double lambda = 1.0d / ((2.0d * this.C) * n);
        while (this.S_a > this.tolerance / (2.0d * lambda)) {
            // Select the coordinate whose box-clipped step gives the largest gain.
            double bestGain = -1.0d;
            int w = -1;
            double bestDelta = -1.0d;
            for (int i = 0; i < n; i++) {
                double delta = Math.max(Math.min(this.weights.get(i) * this.C, grad[i] + this.alphas[i]), 0.0d) - this.alphas[i];
                double gain = delta * (grad[i] - (delta / 2.0d));
                if (gain >= bestGain) {
                    bestGain = gain;
                    w = i;
                    bestDelta = delta;
                }
            }
            if (w < 0) {
                // Only reachable when every gain is NaN (e.g. NaN kernel values or weights).
                throw new FailedToFitException("No coordinate could be selected; gains were not finite");
            }

            this.alphas[w] += bestDelta;
            // Snap values within 1e-7 of a box edge onto the edge.
            double upper = this.weights.get(w) * this.C;
            if (this.alphas[w] + 1.0E-7d > upper) {
                this.alphas[w] = upper;
            } else if (this.alphas[w] - 1.0E-7d < 0.0d) {
                this.alphas[w] = 0.0d;
            }

            // Uses grad[w] from before this step's gradient refresh.
            this.T_a -= bestDelta * (((2.0d * grad[w]) - 1.0d) - bestDelta);

            final double delta = bestDelta;
            final int chosen = w;
            List<Future<Double>> futures = new ArrayList<Future<Double>>(cores);
            accessingRow(chosen);
            for (int id = 0; id < cores; id++) {
                final int start = ParallelUtils.getStartBlock(n, id, cores);
                final int end = ParallelUtils.getEndBlock(n, id, cores);
                futures.add(executorService.submit(new Callable<Double>() {
                    @Override // java.util.concurrent.Callable
                    public Double call() {
                        double partial = 0.0d;
                        for (int j = start; j < end; j++) {
                            grad[j] -= ((delta * SVMnoBias.this.label[chosen]) * SVMnoBias.this.label[j]) * SVMnoBias.this.kEval(chosen, j);
                            partial += SVMnoBias.this.weights.get(j) * SVMnoBias.this.C * Math.min(Math.max(0.0d, grad[j]), 2.0d);
                        }
                        return partial;
                    }
                }));
            }
            this.S_a = this.T_a + sumFutures(futures);
        }
        accessingRow(-1);
        for (int i = 0; i < this.label.length; i++) {
            this.alphas[i] *= this.label[i];
        }
    }

    /**
     * Initializes the solver state for a cold start with all alphas at zero.
     *
     * <p>The gradient is all ones, {@code T_a = 0}, and {@code S_a = sum_i weights[i] * C}.
     *
     * @return the initial gradient vector
     */
    private double[] procedure3_init() {
        int n = this.alphas.length;
        this.T_a = 0.0d;
        this.S_a = 0.0d;
        double[] grad = new double[n];
        Arrays.fill(grad, 1.0d);
        for (int i = 0; i < n; i++) {
            this.S_a += this.weights.get(i) * this.C;
        }
        return grad;
    }

    /**
     * Initializes the solver state from the current (warm-start) alphas.
     *
     * <p>Computes {@code grad_i = 1 - sum_j alpha_j y_i y_j k(i, j)} in parallel, and from it
     * {@code T_a} and {@code S_a}.
     *
     * @param executorService the executor used to split the O(N^2) gradient computation
     * @return the gradient vector
     */
    private double[] procedure4m_init(ExecutorService executorService) {
        final int cores = threadCount(executorService);
        final int n = this.alphas.length;
        final boolean fullCache = getCacheMode() == SupportVectorLearner.CacheMode.FULL;
        final AtomicDouble clippedSum = new AtomicDouble(0.0d);
        final double[] grad = new double[n];
        this.T_a = 0.0d;
        List<Future<Double>> futures = new ArrayList<Future<Double>>(cores);
        for (int id = 0; id < cores; id++) {
            final int start = ParallelUtils.getStartBlock(n, id, cores);
            final int end = ParallelUtils.getEndBlock(n, id, cores);
            futures.add(executorService.submit(new Callable<Double>() {
                @Override // java.util.concurrent.Callable
                public Double call() {
                    double tPart = 0.0d;
                    double sPart = 0.0d;
                    for (int i = start; i < end; i++) {
                        double acc = 0.0d;
                        for (int j = 0; j < n; j++) {
                            if (SVMnoBias.this.alphas[j] != 0.0d) {
                                double kij = fullCache ? SVMnoBias.this.kEval(i, j) : SVMnoBias.this.k(i, j);
                                acc -= ((SVMnoBias.this.alphas[j] * SVMnoBias.this.label[i]) * SVMnoBias.this.label[j]) * kij;
                            }
                        }
                        grad[i] = 1.0d + acc;
                        tPart -= SVMnoBias.this.alphas[i] * grad[i];
                        sPart += SVMnoBias.this.weights.get(i) * SVMnoBias.this.C * Math.min(Math.max(grad[i], 0.0d), 2.0d);
                    }
                    clippedSum.addAndGet(sPart);
                    return tPart;
                }
            }));
        }
        this.T_a = sumFutures(futures);
        this.S_a = this.T_a + clippedSum.get();
        return grad;
    }

    /**
     * Loads the data set into the learner's training state.
     *
     * <p>Loads the vectors, weights and -1/+1 labels, re-applies the current cache mode,
     * and allocates zeroed alphas.
     *
     * @param classificationDataSet the training set; must have exactly two classes
     * @throws FailedToFitException if the data set does not have exactly two classes
     */
    private void bookKeepingInit(ClassificationDataSet classificationDataSet) {
        if (classificationDataSet.getClassSize() != 2) {
            throw new FailedToFitException("SVM only supports binary classification problems");
        }
        int sampleSize = classificationDataSet.getSampleSize();
        this.vecs = classificationDataSet.getDataVectors();
        this.weights = classificationDataSet.getDataWeights();
        this.label = new short[sampleSize];
        for (int i = 0; i < sampleSize; i++) {
            // category 0 -> -1, category 1 -> +1
            this.label[i] = (short) ((classificationDataSet.getDataPointCategory(i) * 2) - 1);
        }
        setCacheMode(getCacheMode());
        this.alphas = new double[sampleSize];
    }

    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return true;
    }

    /**
     * Sets the regularization constant.
     *
     * @param d the new C; must be positive and finite
     * @throws ArithmeticException if {@code d} is not positive, is NaN, or is infinite
     */
    public void setC(double d) {
        // !(d > 0) also rejects NaN; an infinite C would make the stopping threshold infinite.
        if (!(d > 0.0d) || Double.isInfinite(d)) {
            throw new ArithmeticException("C must be a positive constant");
        }
        this.C = d;
    }

    /**
     * Returns the regularization constant.
     *
     * @return the regularization constant C
     */
    public double getC() {
        return this.C;
    }

    /**
     * Sets the stopping tolerance.
     *
     * <p>A NaN value would make training skip silently. A negative value asks for a
     * non-positive stopping statistic, which may never be reached.
     *
     * @param d the new tolerance; must be non-negative and not NaN
     * @throws ArithmeticException if {@code d} is negative or NaN
     */
    public void setTolerance(double d) {
        if (Double.isNaN(d) || d < 0.0d) {
            throw new ArithmeticException("tolerance must be a non-negative value");
        }
        this.tolerance = d;
    }

    /**
     * Returns the stopping tolerance.
     *
     * @return the stopping tolerance
     */
    public double getTolerance() {
        return this.tolerance;
    }
}
