package jsat.classifiers.linear.kernelized;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.DataSet;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.linear.kernelized.CSKLR;
import jsat.classifiers.svm.SupportVectorLearner;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.distributions.kernels.KernelTrick;
import jsat.exceptions.FailedToFitException;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.random.XORWOW;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/linear/kernelized/CSKLRBatch.class */
/**
 * Batch, multi-epoch variant of {@link CSKLR} for binary classification.
 * <p>
 * Each epoch visits every training point once in a shuffled order. Unless the
 * update mode is {@link CSKLR.UpdateMode#NC NC}, a point is only used for an
 * update with probability {@code mode.pt(...)}. After each update, if a running
 * estimate of the model norm exceeds {@link #getR() R}, all dual coefficients
 * are rescaled so the estimate equals {@code R}. After training, points whose
 * coefficient is zero are discarded from the model.
 */
public class CSKLRBatch extends SupportVectorLearner implements Parameterized, Classifier {
    private static final long serialVersionUID = -2305532659182911285L;

    /** Learning rate applied to every coefficient update. */
    private double eta;
    /** Running norm estimate: the sum of |update| * k(x, x), reset by projection. */
    private double curNorm;
    /** Maximum allowed value of the running norm estimate. */
    private double R = 10.0d;
    /** Copied by the copy constructor and reset to 0 by training; not otherwise used in this class. */
    private int T = 0;
    /** Strategy providing the update probability ({@code pt}) and gradient ({@code grad}). */
    private CSKLR.UpdateMode mode;
    /** Extra parameter forwarded to the update mode's {@code pt} and {@code grad} functions. */
    protected double gamma = 2.0d;
    /** Number of passes over the training data. */
    private int epochs = 10;

    /**
     * Creates a new batch learner.
     *
     * @param eta the learning rate; must be non-negative and finite
     * @param kernelTrick the kernel to use
     * @param R the maximum model norm; must be non-negative and finite
     * @param updateMode the update mode controlling sampling probability and gradient
     * @param cacheMode the kernel cache mode used while training
     * @throws IllegalArgumentException if {@code eta} or {@code R} is negative, NaN or infinite
     */
    public CSKLRBatch(double eta, KernelTrick kernelTrick, double R, CSKLR.UpdateMode updateMode, SupportVectorLearner.CacheMode cacheMode) {
        super(kernelTrick, cacheMode);
        setEta(eta);
        setR(R);
        setMode(updateMode);
    }

    /**
     * Copy constructor.
     *
     * @param toCopy the learner whose configuration and training state to copy
     */
    protected CSKLRBatch(CSKLRBatch toCopy) {
        super(toCopy);
        this.curNorm = toCopy.curNorm;
        this.epochs = toCopy.epochs;
        this.eta = toCopy.eta;
        this.R = toCopy.R;
        this.T = toCopy.T;
        this.mode = toCopy.mode;
        this.gamma = toCopy.gamma;
    }

    /**
     * Returns a copy of this learner, built with the copy constructor.
     * <p>
     * NOTE(review): the decompiler renamed this method from {@code clone()}
     * (it was merged with a bridge method). Confirm that the override target
     * resolves when recompiling.
     *
     * @return a copy of this learner
     */
    @Override
    public CSKLRBatch m540clone() {
        return new CSKLRBatch(this);
    }

    /**
     * Sets the number of passes over the training data. The value is not validated.
     *
     * @param epochs the number of epochs
     */
    public void setEpochs(int epochs) {
        this.epochs = epochs;
    }

    /**
     * @return the number of passes over the training data
     */
    public int getEpochs() {
        return this.epochs;
    }

    /**
     * Sets the learning rate.
     *
     * @param eta the learning rate; must be non-negative and finite (0 is accepted)
     * @throws IllegalArgumentException if {@code eta} is negative, NaN or infinite
     */
    public void setEta(double eta) {
        if (eta < 0.0d || Double.isNaN(eta) || Double.isInfinite(eta)) {
            throw new IllegalArgumentException("The learning rate should be in (0, Inf), not " + eta);
        }
        this.eta = eta;
    }

    /**
     * @return the learning rate
     */
    public double getEta() {
        return this.eta;
    }

    /**
     * Sets the maximum model norm used by the projection step.
     *
     * @param R the maximum norm; must be non-negative and finite (0 is accepted)
     * @throws IllegalArgumentException if {@code R} is negative, NaN or infinite
     */
    public void setR(double R) {
        if (R < 0.0d || Double.isNaN(R) || Double.isInfinite(R)) {
            throw new IllegalArgumentException("The max norm should be in (0, Inf), not " + R);
        }
        this.R = R;
    }

    /**
     * @return the maximum model norm
     */
    public double getR() {
        return this.R;
    }

    /**
     * Sets the update mode. The value is not checked for {@code null} here.
     *
     * @param updateMode the update mode
     */
    public void setMode(CSKLR.UpdateMode updateMode) {
        this.mode = updateMode;
    }

    /**
     * @return the update mode
     */
    public CSKLR.UpdateMode getMode() {
        return this.mode;
    }

    /**
     * Sets the gamma parameter forwarded to the update mode.
     *
     * @param gamma the gamma value; must be non-negative and finite (0 is accepted)
     * @throws IllegalArgumentException if {@code gamma} is negative, NaN or infinite
     */
    public void setGamma(double gamma) {
        if (gamma < 0.0d || Double.isNaN(gamma) || Double.isInfinite(gamma)) {
            throw new IllegalArgumentException("Gamma must be in (0, Infity), not " + gamma);
        }
        this.gamma = gamma;
    }

    /**
     * @return the gamma parameter
     */
    public double getGamma() {
        return this.gamma;
    }

    /**
     * Returns a distribution of candidate values for {@code R}. The data set is not inspected.
     *
     * @param dataSet ignored
     * @return a {@link LogUniform} distribution with bounds 1 and 1e5
     */
    public static Distribution guessR(DataSet dataSet) {
        return new LogUniform(1.0d, 100000.0d);
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(getParameters()).get(paramName);
    }

    /**
     * Classifies a point.
     * <p>
     * Class 0 receives {@code CSKLR.getScore(-1, f(x))}, where {@code f(x)} is
     * the weighted kernel sum over the stored support vectors. Class 1 receives
     * the complement.
     *
     * @param dataPoint the point to classify
     * @return a two-class result
     */
    @Override
    public CategoricalResults classify(DataPoint dataPoint) {
        CategoricalResults results = new CategoricalResults(2);
        double negScore = CSKLR.getScore(-1.0d, getPreScore(dataPoint.getNumericalValues()));
        results.setProb(0, negScore);
        results.setProb(1, 1.0d - negScore);
        return results;
    }

    /**
     * Trains the model. The executor is ignored and training runs on the calling thread.
     */
    @Override
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        trainC(classificationDataSet);
    }

    /**
     * Trains the model with {@link #getEpochs()} shuffled passes of stochastic updates.
     *
     * @param classificationDataSet the training data; must have exactly two classes
     * @throws FailedToFitException if the data set does not have exactly two classes
     */
    @Override
    public void trainC(ClassificationDataSet classificationDataSet) {
        if (classificationDataSet.getClassSize() != 2) {
            throw new FailedToFitException("CSKLR supports only binary classification");
        }
        final int n = classificationDataSet.getSampleSize();
        this.vecs = new ArrayList<Vec>(n);
        this.alphas = new double[n];
        for (int i = 0; i < n; i++) {
            this.vecs.add(classificationDataSet.getDataPoint(i).getNumericalValues());
        }
        this.curNorm = 0.0d;
        this.T = 0;
        XORWOW rand = new XORWOW();
        IntList order = new IntList(n);
        ListUtils.addRange(order, 0, n, 1);
        // Re-applying the current cache mode after replacing vecs; presumably this
        // rebuilds the kernel cache used by kEval. NOTE(review): confirm in SupportVectorLearner.
        setCacheMode(getCacheMode());

        for (int epoch = 0; epoch < this.epochs; epoch++) {
            Collections.shuffle(order);
            for (int idx : order) {
                final double weight = classificationDataSet.getDataPoint(idx).getWeight();
                // Map class index {0, 1} to label {-1, +1}.
                final double y = (classificationDataSet.getDataPointCategory(idx) * 2) - 1;
                final double pre = getPreScore(this.vecs.get(idx));
                final double score = CSKLR.getScore(y, pre);

                // Conservative update: outside NC mode, the point is only used with
                // probability pt. Otherwise it is skipped for this epoch.
                if (this.mode != CSKLR.UpdateMode.NC
                        && rand.nextDouble() > this.mode.pt(y, score, pre, this.eta, this.gamma)) {
                    continue;
                }

                final double delta = (-this.eta) * y * this.mode.grad(y, score, pre, this.gamma) * weight;
                this.alphas[idx] += delta;
                this.curNorm += Math.abs(delta) * kEval(idx, idx);
                projectOntoNormBall();
            }
        }

        pruneZeroAlphas();
        setCacheMode(null);
        setAlphas(this.alphas);
    }

    /**
     * Rescales all coefficients when the running norm estimate exceeds {@code R}.
     * <p>
     * Scaling every coefficient by {@code R / curNorm} scales the estimate by the
     * same factor, so afterwards the estimate is exactly {@code R}.
     */
    private void projectOntoNormBall() {
        if (this.curNorm > this.R) {
            final double coef = this.R / this.curNorm;
            for (int j = 0; j < this.alphas.length; j++) {
                this.alphas[j] *= coef;
            }
            this.curNorm = this.R;
        }
    }

    /**
     * Keeps only the vectors whose coefficient is non-zero, in their original
     * relative order, and truncates {@code vecs} and {@code alphas} to match.
     * <p>
     * A coefficient is kept when it is {@code > 0} or {@code < 0}, so zero and
     * NaN coefficients are dropped.
     */
    private void pruneZeroAlphas() {
        int kept = 0;
        for (int i = 0; i < this.alphas.length; i++) {
            final double a = this.alphas[i];
            if (a > 0.0d || a < 0.0d) {
                ListUtils.swap(this.vecs, kept, i);
                this.alphas[kept++] = a;
            }
        }
        this.vecs = new ArrayList<Vec>(this.vecs.subList(0, kept));
        this.alphas = Arrays.copyOf(this.alphas, kept);
    }

    /** Weighted kernel sum of the stored support vectors against {@code vec}. */
    private double getPreScore(Vec vec) {
        return kEvalSum(vec);
    }

    /**
     * @return {@code true}; each point's weight scales its coefficient update
     */
    @Override
    public boolean supportsWeightedData() {
        return true;
    }
}
