package jsat.math.optimization.stochastic;

import java.util.Iterator;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.ScaledVector;
import jsat.linear.Vec;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/math/optimization/stochastic/AdaDelta.class */
/**
 * Gradient updater implementing the AdaDelta adaptive learning-rate method
 * (Zeiler, 2012).
 * <p>
 * For every coordinate it keeps exponentially decaying running averages of the
 * squared gradient, E[g<sup>2</sup>], and of the squared step,
 * E[&Delta;x<sup>2</sup>]. The step for a coordinate with gradient {@code g}
 * is {@code -sqrt((E[dx^2] + eps) / (E[g^2] + eps)) * g}, which is then scaled
 * by the learning rate passed to {@code update}. An analogous pair of scalar
 * running averages is kept for a bias term.
 * <p>
 * {@link #setup(int)} must be called before the first call to {@code update}.
 */
public class AdaDelta implements GradientUpdater {
    private static final long serialVersionUID = 5855631993426837618L;

    /** Smoothing constant used unless {@link #setEps(double)} is called. */
    private static final double DEFAULT_EPS = 1.0E-4d;

    /** Decay rate of the running averages; always in (0, 1). */
    private double rho;
    /** Running average of squared weight gradients, one entry per coordinate. */
    private Vec gSqrd;
    /** Running average of squared weight steps, one entry per coordinate. */
    private Vec deltaXSqrt;
    /** Running average of the squared bias gradient. */
    private double biasGSqrd;
    /** Running average of the squared bias step. */
    private double deltaBiasSqrt;
    /** Added to both running averages before their ratio is taken, so the first steps are non-zero and finite. */
    private double eps = DEFAULT_EPS;

    /** Creates a new updater with decay rate {@code rho = 0.95}. */
    public AdaDelta() {
        this(0.95d);
    }

    /**
     * Creates a new updater with the given decay rate.
     *
     * @param d the decay rate of the running averages, in (0, 1)
     * @throws IllegalArgumentException if {@code d} is not in (0, 1) or is NaN
     */
    public AdaDelta(double d) {
        setRho(d);
    }

    /**
     * Copy constructor. Copies the hyper-parameters ({@code rho}, {@code eps})
     * and the current accumulator state. The accumulator vectors are copied via
     * {@code mo525clone()} (presumably an independent clone of the Vec; confirm),
     * so that further updates to the copy do not alter {@code adaDelta}.
     *
     * @param adaDelta the updater to copy
     */
    public AdaDelta(AdaDelta adaDelta) {
        this.rho = adaDelta.rho;
        // Copy eps rather than re-applying the default so the copy is faithful.
        this.eps = adaDelta.eps;
        if (adaDelta.gSqrd != null) {
            // setup(int) always allocates both vectors together, so deltaXSqrt is non-null here.
            this.gSqrd = adaDelta.gSqrd.mo525clone();
            this.deltaXSqrt = adaDelta.deltaXSqrt.mo525clone();
        }
        this.biasGSqrd = adaDelta.biasGSqrd;
        this.deltaBiasSqrt = adaDelta.deltaBiasSqrt;
    }

    /**
     * Sets the decay rate of the running averages.
     *
     * @param d the new decay rate, in (0, 1)
     * @throws IllegalArgumentException if {@code d} is not in (0, 1) or is NaN
     */
    public void setRho(double d) {
        // NaN fails every comparison, so it needs an explicit check.
        if (d <= 0.0d || d >= 1.0d || Double.isNaN(d)) {
            throw new IllegalArgumentException("Rho must be in (0, 1)");
        }
        this.rho = d;
    }

    /**
     * Returns the decay rate of the running averages.
     *
     * @return the decay rate, in (0, 1)
     */
    public double getRho() {
        return this.rho;
    }

    /**
     * Sets the smoothing constant that is added to both running averages before
     * their ratio is taken. The default is {@code 1e-4}.
     *
     * @param eps the smoothing constant, a positive finite value
     * @throws IllegalArgumentException if {@code eps} is not positive and finite
     */
    public void setEps(double eps) {
        if (Double.isNaN(eps) || Double.isInfinite(eps) || eps <= 0.0d) {
            throw new IllegalArgumentException("Eps must be a positive finite value, not " + eps);
        }
        this.eps = eps;
    }

    /**
     * Returns the smoothing constant added to both running averages.
     *
     * @return the smoothing constant
     */
    public double getEps() {
        return this.eps;
    }

    /**
     * Applies one AdaDelta step to {@code vec} with no bias term.
     * <p>
     * This is equivalent to the five-argument overload with a bias gradient of
     * zero, so the bias running averages still decay by {@code rho}.
     *
     * @param vec the weight vector, updated in place
     * @param vec2 the gradient for {@code vec}
     * @param d the learning rate
     * @throws IllegalStateException if {@link #setup(int)} has not been called
     */
    @Override // jsat.math.optimization.stochastic.GradientUpdater
    public void update(Vec vec, Vec vec2, double d) {
        update(vec, vec2, d, 0.0d, 0.0d);
    }

    /**
     * Applies one AdaDelta step to the weights {@code x} and computes the step
     * for the bias term.
     * <p>
     * Only coordinates that {@code grad}'s iterator visits get a weight change.
     * The running averages of <em>all</em> coordinates decay by {@code rho}.
     *
     * @param x the weight vector, updated in place by {@code eta * step}
     * @param grad the gradient for {@code x}
     * @param eta the learning rate
     * @param bias the current bias value (not read by this implementation)
     * @param biasGrad the gradient for the bias term
     * @return {@code eta * sqrt((E[db^2] + eps) / (E[gb^2] + eps)) * biasGrad}. Unlike the
     *         weight step, this value is not negated. Presumably the caller subtracts it
     *         from the bias per the {@code GradientUpdater} contract; confirm.
     * @throws IllegalStateException if {@link #setup(int)} has not been called
     */
    @Override // jsat.math.optimization.stochastic.GradientUpdater
    public double update(Vec x, Vec grad, double eta, double bias, double biasGrad) {
        if (this.gSqrd == null || this.deltaXSqrt == null) {
            throw new IllegalStateException("setup(int) must be called before update");
        }
        final double oneMinusRho = 1.0d - this.rho;
        // Multiplier that, combined with the mutableMultiply(rho) after the loop, yields
        // rho * E[dx^2] + (1 - rho) * step^2 for the coordinates touched in the loop.
        final double stepBoost = oneMinusRho / this.rho;

        // Decay E[g^2] everywhere up front; touched coordinates get (1 - rho) * g^2 added below.
        this.gSqrd.mutableMultiply(this.rho);
        this.biasGSqrd *= this.rho;

        Iterator<IndexValue> it = grad.iterator();
        while (it.hasNext()) {
            IndexValue iv = it.next();
            final int i = iv.getIndex();
            final double g = iv.getValue();
            this.gSqrd.increment(i, g * g * oneMinusRho);
            // Uses E[dx^2] from the previous call (not yet decayed) and the freshly updated E[g^2].
            final double step = (-Math.sqrt((this.deltaXSqrt.get(i) + this.eps) / (this.gSqrd.get(i) + this.eps))) * g;
            x.increment(i, eta * step);
            this.deltaXSqrt.increment(i, stepBoost * step * step);
        }
        // Completes the E[dx^2] update for touched coordinates and plainly decays all others.
        this.deltaXSqrt.mutableMultiply(this.rho);

        // Same recurrence for the scalar bias term.
        this.biasGSqrd += biasGrad * biasGrad * oneMinusRho;
        final double biasStep = Math.sqrt((this.deltaBiasSqrt + this.eps) / (this.biasGSqrd + this.eps)) * biasGrad;
        this.deltaBiasSqrt += stepBoost * biasStep * biasStep;
        this.deltaBiasSqrt *= this.rho;
        return eta * biasStep;
    }

    /**
     * Returns a copy of this updater, including its hyper-parameters and current
     * accumulator state.
     *
     * @return a new {@code AdaDelta} built by the copy constructor
     */
    @Override // jsat.math.optimization.stochastic.GradientUpdater
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public AdaDelta m706clone() {
        return new AdaDelta(this);
    }

    /**
     * Allocates fresh per-coordinate running averages for {@code i} weights and
     * resets the bias running averages to zero. Any previously accumulated
     * state is discarded.
     *
     * @param i the number of weight coordinates
     */
    @Override // jsat.math.optimization.stochastic.GradientUpdater
    public void setup(int i) {
        // NOTE(review): assumes DenseVector(i) is zero-initialized, and that ScaledVector is
        // used so the per-update mutableMultiply(rho) is cheap; confirm against jsat.linear.
        this.gSqrd = new ScaledVector(new DenseVector(i));
        this.deltaXSqrt = new ScaledVector(new DenseVector(i));
        this.biasGSqrd = 0.0d;
        this.deltaBiasSqrt = 0.0d;
    }
}
