package jsat.classifiers.linear;

import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.lossfunctions.LossC;
import jsat.lossfunctions.LossFunc;
import jsat.lossfunctions.LossR;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.random.XORWOW;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/linear/SCD.class */
/**
 * Stochastic Coordinate Descent (SCD) for L1-regularized linear models.
 * <p>
 * Each training iteration picks one feature {@code j} uniformly at random. It then computes the
 * average derivative of the loss with respect to {@code w[j]} over all training samples. Finally
 * it applies a soft-thresholding update with step size {@code 1/beta}, where {@code beta} is the
 * loss function's maximal second derivative ({@link LossFunc#getDeriv2Max()}). The soft-threshold
 * of {@code reg/beta} can set coefficients to exactly zero.
 * <p>
 * The learned model has no bias term ({@link #getBias()} always returns 0). Prediction requires
 * the loss to be a {@link LossC} for classification or a {@link LossR} for regression.
 */
public class SCD implements Classifier, Regressor, Parameterized, SingleWeightVectorModel {
    private static final long serialVersionUID = 3576901723216525618L;

    /** Learned weight vector, one entry per numeric feature; {@code null} until trained. */
    private Vec w;
    /** Loss being minimized; validated to have a finite, positive maximal second derivative. */
    private LossFunc loss;
    /** Strength of the soft-threshold (L1) penalty; always positive and finite. */
    private double reg;
    /** Number of single-coordinate updates performed per training call; always at least 1. */
    private int iterations;

    /**
     * Creates a new SCD learner.
     *
     * @param lossFunc the loss to minimize; its {@code getDeriv2Max()} must be finite and positive
     * @param d the regularization strength; must be positive and finite
     * @param i the number of coordinate updates to perform; must be at least 1
     * @throws IllegalArgumentException if any argument is out of range
     */
    public SCD(LossFunc lossFunc, double d, int i) {
        double deriv2Max = lossFunc.getDeriv2Max();
        if (Double.isNaN(deriv2Max) || Double.isInfinite(deriv2Max) || deriv2Max <= 0.0d) {
            throw new IllegalArgumentException("SCD needs a loss function with a finite positive maximal second derivative");
        }
        this.loss = lossFunc;
        setRegularization(d);
        setIterations(i);
    }

    /**
     * Copy constructor. Clones the loss and, if present, the learned weight vector.
     *
     * @param scd the learner to copy
     */
    public SCD(SCD scd) {
        this(scd.loss.m685clone(), scd.reg, scd.iterations);
        if (scd.w != null) {
            this.w = scd.w.mo524clone();
        }
    }

    /**
     * Sets the number of coordinate updates performed during training.
     *
     * @param i the number of iterations; must be at least 1
     * @throws IllegalArgumentException if {@code i < 1}
     */
    public void setIterations(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("The iterations must be a positive value, not " + i);
        }
        this.iterations = i;
    }

    /**
     * @return the number of coordinate updates performed during training
     */
    public int getIterations() {
        return this.iterations;
    }

    /**
     * Sets the regularization strength used as the soft-threshold.
     *
     * @param d the regularization; must be positive and finite
     * @throws IllegalArgumentException if {@code d} is not positive and finite
     */
    public void setRegularization(double d) {
        if (Double.isInfinite(d) || Double.isNaN(d) || d <= 0.0d) {
            throw new IllegalArgumentException("Regularization must be a positive value");
        }
        this.reg = d;
    }

    /**
     * @return the regularization strength
     */
    public double getRegularization() {
        return this.reg;
    }

    /**
     * @return the learned weight vector (the internal instance, not a copy), or {@code null} if
     *         the model has not been trained
     */
    @Override // jsat.SingleWeightVectorModel
    public Vec getRawWeight() {
        return this.w;
    }

    /**
     * @return always {@code 0}; this model is trained without a bias term
     */
    @Override // jsat.SingleWeightVectorModel
    public double getBias() {
        return 0.0d;
    }

    /**
     * Returns the single weight vector. Any index below 1 selects it.
     *
     * @param i the weight vector index
     * @return the learned weight vector
     * @throws IndexOutOfBoundsException if {@code i >= 1}
     */
    @Override // jsat.SimpleWeightVectorModel
    public Vec getRawWeight(int i) {
        if (i < 1) {
            return getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * Returns the bias of the single weight vector. Any index below 1 selects it.
     *
     * @param i the weight vector index
     * @return always {@code 0}
     * @throws IndexOutOfBoundsException if {@code i >= 1}
     */
    @Override // jsat.SimpleWeightVectorModel
    public double getBias(int i) {
        if (i < 1) {
            return getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * @return always {@code 1}
     */
    @Override // jsat.SimpleWeightVectorModel
    public int numWeightsVecs() {
        return 1;
    }

    /**
     * Classifies a point by passing {@code w . x} to the classification loss.
     *
     * @param dataPoint the point to classify
     * @return the class probabilities produced by the {@link LossC}
     * @throws UntrainedModelException if the model is untrained or the loss is not a {@link LossC}
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        if (this.w == null || !(this.loss instanceof LossC)) {
            throw new UntrainedModelException("Model was not trained with a classification function");
        }
        return ((LossC) this.loss).getClassification(this.w.dot(dataPoint.getNumericalValues()));
    }

    /**
     * Trains as a binary classifier. Training is single-threaded, so the executor is ignored.
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        trainC(classificationDataSet);
    }

    /**
     * Trains as a binary classifier. Category indices {0, 1} become targets {-1, +1}.
     *
     * @param classificationDataSet the training data; must have exactly two classes
     * @throws IllegalArgumentException if the data set does not have exactly two classes, or has
     *         no samples or no numeric features
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        int numClasses = classificationDataSet.getClassSize();
        if (numClasses != 2) {
            // The 2c-1 mapping below is only meaningful for categories 0 and 1.
            throw new IllegalArgumentException("SCD supports only binary classification, not " + numClasses + " classes");
        }
        double[] targets = new double[classificationDataSet.getSampleSize()];
        for (int i = 0; i < targets.length; i++) {
            targets[i] = classificationDataSet.getDataPointCategory(i) * 2 - 1;
        }
        train(classificationDataSet.getNumericColumns(), targets);
    }

    /**
     * @return {@code false}; sample weights are ignored during training
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Predicts a value by passing {@code w . x} to the regression loss.
     *
     * @param dataPoint the point to predict
     * @return the regression output produced by the {@link LossR}
     * @throws UntrainedModelException if the model is untrained or the loss is not a {@link LossR}
     */
    @Override // jsat.regression.Regressor
    public double regress(DataPoint dataPoint) {
        if (this.w == null || !(this.loss instanceof LossR)) {
            throw new UntrainedModelException("Model was not trained with a regression function");
        }
        return ((LossR) this.loss).getRegression(this.w.dot(dataPoint.getNumericalValues()));
    }

    /**
     * Trains as a regressor. Training is single-threaded, so the executor is ignored.
     */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet, ExecutorService executorService) {
        train(regressionDataSet);
    }

    /**
     * Trains as a regressor on the data set's target values.
     *
     * @param regressionDataSet the training data
     * @throws IllegalArgumentException if the data set has no samples or no numeric features
     */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet) {
        train(regressionDataSet.getNumericColumns(), regressionDataSet.getTargetValues().arrayCopy());
    }

    /**
     * Runs stochastic coordinate descent on column-major data and stores the result in {@code w}.
     *
     * @param columns one vector per numeric feature, whose entries are indexed by sample number
     * @param targets the target value of each sample
     * @throws IllegalArgumentException if there are no feature columns or no samples
     */
    private void train(Vec[] columns, double[] targets) {
        if (columns.length == 0) {
            // Without this guard, the random coordinate choice would fail inside the RNG.
            throw new IllegalArgumentException("SCD requires at least one numeric feature");
        }
        if (targets.length == 0) {
            // Without this guard, the averaged gradient would be 0/0 = NaN and poison every weight.
            throw new IllegalArgumentException("SCD requires at least one training sample");
        }
        final double beta = this.loss.getDeriv2Max();
        final double threshold = this.reg / beta;
        // margins[i] caches w . x_i so each update only needs to touch the chosen column.
        double[] margins = new double[targets.length];
        Vec weights = new DenseVector(columns.length);
        XORWOW rand = new XORWOW();
        for (int iter = 0; iter < this.iterations; iter++) {
            int j = rand.nextInt(columns.length);
            double grad = 0.0d;
            for (IndexValue iv : columns[j]) {
                int sample = iv.getIndex();
                grad += this.loss.getDeriv(margins[sample], targets[sample]) * iv.getValue();
            }
            grad /= targets.length;
            double step = softThresholdStep(weights.get(j), grad / beta, threshold);
            weights.increment(j, step);
            for (IndexValue iv : columns[j]) {
                margins[iv.getIndex()] += step * iv.getValue();
            }
        }
        // Publish the weights only once training has finished.
        this.w = weights;
    }

    /**
     * Computes the increment that moves one coordinate to its soft-thresholded value.
     *
     * @param current the coordinate's current value
     * @param scaledGrad the averaged derivative divided by {@code beta}
     * @param threshold the regularization divided by {@code beta}
     * @return the amount to add to the coordinate
     */
    private static double softThresholdStep(double current, double scaledGrad, double threshold) {
        double proposal = current - scaledGrad;
        if (proposal > threshold) {
            return -scaledGrad - threshold;
        } else if (proposal < -threshold) {
            return -scaledGrad + threshold;
        } else {
            return -current; // inside the threshold band: snap the coordinate to exactly zero
        }
    }

    /**
     * Returns a deep copy of this learner. This is the decompiled name of {@code clone()}, which
     * was merged with its bridge methods.
     *
     * @return an independent copy of this learner
     */
    @Override // jsat.regression.Regressor
    /* renamed from: clone, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public SCD m570clone() {
        return new SCD(this);
    }

    /**
     * @return the tunable parameters discovered from this object's getter/setter methods
     */
    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    /**
     * @param str the parameter name
     * @return the matching parameter, or {@code null} if there is none
     */
    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }
}
