package jsat.classifiers.svm;

import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import jsat.DataSet;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.WarmClassifier;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.distributions.Distribution;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.regression.WarmRegressor;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.random.XORWOW;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/svm/DCDs.class */
/**
 * Linear Support Vector Machine trained by dual coordinate descent with shrinking.
 * <p>
 * The same model can be trained as a binary classifier ({@link #trainC(ClassificationDataSet)})
 * or as an epsilon-insensitive regressor ({@link #train(RegressionDataSet)}). The learned model is
 * a single weight vector {@code w} plus an optional bias, so every prediction is
 * {@code w . x + bias}.
 * <p>
 * {@link #setUseL1(boolean)} selects between the L1 loss (dual variables boxed by
 * {@code C * weight}, no diagonal term) and the L2 (squared) loss (unbounded dual variables with
 * an extra diagonal term {@code 1 / (2 * C * weight)}). When the bias is enabled it is learned as
 * the weight of an implicit constant feature of value 1 appended to every input.
 * <p>
 * The classification solver uses the shrinking strategy of LIBLINEAR / Hsieh et al. (2008); the
 * regression solver appears to follow Ho &amp; Lin (2012), whose per-variable optimality measure
 * is implemented by {@link #eq24(double, double, double, double)}.
 */
public class DCDs implements BinaryScoreClassifier, Regressor, Parameterized, SingleWeightVectorModel, WarmClassifier, WarmRegressor {
    private static final long serialVersionUID = -1686294187234524696L;
    /** Maximum number of outer passes over the (active) training points. */
    private int maxIterations;
    /** Stopping tolerance of the optimality test. */
    private double tolerance;
    /** Training inputs; only non-null while a training call is running. */
    private Vec[] vecs;
    /** Dual variables of the last fit; kept after training so the model can warm start another. */
    private double[] alpha;
    /** Training targets ({@code -1/+1} labels or regression values); only set during training. */
    private double[] y;
    /** Learned bias term (stays 0 when {@link #useBias} is false). */
    private double bias;
    /** Learned weight vector; {@code null} until the model has been trained. */
    private Vec w;
    /** Penalty parameter of the SVM. */
    private double C;
    /** {@code true} for the L1 loss, {@code false} for the L2 (squared) loss. */
    private boolean useL1;
    /** Width of the insensitive zone used by regression. */
    private double eps;
    /** Whether a bias term is learned. */
    private boolean useBias;
    private final List<Parameter> params;
    private final Map<String, Parameter> paramMap;

    /**
     * Creates a solver with 10000 iterations, the L2 loss, tolerance 0.001 and {@code C = 1}.
     */
    public DCDs() {
        this(10000, false);
    }

    /**
     * Creates a solver with tolerance 0.001 and {@code C = 1}.
     *
     * @param maxIterations maximum number of outer passes, must be positive
     * @param useL1 {@code true} for the L1 loss, {@code false} for the L2 loss
     */
    public DCDs(int maxIterations, boolean useL1) {
        this(maxIterations, 0.001d, 1.0d, useL1);
    }

    /**
     * Creates a fully specified solver.
     *
     * @param maxIterations maximum number of outer passes, must be positive
     * @param tolerance stopping tolerance, must be non-negative
     * @param C penalty parameter, must be positive and finite
     * @param useL1 {@code true} for the L1 loss, {@code false} for the L2 loss
     */
    public DCDs(int maxIterations, double tolerance, double C, boolean useL1) {
        this.eps = 0.001d;
        this.useBias = true;
        this.params = Collections.unmodifiableList(Parameter.getParamsFromMethods(this));
        this.paramMap = Parameter.toParameterMap(this.params);
        setMaxIterations(maxIterations);
        setTolerance(tolerance);
        setC(C);
        setUseL1(useL1);
    }

    /**
     * Sets the penalty parameter.
     *
     * @param C a positive, finite value
     * @throws ArithmeticException if {@code C} is NaN, infinite or not positive
     */
    public void setC(double C) {
        if (Double.isNaN(C) || Double.isInfinite(C) || C <= 0.0d) {
            throw new ArithmeticException("Penalty parameter must be a positive value, not " + C);
        }
        this.C = C;
    }

    /** @return the penalty parameter */
    public double getC() {
        return this.C;
    }

    /**
     * Sets the width of the epsilon-insensitive zone used by regression.
     *
     * @param eps a non-negative, finite value
     * @throws IllegalArgumentException if {@code eps} is NaN, negative or infinite
     */
    public void setEps(double eps) {
        if (Double.isNaN(eps) || eps < 0.0d || Double.isInfinite(eps)) {
            throw new IllegalArgumentException("eps must be non-negative, not " + eps);
        }
        this.eps = eps;
    }

    /** @return the width of the epsilon-insensitive zone used by regression */
    public double getEps() {
        return this.eps;
    }

    /**
     * Sets the stopping tolerance of the optimality test.
     *
     * @param tolerance a non-negative value
     * @throws IllegalArgumentException if {@code tolerance} is NaN or negative (a NaN tolerance
     *     would otherwise silently end training after a single pass)
     */
    public void setTolerance(double tolerance) {
        if (Double.isNaN(tolerance) || tolerance < 0.0d) {
            throw new IllegalArgumentException("Tolerance must be non-negative, not " + tolerance);
        }
        this.tolerance = tolerance;
    }

    /** @return the stopping tolerance */
    public double getTolerance() {
        return this.tolerance;
    }

    /** @param useL1 {@code true} for the L1 loss, {@code false} for the L2 (squared) loss */
    public void setUseL1(boolean useL1) {
        this.useL1 = useL1;
    }

    /** @return {@code true} if the L1 loss is used, {@code false} for the L2 loss */
    public boolean isUseL1() {
        return this.useL1;
    }

    /**
     * Sets the maximum number of outer passes over the training points.
     *
     * @param maxIterations a positive value
     * @throws IllegalArgumentException if {@code maxIterations} is not positive
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations <= 0) {
            throw new IllegalArgumentException("Number of iterations must be positive, not " + maxIterations);
        }
        this.maxIterations = maxIterations;
    }

    /** @return the maximum number of outer passes */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /** @param useBias whether a bias term should be learned */
    public void setUseBias(boolean useBias) {
        this.useBias = useBias;
    }

    /** @return whether a bias term is learned */
    public boolean isUseBias() {
        return this.useBias;
    }

    @Override // jsat.SingleWeightVectorModel
    public Vec getRawWeight() {
        return this.w;
    }

    @Override // jsat.SingleWeightVectorModel
    public double getBias() {
        return this.bias;
    }

    /**
     * @param index must be 0, the model has a single weight vector
     * @throws IndexOutOfBoundsException for any other index (negative indices included)
     */
    @Override // jsat.SimpleWeightVectorModel
    public Vec getRawWeight(int index) {
        if (index == 0) {
            return getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * @param index must be 0, the model has a single bias term
     * @throws IndexOutOfBoundsException for any other index (negative indices included)
     */
    @Override // jsat.SimpleWeightVectorModel
    public double getBias(int index) {
        if (index == 0) {
            return getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    @Override // jsat.SimpleWeightVectorModel
    public int numWeightsVecs() {
        return 1;
    }

    /**
     * Hard classification: class 0 when the score is negative, class 1 otherwise.
     *
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        if (this.w == null) {
            throw new UntrainedModelException("The model has not been trained");
        }
        CategoricalResults results = new CategoricalResults(2);
        if (getScore(dataPoint) < 0.0d) {
            results.setProb(0, 1.0d);
        } else {
            results.setProb(1, 1.0d);
        }
        return results;
    }

    /**
     * @return the raw decision value {@code w . x + bias}
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override // jsat.classifiers.calibration.BinaryScoreClassifier
    public double getScore(DataPoint dataPoint) {
        if (this.w == null) {
            throw new UntrainedModelException("The model has not been trained");
        }
        return this.w.dot(dataPoint.getNumericalValues()) + this.bias;
    }

    /** Trains on the calling thread; the executor is ignored. */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        trainC(classificationDataSet);
    }

    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        trainC(classificationDataSet, (Classifier) null);
    }

    /** Trains on the calling thread; the executor is ignored. */
    @Override // jsat.classifiers.WarmClassifier
    public void trainC(ClassificationDataSet classificationDataSet, Classifier classifier, ExecutorService executorService) {
        trainC(classificationDataSet, classifier);
    }

    /**
     * Trains a binary SVM classifier with dual coordinate descent and shrinking.
     *
     * @param classificationDataSet data set with exactly two classes
     * @param classifier optional warm start: a trained {@code DCDs} fit on the same data set (may
     *     be this very object), or {@code null}
     * @throws FailedToFitException if the data set is not binary or empty, or the warm start is
     *     not a trained {@code DCDs} of matching size
     */
    @Override // jsat.classifiers.WarmClassifier
    public void trainC(ClassificationDataSet classificationDataSet, Classifier classifier) {
        if (classificationDataSet.getClassSize() != 2) {
            throw new FailedToFitException("SVM only supports binary classificaiton problems");
        }
        final int n = classificationDataSet.getSampleSize();
        if (n == 0) {
            throw new FailedToFitException("Can not train on an empty data set");
        }
        // Snapshot the warm start before any field is reset so a model can warm start from itself.
        final WarmStart warm = captureWarmStart(classifier, n);

        this.vecs = new Vec[n];
        this.alpha = new double[n];
        this.y = new double[n];
        this.bias = 0.0d;
        final double[] qDiag = new double[n]; // x_i.x_i + D_ii (+1 for the bias): coordinate curvature
        final double[] upper = new double[n]; // per-point upper bound on alpha_i
        final double[] dDiag = new double[n]; // per-point diagonal term D_ii
        for (int i = 0; i < n; i++) {
            final DataPoint dataPoint = classificationDataSet.getDataPoint(i);
            this.vecs[i] = dataPoint.getNumericalValues();
            // class index {0, 1} -> label {-1, +1}
            this.y[i] = (classificationDataSet.getDataPointCategory(i) * 2) - 1;
            upper[i] = getU(dataPoint.getWeight());
            dDiag[i] = getD(dataPoint.getWeight());
            qDiag[i] = this.vecs[i].dot(this.vecs[i]) + dDiag[i];
            if (this.useBias) {
                qDiag[i] += 1.0d;
            }
        }
        this.w = new DenseVector(this.vecs[0].length());
        if (warm != null) {
            applyWarmStart(warm);
        }

        final IntList active = new IntList(n);
        ListUtils.addRange(active, 0, n, 1);
        // Shrinking thresholds taken from the PREVIOUS pass (M-bar / m-bar in LIBLINEAR); infinite
        // values disable shrinking at the corresponding bound.
        double prevMaxPG = Double.POSITIVE_INFINITY;
        double prevMinPG = Double.NEGATIVE_INFINITY;
        final XORWOW rand = new XORWOW();
        for (int iter = 0; iter < this.maxIterations; iter++) {
            Collections.shuffle(active, rand);
            double maxPG = Double.NEGATIVE_INFINITY;
            double minPG = Double.POSITIVE_INFINITY;
            final Iterator<Integer> it = active.iterator();
            while (it.hasNext()) {
                final int i = it.next();
                final double grad = ((this.y[i] * (this.w.dot(this.vecs[i]) + this.bias)) - 1.0d) + (dDiag[i] * this.alpha[i]);
                double projGrad = 0.0d;
                if (this.alpha[i] == 0.0d) {
                    if (grad > prevMaxPG) {
                        it.remove();
                    }
                    if (grad < 0.0d) {
                        projGrad = grad;
                    }
                } else if (this.alpha[i] == upper[i]) {
                    if (grad < prevMinPG) {
                        it.remove();
                    }
                    if (grad > 0.0d) {
                        projGrad = grad;
                    }
                } else {
                    projGrad = grad;
                }
                maxPG = Math.max(maxPG, projGrad);
                minPG = Math.min(minPG, projGrad);
                if (projGrad != 0.0d) {
                    final double oldAlpha = this.alpha[i];
                    this.alpha[i] = Math.min(Math.max(oldAlpha - (grad / qDiag[i]), 0.0d), upper[i]);
                    final double delta = (this.alpha[i] - oldAlpha) * this.y[i];
                    this.w.mutableAdd(delta, this.vecs[i]);
                    if (this.useBias) {
                        this.bias += delta;
                    }
                }
            }
            if (maxPG - minPG < this.tolerance) {
                // Converged on the active set; only stop if nothing was shrunk, else re-check all.
                if (active.size() == n) {
                    break;
                }
                active.clear();
                ListUtils.addRange(active, 0, n, 1);
                prevMaxPG = Double.POSITIVE_INFINITY;
                prevMinPG = Double.NEGATIVE_INFINITY;
            } else {
                prevMaxPG = maxPG <= 0.0d ? Double.POSITIVE_INFINITY : maxPG;
                prevMinPG = minPG >= 0.0d ? Double.NEGATIVE_INFINITY : minPG;
            }
        }
        this.vecs = null;
        this.y = null;
    }

    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return true;
    }

    @Override // jsat.classifiers.WarmClassifier, jsat.regression.WarmRegressor
    public boolean warmFromSameDataOnly() {
        return true;
    }

    /**
     * Deep copy of this model: all hyper-parameters (including {@code eps} and the bias flag),
     * the weight vector, the bias and the dual variables.
     */
    @Override // jsat.regression.Regressor
    /* renamed from: clone, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public DCDs mo526clone() {
        final DCDs copy = new DCDs(this.maxIterations, this.tolerance, this.C, this.useL1);
        copy.eps = this.eps;
        copy.useBias = this.useBias;
        copy.bias = this.bias;
        if (this.w != null) {
            copy.w = this.w.mo525clone();
        }
        if (this.alpha != null) {
            copy.alpha = Arrays.copyOf(this.alpha, this.alpha.length);
        }
        return copy;
    }

    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return this.params;
    }

    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        return this.paramMap.get(str);
    }

    /**
     * @return the regression prediction {@code w . x + bias}
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override // jsat.regression.Regressor
    public double regress(DataPoint dataPoint) {
        if (this.w == null) {
            throw new UntrainedModelException("The model has not been trained");
        }
        return this.w.dot(dataPoint.getNumericalValues()) + this.bias;
    }

    /** Trains on the calling thread; the executor is ignored. */
    @Override // jsat.regression.WarmRegressor
    public void train(RegressionDataSet regressionDataSet, Regressor regressor, ExecutorService executorService) {
        train(regressionDataSet, regressor);
    }

    /** Trains on the calling thread; the executor is ignored. */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet, ExecutorService executorService) {
        train(regressionDataSet);
    }

    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet) {
        train(regressionDataSet, (Regressor) null);
    }

    /**
     * Trains an epsilon-insensitive linear SVR with dual coordinate descent and shrinking.
     *
     * @param regressionDataSet the (non-empty) training data
     * @param regressor optional warm start: a trained {@code DCDs} fit on the same data set (may
     *     be this very object), or {@code null}
     * @throws FailedToFitException if the data set is empty or the warm start is not a trained
     *     {@code DCDs} of matching size
     */
    @Override // jsat.regression.WarmRegressor
    public void train(RegressionDataSet regressionDataSet, Regressor regressor) {
        final int n = regressionDataSet.getSampleSize();
        if (n == 0) {
            throw new FailedToFitException("Can not train on an empty data set");
        }
        // Snapshot the warm start before any field is reset so a model can warm start from itself.
        final WarmStart warm = captureWarmStart(regressor, n);

        this.vecs = new Vec[n];
        this.alpha = new double[n];
        this.y = new double[n];
        this.bias = 0.0d;
        final double[] qDiag = new double[n]; // x_i.x_i + D_ii (+1 for the bias): coordinate curvature
        final double[] upper = new double[n]; // per-point bound, beta_i in [-U_i, U_i]
        final double[] dDiag = new double[n]; // per-point diagonal term D_ii
        // Total optimality violation at beta = 0; makes the stopping test relative.
        double initialViolation = 0.0d;
        for (int i = 0; i < n; i++) {
            final DataPoint dataPoint = regressionDataSet.getDataPoint(i);
            this.vecs[i] = dataPoint.getNumericalValues();
            this.y[i] = regressionDataSet.getTargetValue(i);
            upper[i] = getU(dataPoint.getWeight());
            dDiag[i] = getD(dataPoint.getWeight());
            qDiag[i] = this.vecs[i].dot(this.vecs[i]) + dDiag[i];
            if (this.useBias) {
                qDiag[i] += 1.0d;
            }
            initialViolation += Math.abs(eq24(0.0d, -this.y[i] - this.eps, -this.y[i] + this.eps, upper[i]));
        }
        this.w = new DenseVector(this.vecs[0].length());
        if (warm != null) {
            applyWarmStart(warm);
        }

        final IntList active = new IntList(2 * n);
        ListUtils.addRange(active, 0, n, 1);
        final XORWOW rand = new XORWOW();
        // Largest violation of the previous pass (M in the shrinking tests); infinite = no shrinking.
        double prevMaxViolation = Double.POSITIVE_INFINITY;
        for (int iter = 0; iter < this.maxIterations; iter++) {
            double maxViolation = Double.NEGATIVE_INFINITY;
            double totalViolation = 0.0d;
            Collections.shuffle(active, rand);
            final Iterator<Integer> it = active.iterator();
            while (it.hasNext()) {
                final int i = it.next();
                final double target = this.y[i];
                final Vec x = this.vecs[i];
                final double grad = -target + this.w.dot(x) + this.bias + (dDiag[i] * this.alpha[i]);
                final double gp = grad + this.eps; // derivative for beta_i > 0
                final double gn = grad - this.eps; // derivative for beta_i < 0
                final double violation = eq24(this.alpha[i], gn, gp, upper[i]);
                maxViolation = Math.max(maxViolation, violation);
                totalViolation += Math.abs(violation);

                final double m = prevMaxViolation;
                final boolean shrink = (this.alpha[i] == 0.0d && gn < -m && -m < 0.0d && m < gp)
                        || (this.alpha[i] == upper[i] && gp < -m)
                        || (this.alpha[i] == -upper[i] && gn > m);
                if (shrink) {
                    it.remove();
                }

                // Exact minimizer of the one-variable sub-problem, before clipping to the box.
                final double curvature = qDiag[i];
                final double step;
                if (gp < curvature * this.alpha[i]) {
                    step = -gp / curvature;
                } else if (gn > curvature * this.alpha[i]) {
                    step = -gn / curvature;
                } else {
                    step = -this.alpha[i];
                }
                if (Math.abs(step) >= 1.0E-14d) {
                    final double newAlpha = Math.max(-upper[i], Math.min(upper[i], this.alpha[i] + step));
                    this.w.mutableAdd(newAlpha - this.alpha[i], x);
                    if (this.useBias) {
                        this.bias += newAlpha - this.alpha[i];
                    }
                    this.alpha[i] = newAlpha;
                }
            }
            if (totalViolation / initialViolation >= this.tolerance) {
                prevMaxViolation = maxViolation;
            } else {
                // Converged on the active set; only stop if nothing was shrunk, else re-check all.
                if (active.size() == n) {
                    break;
                }
                active.clear();
                ListUtils.addRange(active, 0, n, 1);
                prevMaxViolation = Double.POSITIVE_INFINITY;
            }
        }
        this.y = null;
        this.vecs = null;
    }

    /**
     * Validates a warm-start candidate and snapshots its solution.
     *
     * @param warmSolution the candidate, may be {@code null}
     * @param sampleSize number of points in the data set about to be fit
     * @return the snapshot, or {@code null} when {@code warmSolution} is {@code null}
     * @throws FailedToFitException if the candidate is not a trained {@code DCDs} whose dual
     *     solution has {@code sampleSize} entries
     */
    private static WarmStart captureWarmStart(Object warmSolution, int sampleSize) {
        if (warmSolution == null) {
            return null;
        }
        if (!(warmSolution instanceof DCDs)) {
            throw new FailedToFitException("Warm solution can not be used for warm start");
        }
        final DCDs other = (DCDs) warmSolution;
        if (other.alpha == null || other.w == null) {
            throw new FailedToFitException("Warm solution has not been trained");
        }
        if (other.alpha.length != sampleSize) {
            throw new FailedToFitException("Warm solution could not have been trained on the same data set");
        }
        return new WarmStart(other);
    }

    /**
     * Initializes {@code w}, {@code bias} and {@code alpha} (already allocated) from a warm start.
     * Every quantity is scaled by {@code C / C_warm}, so {@code w} stays consistent with the
     * scaled dual variables it is built from.
     */
    private void applyWarmStart(WarmStart warm) {
        final double scale = this.C / warm.C;
        warm.w.copyTo(this.w);
        this.w.mutableMultiply(scale);
        this.bias = warm.bias * scale;
        for (int i = 0; i < this.alpha.length; i++) {
            this.alpha[i] = warm.alpha[i] * scale;
        }
    }

    /** Immutable references to a trained solution, captured before the fields of a model are reset. */
    private static final class WarmStart {
        private final double[] alpha;
        private final Vec w;
        private final double bias;
        private final double C;

        private WarmStart(DCDs source) {
            this.alpha = source.alpha;
            this.w = source.w;
            this.bias = source.bias;
            this.C = source.C;
        }
    }

    /**
     * @param weight the data point's weight
     * @return upper bound of the dual variable: {@code C * weight} for the L1 loss, unbounded
     *     ({@code +Infinity}) for the L2 loss
     */
    private double getU(double weight) {
        if (this.useL1) {
            return this.C * weight;
        }
        return Double.POSITIVE_INFINITY;
    }

    /**
     * @param weight the data point's weight
     * @return diagonal term added to the dual problem: 0 for the L1 loss,
     *     {@code 1 / (2 * C * weight)} for the L2 loss
     */
    private double getD(double weight) {
        if (this.useL1) {
            return 0.0d;
        }
        return 1.0d / ((2.0d * this.C) * weight);
    }

    /**
     * Optimality violation of a single dual variable of the SVR problem (presumably eq. 24 of
     * Ho &amp; Lin, "Large-scale Linear Support Vector Regression" — confirm).
     *
     * @param beta current value of the dual variable, in {@code [-upper, upper]}
     * @param gn derivative of the dual objective on the {@code beta < 0} side (gradient - eps)
     * @param gp derivative of the dual objective on the {@code beta > 0} side (gradient + eps)
     * @param upper bound on {@code |beta|}
     * @return a non-negative violation, 0 when {@code beta} satisfies its optimality condition
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public static double eq24(double beta, double gn, double gp, double upper) {
        if (beta == 0.0d) {
            if (gn >= 0.0d) {
                return gn;
            }
            return gp <= 0.0d ? -gp : 0.0d;
        }
        if (beta < 0.0d) {
            final boolean violated = beta > -upper || (beta == -upper && gn <= 0.0d);
            return violated ? Math.abs(gn) : 0.0d;
        }
        final boolean violated = beta < upper || (beta == upper && gp >= 0.0d);
        return violated ? Math.abs(gp) : 0.0d;
    }

    /**
     * Suggests a distribution over plausible values of {@code C} for the given data set.
     * Delegates to {@code PlattSMO.guessC}.
     */
    public static Distribution guessC(DataSet dataSet) {
        return PlattSMO.guessC(dataSet);
    }
}
