package jsat.math.optimization.stochastic;

import java.util.Arrays;
import java.util.Iterator;
import jsat.linear.IndexValue;
import jsat.linear.ScaledVector;
import jsat.linear.Vec;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/math/optimization/stochastic/NAdaGrad.class */
/**
 * Gradient updater implementing a normalized variant of AdaGrad.
 * NOTE(review): the update appears to follow Ross, Mineiro and Langford, "Normalized Online
 * Learning" — confirm.
 *
 * <p>State kept between updates:
 * <ul>
 *   <li>{@code G[i]}: running sum of squared gradient values for coordinate {@code i};</li>
 *   <li>{@code S[i]}: largest absolute input value observed so far for coordinate {@code i}
 *       (only updated when the gradient arrives as a {@link ScaledVector});</li>
 *   <li>{@code N}: global normalizer, the running sum of {@code (|x_i| / S[i])^2};</li>
 *   <li>{@code t}: number of updates that carried input-scale information;</li>
 *   <li>{@code biasG}: AdaGrad accumulator for the bias term.</li>
 * </ul>
 *
 * <p>{@link #setup(int)} allocates the per-coordinate arrays used by {@code update} and resets
 * all accumulated state, so it must be called before the first update.
 */
public class NAdaGrad implements GradientUpdater {
    private static final long serialVersionUID = 5138675613579751777L;
    /** Per-coordinate sum of squared gradient values. */
    private double[] G;
    /** Per-coordinate maximum absolute input value seen so far. */
    private double[] S;
    /** Running sum of squared normalized input values, {@code (|x_i| / S[i])^2}. */
    private double N;
    /** AdaGrad accumulator for the bias gradient (reset to 1 by {@link #setup(int)}). */
    private double biasG;
    /** Number of updates performed with a {@link ScaledVector} gradient. */
    private long t;

    /** Creates an updater; {@link #setup(int)} must be called before use. */
    public NAdaGrad() {
    }

    /**
     * Copy constructor producing an independent deep copy of {@code toCopy}.
     *
     * @param toCopy the updater whose state is copied; its arrays are cloned, not shared
     */
    public NAdaGrad(NAdaGrad toCopy) {
        if (toCopy.G != null) {
            this.G = Arrays.copyOf(toCopy.G, toCopy.G.length);
        }
        if (toCopy.S != null) {
            this.S = Arrays.copyOf(toCopy.S, toCopy.S.length);
        }
        this.biasG = toCopy.biasG;
        this.N = toCopy.N;
        this.t = toCopy.t;
    }

    /**
     * Applies one update to {@code x} with no bias term.
     * Equivalent to {@code update(x, grad, eta, 0, 0)} with the bias step discarded.
     *
     * @param x the weight vector, modified in place
     * @param grad the gradient
     * @param eta the learning rate
     */
    @Override // jsat.math.optimization.stochastic.GradientUpdater
    public void update(Vec x, Vec grad, double eta) {
        update(x, grad, eta, 0.0d, 0.0d);
    }

    /**
     * Applies one normalized AdaGrad step to {@code x} in place and computes the bias step.
     *
     * <p>If {@code grad} is a {@link ScaledVector}, its base vector is treated as the raw input
     * (NOTE(review): this assumes gradient = scalar times input, as for linear models — confirm
     * with callers). The step counter advances and the input scales and normalizer are updated
     * (see {@link #updateInputScale}). Then an AdaGrad step is taken with rate
     * {@code -eta * sqrt(t / (N + 1e-6))}, dividing each coordinate by {@code S[i]}.
     *
     * <p>Otherwise the input scale is unknown. A plain AdaGrad step is taken with rate
     * {@code -eta * sqrt((t + 1) / max(N, t + 1))}, dividing by {@code max(S[i], 1)}. The
     * counter and normalizer are left unchanged.
     *
     * @param x the weight vector, modified in place
     * @param grad the gradient
     * @param eta the learning rate
     * @param bias the current bias value (not used by this updater)
     * @param biasGrad the gradient of the bias term
     * @return {@code eta * biasGrad / sqrt(biasG)}, using the bias accumulator's value from
     *     before this call's {@code biasGrad^2} is added to it
     */
    @Override // jsat.math.optimization.stochastic.GradientUpdater
    public double update(Vec x, Vec grad, double eta, double bias, double biasGrad) {
        final double rate;
        final double scaleFloor;
        if (grad instanceof ScaledVector) {
            this.t++;
            updateInputScale(x, ((ScaledVector) grad).getBase());
            rate = -eta * Math.sqrt(this.t / (this.N + 1.0E-6d));
            // S[i] >= 0 always, so a floor of 0 leaves the recorded scale untouched.
            scaleFloor = 0.0d;
        } else {
            rate = -eta * Math.sqrt((this.t + 1) / Math.max(this.N, this.t + 1));
            // No scale information: never divide by a scale smaller than 1.
            scaleFloor = 1.0d;
        }
        adaGradStep(x, grad, rate, scaleFloor);
        return biasStep(eta, biasGrad);
    }

    /**
     * Records the magnitude of each input coordinate and accumulates the normalizer.
     * Helper for {@link #update(Vec, Vec, double, double, double)}.
     *
     * <p>When {@code |x_i|} exceeds the recorded maximum {@code S[i]}, the existing weight is
     * first shrunk by {@code S[i] / |x_i|} and {@code S[i]} is raised to {@code |x_i|}. Each
     * coordinate then adds {@code (|x_i| / S[i])^2} to {@code N}.
     *
     * @param x the weight vector, rescaled in place where needed
     * @param input the raw input values
     */
    private void updateInputScale(Vec x, Vec input) {
        Iterator<IndexValue> it = input.iterator();
        while (it.hasNext()) {
            IndexValue iv = it.next();
            double absXi = Math.abs(iv.getValue());
            if (absXi == 0.0d) {
                // Contributes nothing to N; skipping avoids 0/0 = NaN when S[idx] is still 0.
                continue;
            }
            int idx = iv.getIndex();
            if (absXi > this.S[idx]) {
                x.set(idx, x.get(idx) * this.S[idx] / absXi);
                this.S[idx] = absXi;
            }
            this.N += absXi * absXi / (this.S[idx] * this.S[idx]);
        }
    }

    /**
     * Accumulates squared gradients into {@code G} and applies the per-coordinate step
     * {@code rate * g_i / (max(S[i], scaleFloor) * sqrt(G[i]))} to {@code x}.
     * Helper for {@link #update(Vec, Vec, double, double, double)}.
     *
     * @param x the weight vector, modified in place
     * @param grad the gradient
     * @param rate the signed global step size
     * @param scaleFloor lower bound applied to {@code S[i]} in the denominator
     */
    private void adaGradStep(Vec x, Vec grad, double rate, double scaleFloor) {
        Iterator<IndexValue> it = grad.iterator();
        while (it.hasNext()) {
            IndexValue iv = it.next();
            double g = iv.getValue();
            if (g == 0.0d) {
                // The step would be 0 anyway; skipping avoids 0/0 = NaN when G[idx] is still 0.
                continue;
            }
            int idx = iv.getIndex();
            this.G[idx] += g * g;
            x.increment(idx, rate * g / (Math.max(this.S[idx], scaleFloor) * Math.sqrt(this.G[idx])));
        }
    }

    /**
     * Computes the AdaGrad bias step, then folds {@code biasGrad^2} into the bias accumulator.
     * Helper for {@link #update(Vec, Vec, double, double, double)}.
     *
     * @param eta the learning rate
     * @param biasGrad the gradient of the bias term
     * @return {@code eta * biasGrad / sqrt(biasG)} using the accumulator before this update
     */
    private double biasStep(double eta, double biasGrad) {
        double step = eta * biasGrad / Math.sqrt(this.biasG);
        this.biasG += biasGrad * biasGrad;
        return step;
    }

    /**
     * Returns a deep copy of this updater (via the copy constructor).
     *
     * <p>The name is a decompiler artifact standing in for {@code clone()}. It is kept as-is
     * because other decompiled code may reference it.
     *
     * @return an independent copy of this updater
     */
    @Override // jsat.math.optimization.stochastic.GradientUpdater
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public NAdaGrad m709clone() {
        return new NAdaGrad(this);
    }

    /**
     * Allocates per-coordinate state for {@code dimension} coordinates and resets all
     * accumulated statistics.
     * {@code N} is now reset as well; previously it leaked across re-initialization while
     * {@code t} restarted at 0.
     *
     * @param dimension number of coordinates in the weight vector
     */
    @Override // jsat.math.optimization.stochastic.GradientUpdater
    public void setup(int dimension) {
        this.G = new double[dimension];
        this.S = new double[dimension];
        this.N = 0.0d;
        this.biasG = 1.0d;
        this.t = 0L;
    }
}
