package jsat.classifiers.svm.extended;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import jsat.DataSet;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.ScaledVector;
import jsat.linear.Vec;
import jsat.linear.VecWithNorm;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;
import jsat.utils.IndexTable;
import jsat.utils.IntList;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/svm/extended/OnlineAMM.class */
/**
 * Online, stochastic-gradient trained multi-hyperplane linear classifier.
 * Each class owns a bounded set of weight vectors (hyperplanes). A point is
 * scored against every hyperplane by dot product and is assigned to the class
 * owning the single highest-scoring hyperplane.
 * <p>
 * Each update uses the step size {@code eta = 1 / (lambda * t)} and shrinks
 * every hyperplane by {@code (1 - eta * lambda)}. On a positive multiclass
 * hinge loss, the best true-class hyperplane is moved towards the input and the
 * best competing hyperplane is moved away from it. Hyperplanes are created on
 * demand, up to {@link #getClassBudget()} per class. Every
 * {@link #getPruneFrequency()} updates, hyperplanes with small squared norms
 * are discarded.
 * <p>
 * NOTE(review): the name suggests this is the online variant of the Adaptive
 * Multi-hyperplane Machine (AMM) of Wang et al. - confirm against that paper
 * before relying on the attribution.
 */
public class OnlineAMM extends BaseUpdateableClassifier implements Parameterized {
    private static final long serialVersionUID = 8291068484917637037L;

    /**
     * For each class index, the map from hyperplane ID to that hyperplane's
     * weight vector. {@code null} until {@link #setUp} has been called.
     */
    protected List<Map<Integer, Vec>> weightMatrix;
    /** For each class index, the ID that the next newly created hyperplane of that class receives. */
    protected int[] nextID;
    /** Regularization strength; also determines the step size {@code 1 / (lambda * t)}. */
    protected double lambda;
    /** Number of updates between pruning passes. */
    protected int k;
    /** Pruning constant; scales the squared-norm mass a pruning pass may remove. */
    protected double c;
    /** Update counter {@code t}; reset to 1 by {@link #setUp}. */
    protected int time;
    /** Maximum number of hyperplanes any single class may hold. */
    protected int classBudget;

    /** Default number of updates between pruning passes. */
    public static final int DEFAULT_PRUNE_FREQUENCY = 10000;
    /** Default pruning constant. */
    public static final double DEFAULT_PRUNE_CONSTANT = 10.0d;
    /** Default maximum number of hyperplanes per class. */
    public static final int DEFAULT_CLASS_BUDGET = 50;
    /** Default regularization strength. */
    public static final double DEFAULT_REGULARIZER = 0.01d;

    /** Creates a classifier using {@link #DEFAULT_REGULARIZER} and {@link #DEFAULT_CLASS_BUDGET}. */
    public OnlineAMM() {
        this(DEFAULT_REGULARIZER);
    }

    /**
     * Creates a classifier with the given regularization and the default class budget.
     *
     * @param lambda the regularization strength, must be positive and finite
     */
    public OnlineAMM(double lambda) {
        this(lambda, DEFAULT_CLASS_BUDGET);
    }

    /**
     * Creates a classifier with the given regularization and per-class hyperplane budget.
     * The prune frequency and prune constant take their default values.
     *
     * @param lambda      the regularization strength, must be positive and finite
     * @param classBudget the maximum number of hyperplanes per class, must be at least 1
     */
    public OnlineAMM(double lambda, int classBudget) {
        setLambda(lambda);
        setClassBudget(classBudget);
        setPruneFrequency(DEFAULT_PRUNE_FREQUENCY);
        setC(DEFAULT_PRUNE_CONSTANT);
    }

    /**
     * Copy constructor. Weight vectors are deep-copied, so the copy and the
     * original can be trained independently.
     *
     * @param toCopy the classifier to copy
     */
    public OnlineAMM(OnlineAMM toCopy) {
        if (toCopy.weightMatrix != null) {
            this.weightMatrix = new ArrayList<Map<Integer, Vec>>(toCopy.weightMatrix.size());
            for (Map<Integer, Vec> planes : toCopy.weightMatrix) {
                Map<Integer, Vec> planesCopy = new LinkedHashMap<Integer, Vec>(planes.size());
                for (Map.Entry<Integer, Vec> entry : planes.entrySet()) {
                    planesCopy.put(entry.getKey(), entry.getValue().mo524clone());
                }
                this.weightMatrix.add(planesCopy);
            }
            this.nextID = Arrays.copyOf(toCopy.nextID, toCopy.nextID.length);
        }
        this.time = toCopy.time;
        this.lambda = toCopy.lambda;
        this.k = toCopy.k;
        this.c = toCopy.c;
        this.classBudget = toCopy.classBudget;
        setEpochs(toCopy.getEpochs());
    }

    @Override // jsat.classifiers.BaseUpdateableClassifier
    /* renamed from: clone */
    public OnlineAMM mo479clone() {
        return new OnlineAMM(this);
    }

    /**
     * Sets the regularization strength, which also controls the step size {@code 1 / (lambda * t)}.
     *
     * @param lambda the regularization strength
     * @throws IllegalArgumentException if {@code lambda} is not positive and finite
     */
    public void setLambda(double lambda) {
        if (lambda <= 0.0d || Double.isNaN(lambda) || Double.isInfinite(lambda)) {
            throw new IllegalArgumentException("Lambda must be positive, not " + lambda);
        }
        this.lambda = lambda;
    }

    /** @return the regularization strength */
    public double getLambda() {
        return this.lambda;
    }

    /**
     * Sets how many updates occur between pruning passes.
     *
     * @param frequency the number of updates between pruning passes
     * @throws IllegalArgumentException if {@code frequency < 1}
     */
    public void setPruneFrequency(int frequency) {
        if (frequency < 1) {
            throw new IllegalArgumentException("Pruning frequency must be positive, not " + frequency);
        }
        this.k = frequency;
    }

    /** @return the number of updates between pruning passes */
    public int getPruneFrequency() {
        return this.k;
    }

    /**
     * Sets the pruning constant. A pruning pass removes hyperplanes in order of
     * increasing squared norm, but only while the removed mass stays below
     * {@code C / ((t - 1) * lambda)}.
     *
     * @param c the pruning constant
     * @throws IllegalArgumentException if {@code c} is not positive and finite
     */
    public void setC(double c) {
        if (c <= 0.0d || Double.isNaN(c) || Double.isInfinite(c)) {
            throw new IllegalArgumentException("C must be positive, not " + c);
        }
        this.c = c;
    }

    /** @return the pruning constant */
    public double getC() {
        return this.c;
    }

    /**
     * Sets the maximum number of hyperplanes each class may hold.
     *
     * @param classBudget the per-class hyperplane limit
     * @throws IllegalArgumentException if {@code classBudget < 1}
     */
    public void setClassBudget(int classBudget) {
        if (classBudget < 1) {
            throw new IllegalArgumentException("Number of hyperplanes must be positive, not " + classBudget);
        }
        this.classBudget = classBudget;
    }

    /** @return the maximum number of hyperplanes each class may hold */
    public int getClassBudget() {
        return this.classBudget;
    }

    /**
     * Resets the model to one empty hyperplane set per class and restarts the update counter at 1.
     *
     * @throws FailedToFitException if there are no numeric features
     */
    @Override // jsat.classifiers.UpdateableClassifier
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        if (numericAttributes < 1) {
            throw new FailedToFitException("OnlineAMM requires numeric features to perform classification");
        }
        int numClasses = predicting.getNumOfCategories();
        this.weightMatrix = new ArrayList<Map<Integer, Vec>>(numClasses);
        for (int cls = 0; cls < numClasses; cls++) {
            this.weightMatrix.add(new LinkedHashMap<Integer, Vec>());
        }
        this.nextID = new int[this.weightMatrix.size()];
        this.time = 1;
    }

    @Override // jsat.classifiers.UpdateableClassifier
    public void update(DataPoint dataPoint, int targetClass) {
        update(dataPoint, targetClass, Integer.MIN_VALUE);
    }

    /**
     * Performs one stochastic gradient step on a single labeled point. The caller
     * may optionally fix which hyperplane of the true class is credited.
     * <p>
     * JADX reported this method as originally {@code protected}; it is kept
     * {@code public} so existing callers keep compiling.
     *
     * @param dataPoint  the training point; only its numeric features are used
     * @param trueClass  the index of the point's true class
     * @param assignedID the ID of the true-class hyperplane to credit, or
     *                   {@link Integer#MIN_VALUE} to pick the highest-scoring one.
     *                   An ID not currently present in the true class (for
     *                   example because it was pruned) triggers a fresh search.
     * @return the ID of the true-class hyperplane used or created by this step.
     *         Returns -1 when no existing true-class hyperplane was selected and
     *         none could be created because the class budget was full. The ID
     *         may already be absent if the pruning pass at the end of this call
     *         removed it.
     */
    public int update(DataPoint dataPoint, int trueClass, int assignedID) {
        Vec x = dataPoint.getNumericalValues();
        Map<Integer, Vec> truePlanes = this.weightMatrix.get(trueClass);

        // ID -1 stands for an implicit all-zero hyperplane (score 0) that has not been materialized yet
        double trueScore;
        if (assignedID == Integer.MIN_VALUE || assignedID > this.nextID[trueClass]) {
            trueScore = 0.0d;
            assignedID = -1;
            for (Map.Entry<Integer, Vec> entry : truePlanes.entrySet()) {
                double score = x.dot(entry.getValue());
                if (score >= trueScore) {
                    assignedID = entry.getKey();
                    trueScore = score;
                }
            }
        } else {
            if (!truePlanes.containsKey(assignedID)) {
                return update(dataPoint, trueClass, Integer.MIN_VALUE);
            }
            trueScore = assignedID == -1 ? 0.0d : truePlanes.get(assignedID).dot(x);
        }

        // Step size for this iteration; the counter advances exactly once per call
        double eta = 1.0d / (this.lambda * this.time++);

        // Best-scoring hyperplane belonging to any other class (-1 again means the implicit zero hyperplane)
        int rivalClass = trueClass > 0 ? 0 : 1;
        int rivalID = -1;
        double rivalScore = 0.0d;
        for (int cls = 0; cls < this.weightMatrix.size(); cls++) {
            if (cls == trueClass) {
                continue;
            }
            for (Map.Entry<Integer, Vec> entry : this.weightMatrix.get(cls).entrySet()) {
                double score = x.dot(entry.getValue());
                if (score > rivalScore) {
                    rivalClass = cls;
                    rivalID = entry.getKey();
                    rivalScore = score;
                }
            }
        }

        // Multiclass hinge loss max(0, 1 + rivalScore - trueScore) is positive
        boolean lossOccurred = 0.0d < (1.0d + rivalScore) - trueScore;
        // Same for every hyperplane; kept in the original -(eta*lambda - 1) form for identical rounding
        double decay = -((eta * this.lambda) - 1.0d);

        for (int cls = 0; cls < this.weightMatrix.size(); cls++) {
            Map<Integer, Vec> planes = this.weightMatrix.get(cls);
            for (Map.Entry<Integer, Vec> entry : planes.entrySet()) {
                int id = entry.getKey();
                Vec w = entry.getValue();
                w.mutableMultiply(decay);
                if (lossOccurred) {
                    if (cls == rivalClass && id == rivalID) {
                        w.mutableSubtract(eta, x);
                    } else if (cls == trueClass && id == assignedID) {
                        w.mutableAdd(eta, x);
                    }
                }
            }

            // Materialize the implicit zero hyperplane when it was the one selected and there is room
            if (lossOccurred && planes.size() < this.classBudget) {
                if (cls == rivalClass && rivalID == -1) {
                    planes.put(this.nextID[cls]++, createHyperplane(x, -eta));
                } else if (cls == trueClass && assignedID == -1) {
                    int newID = this.nextID[cls]++;
                    planes.put(newID, createHyperplane(x, eta));
                    // Report the key actually inserted; map.size()-1 diverges from it once pruning has removed entries
                    assignedID = newID;
                }
            }
        }

        if (this.time % this.k == 0) {
            pruneSmallHyperplanes();
        }
        return assignedID;
    }

    /**
     * Builds a new hyperplane equal to {@code scale * x}, backed by a dense copy of {@code x}.
     */
    private static Vec createHyperplane(Vec x, double scale) {
        ScaledVector w = new ScaledVector(new VecWithNorm(new DenseVector(x), x.pNorm(2.0d)));
        w.mutableMultiply(scale);
        return w;
    }

    /**
     * Removes hyperplanes (across all classes) in order of increasing squared
     * norm. Removal continues while the next squared norm is below the
     * remaining budget; the budget starts at {@code c / ((time - 1) * lambda)}.
     */
    private void pruneSmallHyperplanes() {
        double budget = this.c / ((this.time - 1) * this.lambda);
        IntList owningClass = new IntList(this.weightMatrix.size());
        IntList hyperplaneID = new IntList(this.weightMatrix.size());
        DoubleList squaredNorms = new DoubleList(this.weightMatrix.size());
        for (int cls = 0; cls < this.weightMatrix.size(); cls++) {
            for (Map.Entry<Integer, Vec> entry : this.weightMatrix.get(cls).entrySet()) {
                Vec w = entry.getValue();
                owningClass.add(cls);
                hyperplaneID.add(entry.getKey());
                squaredNorms.add(w.dot(w));
            }
        }
        IndexTable order = new IndexTable(squaredNorms);
        for (int rank = 0; rank < squaredNorms.size(); rank++) {
            int idx = order.index(rank);
            double squaredNorm = squaredNorms.get(idx).doubleValue();
            if (squaredNorm >= budget) {
                break;
            }
            budget -= squaredNorm;
            this.weightMatrix.get(owningClass.getI(idx)).remove(Integer.valueOf(hyperplaneID.getI(idx)));
        }
    }

    /**
     * Predicts the class owning the hyperplane with the largest dot product
     * against the point. All probability mass goes to that class. If no class
     * has any hyperplane yet, class 0 is returned.
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        Vec x = dataPoint.getNumericalValues();
        int bestClass = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int cls = 0; cls < this.weightMatrix.size(); cls++) {
            for (Vec w : this.weightMatrix.get(cls).values()) {
                double score = x.dot(w);
                if (score > bestScore) {
                    bestClass = cls;
                    bestScore = score;
                }
            }
        }
        CategoricalResults results = new CategoricalResults(this.weightMatrix.size());
        results.setProb(bestClass, 1.0d);
        return results;
    }

    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(getParameters()).get(paramName);
    }

    /**
     * Returns a search distribution for lambda. The data set is not inspected;
     * the result is always log-uniform on [1e-7, 1e-2].
     *
     * @param d the data set (unused)
     * @return the distribution to sample lambda from
     */
    public static Distribution guessLambda(DataSet d) {
        return new LogUniform(1.0E-7d, 0.01d);
    }
}
