package jsat.classifiers.linear;

import java.util.Iterator;
import java.util.List;
import jsat.DataSet;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.distributions.Normal;
import jsat.distributions.Uniform;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Matrix;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/linear/SCW.class */
/**
 * Soft Confidence-Weighted (SCW) online learner for binary classification.
 * <p>
 * The model keeps a mean weight vector {@code w} together with a covariance
 * that is either a full matrix or only its diagonal (see
 * {@link #setDiagonalOnly(boolean)}). Each call to
 * {@link #update(DataPoint, int)} moves the mean toward classifying the point
 * correctly and shrinks the covariance along the observed input. The decision
 * function is {@code w . x}; there is no bias term.
 * <p>
 * The closed-form update rules correspond to Wang, Zhao &amp; Hoi,
 * "Exact Soft Confidence-Weighted Learning" (ICML 2012). Plain CW behaviour
 * is available through {@link Mode#CW}.
 */
public class SCW extends BaseUpdateableClassifier implements BinaryScoreClassifier, Parameterized, SingleWeightVectorModel {
    private static final long serialVersionUID = -6721377074407660742L;

    /** Aggressiveness parameter; bounds the step in SCW-I and regularizes SCW-II. */
    private double C = 1.0d;
    /** Confidence level in [0.5, 1). */
    private double eta;
    /** Inverse standard-normal CDF of {@link #eta}. */
    private double phi;
    /** Cached {@code phi^2}. */
    private double phiSqrd;
    /** Cached {@code 1 + phi^2}. */
    private double zeta;
    /** Cached {@code 1 + phi^2 / 2}. */
    private double psi;
    /** Which update rule {@link #update(DataPoint, int)} applies. */
    private Mode mode;
    /** Mean weight vector; {@code null} until {@link #setUp} is called. */
    private Vec w;
    /** Full covariance matrix; only allocated when {@link #diagonalOnly} is false. */
    private Matrix sigmaM;
    /** Diagonal of the covariance; only allocated when {@link #diagonalOnly} is true. */
    private Vec sigmaV;
    /** Scratch buffer holding {@code sigmaM * x} during a full-covariance update. */
    private Vec Sigma_xt;
    /** Whether only the diagonal of the covariance is tracked. */
    private boolean diagonalOnly = false;

    /** Selects the step-size rule used by {@link SCW#update(DataPoint, int)}. */
    public enum Mode {
        /** Confidence-Weighted step; the step size is not bounded by C. */
        CW,
        /** SCW-I: the step size is clipped to at most C. */
        SCWI,
        /** SCW-II: C enters the step size through an extra {@code 1/(2C)} term. */
        SCWII
    }

    /**
     * Resets the {@link #Sigma_xt} scratch buffer to zero.
     * <p>
     * When the covariance is full, {@code Sigma * x} is generally dense, so the
     * whole buffer is cleared. The index-only path applies only to the diagonal,
     * sparse case.
     *
     * @param vec the input vector that was used to fill the buffer
     */
    private void zeroOutSigmaXt(Vec vec) {
        if (!this.diagonalOnly || !vec.isSparse()) {
            this.Sigma_xt.zeroOut();
            return;
        }
        Iterator<IndexValue> it = vec.iterator();
        while (it.hasNext()) {
            this.Sigma_xt.set(it.next().getIndex(), 0.0d);
        }
    }

    /**
     * Creates an SCW-I learner with {@code eta = 0.5} and a diagonal-only covariance.
     */
    public SCW() {
        this(0.5d, Mode.SCWI, true);
    }

    /**
     * Creates a new SCW learner.
     *
     * @param eta          the confidence parameter, in [0.5, 1)
     * @param mode         the update rule to use; must not be null
     * @param diagonalOnly {@code true} to track only the covariance diagonal
     * @throws IllegalArgumentException if {@code eta} is out of range
     * @throws NullPointerException     if {@code mode} is null
     */
    public SCW(double eta, Mode mode, boolean diagonalOnly) {
        setEta(eta);
        setMode(mode);
        setDiagonalOnly(diagonalOnly);
    }

    /**
     * Copy constructor; deep-copies all learned state.
     *
     * @param scw the learner to copy
     */
    protected SCW(SCW scw) {
        this.C = scw.C;
        this.diagonalOnly = scw.diagonalOnly;
        this.mode = scw.mode;
        setEta(scw.eta);
        if (scw.w != null) {
            this.w = scw.w.mo525clone();
        }
        if (scw.sigmaM != null) {
            this.sigmaM = scw.sigmaM.mo641clone();
        }
        if (scw.sigmaV != null) {
            this.sigmaV = scw.sigmaV.mo525clone();
        }
        if (scw.Sigma_xt != null) {
            this.Sigma_xt = scw.Sigma_xt.mo525clone();
        }
    }

    /**
     * Sets the confidence parameter and recomputes the derived constants.
     * <p>
     * The upper bound is exclusive: the inverse normal CDF of 1 is infinite,
     * which would make every subsequent update produce NaN weights.
     *
     * @param eta the confidence parameter, in [0.5, 1)
     * @throws IllegalArgumentException if {@code eta} is NaN or outside [0.5, 1)
     */
    public void setEta(double eta) {
        if (Double.isNaN(eta) || eta < 0.5d || eta >= 1.0d) {
            throw new IllegalArgumentException("eta must be in [0.5, 1) not " + eta);
        }
        this.eta = eta;
        this.phi = Normal.invcdf(eta, 0.0d, 1.0d);
        this.phiSqrd = this.phi * this.phi;
        this.zeta = 1.0d + this.phiSqrd;
        this.psi = 1.0d + (this.phiSqrd / 2.0d);
    }

    /**
     * @return the confidence parameter
     */
    public double getEta() {
        return this.eta;
    }

    /**
     * Sets the aggressiveness parameter C.
     * <p>
     * C must be positive. With SCW-I a non-positive C clips every step to at
     * most zero, so the model never learns. With SCW-II it makes
     * {@code 1/(2C)} infinite or negative.
     *
     * @param C the positive aggressiveness parameter
     * @throws IllegalArgumentException if {@code C} is NaN or not positive
     */
    public void setC(double C) {
        if (Double.isNaN(C) || C <= 0.0d) {
            throw new IllegalArgumentException("C must be a positive constant, not " + C);
        }
        this.C = C;
    }

    /**
     * @return the aggressiveness parameter C
     */
    public double getC() {
        return this.C;
    }

    /**
     * Sets which update rule to use.
     *
     * @param mode the update rule; must not be null
     * @throws NullPointerException if {@code mode} is null
     */
    public void setMode(Mode mode) {
        if (mode == null) {
            throw new NullPointerException("mode must not be null");
        }
        this.mode = mode;
    }

    /**
     * @return the update rule in use
     */
    public Mode getMode() {
        return this.mode;
    }

    /**
     * Chooses between a full covariance matrix and only its diagonal.
     * <p>
     * This should be set before {@link #setUp} is called. {@code setUp}
     * allocates only the structure matching the current setting, so changing
     * it afterwards leaves {@link #update} reading a null field.
     *
     * @param diagonalOnly {@code true} to track only the diagonal
     */
    public void setDiagonalOnly(boolean diagonalOnly) {
        this.diagonalOnly = diagonalOnly;
    }

    /**
     * @return {@code true} if only the covariance diagonal is tracked
     */
    public boolean isDiagonalOnly() {
        return this.diagonalOnly;
    }

    /**
     * @return the live (not copied) mean weight vector, or null if untrained
     */
    public Vec getWeightVec() {
        return this.w;
    }

    @Override // jsat.SingleWeightVectorModel
    public Vec getRawWeight() {
        return this.w;
    }

    /** The model has no bias term, so this is always 0. */
    @Override // jsat.SingleWeightVectorModel
    public double getBias() {
        return 0.0d;
    }

    @Override // jsat.SimpleWeightVectorModel
    public Vec getRawWeight(int i) {
        if (i < 1) {
            return getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    @Override // jsat.SimpleWeightVectorModel
    public double getBias(int i) {
        if (i < 1) {
            return getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    @Override // jsat.SimpleWeightVectorModel
    public int numWeightsVecs() {
        return 1;
    }

    // Decompiler-renamed clone(); the name is kept because callers reference it.
    @Override // jsat.classifiers.BaseUpdateableClassifier
    public SCW mo480clone() {
        return new SCW(this);
    }

    /**
     * Allocates a zero mean vector and an identity covariance, either a full
     * matrix or a diagonal of ones depending on {@link #isDiagonalOnly()}.
     *
     * @throws FailedToFitException if there are no numeric attributes or the
     *                              target is not binary
     */
    @Override // jsat.classifiers.UpdateableClassifier
    public void setUp(CategoricalData[] categoricalDataArr, int i, CategoricalData categoricalData) {
        if (i <= 0) {
            throw new FailedToFitException("SCW requires numeric attributes to perform classification");
        }
        if (categoricalData.getNumOfCategories() != 2) {
            throw new FailedToFitException("SCW is a binary classifier");
        }
        this.w = new DenseVector(i);
        this.Sigma_xt = new DenseVector(i);
        if (!this.diagonalOnly) {
            this.sigmaM = Matrix.eye(i);
        } else {
            this.sigmaV = new DenseVector(i);
            this.sigmaV.mutableAdd(1.0d);
        }
    }

    /**
     * Performs one online update with a labelled point.
     *
     * @param dataPoint the input point
     * @param i         the class index, 0 or 1 (mapped to the label -1 / +1)
     * @throws FailedToFitException if {@code x^T Sigma x} is not positive
     */
    @Override // jsat.classifiers.UpdateableClassifier
    public void update(DataPoint dataPoint, int i) {
        final Vec x = dataPoint.getNumericalValues();
        final double y = (i * 2) - 1;
        final double m = y * x.dot(this.w);
        final double v = computeVariance(x);
        if (v <= 0.0d) {
            // Clear the scratch buffer so a caller that recovers from this
            // exception does not carry a stale Sigma*x into the next update.
            clearScratch(x);
            throw new FailedToFitException("Numerical issues occured");
        }
        // No update is needed once the margin already exceeds phi * sqrt(v).
        if (Math.max(0.0d, (this.phi * Math.sqrt(v)) - m) <= 1.0E-15d) {
            clearScratch(x);
            return;
        }
        final double alpha = computeAlpha(m, v);
        if (alpha < 1.0E-7d) {
            clearScratch(x);
            return;
        }
        final double u = Math.pow((-alpha * v * this.phi) + Math.sqrt(alpha * alpha * v * v * this.phiSqrd + 4.0d * v), 2.0d) / 4.0d;

        if (!this.diagonalOnly) {
            this.w.mutableAdd(alpha * y, this.Sigma_xt);
            double beta = (alpha * this.phi) / (Math.sqrt(u) + v * alpha * this.phi);
            Matrix.OuterProductUpdate(this.sigmaM, this.Sigma_xt, this.Sigma_xt, -beta);
            zeroOutSigmaXt(x);
            return;
        }

        // Mean update must use the pre-update diagonal, so it runs as a separate pass.
        Iterator<IndexValue> it = x.iterator();
        while (it.hasNext()) {
            IndexValue iv = it.next();
            int idx = iv.getIndex();
            this.w.increment(idx, alpha * y * iv.getValue() * this.sigmaV.get(idx));
        }
        // Diagonal analogue of Sigma^-1 += alpha * phi / sqrt(u) * x x^T.
        double shrink = alpha * this.phi * Math.pow(u, -0.5d);
        it = x.iterator();
        while (it.hasNext()) {
            IndexValue iv = it.next();
            int idx = iv.getIndex();
            this.sigmaV.set(idx, 1.0d / ((1.0d / this.sigmaV.get(idx)) + (shrink * Math.pow(iv.getValue(), 2.0d))));
        }
    }

    /**
     * Computes {@code v = x^T Sigma x}.
     * <p>
     * In full-covariance mode this also leaves {@code Sigma * x} in
     * {@link #Sigma_xt}. NOTE(review): {@code multiply} appears to accumulate
     * into its output, which is why the buffer is zeroed after every update.
     */
    private double computeVariance(Vec x) {
        if (this.diagonalOnly) {
            double v = 0.0d;
            Iterator<IndexValue> it = x.iterator();
            while (it.hasNext()) {
                IndexValue iv = it.next();
                double val = iv.getValue();
                v += val * val * this.sigmaV.get(iv.getIndex());
            }
            return v;
        }
        this.sigmaM.multiply(x, 1.0d, this.Sigma_xt);
        return x.dot(this.Sigma_xt);
    }

    /**
     * Computes the step size {@code alpha} for the current {@link #mode}.
     *
     * @param m the signed margin {@code y * w.x}
     * @param v the variance {@code x^T Sigma x}, assumed positive
     * @return a non-negative step size
     */
    private double computeAlpha(double m, double v) {
        if (this.mode == Mode.SCWI || this.mode == Mode.CW) {
            double alpha = Math.max(0.0d, ((-m * this.psi) + Math.sqrt((m * m * this.phiSqrd * this.phiSqrd) / 4.0d + v * this.phiSqrd * this.zeta)) / (v * this.zeta));
            return this.mode == Mode.SCWI ? Math.min(this.C, alpha) : alpha;
        }
        // SCW-II
        double n = v + (1.0d / (2.0d * this.C));
        double gamma = this.phi * Math.sqrt(this.phiSqrd * v * v * m * m + 4.0d * n * v * (n + v * this.phiSqrd));
        return Math.max(0.0d, (-(2.0d * m * n + this.phiSqrd * m * v) + gamma) / (2.0d * (n * n + n * v * this.phiSqrd)));
    }

    /** Clears the full-covariance scratch buffer; a no-op in diagonal mode. */
    private void clearScratch(Vec x) {
        if (!this.diagonalOnly) {
            zeroOutSigmaXt(x);
        }
    }

    /**
     * Returns a hard prediction: probability 1 for class 0 when the score is
     * negative, otherwise probability 1 for class 1.
     *
     * @throws UntrainedModelException if {@link #setUp} has not been called
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        if (this.w == null) {
            throw new UntrainedModelException("Model has not yet ben trained");
        }
        CategoricalResults categoricalResults = new CategoricalResults(2);
        if (getScore(dataPoint) < 0.0d) {
            categoricalResults.setProb(0, 1.0d);
        } else {
            categoricalResults.setProb(1, 1.0d);
        }
        return categoricalResults;
    }

    /**
     * Returns the raw score {@code w . x}.
     *
     * @throws UntrainedModelException if {@link #setUp} has not been called
     */
    @Override // jsat.classifiers.calibration.BinaryScoreClassifier
    public double getScore(DataPoint dataPoint) {
        if (this.w == null) {
            throw new UntrainedModelException("Model has not yet ben trained");
        }
        return this.w.dot(dataPoint.getNumericalValues());
    }

    /** Sample weights are ignored by this learner. */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }

    /**
     * Suggests a search distribution for C: log-uniform over [2^-4, 2^4].
     * The data set is not consulted.
     */
    public static Distribution guessC(DataSet dataSet) {
        return new LogUniform(Math.pow(2.0d, -4.0d), Math.pow(2.0d, 4.0d));
    }

    /**
     * Suggests a search distribution for eta: uniform over [0.5, 0.95].
     * The data set is not consulted.
     */
    public static Distribution guessEta(DataSet dataSet) {
        return new Uniform(0.5d, 0.95d);
    }
}
