package jsat.classifiers.linear;

import java.util.Iterator;
import java.util.List;
import jsat.DataSet;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Matrix;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/linear/NHERD.class */
/**
 * Online binary linear classifier that maintains a Gaussian over the weight vector: a mean
 * {@code w} and a covariance {@code Sigma}. The covariance is either a full matrix or a
 * diagonal, depending on {@link CovMode}.
 * <p>
 * Each call to {@link #update(DataPoint, int)} checks the example's margin. If the margin is at
 * most 1, the mean moves toward the example and the covariance shrinks along the example's
 * direction. Otherwise nothing changes. The model has no bias term. Predictions use the sign of
 * {@code w . x}.
 * <p>
 * NOTE(review): the name and update equations suggest this is the Normal HERD algorithm
 * (Crammer &amp; Lee, "Learning via Gaussian Herding", NIPS 2010) — confirm against the paper.
 */
public class NHERD extends BaseUpdateableClassifier implements BinaryScoreClassifier, Parameterized, SingleWeightVectorModel {
    private static final long serialVersionUID = -1186002893766449917L;

    /** Mean of the weight distribution; this is the linear model used for prediction. */
    private Vec w;
    /** Full covariance matrix; only allocated when {@code covMode == FULL}. */
    private Matrix sigmaM;
    /** Diagonal of the covariance; only allocated for the non-FULL modes. */
    private Vec sigmaV;
    /** How the covariance is represented and updated. */
    private CovMode covMode;
    /** Aggressiveness parameter; larger values produce larger mean updates. */
    private double C;
    /** Scratch space holding {@code Sigma * x} during a FULL-mode update; zeroed afterwards. */
    private Vec Sigma_xt;

    /**
     * Representation and update rule for the covariance of the weight distribution.
     */
    public enum CovMode {
        /** Full d-by-d covariance matrix with a rank-one update; needs O(d^2) memory. */
        FULL,
        /** Diagonal covariance; applies only the diagonal of the full rank-one update. */
        DROP,
        /** Diagonal covariance updated in inverse form: {@code 1/s_i += c * x_i^2}. */
        PROJECT,
        /** Diagonal covariance updated as {@code s_i / (C * s_i * x_i^2 + 1)^2}. */
        EXACT
    }

    /**
     * Creates a new NHERD learner.
     *
     * @param d       the aggressiveness parameter C; must be positive and finite
     * @param covMode how the covariance is represented and updated; must not be null
     * @throws IllegalArgumentException if {@code d} is not a positive finite number, or
     *                                  {@code covMode} is null
     */
    public NHERD(double d, CovMode covMode) {
        setC(d);
        setCovMode(covMode);
    }

    /**
     * Deep-copy constructor. The mean, covariance and scratch vectors are cloned, so the copy
     * trains independently of the original.
     *
     * @param nherd the learner to copy
     */
    protected NHERD(NHERD nherd) {
        // NOTE(review): state held by BaseUpdateableClassifier (if any) is not copied here;
        // confirm whether a superclass copy constructor should be invoked.
        this.C = nherd.C;
        this.covMode = nherd.covMode;
        if (nherd.w != null) {
            this.w = nherd.w.mo524clone();
        }
        if (nherd.sigmaM != null) {
            this.sigmaM = nherd.sigmaM.mo640clone();
        }
        if (nherd.sigmaV != null) {
            this.sigmaV = nherd.sigmaV.mo524clone();
        }
        if (nherd.Sigma_xt != null) {
            this.Sigma_xt = nherd.Sigma_xt.mo524clone();
        }
    }

    /**
     * Sets the aggressiveness parameter C. Larger values shrink the {@code 1/C} regularizer in
     * the mean update and therefore produce larger steps.
     *
     * @param d the new value of C; must be positive and finite
     * @throws IllegalArgumentException if {@code d} is NaN, infinite, or not positive
     */
    public void setC(double d) {
        if (Double.isNaN(d) || Double.isInfinite(d) || d <= 0.0d) {
            throw new IllegalArgumentException("C must be a positive constant, not " + d);
        }
        this.C = d;
    }

    /**
     * Returns the aggressiveness parameter C.
     *
     * @return the current value of C
     */
    public double getC() {
        return this.C;
    }

    /**
     * Sets how the covariance is represented and updated. If the model has already been set
     * up, {@code setUp} must be called again before further updates, because the covariance
     * storage is allocated there for the mode in effect at that time.
     *
     * @param covMode the covariance mode; must not be null
     * @throws IllegalArgumentException if {@code covMode} is null
     */
    public void setCovMode(CovMode covMode) {
        if (covMode == null) {
            throw new IllegalArgumentException("covMode must not be null");
        }
        this.covMode = covMode;
    }

    /**
     * Returns the covariance mode.
     *
     * @return the current covariance mode
     */
    public CovMode getCovMode() {
        return this.covMode;
    }

    /**
     * Returns the mean weight vector, i.e. the linear model used for prediction. Same as
     * {@link #getRawWeight()}.
     *
     * @return the live (not copied) weight vector, or null before {@code setUp}
     */
    public Vec getWeightVec() {
        return this.w;
    }

    /**
     * Returns the mean weight vector.
     *
     * @return the live (not copied) weight vector, or null before {@code setUp}
     */
    @Override // jsat.SingleWeightVectorModel
    public Vec getRawWeight() {
        return this.w;
    }

    /**
     * Returns the bias term, which this model does not use.
     *
     * @return always {@code 0.0}
     */
    @Override // jsat.SingleWeightVectorModel
    public double getBias() {
        return 0.0d;
    }

    /**
     * Returns the weight vector at the given index; this model has exactly one.
     *
     * @param i the weight vector index; only {@code 0} is valid
     * @return the mean weight vector
     * @throws IndexOutOfBoundsException if {@code i != 0}
     */
    @Override // jsat.SimpleWeightVectorModel
    public Vec getRawWeight(int i) {
        if (i == 0) {
            return getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * Returns the bias at the given index; this model has exactly one, and it is always zero.
     *
     * @param i the bias index; only {@code 0} is valid
     * @return always {@code 0.0}
     * @throws IndexOutOfBoundsException if {@code i != 0}
     */
    @Override // jsat.SimpleWeightVectorModel
    public double getBias(int i) {
        if (i == 0) {
            return getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * Returns the number of weight vectors in this model.
     *
     * @return always {@code 1}
     */
    @Override // jsat.SimpleWeightVectorModel
    public int numWeightsVecs() {
        return 1;
    }

    /**
     * Returns a deep copy of this learner. This is the decompiler's name for {@code clone()}.
     *
     * @return an independent copy of this learner
     */
    @Override // jsat.classifiers.BaseUpdateableClassifier
    public NHERD mo479clone() {
        return new NHERD(this);
    }

    /**
     * Resets the model. The mean becomes the zero vector and the covariance becomes the
     * identity, stored in the representation required by the current {@link CovMode}.
     *
     * @param categoricalDataArr categorical attribute info (unused)
     * @param i                  number of numeric attributes; must be positive
     * @param categoricalData    the target class info; must have exactly 2 categories
     * @throws FailedToFitException if there are no numeric attributes or the target is not
     *                              binary
     */
    @Override // jsat.classifiers.UpdateableClassifier
    public void setUp(CategoricalData[] categoricalDataArr, int i, CategoricalData categoricalData) {
        if (i <= 0) {
            throw new FailedToFitException("NHERD requires numeric attributes to perform classification");
        }
        if (categoricalData.getNumOfCategories() != 2) {
            throw new FailedToFitException("NHERD is a binary classifier");
        }
        this.w = new DenseVector(i);
        this.Sigma_xt = new DenseVector(i);
        if (this.covMode == CovMode.FULL) {
            this.sigmaM = Matrix.eye(i);
            // discard any diagonal left over from a previous setUp in another mode
            this.sigmaV = null;
        } else {
            this.sigmaV = new DenseVector(i);
            this.sigmaV.mutableAdd(1.0d);
            // discard any full matrix left over from a previous setUp in FULL mode
            this.sigmaM = null;
        }
    }

    /**
     * Updates the model with one labeled example. If {@code y * (w . x) > 1}, nothing changes.
     * Otherwise the mean moves by {@code y * loss / (alpha + 1/C)} along {@code Sigma x}, and
     * the covariance shrinks according to the current {@link CovMode}. Here
     * {@code loss = max(0, 1 - y * (w . x))} and {@code alpha = x^T Sigma x}.
     *
     * @param dataPoint the example; only its numeric values are used
     * @param i         the class index, expected to be 0 or 1 (mapped to labels -1 / +1)
     * @throws IllegalStateException if {@code setUp} has not been called for the current
     *                               covariance mode
     */
    @Override // jsat.classifiers.UpdateableClassifier
    public void update(DataPoint dataPoint, int i) {
        boolean covarianceMissing = this.covMode == CovMode.FULL ? this.sigmaM == null : this.sigmaV == null;
        if (this.w == null || covarianceMissing) {
            throw new IllegalStateException(
                    "setUp must be called before update, and again after changing the covariance mode");
        }
        Vec x = dataPoint.getNumericalValues();
        double y = (i * 2) - 1; // class index {0, 1} -> label {-1, +1}
        double pred = x.dot(this.w);
        if (y * pred > 1.0d) {
            return; // correctly classified with margin greater than 1: no update
        }
        double alpha = varianceAlong(x);
        double loss = Math.max(0.0d, 1.0d - (y * pred));
        double meanStep = (y * loss) / (alpha + (1.0d / this.C));
        updateMean(x, meanStep);
        updateCovariance(x, alpha);
        if (this.covMode == CovMode.FULL) {
            // Sigma_xt is zeroed after every FULL update, presumably because
            // Matrix.multiply accumulates into its output vector (confirm)
            this.Sigma_xt.zeroOut();
        }
    }

    /**
     * Computes {@code alpha = x^T Sigma x}. In FULL mode this also leaves {@code Sigma * x} in
     * {@code Sigma_xt} for reuse by the mean and covariance updates.
     */
    private double varianceAlong(Vec x) {
        if (this.covMode == CovMode.FULL) {
            this.sigmaM.multiply(x, 1.0d, this.Sigma_xt);
            return x.dot(this.Sigma_xt);
        }
        double alpha = 0.0d;
        for (Iterator<IndexValue> it = x.iterator(); it.hasNext();) {
            IndexValue iv = it.next();
            double xi = iv.getValue();
            alpha += xi * xi * this.sigmaV.get(iv.getIndex());
        }
        return alpha;
    }

    /** Adds {@code step * Sigma x} to the mean vector {@code w}. */
    private void updateMean(Vec x, double step) {
        if (this.covMode == CovMode.FULL) {
            this.w.mutableAdd(step, this.Sigma_xt);
            return;
        }
        for (Iterator<IndexValue> it = x.iterator(); it.hasNext();) {
            IndexValue iv = it.next();
            int idx = iv.getIndex();
            this.w.increment(idx, step * iv.getValue() * this.sigmaV.get(idx));
        }
    }

    /** Shrinks the covariance along {@code x} using the rule selected by {@link #covMode}. */
    private void updateCovariance(Vec x, double alpha) {
        double numer = this.C * ((this.C * alpha) + 2.0d);
        double denom = (1.0d + (this.C * alpha)) * (1.0d + (this.C * alpha));
        switch (this.covMode) {
            case FULL:
                // presumably Sigma += scale * (Sigma x)(Sigma x)^T — confirm OuterProductUpdate semantics
                Matrix.OuterProductUpdate(this.sigmaM, this.Sigma_xt, this.Sigma_xt, (-numer) / denom);
                break;
            case DROP: {
                double scale = (-numer) / denom;
                for (Iterator<IndexValue> it = x.iterator(); it.hasNext();) {
                    IndexValue iv = it.next();
                    int idx = iv.getIndex();
                    double xs = iv.getValue() * this.sigmaV.get(idx);
                    this.sigmaV.increment(idx, scale * xs * xs);
                }
                break;
            }
            case PROJECT:
                for (Iterator<IndexValue> it = x.iterator(); it.hasNext();) {
                    IndexValue iv = it.next();
                    int idx = iv.getIndex();
                    double xi = iv.getValue();
                    this.sigmaV.set(idx, 1.0d / ((1.0d / this.sigmaV.get(idx)) + ((numer * xi) * xi)));
                }
                break;
            case EXACT:
                for (Iterator<IndexValue> it = x.iterator(); it.hasNext();) {
                    IndexValue iv = it.next();
                    int idx = iv.getIndex();
                    double xi = iv.getValue();
                    double s = this.sigmaV.get(idx);
                    this.sigmaV.set(idx, s / Math.pow((((s * xi) * xi) * this.C) + 1.0d, 2.0d));
                }
                break;
        }
    }

    /** Throws if the model has not been set up yet. */
    private void requireTrained() {
        if (this.w == null) {
            throw new UntrainedModelException("Model has not yet been trained");
        }
    }

    /**
     * Classifies a point by the sign of its score. Probability 1 goes to class 0 when the score
     * is negative, and to class 1 otherwise.
     *
     * @param dataPoint the point to classify
     * @return a two-class result with all mass on the predicted class
     * @throws UntrainedModelException if the model has not been set up
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        requireTrained();
        CategoricalResults categoricalResults = new CategoricalResults(2);
        if (getScore(dataPoint) < 0.0d) {
            categoricalResults.setProb(0, 1.0d);
        } else {
            categoricalResults.setProb(1, 1.0d);
        }
        return categoricalResults;
    }

    /**
     * Returns the raw score {@code w . x}; negative scores indicate class 0.
     *
     * @param dataPoint the point to score
     * @return the dot product of the mean weight vector with the point's numeric values
     * @throws UntrainedModelException if the model has not been set up
     */
    @Override // jsat.classifiers.calibration.BinaryScoreClassifier
    public double getScore(DataPoint dataPoint) {
        requireTrained();
        return this.w.dot(dataPoint.getNumericalValues());
    }

    /**
     * Reports that per-example weights are ignored by this learner.
     *
     * @return always {@code false}
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Returns the tunable parameters discovered from this object's getter/setter methods.
     *
     * @return the parameter list
     */
    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    /**
     * Looks up a tunable parameter by name.
     *
     * @param str the parameter name
     * @return the matching parameter, or null if none has that name
     */
    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }

    /**
     * Proposes a search distribution for C: log-uniform over {@code [2^-4, 2^4]}. The data set
     * is not inspected.
     *
     * @param dataSet the data set (unused)
     * @return the search distribution for C
     */
    public static Distribution guessC(DataSet dataSet) {
        return new LogUniform(Math.pow(2.0d, -4.0d), Math.pow(2.0d, 4.0d));
    }
}
