package jsat.classifiers.linear.kernelized;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.distributions.kernels.KernelTrick;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.SubMatrix;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/linear/kernelized/Projectron.class */
/**
 * Kernelized Projectron: an online binary classifier that updates like a kernel
 * Perceptron but limits the growth of its support set.
 * <p>
 * On a prediction error, the kernel function of the new example is projected
 * onto the span of the current support vectors. If the norm of the projection
 * residual is at most {@link #getEta() eta}, the projected update is applied and
 * no support vector is added. Otherwise the example becomes a new support vector.
 * <p>
 * When {@link #isUseMarginUpdates() margin updates} are enabled, correctly
 * classified points with margin at most 1 also trigger a (projected) update.
 * This corresponds to the "Projectron++" variant described by Orabona, Keshet
 * and Caputo.
 * <p>
 * The inverse kernel (Gram) matrix of the support set is maintained
 * incrementally in {@code InvK}, a view over the over-allocated
 * {@code InvKExpanded}. The backing matrix is doubled whenever it fills up.
 */
public class Projectron extends BaseUpdateableClassifier implements BinaryScoreClassifier, Parameterized {
    private static final long serialVersionUID = -4025799790045954359L;

    @Parameter.ParameterHolder
    private KernelTrick k;
    /** Maximum projection-residual norm for which a projected update is used. */
    private double eta;
    /** Coefficient of each support vector in the decision function. */
    private DoubleList alpha;
    /** Support vectors. */
    private List<Vec> S;
    /** Kernel acceleration cache for the support vectors, as produced by the kernel's query info. */
    private List<Double> cacheAccel;
    /** |S| x |S| view over {@link #InvKExpanded} holding the inverse Gram matrix of S. */
    private Matrix InvK;
    /** Over-allocated backing storage for {@link #InvK}. */
    private Matrix InvKExpanded;
    /** Scratch buffer: k(s_i, x) for the point currently being processed. */
    private double[] k_raw;
    /** Whether margin errors (0 &lt; y*f(x) &lt; 1) also trigger updates. */
    private boolean useMarginUpdates;

    /**
     * Creates a Projectron with {@code eta = 0.1} and margin updates enabled.
     *
     * @param kernelTrick the kernel to use
     */
    public Projectron(KernelTrick kernelTrick) {
        this(kernelTrick, 0.1d);
    }

    /**
     * Creates a Projectron with margin updates enabled.
     *
     * @param kernelTrick the kernel to use
     * @param d the projection threshold eta, in [0, Infinity)
     */
    public Projectron(KernelTrick kernelTrick, double d) {
        this(kernelTrick, d, true);
    }

    /**
     * Creates a Projectron.
     *
     * @param kernelTrick the kernel to use
     * @param d the projection threshold eta, in [0, Infinity)
     * @param z whether margin errors should trigger updates
     */
    public Projectron(KernelTrick kernelTrick, double d, boolean z) {
        setKernel(kernelTrick);
        setEta(d);
        setUseMarginUpdates(z);
    }

    /**
     * Copy constructor. Copies the kernel, all hyper-parameters and, if the
     * source has been set up, the learned model.
     *
     * @param projectron the object to copy
     */
    protected Projectron(Projectron projectron) {
        this.k = projectron.k.m628clone();
        this.eta = projectron.eta;
        // previously forgotten: clones silently lost the margin-update setting
        this.useMarginUpdates = projectron.useMarginUpdates;
        if (projectron.S != null) {
            this.alpha = new DoubleList(projectron.alpha);
            this.S = new ArrayList<Vec>(projectron.S);
            this.cacheAccel = new DoubleList(projectron.cacheAccel);
            this.InvKExpanded = projectron.InvKExpanded.mo640clone();
            // InvK is still null after setUp() until the first update()
            if (projectron.InvK != null) {
                this.InvK = new SubMatrix(this.InvKExpanded, 0, 0, projectron.InvK.rows(), projectron.InvK.cols());
            }
            this.k_raw = Arrays.copyOf(projectron.k_raw, projectron.k_raw.length);
        }
    }

    /**
     * Sets the kernel used by this classifier.
     *
     * @param kernelTrick the kernel to use
     */
    public void setKernel(KernelTrick kernelTrick) {
        this.k = kernelTrick;
    }

    /**
     * Returns the kernel used by this classifier.
     *
     * @return the kernel in use
     */
    public KernelTrick getKernel() {
        return this.k;
    }

    /**
     * Sets the projection threshold. Larger values accept coarser projections
     * and therefore keep fewer support vectors.
     *
     * @param d the threshold, finite and non-negative
     * @throws IllegalArgumentException if {@code d} is NaN, infinite or negative
     */
    public void setEta(double d) {
        if (Double.isNaN(d) || Double.isInfinite(d) || d < 0.0d) {
            throw new IllegalArgumentException("eta must be in the range [0, Infinity), not " + d);
        }
        this.eta = d;
    }

    /**
     * Returns the projection threshold.
     *
     * @return the projection threshold eta
     */
    public double getEta() {
        return this.eta;
    }

    /**
     * Sets whether margin errors (correct sign, margin below 1) trigger updates.
     *
     * @param z {@code true} to enable margin updates
     */
    public void setUseMarginUpdates(boolean z) {
        this.useMarginUpdates = z;
    }

    /**
     * Returns whether margin errors trigger updates.
     *
     * @return {@code true} if margin updates are enabled
     */
    public boolean isUseMarginUpdates() {
        return this.useMarginUpdates;
    }

    @Override // jsat.classifiers.BaseUpdateableClassifier
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public Projectron mo479clone() {
        return new Projectron(this);
    }

    /**
     * Resets the model.
     *
     * @param categoricalDataArr categorical attribute info (unused)
     * @param i number of numeric features; must be at least 1
     * @param categoricalData target info; must have exactly two categories
     */
    @Override // jsat.classifiers.UpdateableClassifier
    public void setUp(CategoricalData[] categoricalDataArr, int i, CategoricalData categoricalData) {
        if (i < 1) {
            throw new IllegalArgumentException("Projectron requires numeric features");
        }
        if (categoricalData.getNumOfCategories() != 2) {
            throw new FailedToFitException("Projectron only supports binary classification");
        }
        this.alpha = new DoubleList(50);
        this.cacheAccel = new DoubleList(50);
        this.S = new ArrayList<Vec>(50);
        this.InvKExpanded = new DenseMatrix(50, 50);
        // drop any view into the previous backing matrix; rebuilt on first update
        this.InvK = null;
        this.k_raw = new double[50];
    }

    /**
     * Performs one online update with a labeled example.
     *
     * @param dataPoint the example
     * @param targetClass the class label, 0 or 1 (mapped to -1 / +1)
     */
    @Override // jsat.classifiers.UpdateableClassifier
    public void update(DataPoint dataPoint, int targetClass) {
        final Vec x = dataPoint.getNumericalValues();
        final List<Double> qi = this.k.getQueryInfo(x);
        // Also fills k_raw[0, |S|) with k(s_i, x) for every support vector s_i.
        final double score = getScore(x, qi, this.k_raw);
        final double y = (targetClass * 2) - 1;

        if (this.S.isEmpty()) {
            addFirstSupportVector(x, qi, y);
            return;
        }

        final double margin = y * score;
        if (margin > 1.0d) {
            return; // zero hinge loss: nothing to learn
        }
        if (!this.useMarginUpdates && margin > 0.0d && margin < 1.0d) {
            return; // margin error, but margin updates are disabled
        }

        // d = K^{-1} k_t gives the coefficients of the projection of k(x,.)
        // onto span{k(s_i,.)}; k_t . d is the squared norm of that projection.
        final DenseVector kt = new DenseVector(this.k_raw, 0, this.S.size());
        final Vec d = this.InvK.multiply(kt);
        final double kxx = this.k.eval(0, 0, Arrays.asList(x), qi);
        final double projNormSqrd = kt.dot(d);
        // squared residual of the projection; clamp tiny negative rounding error
        final double deltaSqrd = Math.max(kxx - projNormSqrd, 0.0d);
        final double delta = Math.sqrt(deltaSqrd);

        if (Math.signum(score) == y) {
            marginUpdate(margin, d, delta, projNormSqrd, y);
        } else if (delta <= this.eta) {
            // prediction error, projection is good enough: no new support vector
            addScaled(d, y);
        } else {
            addSupportVector(x, qi, y, d, deltaSqrd);
        }
    }

    /**
     * Initializes the model with its first support vector.
     * The 1x1 inverse Gram matrix is 1/k(x,x); when k(x,x) is not positive, its
     * pseudo-inverse 0 is used instead of an infinite value.
     */
    private void addFirstSupportVector(Vec x, List<Double> qi, double y) {
        final double kxx = this.k.eval(0, 0, Arrays.asList(x), qi);
        this.InvK = new SubMatrix(this.InvKExpanded, 0, 0, 1, 1);
        this.InvK.set(0, 0, kxx > 0.0d ? 1.0d / kxx : 0.0d);
        this.S.add(x);
        this.alpha.add(y);
        this.cacheAccel.addAll(qi);
    }

    /**
     * Projectron++ margin update for a correctly classified point with margin at most 1.
     * The step size is min(loss/|P|^2, 2(loss - delta/eta)/|P|^2, 1). Keeping it
     * at or below 2(loss - delta/eta)/|P|^2 keeps the progress term non-negative.
     */
    private void marginUpdate(double margin, Vec d, double delta, double projNormSqrd, double y) {
        final double loss = 1.0d - margin;
        // delta/eta, with 0/0 treated as 0 so eta == 0 cannot produce NaN
        final double penalty = delta == 0.0d ? 0.0d : delta / this.eta;
        // a zero (or NaN) projection norm would divide by zero; the update would be null anyway
        if (loss < penalty || !(projNormSqrd > 0.0d)) {
            return;
        }
        final double tau = Math.min(
                Math.min(loss / projNormSqrd, (2.0d * (loss - penalty)) / projNormSqrd), 1.0d);
        addScaled(d, y * tau);
    }

    /** Adds {@code scale * coef[i]} to the coefficient of every support vector i. */
    private void addScaled(Vec coef, double scale) {
        for (int i = 0; i < this.S.size(); i++) {
            this.alpha.set(i, this.alpha.get(i).doubleValue() + (scale * coef.get(i)));
        }
    }

    /**
     * Appends {@code x} as a support vector. The inverse Gram matrix is extended
     * with the block-inverse update
     * [K^{-1} 0; 0 0] + (1/delta^2) [d; -1][d; -1]^T.
     */
    private void addSupportVector(Vec x, List<Double> qi, double y, Vec d, double deltaSqrd) {
        growStorageIfFull();
        final int n = this.S.size();
        this.InvK = new SubMatrix(this.InvKExpanded, 0, 0, n + 1, n + 1);
        final DenseVector v = new DenseVector(n + 1);
        for (int i = 0; i < n; i++) {
            v.set(i, d.get(i));
        }
        v.set(n, -1.0d);
        // Always true when reached via update() (delta > eta >= 0); kept as a guard.
        if (deltaSqrd > 0.0d) {
            Matrix.OuterProductUpdate(this.InvK, v, v, 1.0d / deltaSqrd);
        }
        this.S.add(x);
        this.alpha.add(y);
        this.cacheAccel.addAll(qi);
    }

    /** Doubles the inverse-Gram backing matrix and k_raw buffer when the support set fills them. */
    private void growStorageIfFull() {
        final int n = this.S.size();
        if (n < this.InvKExpanded.rows()) {
            return;
        }
        final Matrix bigger = new DenseMatrix(n * 2, n * 2);
        for (int r = 0; r < this.InvK.rows(); r++) {
            for (int c = 0; c < this.InvK.cols(); c++) {
                bigger.set(r, c, this.InvK.get(r, c));
            }
        }
        this.InvKExpanded = bigger;
        this.InvK = new SubMatrix(this.InvKExpanded, 0, 0, n, n);
        this.k_raw = Arrays.copyOf(this.k_raw, n * 2);
    }

    /**
     * Classifies a point: class 0 if the score is negative, class 1 otherwise.
     *
     * @param dataPoint the point to classify
     * @return results with probability 1 on the predicted class
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        CategoricalResults categoricalResults = new CategoricalResults(2);
        if (getScore(dataPoint) < 0.0d) {
            categoricalResults.setProb(0, 1.0d);
        } else {
            categoricalResults.setProb(1, 1.0d);
        }
        return categoricalResults;
    }

    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Computes sum_i alpha_i k(s_i, vec).
     *
     * @param vec the query point
     * @param list the kernel query info for {@code vec}
     * @param dArr if non-null, receives k(s_i, vec) at index i for every support
     *             vector; must have length at least |S|
     * @return the decision value
     */
    private double getScore(Vec vec, List<Double> list, double[] dArr) {
        double d = 0.0d;
        for (int i = 0; i < this.S.size(); i++) {
            double eval = this.k.eval(i, vec, list, this.S, this.cacheAccel);
            if (dArr != null) {
                dArr[i] = eval;
            }
            d += this.alpha.get(i).doubleValue() * eval;
        }
        return d;
    }

    @Override // jsat.classifiers.calibration.BinaryScoreClassifier
    public double getScore(DataPoint dataPoint) {
        return this.k.evalSum(this.S, this.cacheAccel, this.alpha.getBackingArray(), dataPoint.getNumericalValues(), 0, this.S.size());
    }

    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }
}
