package jsat.classifiers.linear;

import jsat.SingleWeightVectorModel;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/linear/ROMMA.class */
/**
 * Online binary linear classifier named ROMMA (presumably the "Relaxed Online
 * Maximum Margin Algorithm" - confirm against the library documentation).
 *
 * <p>The model is a single weight vector {@code w} plus a scalar bias. The
 * score of a data point is {@code w . x + bias}; a negative score is
 * classified as class 0, any other score as class 1. Training is incremental:
 * {@link #setUp} allocates a zero weight vector and {@link #update} adjusts it
 * in place one data point at a time.
 *
 * <p>Only numerical features are used and only two-class problems are
 * accepted (see {@link #setUp}).
 */
public class ROMMA extends BaseUpdateableClassifier implements BinaryScoreClassifier, SingleWeightVectorModel {
    private static final long serialVersionUID = 8163937542627337711L;

    /** Whether {@link #update} also adjusts the bias term; {@code true} by default. */
    private boolean useBias = true;

    /** Whether the "aggressive" branch of {@link #update} is enabled. */
    private boolean aggressive;

    /** Weight vector; {@code null} until {@link #setUp} has been called. */
    private Vec w;

    /** Bias term added to {@code w . x} when scoring. */
    private double bias;

    /**
     * Creates a new model with the aggressive variant enabled.
     */
    public ROMMA() {
        this(true);
    }

    /**
     * Creates a new model.
     *
     * @param aggressive {@code true} to enable the aggressive variant
     */
    public ROMMA(boolean aggressive) {
        setAggressive(aggressive);
    }

    /**
     * Copy constructor used by {@link #mo479clone()}.
     *
     * @param toCopy the model to copy
     */
    protected ROMMA(ROMMA toCopy) {
        this.aggressive = toCopy.aggressive;
        if (toCopy.w != null) {
            // Deep copy: update() mutates the weight vector in place, so sharing
            // the reference would let training one model silently alter the other.
            // setUp() only ever creates DenseVector instances, so the concrete
            // type of the vector is preserved by this copy.
            this.w = new DenseVector(toCopy.w);
        }
        this.bias = toCopy.bias;
        this.useBias = toCopy.useBias;
    }

    /**
     * Returns an independent copy of this model, including its own copy of the
     * learned weight vector (if any).
     *
     * @return a new {@code ROMMA} with the same settings and learned state
     */
    @Override // jsat.classifiers.BaseUpdateableClassifier
    public ROMMA mo479clone() {
        return new ROMMA(this);
    }

    /**
     * Enables or disables the aggressive variant of the update rule.
     *
     * @param aggressive {@code true} to enable the aggressive variant
     */
    public void setAggressive(boolean aggressive) {
        this.aggressive = aggressive;
    }

    /**
     * @return {@code true} if the aggressive variant is enabled
     */
    public boolean isAggressive() {
        return this.aggressive;
    }

    /**
     * Controls whether {@link #update} adjusts the bias term. When disabled the
     * bias keeps whatever value it currently has (0 after {@link #setUp}).
     *
     * @param useBias {@code true} to update the bias term during training
     */
    public void setUseBias(boolean useBias) {
        this.useBias = useBias;
    }

    /**
     * @return {@code true} if the bias term is updated during training
     */
    public boolean isUseBias() {
        return this.useBias;
    }

    /**
     * Returns the live weight vector (not a copy).
     *
     * @return the weight vector, or {@code null} if {@link #setUp} has not been called
     */
    public Vec getWeightVec() {
        return this.w;
    }

    /**
     * Returns the live weight vector (not a copy).
     *
     * @return the weight vector, or {@code null} if {@link #setUp} has not been called
     */
    @Override // jsat.SingleWeightVectorModel
    public Vec getRawWeight() {
        return this.w;
    }

    /**
     * @return the current bias term
     */
    @Override // jsat.SingleWeightVectorModel
    public double getBias() {
        return this.bias;
    }

    /**
     * Indexed access to the single weight vector of this model.
     *
     * @param index the weight vector index; any value below 1 is accepted
     * @return the same vector as {@link #getRawWeight()}
     * @throws IndexOutOfBoundsException if {@code index} is 1 or greater
     */
    @Override // jsat.SimpleWeightVectorModel
    public Vec getRawWeight(int index) {
        if (index < 1) {
            return getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * Indexed access to the single bias term of this model.
     *
     * @param index the weight vector index; any value below 1 is accepted
     * @return the same value as {@link #getBias()}
     * @throws IndexOutOfBoundsException if {@code index} is 1 or greater
     */
    @Override // jsat.SimpleWeightVectorModel
    public double getBias(int index) {
        if (index < 1) {
            return getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * @return always 1, since this model has a single weight vector
     */
    @Override // jsat.SimpleWeightVectorModel
    public int numWeightsVecs() {
        return 1;
    }

    /**
     * Prepares the model for online training by allocating a zero weight
     * vector of the given size and resetting the bias to 0. Any previously
     * learned state is discarded.
     *
     * @param categoricalAttributes the categorical attributes of the data (not used here)
     * @param numericAttributes the number of numerical features; must be positive
     * @param predicting the target attribute; must have exactly two categories
     * @throws FailedToFitException if there are no numerical features or the
     *         target does not have exactly two categories
     */
    @Override // jsat.classifiers.UpdateableClassifier
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        if (numericAttributes <= 0) {
            throw new FailedToFitException("ROMMA requires numerical features");
        }
        if (predicting.getNumOfCategories() != 2) {
            throw new FailedToFitException("ROMMA only supports binary classification");
        }
        this.w = new DenseVector(numericAttributes);
        this.bias = 0.0d;
    }

    /**
     * Performs one online update with a single labelled data point.
     *
     * <p>The class index is mapped to a label {@code y} in {-1, +1} and the
     * signed margin {@code y * (w . x + bias)} is computed. Nothing changes
     * when the margin is at least 1. Otherwise:
     * <ul>
     *   <li>if the aggressive variant is enabled and the margin is at least
     *       {@code |w|^2 |x|^2}, the weights are replaced by
     *       {@code (y / |x|^2) x};</li>
     *   <li>in all other cases the weights become {@code c * w + d * x} with
     *       the coefficients computed below.</li>
     * </ul>
     * The bias is updated with the same coefficients only when
     * {@link #isUseBias()} is {@code true}.
     *
     * <p>NOTE(review): the divisions are unguarded, so an all-zero feature
     * vector, or a zero denominator in the general case, would presumably
     * yield non-finite weights - confirm whether callers can supply such input.
     *
     * @param dataPoint the training point; only its numerical values are used
     * @param targetClass the true class index (0 or 1)
     */
    @Override // jsat.classifiers.UpdateableClassifier
    public void update(DataPoint dataPoint, int targetClass) {
        final Vec x = dataPoint.getNumericalValues();
        final double wx = this.w.dot(x) + this.bias;
        // Map class index {0, 1} to label {-1, +1}.
        final double y = (targetClass * 2) - 1;
        final double margin = y * wx;
        if (margin < 1.0d) {
            final double ww = this.w.dot(this.w);
            final double xx = x.dot(x);
            final double wwxx = ww * xx;
            if (this.aggressive && margin >= wwxx) {
                // Aggressive shortcut: discard the old weights entirely.
                this.w.zeroOut();
                this.w.mutableAdd(y / xx, x);
                if (this.useBias) {
                    this.bias = y / xx;
                }
                return;
            }
            // General case: new w = c * w + d * x.
            final double denom = wwxx - (wx * wx);
            final double c = (wwxx - margin) / denom;
            final double d = (ww * (y - wx)) / denom;
            this.w.mutableMultiply(c);
            this.w.mutableAdd(d, x);
            if (this.useBias) {
                this.bias = (c * this.bias) + d;
            }
        }
    }

    /**
     * Classifies a data point by the sign of its score. The result is a hard
     * assignment: probability 1.0 is given to class 0 when the score is
     * negative and to class 1 otherwise.
     *
     * @param dataPoint the point to classify
     * @return a two-category result with probability 1.0 set on the chosen class
     * @throws UntrainedModelException if {@link #setUp} has not been called
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        if (this.w == null) {
            throw new UntrainedModelException("Model has not been trained");
        }
        final double score = getScore(dataPoint);
        final CategoricalResults results = new CategoricalResults(2);
        if (score < 0.0d) {
            results.setProb(0, 1.0d);
        } else {
            results.setProb(1, 1.0d);
        }
        return results;
    }

    /**
     * Computes the raw linear score {@code w . x + bias} of a data point.
     * No check is made here that the model has been set up.
     *
     * @param dataPoint the point to score; only its numerical values are used
     * @return the signed score; negative values correspond to class 0
     */
    @Override // jsat.classifiers.calibration.BinaryScoreClassifier
    public double getScore(DataPoint dataPoint) {
        return this.w.dot(dataPoint.getNumericalValues()) + this.bias;
    }

    /**
     * @return always {@code false}
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }
}
