package jsat.classifiers.linear.kernelized;

import java.util.Arrays;
import java.util.List;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.distributions.kernels.KernelTrick;
import jsat.exceptions.FailedToFitException;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/linear/kernelized/Forgetron.class */
/**
 * Kernelized Perceptron that stores at most {@link #getBudget() budget} support vectors. It does this
 * by forgetting the oldest stored vector and shrinking the remaining weights (the "Forgetron" of Dekel,
 * Shalev-Shwartz and Singer). Only binary problems with numeric features are supported, and updates
 * happen only when the current model makes a mistake.
 * <p>
 * Two update rules are available, selected by {@link #setSelfTurned(boolean)}:
 * <ul>
 * <li>self-tuned (the default): once the budget is full, each mistake evicts the oldest vector. It then
 * scales all weights by the largest factor in (0, 1] that keeps the accumulated removal damage
 * {@code Q} at or below {@code 15/32 * M}, where {@code M} is the number of mistakes so far;</li>
 * <li>remove-shrink: every mistake scales all weights by
 * {@code min((B+1)^(-1/(2(B+1))), U / ||f||)}, with {@code U = sqrt((B+1)/ln(B+1)) / 4} and
 * {@code ||f||} approximated (see the remove-shrink update).</li>
 * </ul>
 * Support vectors live in a circular buffer, so the slot written next is always the oldest one.
 */
public class Forgetron extends BaseUpdateableClassifier implements BinaryScoreClassifier, Parameterized {
    private static final long serialVersionUID = -2631315082407427077L;

    /** Fraction of the mistake count the self-tuned variant tolerates as accumulated removal damage. */
    private static final double DAMAGE_RATIO = 15.0d / 32.0d;

    /** Kernel used to compare stored support vectors with query points. */
    @Parameter.ParameterHolder
    private KernelTrick K;
    /** Circular buffer of stored support vectors (length {@code budget} after {@link #setUp}). */
    private Vec[] I;
    /** Signed weight of each stored support vector, parallel to {@link #I}. */
    private double[] s;
    /** Number of occupied slots in {@link #I} / {@link #s}. */
    private int size;
    /** Slot the next mistake writes to; once the buffer is full this is the oldest entry. */
    private int curPos;
    /** Maximum number of support vectors kept. */
    private int budget;
    /** Norm bound U used by the remove-shrink variant. */
    private double U;
    /** Per-mistake shrink cap {@code (B+1)^(-1/(2(B+1)))} used by the remove-shrink variant. */
    private double Bconst;
    /** Accumulated removal damage (self-tuned variant). */
    private double Q;
    /** Number of mistakes made since {@link #setUp}. */
    private double M;
    /** {@code true} for the self-tuned update, {@code false} for remove-shrink. */
    private boolean selfTuned;

    /**
     * Creates a new self-tuned Forgetron.
     *
     * @param kernelTrick the kernel to use
     * @param i the maximum number of support vectors to keep
     * @throws IllegalArgumentException if {@code i < 1}
     */
    public Forgetron(KernelTrick kernelTrick, int i) {
        this.selfTuned = true;
        this.K = kernelTrick;
        setBudget(i);
    }

    /**
     * Chooses between the self-tuned update ({@code true}, the default) and the remove-shrink update
     * ({@code false}). The method name keeps its historical spelling for backward compatibility.
     *
     * @param z {@code true} to use the self-tuned variant
     */
    public void setSelfTurned(boolean z) {
        this.selfTuned = z;
    }

    /**
     * @return {@code true} if the self-tuned update is used, {@code false} for remove-shrink
     */
    public boolean isSelfTuned() {
        return this.selfTuned;
    }

    /**
     * Copy constructor producing a deep copy of the stored support vectors, weights and kernel.
     *
     * @param forgetron the model to copy
     */
    protected Forgetron(Forgetron forgetron) {
        super(forgetron);
        // Copy the variant. It used to be reset to true, which silently switched clones of a
        // remove-shrink model to the self-tuned algorithm.
        this.selfTuned = forgetron.selfTuned;
        this.K = forgetron.K == null ? null : forgetron.K.mo624clone();
        this.budget = forgetron.budget;
        this.U = forgetron.U;
        this.Bconst = forgetron.Bconst;
        this.Q = forgetron.Q;
        this.M = forgetron.M;
        this.curPos = forgetron.curPos;
        this.size = forgetron.size;
        if (forgetron.I != null) {
            this.I = new Vec[forgetron.I.length];
            for (int i = 0; i < forgetron.I.length; i++) {
                if (forgetron.I[i] != null) {
                    this.I[i] = forgetron.I[i].mo525clone();
                }
            }
        }
        if (forgetron.s != null) {
            this.s = Arrays.copyOf(forgetron.s, forgetron.s.length);
        }
    }

    /**
     * Sets the maximum number of support vectors and recomputes the remove-shrink constants. This
     * takes effect on the next {@link #setUp}, which allocates the buffers.
     *
     * @param i the budget, at least 1
     * @throws IllegalArgumentException if {@code i < 1}
     */
    public void setBudget(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("Budget must be positive, not " + i);
        }
        this.budget = i;
        double bPlusOne = i + 1.0d;
        this.U = Math.sqrt(bPlusOne / Math.log(bPlusOne)) / 4.0d;
        // The exponent must be negative so that every mistake really shrinks the weights. After B+1
        // shrinks, the evicted vector's weight is then at most 1/sqrt(B+1). The former positive
        // exponent produced a factor > 1.
        this.Bconst = Math.pow(bPlusOne, -1.0d / (2.0d * bPlusOne));
    }

    /**
     * @return the maximum number of support vectors kept
     */
    public int getBudget() {
        return this.budget;
    }

    /**
     * @param kernelTrick the kernel to use
     */
    public void setKernelTrick(KernelTrick kernelTrick) {
        this.K = kernelTrick;
    }

    /**
     * @return the kernel in use
     */
    public KernelTrick getKernelTrick() {
        return this.K;
    }

    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        CategoricalResults categoricalResults = new CategoricalResults(2);
        // A strictly positive score predicts class 1; zero, negative or NaN predicts class 0.
        int predicted = getScore(dataPoint) > 0.0d ? 1 : 0;
        categoricalResults.setProb(predicted, 1.0d);
        return categoricalResults;
    }

    @Override // jsat.classifiers.calibration.BinaryScoreClassifier
    public double getScore(DataPoint dataPoint) {
        return classify(dataPoint.getNumericalValues());
    }

    /**
     * Returns the raw decision value: sum over stored vectors of {@code s[i] * K(I[i], vec)}.
     */
    private double classify(Vec vec) {
        double score = 0.0d;
        for (int i = 0; i < this.size; i++) {
            score += this.s[i] * this.K.eval(this.I[i], vec);
        }
        return score;
    }

    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    @Override // jsat.classifiers.BaseUpdateableClassifier
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public Forgetron mo480clone() {
        return new Forgetron(this);
    }

    @Override // jsat.classifiers.UpdateableClassifier
    public void setUp(CategoricalData[] categoricalDataArr, int i, CategoricalData categoricalData) {
        if (categoricalData.getNumOfCategories() != 2) {
            throw new FailedToFitException("Forgetron only supports binary classification");
        }
        if (i == 0) {
            throw new FailedToFitException("Forgetron requires numeric attributes");
        }
        this.I = new Vec[this.budget];
        this.s = new double[this.budget];
        this.M = 0.0d;
        this.Q = 0.0d;
        this.size = 0;
        this.curPos = 0;
    }

    /**
     * Removal damage term: {@code psi(lambda, mu) = lambda^2 + 2*lambda - 2*lambda*mu}.
     */
    private double psi(double d, double d2) {
        return ((d * d) + (2.0d * d)) - ((2.0d * d) * d2);
    }

    @Override // jsat.classifiers.UpdateableClassifier
    public void update(DataPoint dataPoint, int i) {
        Vec x = dataPoint.getNumericalValues();
        double y = (i * 2) - 1; // class {0, 1} -> label {-1, +1}
        if (y * classify(x) > 0.0d) {
            return; // correct with positive margin: only mistakes trigger an update
        }
        this.M += 1.0d;
        if (!this.selfTuned) {
            removeShrinkUpdate(x, y);
        } else if (this.size < this.budget) {
            // Budget not exhausted yet: simply add the new support vector.
            this.I[this.curPos] = x;
            this.s[this.curPos] = y;
            this.size++;
        } else {
            selfTunedUpdate(x, y);
        }
        this.curPos = (this.curPos + 1) % this.I.length;
    }

    /**
     * Remove-shrink update: overwrites the oldest slot with {@code (x, y)} and scales every stored
     * weight by {@code min(Bconst, U / norm)}.
     */
    private void removeShrinkUpdate(Vec x, double y) {
        // NOTE(review): diagonal-only approximation of ||f||^2. Cross terms are ignored and the vector
        // about to be evicted is still counted. The leading 1.0 stands in for the new vector,
        // presumably assuming K(x, x) = 1; TODO confirm for unnormalized kernels.
        double normSqrd = 1.0d;
        for (int j = 0; j < this.size; j++) {
            normSqrd += this.s[j] * this.s[j] * this.K.eval(this.I[j], this.I[j]);
        }
        double phi = Math.min(this.Bconst, this.U / Math.sqrt(normSqrd));
        this.I[this.curPos] = x;
        this.s[this.curPos] = y;
        if (this.size < this.budget) {
            this.size++;
        }
        for (int j = 0; j < this.size; j++) {
            this.s[j] *= phi;
        }
    }

    /**
     * Self-tuned update for a full buffer. It evicts the oldest vector r, adds {@code (x, y)}, and
     * scales all weights by the largest phi in (0, 1] with
     * {@code Q + psi(phi*sigma_r, phi*mu) <= 15/32 * M}.
     */
    private void selfTunedUpdate(Vec x, double y) {
        int r = this.curPos; // oldest slot, about to be overwritten
        // Score of the evicted vector under the model after the new example has been added.
        double fPrime = classify(this.I[r]) + (y * this.K.eval(x, this.I[r]));
        double sigma = Math.abs(this.s[r]);
        double yR = Math.signum(this.s[r]);
        // Q + psi(phi*sigma, phi*mu) - 15/32*M <= 0 rewritten as a*phi^2 + b*phi + c <= 0.
        double a = (sigma * sigma) - (2.0d * yR * sigma * fPrime);
        double b = 2.0d * sigma;
        double c = this.Q - (DAMAGE_RATIO * this.M);
        double phi = largestFeasiblePhi(a, b, c);
        this.Q += psi(phi * sigma, yR * phi * fPrime);
        this.I[r] = x;
        this.s[r] = y;
        if (phi != 1.0d) {
            for (int j = 0; j < this.size; j++) {
                this.s[j] *= phi;
            }
        }
    }

    /**
     * Returns the largest phi in (0, 1] with {@code a*phi^2 + b*phi + c <= 0}. It assumes
     * {@code b >= 0} and {@code c < 0}, which holds while the damage invariant
     * {@code Q <= 15/32 * M} is maintained.
     */
    private static double largestFeasiblePhi(double a, double b, double c) {
        double disc = (b * b) - (4.0d * a * c);
        if (a < 0.0d) {
            // Downward parabola: feasible outside its two positive roots. If there are no real
            // roots, or the larger root is at most 1, phi = 1 already satisfies the constraint.
            if (disc <= 0.0d) {
                return 1.0d;
            }
            double largerRoot = (-b - Math.sqrt(disc)) / (2.0d * a);
            if (largerRoot <= 1.0d) {
                return 1.0d;
            }
        }
        // Remaining cases: upward parabola, the linear case a == 0, or a downward parabola whose
        // larger root exceeds 1. The bound is the smaller positive root, computed in the
        // cancellation-free form 2c / (-b - sqrt(disc)), which reduces to -c/b when a == 0.
        double denom = -b - Math.sqrt(Math.max(disc, 0.0d));
        if (denom == 0.0d) {
            return 1.0d; // a == 0 and b == 0: the constraint reduces to c <= 0
        }
        return Math.min(1.0d, (2.0d * c) / denom);
    }

    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }
}
