package jsat.classifiers.linear.kernelized;

import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.DataSet;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.UpdateableClassifier;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.distributions.kernels.KernelPoint;
import jsat.distributions.kernels.KernelPoints;
import jsat.distributions.kernels.KernelTrick;
import jsat.distributions.kernels.RBFKernel;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.lossfunctions.LossC;
import jsat.lossfunctions.LossFunc;
import jsat.lossfunctions.LossMC;
import jsat.lossfunctions.LossR;
import jsat.lossfunctions.SoftmaxLoss;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.BaseUpdateableRegressor;
import jsat.regression.RegressionDataSet;
import jsat.regression.UpdateableRegressor;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/linear/kernelized/KernelSGD.class */
/**
 * Kernelized stochastic gradient descent learner usable both as an updateable classifier and as
 * an updateable regressor.
 *
 * <p>Every call to an {@code update} method performs two steps on the stored kernel model:
 * <ol>
 *   <li>It scales the model by {@code 1 - eta_t * lambda}, which is the L2-regularization shrink.</li>
 *   <li>It adds the data point with coefficient {@code -eta_t * loss'}, the gradient step on the
 *       configured loss.</li>
 * </ol>
 * The step size is {@code eta_t = eta / (lambda * (t + 2 / lambda))}, where {@code t} counts the
 * updates since the last {@code setUp}.
 *
 * <p>Binary classification and regression use a single {@link KernelPoint}. Multi-class
 * classification uses a {@link KernelPoints} with one output per category. Both are configured with
 * the chosen {@link KernelPoint.BudgetStrategy}, the error tolerance, and a maximum budget of
 * {@code budgetSize}.
 */
public class KernelSGD implements UpdateableClassifier, UpdateableRegressor, Parameterized {
    private static final long serialVersionUID = -4956596506787859023L;

    /** Loss being minimized; must be a LossC, LossMC or LossR depending on the task. */
    private LossFunc loss;

    @Parameter.ParameterHolder
    private KernelTrick kernel;
    /** Regularization strength; always strictly positive (see setLambda). */
    private double lambda;
    /** Base learning rate used by the step-size schedule in getNextEta. */
    private double eta;
    private KernelPoint.BudgetStrategy budgetStrategy;
    private int budgetSize;
    /** Error tolerance in [0, 1], forwarded to the kernel point(s). */
    private double errorTolerance;
    /** Number of updates performed since the last setUp. */
    private int time;
    /** Model for binary classification / regression; null when the multi-class model is in use. */
    private KernelPoint kpoint;
    /** Model for multi-class classification; null when the single-output model is in use. */
    private KernelPoints kpoints;
    /** Number of passes over the data made by trainC / train. */
    private int epochs = 1;

    /**
     * Creates a learner with the softmax loss and an RBF kernel. It uses lambda = 1e-4, the
     * MERGE_RBF budget strategy and a budget of 300 support vectors.
     */
    public KernelSGD() {
        this(new SoftmaxLoss(), new RBFKernel(), 1.0E-4d, KernelPoint.BudgetStrategy.MERGE_RBF, 300);
    }

    /**
     * Creates a learner with eta = 1.0 and an error tolerance of 0.05.
     *
     * @param lossFunc the loss to minimize
     * @param kernelTrick the kernel to use
     * @param d the regularization parameter lambda
     * @param budgetStrategy the strategy used to maintain the budget
     * @param i the maximum number of support vectors (budget size)
     */
    public KernelSGD(LossFunc lossFunc, KernelTrick kernelTrick, double d, KernelPoint.BudgetStrategy budgetStrategy, int i) {
        this(lossFunc, kernelTrick, d, budgetStrategy, i, 1.0d, 0.05d);
    }

    /**
     * Creates a fully specified learner.
     *
     * @param lossFunc the loss to minimize
     * @param kernelTrick the kernel to use
     * @param d the regularization parameter lambda
     * @param budgetStrategy the strategy used to maintain the budget
     * @param i the maximum number of support vectors (budget size)
     * @param d2 the base learning rate eta
     * @param d3 the error tolerance, in [0, 1]
     */
    public KernelSGD(LossFunc lossFunc, KernelTrick kernelTrick, double d, KernelPoint.BudgetStrategy budgetStrategy, int i, double d2, double d3) {
        setLoss(lossFunc);
        setKernel(kernelTrick);
        setLambda(d);
        setEta(d2);
        setBudgetStrategy(budgetStrategy);
        setErrorTolerance(d3);
        setBudgetSize(i);
    }

    /**
     * Copy constructor. The loss, the kernel and any trained model are cloned. Scalar settings and
     * the update counter are copied.
     *
     * @param kernelSGD the learner to copy
     */
    public KernelSGD(KernelSGD kernelSGD) {
        this.loss = kernelSGD.loss.m683clone();
        this.kernel = kernelSGD.kernel.m629clone();
        this.lambda = kernelSGD.lambda;
        this.eta = kernelSGD.eta;
        this.budgetStrategy = kernelSGD.budgetStrategy;
        this.budgetSize = kernelSGD.budgetSize;
        this.errorTolerance = kernelSGD.errorTolerance;
        this.time = kernelSGD.time;
        this.epochs = kernelSGD.epochs;
        if (kernelSGD.kpoint != null) {
            this.kpoint = kernelSGD.kpoint.m626clone();
        }
        if (kernelSGD.kpoints != null) {
            this.kpoints = kernelSGD.kpoints.m628clone();
        }
    }

    /**
     * Sets the number of passes over the data made by trainC / train.
     *
     * @param i the number of epochs, at least 1
     * @throws IllegalArgumentException if {@code i < 1}
     */
    public void setEpochs(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("Epochs must be a positive constant, not " + i);
        }
        this.epochs = i;
    }

    /** @return the number of passes over the data made by trainC / train */
    public int getEpochs() {
        return this.epochs;
    }

    /**
     * Sets the loss to minimize.
     *
     * @param lossFunc the loss, non-null
     * @throws NullPointerException if {@code lossFunc} is null
     */
    public void setLoss(LossFunc lossFunc) {
        if (lossFunc == null) {
            throw new NullPointerException("Loss may not be null");
        }
        this.loss = lossFunc;
    }

    /** @return the loss being minimized */
    public LossFunc getLoss() {
        return this.loss;
    }

    /**
     * Sets the regularization strength.
     *
     * @param d a finite, strictly positive value
     * @throws IllegalArgumentException if {@code d} is not finite and positive
     */
    public void setLambda(double d) {
        if (d <= 0.0d || Double.isNaN(d) || Double.isInfinite(d)) {
            throw new IllegalArgumentException("lambda must be a positive constant, not " + d);
        }
        this.lambda = d;
    }

    /** @return the regularization strength */
    public double getLambda() {
        return this.lambda;
    }

    /**
     * Sets the error tolerance forwarded to the kernel point(s) on setUp.
     *
     * @param d a value in [0, 1]
     * @throws IllegalArgumentException if {@code d} is outside [0, 1] or NaN
     */
    public void setErrorTolerance(double d) {
        if (d < 0.0d || d > 1.0d || Double.isNaN(d)) {
            throw new IllegalArgumentException("Error tolerance must be in [0, 1], not " + d);
        }
        this.errorTolerance = d;
    }

    /** @return the error tolerance */
    public double getErrorTolerance() {
        return this.errorTolerance;
    }

    /**
     * Sets the maximum budget (number of support vectors) used on setUp.
     *
     * @param i the budget size, at least 1
     * @throws IllegalArgumentException if {@code i < 1}
     */
    public void setBudgetSize(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("Budget size must be a positive constant, not " + i);
        }
        this.budgetSize = i;
    }

    /** @return the maximum budget size */
    public int getBudgetSize() {
        return this.budgetSize;
    }

    /**
     * Sets the strategy used to maintain the budget.
     *
     * @param budgetStrategy the strategy, non-null
     * @throws NullPointerException if {@code budgetStrategy} is null
     */
    public void setBudgetStrategy(KernelPoint.BudgetStrategy budgetStrategy) {
        if (budgetStrategy == null) {
            throw new NullPointerException("Budget strategy must be non null");
        }
        this.budgetStrategy = budgetStrategy;
    }

    /** @return the budget maintenance strategy */
    public KernelPoint.BudgetStrategy getBudgetStrategy() {
        return this.budgetStrategy;
    }

    /**
     * Sets the base learning rate used by the step-size schedule.
     *
     * @param d a finite, strictly positive value
     * @throws IllegalArgumentException if {@code d} is not finite and positive
     */
    public void setEta(double d) {
        // A non-positive or non-finite eta gives a meaningless (or NaN) step size for every update.
        if (d <= 0.0d || Double.isNaN(d) || Double.isInfinite(d)) {
            throw new IllegalArgumentException("eta must be a positive constant, not " + d);
        }
        this.eta = d;
    }

    /** @return the base learning rate */
    public double getEta() {
        return this.eta;
    }

    /**
     * Sets the kernel used by the model.
     *
     * @param kernelTrick the kernel, non-null
     * @throws NullPointerException if {@code kernelTrick} is null
     */
    public void setKernel(KernelTrick kernelTrick) {
        if (kernelTrick == null) {
            throw new NullPointerException("kernel trick must be non null");
        }
        this.kernel = kernelTrick;
    }

    /** @return the kernel used by the model */
    public KernelTrick getKernel() {
        return this.kernel;
    }

    /**
     * Prepares a fresh classification model and resets the update counter.
     *
     * <p>A two-category target uses a single KernelPoint and requires a LossC loss. More categories
     * use a KernelPoints and additionally require a LossMC loss.
     *
     * @throws FailedToFitException if the loss does not support the requested task
     */
    @Override // jsat.classifiers.UpdateableClassifier
    public void setUp(CategoricalData[] categoricalDataArr, int i, CategoricalData categoricalData) {
        if (!(this.loss instanceof LossC)) {
            throw new FailedToFitException("Loss in use (" + this.loss.getClass().getSimpleName() + ") does not support classification");
        }
        if (categoricalData.getNumOfCategories() == 2) {
            this.kpoint = new KernelPoint(this.kernel, this.errorTolerance);
            this.kpoint.setBudgetStrategy(this.budgetStrategy);
            this.kpoint.setErrorTolerance(this.errorTolerance);
            this.kpoint.setMaxBudget(this.budgetSize);
            this.kpoints = null;
        } else {
            if (!(this.loss instanceof LossMC)) {
                throw new FailedToFitException("Loss in use (" + this.loss.getClass().getSimpleName() + ") does not support multi-class classification");
            }
            this.kpoint = null;
            this.kpoints = new KernelPoints(this.kernel, categoricalData.getNumOfCategories(), this.errorTolerance);
            this.kpoints.setBudgetStrategy(this.budgetStrategy);
            this.kpoints.setErrorTolerance(this.errorTolerance);
            this.kpoints.setMaxBudget(this.budgetSize);
        }
        this.time = 0;
    }

    /**
     * Prepares a fresh single-output regression model and resets the update counter.
     *
     * @throws FailedToFitException if the loss is not a LossR
     */
    @Override // jsat.regression.UpdateableRegressor
    public void setUp(CategoricalData[] categoricalDataArr, int i) {
        if (!(this.loss instanceof LossR)) {
            throw new FailedToFitException("Loss in use (" + this.loss.getClass().getSimpleName() + ") does not support regression");
        }
        this.kpoint = new KernelPoint(this.kernel, this.errorTolerance);
        this.kpoint.setBudgetStrategy(this.budgetStrategy);
        this.kpoint.setErrorTolerance(this.errorTolerance);
        this.kpoint.setMaxBudget(this.budgetSize);
        this.kpoints = null;
        this.time = 0;
    }

    /**
     * Performs one SGD step on a labeled classification example.
     *
     * @param dataPoint the example
     * @param i the target category index
     * @throws IllegalStateException if no classification model has been set up
     */
    @Override // jsat.classifiers.UpdateableClassifier
    public void update(DataPoint dataPoint, int i) {
        // Fail before advancing the step counter rather than silently doing nothing.
        if (this.kpoint == null && this.kpoints == null) {
            throw new IllegalStateException("Model has not been set up for classification; call setUp or trainC first");
        }
        Vec numericalValues = dataPoint.getNumericalValues();
        List<Double> queryInfo = this.kernel.getQueryInfo(numericalValues);
        double nextEta = getNextEta();
        if (this.kpoint != null) {
            // Regularization shrink, then a gradient step with the label mapped from {0, 1} to {-1, +1}.
            this.kpoint.mutableMultiply(1.0d - (nextEta * this.lambda));
            double deriv = ((LossC) this.loss).getDeriv(this.kpoint.dot(numericalValues, queryInfo), (i * 2) - 1);
            // Skip the add entirely when the gradient is zero (presumably avoids storing a zero-weight vector).
            if (deriv != 0.0d) {
                this.kpoint.mutableAdd((-nextEta) * deriv, numericalValues, queryInfo);
            }
            return;
        }
        this.kpoints.mutableMultiply(1.0d - (nextEta * this.lambda));
        // Raw scores -> loss-processed outputs -> per-class derivative, all in place.
        DenseVector denseVector = new DenseVector(this.kpoints.dot(numericalValues, queryInfo));
        ((LossMC) this.loss).process(denseVector, denseVector);
        ((LossMC) this.loss).deriv(denseVector, denseVector, i);
        denseVector.mutableMultiply(-nextEta);
        this.kpoints.mutableAdd(numericalValues, denseVector, queryInfo);
    }

    /**
     * Performs one SGD step on a regression example.
     *
     * @param dataPoint the example
     * @param d the target value
     * @throws IllegalStateException if no single-output model has been set up
     */
    @Override // jsat.regression.UpdateableRegressor
    public void update(DataPoint dataPoint, double d) {
        KernelPoint model = requireKernelPoint();
        Vec numericalValues = dataPoint.getNumericalValues();
        List<Double> queryInfo = this.kernel.getQueryInfo(numericalValues);
        double nextEta = getNextEta();
        model.mutableMultiply(1.0d - (nextEta * this.lambda));
        double deriv = ((LossR) this.loss).getDeriv(model.dot(numericalValues, queryInfo), d);
        if (deriv != 0.0d) {
            model.mutableAdd((-nextEta) * deriv, numericalValues, queryInfo);
        }
    }

    /**
     * Classifies a data point with the current model.
     *
     * @param dataPoint the point to classify
     * @return the loss-derived class probabilities / scores
     * @throws IllegalStateException if no classification model has been set up
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        if (this.kpoint == null && this.kpoints == null) {
            throw new IllegalStateException("Model has not been trained for classification");
        }
        Vec numericalValues = dataPoint.getNumericalValues();
        List<Double> queryInfo = this.kernel.getQueryInfo(numericalValues);
        if (this.kpoint != null) {
            return ((LossC) this.loss).getClassification(this.kpoint.dot(numericalValues, queryInfo));
        }
        DenseVector denseVector = new DenseVector(this.kpoints.dot(numericalValues, queryInfo));
        ((LossMC) this.loss).process(denseVector, denseVector);
        return ((LossMC) this.loss).getClassification(denseVector);
    }

    /**
     * Predicts a real value for a data point with the current single-output model.
     *
     * @param dataPoint the point to regress
     * @return the loss-derived regression value
     * @throws IllegalStateException if no single-output model has been set up
     */
    @Override // jsat.regression.Regressor
    public double regress(DataPoint dataPoint) {
        KernelPoint model = requireKernelPoint();
        Vec numericalValues = dataPoint.getNumericalValues();
        return ((LossR) this.loss).getRegression(model.dot(numericalValues, this.kernel.getQueryInfo(numericalValues)));
    }

    /** Trains on the data set; the executor is ignored because training is sequential. */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        trainC(classificationDataSet);
    }

    /** Trains for {@link #getEpochs()} passes over the data set. */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        BaseUpdateableClassifier.trainEpochs(classificationDataSet, this, this.epochs);
    }

    /** @return false: data point weights are not used */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /** Trains on the data set; the executor is ignored because training is sequential. */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet, ExecutorService executorService) {
        train(regressionDataSet);
    }

    /** Trains for {@link #getEpochs()} passes over the data set. */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet) {
        BaseUpdateableRegressor.trainEpochs(regressionDataSet, this, this.epochs);
    }

    /**
     * Returns a deep copy via the copy constructor.
     *
     * <p>NOTE(review): this is the decompiler-assigned name for {@code clone()}. It is kept because
     * other decompiled sources call it by this name.
     */
    @Override // jsat.regression.Regressor
    public KernelSGD mo582clone() {
        return new KernelSGD(this);
    }

    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }

    /**
     * Returns the single-output model, or fails fast when it has not been set up.
     *
     * <p>The model is missing before any setUp, and after a multi-class classification setUp.
     */
    private KernelPoint requireKernelPoint() {
        if (this.kpoint == null) {
            throw new IllegalStateException("Model has not been set up for regression or binary classification; call setUp or train first");
        }
        return this.kpoint;
    }

    /**
     * Advances the update counter and returns the step size for this update,
     * {@code eta / (lambda * (t + 2 / lambda))}.
     */
    private double getNextEta() {
        this.time++;
        return this.eta / (this.lambda * (this.time + (2.0d / this.lambda)));
    }

    /**
     * Suggests a search distribution for lambda: log-uniform over [1e-7, 1e-2].
     *
     * @param dataSet the data set (currently unused)
     * @return the suggested distribution
     */
    public static Distribution guessLambda(DataSet dataSet) {
        return new LogUniform(1.0E-7d, 0.01d);
    }
}
