package jsat.classifiers.linear;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.BaseUpdateableRegressor;
import jsat.regression.RegressionDataSet;
import jsat.regression.UpdateableRegressor;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/linear/STGD.class */
/**
 * Sparse Truncated Gradient Descent (STGD): an online linear model without a bias term.
 * It can be used both as a binary classifier (score {@code w . x}, positive means class 1)
 * and as a regressor (prediction {@code w . x}).
 * <p>
 * Each update takes a gradient step on the weights of the features present in the data point.
 * The updated weight is then passed through a truncation operator. That operator shrinks
 * weights whose magnitude is at most {@code threshold} toward zero, never crossing zero,
 * which encourages sparse solutions. Larger weights are left untouched.
 * <p>
 * Shrinkage is applied lazily. For each weight {@code j}, {@code t[j]} records the time up to
 * which shrinkage has been accounted for. When feature {@code j} is next seen, the weight
 * receives {@code (time - t[j]) / K} whole periods of shrinkage at once. Each period has size
 * {@code gravity * learningRate}.
 * <p>
 * NOTE(review): this appears to follow the truncated-gradient method of Langford, Li &amp;
 * Zhang (2009); confirm against that paper.
 */
public class STGD extends BaseUpdateableClassifier implements UpdateableRegressor, BinaryScoreClassifier, Parameterized, SingleWeightVectorModel {
    private static final long serialVersionUID = 5753298014967370769L;

    /** The weight vector; {@code null} until {@code setUp} has been called. */
    private Vec w;
    /** Number of update steps that make up one shrinkage period. */
    private int K;
    /** Step size of the gradient update; also scales the shrinkage amount. */
    private double learningRate;
    /** Only weights with magnitude at most this value are shrunk. */
    private double threshold;
    /** Shrinkage applied per period (before scaling by the learning rate). */
    private double gravity;
    /** Number of updates performed since the last call to {@code setUp}. */
    private int time;
    /** {@code t[j]} is the time up to which shrinkage has been applied to weight {@code j}. */
    private int[] t;

    /**
     * Creates a new STGD model.
     *
     * @param K the shrinkage period, in updates; must be at least 1
     * @param learningRate the gradient step size; must be positive and finite
     * @param threshold weights with magnitude above this value are never shrunk; must be
     *        positive (infinity is allowed, meaning every weight is subject to shrinkage)
     * @param gravity the per-period shrinkage amount; must be positive and finite
     * @throws IllegalArgumentException if any argument is out of range
     */
    public STGD(int K, double learningRate, double threshold, double gravity) {
        setK(K);
        setLearningRate(learningRate);
        setThreshold(threshold);
        setGravity(gravity);
    }

    /**
     * Copy constructor: deep-copies the weight vector and the per-weight timestamps.
     * NOTE(review): superclass state (e.g. the value returned by {@code getEpochs()}) is not
     * copied here; it is initialised by the implicit {@code super()} call. Confirm whether
     * BaseUpdateableClassifier needs it propagated.
     *
     * @param toCopy the model to copy
     */
    protected STGD(STGD toCopy) {
        if (toCopy.w != null) {
            this.w = toCopy.w.mo524clone();
        }
        this.K = toCopy.K;
        this.learningRate = toCopy.learningRate;
        this.threshold = toCopy.threshold;
        this.gravity = toCopy.gravity;
        this.time = toCopy.time;
        if (toCopy.t != null) {
            this.t = Arrays.copyOf(toCopy.t, toCopy.t.length);
        }
    }

    /**
     * Sets the shrinkage period.
     *
     * @param K number of updates per shrinkage period; must be at least 1
     * @throws IllegalArgumentException if {@code K < 1}
     */
    public void setK(int K) {
        if (K < 1) {
            throw new IllegalArgumentException("K must be positive, not " + K);
        }
        this.K = K;
    }

    /** @return the shrinkage period, in updates */
    public int getK() {
        return K;
    }

    /**
     * Sets the gradient step size.
     *
     * @param learningRate a positive, finite step size
     * @throws IllegalArgumentException if the value is NaN, infinite, or not positive
     */
    public void setLearningRate(double learningRate) {
        if (Double.isInfinite(learningRate) || Double.isNaN(learningRate) || learningRate <= 0.0d) {
            throw new IllegalArgumentException("Learning rate must be positive, not " + learningRate);
        }
        this.learningRate = learningRate;
    }

    /** @return the gradient step size */
    public double getLearningRate() {
        return learningRate;
    }

    /**
     * Sets the truncation threshold. Positive infinity is accepted.
     *
     * @param threshold weights with magnitude above this value are never shrunk
     * @throws IllegalArgumentException if the value is NaN or not positive
     */
    public void setThreshold(double threshold) {
        if (Double.isNaN(threshold) || threshold <= 0.0d) {
            throw new IllegalArgumentException("Threshold must be positive, not " + threshold);
        }
        this.threshold = threshold;
    }

    /** @return the truncation threshold */
    public double getThreshold() {
        return threshold;
    }

    /**
     * Sets the per-period shrinkage amount.
     *
     * @param gravity a positive, finite shrinkage amount
     * @throws IllegalArgumentException if the value is NaN, infinite, or not positive
     */
    public void setGravity(double gravity) {
        if (Double.isInfinite(gravity) || Double.isNaN(gravity) || gravity <= 0.0d) {
            throw new IllegalArgumentException("Gravity must be positive, not " + gravity);
        }
        this.gravity = gravity;
    }

    /** @return the per-period shrinkage amount */
    public double getGravity() {
        return gravity;
    }

    /** @return the live (not copied) weight vector, or {@code null} before {@code setUp} */
    @Override
    public Vec getRawWeight() {
        return w;
    }

    /** @return always {@code 0}; this model has no bias term */
    @Override
    public double getBias() {
        return 0.0d;
    }

    /**
     * Returns the weight vector for the given index. Only index 0 exists.
     *
     * @param index must be 0
     * @return the weight vector
     * @throws IndexOutOfBoundsException if {@code index != 0} (negative indices included)
     */
    @Override
    public Vec getRawWeight(int index) {
        if (index != 0) {
            throw new IndexOutOfBoundsException("Model has only 1 weight vector");
        }
        return getRawWeight();
    }

    /**
     * Returns the bias for the given index. Only index 0 exists.
     *
     * @param index must be 0
     * @return always {@code 0}
     * @throws IndexOutOfBoundsException if {@code index != 0} (negative indices included)
     */
    @Override
    public double getBias(int index) {
        if (index != 0) {
            throw new IndexOutOfBoundsException("Model has only 1 weight vector");
        }
        return getBias();
    }

    /** @return always {@code 1} */
    @Override
    public int numWeightsVecs() {
        return 1;
    }

    /** @return a deep copy of this model (the method name is a decompiler rename of clone) */
    @Override
    /* renamed from: clone */
    public STGD mo479clone() {
        return new STGD(this);
    }

    /**
     * Prepares the model for binary classification.
     *
     * @throws FailedToFitException if the target does not have exactly two categories, or if
     *         there are no numeric features
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        if (predicting.getNumOfCategories() != 2) {
            throw new FailedToFitException("STGD supports only binary classification");
        }
        setUp(categoricalAttributes, numericAttributes);
    }

    /**
     * Prepares the model for training. This allocates a zero weight vector and zero
     * timestamps, and resets the update clock.
     *
     * @throws FailedToFitException if there are no numeric features
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes) {
        if (numericAttributes < 1) {
            throw new FailedToFitException("STGD requires numeric features");
        }
        w = new DenseVector(numericAttributes);
        t = new int[numericAttributes];
        // The clock must restart with the timestamps. Otherwise, on retraining, each feature's
        // first update would receive (oldTime / K) periods of shrinkage at once.
        time = 0;
    }

    /** Trains sequentially; the executor is ignored. */
    @Override
    public void train(RegressionDataSet dataSet, ExecutorService threadPool) {
        train(dataSet);
    }

    /** Trains for {@code getEpochs()} passes over the data set. */
    @Override
    public void train(RegressionDataSet dataSet) {
        BaseUpdateableRegressor.trainEpochs(dataSet, this, getEpochs());
    }

    /**
     * Truncation operator. Values in {@code [0, theta]} are shrunk by {@code a} but not below
     * 0. Values in {@code [-theta, 0]} are shrunk by {@code a} but not above 0. All other
     * values (including NaN) are returned unchanged.
     *
     * @param v the value to truncate
     * @param a the shrinkage amount
     * @param theta the magnitude threshold
     * @return the truncated value
     */
    private static double T(double v, double a, double theta) {
        if (v >= 0.0d && v <= theta) {
            return Math.max(0.0d, v - a);
        } else if (v <= 0.0d && v >= -theta) {
            return Math.min(0.0d, v + a);
        }
        return v;
    }

    /**
     * Perceptron-style classification update. The weights are changed only when the sign of
     * {@code w . x} differs from the +/-1 label; the clock advances on every call.
     *
     * @param dataPoint the training point
     * @param targetClass the class index, 0 or 1
     */
    @Override
    public void update(DataPoint dataPoint, int targetClass) {
        time++;
        final Vec x = dataPoint.getNumericalValues();
        final int y = targetClass * 2 - 1; // class {0, 1} -> label {-1, +1}
        final int yHat = (int) Math.signum(w.dot(x));
        if (yHat == y) {
            return;
        }
        performUpdate(x, y, yHat);
    }

    /**
     * Regression update using the residual {@code targetValue - w . x}.
     *
     * @param dataPoint the training point
     * @param targetValue the regression target
     */
    @Override
    public void update(DataPoint dataPoint, double targetValue) {
        time++;
        final Vec x = dataPoint.getNumericalValues();
        performUpdate(x, targetValue, w.dot(x));
    }

    /**
     * Applies the gradient step and lazy shrinkage to every weight whose feature appears in
     * {@code x}.
     *
     * @param x the feature vector
     * @param y the target value
     * @param yHat the current prediction
     */
    private void performUpdate(Vec x, double y, double yHat) {
        final double step = 2.0d * learningRate * (y - yHat);
        for (IndexValue iv : x) {
            final int j = iv.getIndex();
            // Integer division: only whole periods elapsed since t[j] contribute shrinkage.
            final int periods = (time - t[j]) / K;
            final double shrink = periods * gravity * learningRate;
            w.set(j, T(w.get(j) + step * iv.getValue(), shrink, threshold));
            // Advance by whole periods only, so any partial period carries over.
            t[j] += periods * K;
        }
    }

    /** @return class 1 with probability 1 if the score is positive, otherwise class 0 */
    @Override
    public CategoricalResults classify(DataPoint dataPoint) {
        final CategoricalResults results = new CategoricalResults(2);
        if (getScore(dataPoint) > 0.0d) {
            results.setProb(1, 1.0d);
        } else {
            results.setProb(0, 1.0d);
        }
        return results;
    }

    /** @return the raw score {@code w . x} */
    @Override
    public double regress(DataPoint dataPoint) {
        return getScore(dataPoint);
    }

    /** @return always {@code false}; data point weights are ignored */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /** @return the dot product of the weight vector with the point's numeric values */
    @Override
    public double getScore(DataPoint dataPoint) {
        return w.dot(dataPoint.getNumericalValues());
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(getParameters()).get(paramName);
    }
}
