package jsat.classifiers.linear.kernelized;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import jsat.DataSet;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.distributions.kernels.KernelTrick;
import jsat.linear.Vec;
import jsat.lossfunctions.HingeLoss;
import jsat.lossfunctions.LossC;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;
import jsat.utils.random.XORWOW;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/linear/kernelized/BOGD.class */
/**
 * Bounded Online Gradient Descent (BOGD): a kernelized online binary classifier that keeps
 * at most {@link #getBudget() budget} support vectors.
 * <p>
 * The decision score of a point {@code x} is the kernel sum {@code sum_i alpha_i * k(x_i, x)}
 * over the stored support vectors. Each update shrinks the existing coefficients by the
 * regularization factor and, when the loss derivative is non-zero, adds the new point with
 * coefficient {@code -eta * lossDerivative}.
 * <p>
 * Once the budget is full, one existing support vector is removed. It is chosen either
 * uniformly at random or with a probability that decreases with {@code |alpha_i| * sqrt(k(x_i, x_i))}.
 * The surviving coefficients are rescaled by {@code (1 - eta * reg) / (1 - p_i)} and clipped
 * to magnitude {@code maxCoeff * eta}.
 * <p>
 * NOTE(review): this presumably follows Zhao et al., "Fast Bounded Online Gradient Descent
 * Algorithms for Scalable Kernel-Based Online Learning" (ICML 2012). Confirm against the paper.
 */
public class BOGD extends BaseUpdateableClassifier implements BinaryScoreClassifier, Parameterized {
    private static final long serialVersionUID = -3547832514098781996L;

    /** Kernel used to compare support vectors with query points. */
    @Parameter.ParameterHolder
    private KernelTrick k;
    /** Maximum number of support vectors retained. */
    private int budget;
    /** Learning rate. */
    private double eta;
    /** Regularization strength; coefficients are multiplied by {@code 1 - eta * reg} per update. */
    private double reg;
    /** During budget maintenance, coefficient magnitudes are clipped to {@code maxCoeff * eta}. */
    private double maxCoeff;
    /** Loss whose derivative drives updates and which maps scores to class results. */
    private LossC lossC;
    /** If true, the support vector to drop is chosen uniformly at random. */
    private boolean uniformSampling;
    private Random rand;
    /** Stored support vectors. */
    private List<Vec> vecs;
    /** {@code sqrt(k(x_i, x_i))} for each stored support vector. */
    private List<Double> selfK;
    /** Signed coefficient of each stored support vector. */
    private DoubleList alphas;
    /** Concatenated kernel query info of every support vector, or null when the kernel has no acceleration. */
    private List<Double> accelCache;
    /** Scratch buffer holding the (unnormalized) removal weights used by non-uniform sampling. */
    private double[] dist;

    /**
     * Creates a BOGD learner using the hinge loss.
     *
     * @param kernelTrick the kernel to use
     * @param i the support-vector budget (must be positive)
     * @param d the learning rate eta (must be positive and finite)
     * @param d2 the regularization strength (must be positive and finite)
     * @param d3 the maximum coefficient multiplier (must be positive and finite)
     */
    public BOGD(KernelTrick kernelTrick, int i, double d, double d2, double d3) {
        this(kernelTrick, i, d, d2, d3, new HingeLoss());
    }

    /**
     * Creates a BOGD learner with the given loss. Non-uniform sampling is used by default.
     *
     * @param kernelTrick the kernel to use
     * @param i the support-vector budget (must be positive)
     * @param d the learning rate eta (must be positive and finite)
     * @param d2 the regularization strength (must be positive and finite)
     * @param d3 the maximum coefficient multiplier (must be positive and finite)
     * @param lossC the loss used to compute update derivatives and classifications
     * @throws IllegalArgumentException if any numeric argument is out of range
     */
    public BOGD(KernelTrick kernelTrick, int i, double d, double d2, double d3, LossC lossC) {
        setKernel(kernelTrick);
        setBudget(i);
        setEta(d);
        setRegularization(d2);
        setMaxCoeff(d3);
        this.lossC = lossC;
        setUniformSampling(false);
    }

    /**
     * Copy constructor. Deep-copies the kernel, the loss, the support vectors and the per-vector
     * state. A fresh random generator is created rather than copied.
     *
     * @param bogd the learner to copy
     */
    public BOGD(BOGD bogd) {
        this.k = bogd.k.m629clone();
        this.budget = bogd.budget;
        this.eta = bogd.eta;
        this.reg = bogd.reg;
        this.maxCoeff = bogd.maxCoeff;
        this.lossC = bogd.lossC.m683clone();
        this.uniformSampling = bogd.uniformSampling;
        this.rand = new XORWOW();
        if (bogd.vecs != null) {
            this.vecs = new ArrayList<Vec>(this.budget);
            for (Vec v : bogd.vecs) {
                this.vecs.add(v.mo525clone());
            }
            this.selfK = new DoubleList(bogd.selfK);
            this.alphas = new DoubleList(bogd.alphas);
        }
        if (bogd.accelCache != null) {
            this.accelCache = new DoubleList(bogd.accelCache);
        }
        if (bogd.dist != null) {
            this.dist = Arrays.copyOf(bogd.dist, bogd.dist.length);
        }
    }

    /**
     * Sets the regularization strength.
     *
     * @param d the regularization strength; must be positive and finite
     * @throws IllegalArgumentException if {@code d} is not positive and finite
     */
    public void setRegularization(double d) {
        if (d <= 0.0d || Double.isNaN(d) || Double.isInfinite(d)) {
            throw new IllegalArgumentException("Regularization must be positive, not " + d);
        }
        this.reg = d;
    }

    /** @return the regularization strength */
    public double getRegularization() {
        return this.reg;
    }

    /**
     * Sets the learning rate.
     *
     * @param d the learning rate; must be positive and finite
     * @throws IllegalArgumentException if {@code d} is not positive and finite
     */
    public void setEta(double d) {
        if (d <= 0.0d || Double.isNaN(d) || Double.isInfinite(d)) {
            throw new IllegalArgumentException("Eta must be positive, not " + d);
        }
        this.eta = d;
    }

    /** @return the learning rate */
    public double getEta() {
        return this.eta;
    }

    /**
     * Sets the coefficient multiplier. During budget maintenance, coefficient magnitudes are
     * clipped to {@code maxCoeff * eta}.
     *
     * @param d the multiplier; must be positive and finite
     * @throws IllegalArgumentException if {@code d} is not positive and finite
     */
    public void setMaxCoeff(double d) {
        if (d <= 0.0d || Double.isNaN(d) || Double.isInfinite(d)) {
            throw new IllegalArgumentException("MaxCoeff must be positive, not " + d);
        }
        this.maxCoeff = d;
    }

    /** @return the coefficient multiplier */
    public double getMaxCoeff() {
        return this.maxCoeff;
    }

    /**
     * Sets the maximum number of support vectors.
     *
     * @param i the budget; must be positive
     * @throws IllegalArgumentException if {@code i <= 0}
     */
    public void setBudget(int i) {
        if (i <= 0) {
            throw new IllegalArgumentException("Budget must be positive, not " + i);
        }
        this.budget = i;
    }

    /** @return the maximum number of support vectors */
    public int getBudget() {
        return this.budget;
    }

    /** @param kernelTrick the kernel to use */
    public void setKernel(KernelTrick kernelTrick) {
        this.k = kernelTrick;
    }

    /** @return the kernel in use */
    public KernelTrick getKernel() {
        return this.k;
    }

    /**
     * @param z true to choose the support vector to remove uniformly at random, false to use
     *          coefficient-weighted sampling
     */
    public void setUniformSampling(boolean z) {
        this.uniformSampling = z;
    }

    /** @return true if uniform sampling is used for budget maintenance */
    public boolean isUniformSampling() {
        return this.uniformSampling;
    }

    @Override // jsat.classifiers.BaseUpdateableClassifier
    /* renamed from: clone */
    public BOGD mo480clone() {
        return new BOGD(this);
    }

    /**
     * Resets the model to an empty set of support vectors, sized for the current budget.
     */
    @Override // jsat.classifiers.UpdateableClassifier
    public void setUp(CategoricalData[] categoricalDataArr, int i, CategoricalData categoricalData) {
        this.vecs = new ArrayList<Vec>(this.budget);
        this.alphas = new DoubleList(this.budget);
        this.selfK = new DoubleList(this.budget);
        if (this.k.supportsAcceleration()) {
            this.accelCache = new DoubleList(this.budget);
        } else {
            this.accelCache = null;
        }
        if (!this.uniformSampling) {
            this.dist = new double[this.budget];
        }
        this.rand = new XORWOW();
    }

    /** Returns the raw decision score {@code sum_i alpha_i * k(x_i, x)} for the data point. */
    @Override // jsat.classifiers.calibration.BinaryScoreClassifier
    public double getScore(DataPoint dataPoint) {
        Vec numericalValues = dataPoint.getNumericalValues();
        return score(numericalValues, this.k.getQueryInfo(numericalValues));
    }

    /** Kernel sum of all stored support vectors against {@code vec}, weighted by their coefficients. */
    private double score(Vec vec, List<Double> list) {
        return this.k.evalSum(this.vecs, this.accelCache, this.alphas.getBackingArray(), vec, list, 0, this.alphas.size());
    }

    /**
     * Performs one online update.
     *
     * @param dataPoint the observed point
     * @param i index of its true class. It is mapped to the label {@code 2*i - 1}, so this
     *          presumably expects 0 or 1.
     */
    @Override // jsat.classifiers.UpdateableClassifier
    public void update(DataPoint dataPoint, int i) {
        Vec x = dataPoint.getNumericalValues();
        List<Double> queryInfo = this.k.getQueryInfo(x);
        double lossDeriv = this.lossC.getDeriv(score(x, queryInfo), (i * 2) - 1);
        if (lossDeriv == 0.0d) {
            // No loss gradient: only apply the regularization shrinkage.
            this.alphas.getVecView().mutableMultiply(1.0d - (this.eta * this.reg));
            return;
        }
        if (this.vecs.size() < this.budget) {
            this.alphas.getVecView().mutableMultiply(1.0d - (this.eta * this.reg));
        } else {
            // Budget is full: drop one support vector and rescale the survivors.
            maintainBudget();
        }
        addSupportVector(x, queryInfo, (-this.eta) * lossDeriv);
    }

    /**
     * Removes one support vector (uniformly or by weighted sampling). The remaining coefficients
     * are rescaled by {@code (1 - reg*eta) / (1 - p_i)}, where {@code p_i} is their removal
     * probability, and clipped to magnitude {@code maxCoeff * eta}.
     */
    private void maintainBudget() {
        double[] weights = null;
        double total = 0.0d;
        int toRemove = -1;
        if (!this.uniformSampling) {
            // dist may be missing if uniform sampling was switched off after setUp().
            if (this.dist == null || this.dist.length < this.budget) {
                this.dist = new double[this.budget];
            }
            total = fillRemovalWeights();
            // total is NaN/infinite/zero only in degenerate cases (e.g. all coefficients zero).
            if (total > 0.0d && !Double.isInfinite(total)) {
                weights = this.dist;
                toRemove = sampleWeighted(total);
            }
        }
        if (weights == null) {
            toRemove = this.rand.nextInt(this.budget);
        }

        double decay = 1.0d - (this.reg * this.eta);
        double cap = this.maxCoeff * this.eta;
        for (int j = 0; j < this.budget; j++) {
            if (j == toRemove) {
                continue;
            }
            double alpha = this.alphas.getD(j);
            double removeProb = (weights == null) ? 1.0d / this.budget : weights[j] / total;
            this.alphas.set(j, Math.signum(alpha) * Math.min((decay / (1.0d - removeProb)) * Math.abs(alpha), cap));
        }
        removeSupportVector(toRemove);
    }

    /**
     * Fills {@code dist[0..budget)} with non-negative removal weights and returns their sum.
     * <p>
     * The weight of vector i is {@code 1 - (budget-1) * |alpha_i| * selfK_i / Z}, clamped at 0,
     * where Z is the sum of {@code |alpha_j| * selfK_j}. Vectors with small weighted
     * coefficients are therefore more likely to be dropped. The result may be NaN when Z is 0.
     */
    private double fillRemovalWeights() {
        double z = 0.0d;
        for (int j = 0; j < this.budget; j++) {
            z += Math.abs(this.alphas.getD(j)) * this.selfK.get(j);
        }
        double coeff = (this.budget - 1) / z;
        double total = 0.0d;
        for (int j = 0; j < this.budget; j++) {
            // Use |alpha| so the sign of a coefficient does not affect its removal odds,
            // consistent with the normalizer z above; clamp so every weight is a valid mass.
            double w = Math.max(0.0d, 1.0d - (coeff * Math.abs(this.alphas.getD(j)) * this.selfK.get(j)));
            this.dist[j] = w;
            total += w;
        }
        return total;
    }

    /**
     * Draws an index with probability {@code dist[i] / total}. Requires {@code total > 0}.
     * Zero-weight entries are never selected.
     */
    private int sampleWeighted(double total) {
        double target = this.rand.nextDouble() * total;
        double cumulative = 0.0d;
        int lastPositive = -1;
        for (int j = 0; j < this.budget; j++) {
            double w = this.dist[j];
            if (w <= 0.0d) {
                continue;
            }
            lastPositive = j;
            cumulative += w;
            if (cumulative > target) {
                return j;
            }
        }
        // Only reachable through floating-point rounding in the cumulative sum.
        return lastPositive;
    }

    /** Removes support vector {@code idx} together with its coefficient, self-kernel and cache entries. */
    private void removeSupportVector(int idx) {
        if (this.k.supportsAcceleration()) {
            // Each support vector owns an equal-sized contiguous slice of the cache.
            int perVec = this.accelCache.size() / this.vecs.size();
            for (int j = 0; j < perVec; j++) {
                this.accelCache.remove(idx * perVec);
            }
        }
        this.alphas.remove(idx);
        this.vecs.remove(idx);
        this.selfK.remove(idx);
    }

    /** Appends {@code x} as a new support vector with the given coefficient. */
    private void addSupportVector(Vec x, List<Double> queryInfo, double alpha) {
        this.alphas.add(alpha);
        this.selfK.add(Double.valueOf(Math.sqrt(this.k.eval(0, 0, Arrays.asList(x), queryInfo))));
        // accelCache is null when the kernel has no acceleration support.
        if (this.k.supportsAcceleration()) {
            this.accelCache.addAll(queryInfo);
        }
        this.vecs.add(x);
    }

    /** Classifies the point by passing its score through the loss's classification mapping. */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        Vec numericalValues = dataPoint.getNumericalValues();
        return this.lossC.getClassification(score(numericalValues, this.k.getQueryInfo(numericalValues)));
    }

    /** @return false; sample weights are not used */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Guesses a search range for the regularization:
     * log-uniform on {@code [2^-3 / N^2, 2^3 / N^2]}, where N is the sample size.
     */
    public static Distribution guessRegularization(DataSet dataSet) {
        double sampleSize = dataSet.getSampleSize();
        double d = sampleSize * sampleSize;
        return new LogUniform(Math.pow(2.0d, -3.0d) / d, Math.pow(2.0d, 3.0d) / d);
    }

    /** Guesses a search range for eta: log-uniform on {@code [2^-3, 2^3]}. */
    public static Distribution guessEta(DataSet dataSet) {
        return new LogUniform(Math.pow(2.0d, -3.0d), Math.pow(2.0d, 3.0d));
    }

    /** Guesses a search range for maxCoeff: log-uniform on {@code [2^0, 2^4]}. */
    public static Distribution guessMaxCoeff(DataSet dataSet) {
        return new LogUniform(Math.pow(2.0d, 0.0d), Math.pow(2.0d, 4.0d));
    }

    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }
}
