package jsat.classifiers.svm;

import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.DataSet;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.ScaledVector;
import jsat.linear.Vec;
import jsat.linear.VecWithNorm;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.IntList;
import jsat.utils.ListUtils;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/svm/Pegasos.class */
/**
 * Linear binary SVM trained with the Pegasos algorithm: mini-batch stochastic
 * sub-gradient descent on the regularized hinge loss, using the learning rate
 * {@code 1 / (lambda * t)} at global step {@code t}.
 * <p>
 * The model scores a point as {@code w . x + bias}. A non-negative score is
 * classified as category 1, and a negative score as category 0.
 */
public class Pegasos implements BinaryScoreClassifier, Parameterized, SingleWeightVectorModel {
    private static final long serialVersionUID = -2145631476467081171L;
    /** Number of full passes over the training data. */
    private int epochs;
    /** Regularization constant (lambda); always finite and strictly positive. */
    private double reg;
    /** Number of data points examined per sub-gradient step; at least 1. */
    private int batchSize;
    /** When true, w (and bias) are rescaled after each step so that ||w|| <= 1/sqrt(lambda). */
    private boolean projectionStep;
    /** Learned weight vector; null until {@link #trainC(ClassificationDataSet)} has run. */
    private Vec w;
    /** Learned bias term. */
    private double bias;
    public static final int DEFAULT_EPOCHS = 5;
    public static final double DEFAULT_REG = 1.0E-4d;
    public static final int DEFAULT_BATCH_SIZE = 1;

    /** Creates a Pegasos learner with the default epochs, regularization and batch size. */
    public Pegasos() {
        this(DEFAULT_EPOCHS, DEFAULT_REG, DEFAULT_BATCH_SIZE);
    }

    /**
     * Creates a Pegasos learner.
     *
     * @param epochs    number of passes over the data; must be at least 1
     * @param reg       regularization constant; must be finite and positive
     * @param batchSize points per sub-gradient step; must be at least 1
     * @throws ArithmeticException if any argument is out of range
     */
    public Pegasos(int epochs, double reg, int batchSize) {
        this.projectionStep = false;
        setEpochs(epochs);
        setRegularization(reg);
        setBatchSize(batchSize);
    }

    /**
     * Copy constructor; deep-copies the learned weight vector if present.
     *
     * @param pegasos the instance to copy
     */
    public Pegasos(Pegasos pegasos) {
        this.epochs = pegasos.epochs;
        this.reg = pegasos.reg;
        this.batchSize = pegasos.batchSize;
        if (pegasos.w != null) {
            this.w = pegasos.w.mo525clone();
        }
        this.bias = pegasos.bias;
        this.projectionStep = pegasos.projectionStep;
    }

    /**
     * Sets the number of data points used for each sub-gradient step.
     *
     * @param i the batch size; must be at least 1
     * @throws ArithmeticException if {@code i < 1}
     */
    public void setBatchSize(int i) {
        if (i < 1) {
            throw new ArithmeticException("At least one sample must be take at each iteration");
        }
        this.batchSize = i;
    }

    /** @return the number of data points used per sub-gradient step */
    public int getBatchSize() {
        return this.batchSize;
    }

    /**
     * Sets the number of passes made over the training data.
     *
     * @param i the number of epochs; must be at least 1
     * @throws ArithmeticException if {@code i < 1}
     */
    public void setEpochs(int i) {
        if (i < 1) {
            throw new ArithmeticException("Must perform a positive number of epochs");
        }
        this.epochs = i;
    }

    /**
     * @return the number of epochs. Note that the return type is {@code double}
     *         even though the stored value is an {@code int}; it is kept for API
     *         compatibility.
     */
    public double getEpochs() {
        return this.epochs;
    }

    /**
     * Enables or disables the projection of w onto the ball of radius
     * {@code 1/sqrt(lambda)} after each step.
     *
     * @param z true to enable the projection step
     */
    public void setProjectionStep(boolean z) {
        this.projectionStep = z;
    }

    /** @return whether the projection step is enabled */
    public boolean isProjectionStep() {
        return this.projectionStep;
    }

    /**
     * Sets the regularization constant (lambda).
     *
     * @param d the regularization; must be finite and strictly positive
     * @throws ArithmeticException if {@code d} is NaN, infinite, or not positive
     */
    public void setRegularization(double d) {
        if (Double.isInfinite(d) || Double.isNaN(d) || d <= 0.0d) {
            throw new ArithmeticException("Pegasos requires a positive regularization cosntant");
        }
        this.reg = d;
    }

    /** @return the regularization constant (lambda) */
    public double getRegularization() {
        return this.reg;
    }

    /** @return the learned weight vector, or null if the model has not been trained */
    @Override // jsat.SingleWeightVectorModel
    public Vec getRawWeight() {
        return this.w;
    }

    /** @return the learned bias term */
    @Override // jsat.SingleWeightVectorModel
    public double getBias() {
        return this.bias;
    }

    /**
     * @param i the weight-vector index; only 0 is valid
     * @return the single weight vector of this model
     * @throws IndexOutOfBoundsException if {@code i != 0}
     */
    @Override // jsat.SimpleWeightVectorModel
    public Vec getRawWeight(int i) {
        if (i == 0) {
            return getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * @param i the bias index; only 0 is valid
     * @return the single bias term of this model
     * @throws IndexOutOfBoundsException if {@code i != 0}
     */
    @Override // jsat.SimpleWeightVectorModel
    public double getBias(int i) {
        if (i == 0) {
            return getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /** @return always 1; this model has a single weight vector */
    @Override // jsat.SimpleWeightVectorModel
    public int numWeightsVecs() {
        return 1;
    }

    @Override // jsat.classifiers.Classifier
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public Pegasos m575clone() {
        return new Pegasos(this);
    }

    /**
     * Hard-classifies a point: all probability mass goes to category 0 when the
     * score is negative, and to category 1 otherwise.
     *
     * @param dataPoint the point to classify
     * @return a two-category result with a single category set to 1.0
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        CategoricalResults categoricalResults = new CategoricalResults(2);
        if (getScore(dataPoint) < 0.0d) {
            categoricalResults.setProb(0, 1.0d);
        } else {
            categoricalResults.setProb(1, 1.0d);
        }
        return categoricalResults;
    }

    /**
     * Computes the raw decision value {@code w . x + bias}. Requires a trained
     * model, because {@code w} is null before training.
     *
     * @param dataPoint the point to score
     * @return the signed decision value
     */
    @Override // jsat.classifiers.calibration.BinaryScoreClassifier
    public double getScore(DataPoint dataPoint) {
        return this.w.dot(dataPoint.getNumericalValues()) + this.bias;
    }

    /**
     * Trains the model. The executor is ignored; training is single-threaded.
     *
     * @param classificationDataSet the binary training data
     * @param executorService       unused
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        trainC(classificationDataSet);
    }

    /**
     * Trains the model with mini-batch Pegasos for {@link #getEpochs()} passes,
     * visiting the data in a freshly shuffled order each epoch.
     *
     * @param classificationDataSet the training data; must have exactly 2 classes
     * @throws FailedToFitException if the data set is not binary
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        if (classificationDataSet.getClassSize() != 2) {
            throw new FailedToFitException("SVM only supports binary classificaiton problems");
        }
        final int sampleSize = classificationDataSet.getSampleSize();

        Vec weights = new DenseVector(classificationDataSet.getNumNumericalVars());
        if (this.projectionStep) {
            // NOTE(review): presumably VecWithNorm tracks ||w|| so that pNorm(2) in the
            // projection step is cheap; confirm against jsat.linear.VecWithNorm.
            weights = new VecWithNorm(weights, 0.0d);
        }
        // NOTE(review): presumably ScaledVector makes the per-step mutableMultiply
        // shrink cheap; confirm against jsat.linear.ScaledVector.
        this.w = new ScaledVector(weights);
        this.bias = 0.0d;

        IntList batch = new IntList(this.batchSize);
        IntList order = new IntList(sampleSize);
        ListUtils.addRange(order, 0, sampleSize, 1);

        // Global step counter across all epochs; drives the learning rate 1/(lambda*t).
        long t = 0;
        for (int epoch = 0; epoch < this.epochs; epoch++) {
            Collections.shuffle(order);
            int start = 0;
            while (start < sampleSize) {
                // Written as start + min(...) so the end index cannot overflow for huge batch sizes.
                final int end = start + Math.min(this.batchSize, sampleSize - start);
                t++;
                batch.clear();
                batch.addAll(order.subList(start, end));
                pegasosStep(classificationDataSet, batch, t);
                start = end;
            }
        }
    }

    /**
     * Performs one Pegasos sub-gradient step (plus the optional projection) on a mini-batch.
     *
     * @param dataSet the training data
     * @param batch   indices of the mini-batch. Indices that already satisfy the
     *                margin are removed from this list in place.
     * @param t       the 1-based global step number
     */
    private void pegasosStep(ClassificationDataSet dataSet, IntList batch, long t) {
        // |A_t|, the size of the sampled batch. It is captured before filtering and
        // may be smaller than batchSize for the final batch of an epoch.
        final int k = batch.size();

        // Keep only the points that violate the margin, i.e. y * (w.x + b) < 1. This
        // must be evaluated against w before this step's update.
        Iterator<Integer> it = batch.iterator();
        while (it.hasNext()) {
            int idx = it.next().intValue();
            if (getSign(dataSet, idx) * (this.w.dot(getX(dataSet, idx)) + this.bias) >= 1.0d) {
                it.remove();
            }
        }

        final double eta = 1.0d / (this.reg * t);
        // Regularization shrink: w <- (1 - eta*lambda) * w.
        this.w.mutableMultiply(1.0d - (eta * this.reg));
        for (int idx : batch) {
            double delta = (getSign(dataSet, idx) * eta) / k;
            this.w.mutableAdd(delta, getX(dataSet, idx));
            this.bias += delta;
        }

        if (this.projectionStep) {
            // Rescale so that ||w|| <= 1/sqrt(lambda). The bias is scaled by the same factor.
            double scale = Math.min(1.0d, 1.0d / (Math.sqrt(this.reg) * this.w.pNorm(2.0d)));
            this.w.mutableMultiply(scale);
            this.bias *= scale;
        }
    }

    /** @return always false; per-point weights are not used by this trainer */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /** @return the numeric feature vector of data point {@code i} */
    private Vec getX(ClassificationDataSet classificationDataSet, int i) {
        return classificationDataSet.getDataPoint(i).getNumericalValues();
    }

    /** @return +1 if data point {@code i} is in category 1, otherwise -1 */
    private double getSign(ClassificationDataSet classificationDataSet, int i) {
        return classificationDataSet.getDataPointCategory(i) == 1 ? 1.0d : -1.0d;
    }

    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }

    /**
     * Suggests a search distribution for the regularization parameter. The data
     * set is currently ignored.
     *
     * @param dataSet the data set (unused)
     * @return a {@link LogUniform} distribution constructed with bounds 1e-7 and 1e-2
     */
    public static Distribution guessRegularization(DataSet dataSet) {
        return new LogUniform(1.0E-7d, 0.01d);
    }
}
