package smile.base.mlp;

import java.io.Serializable;
import java.util.Arrays;
import smile.math.MathEx;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.Matrix;

/* loaded from: input_file:smile-core-2.4.0.jar:smile/base/mlp/Layer.class */
/**
 * Base class for a network layer that owns an {@code n}-by-{@code p} weight
 * matrix, a bias vector of length {@code n}, and the working buffers used
 * while propagating a signal and accumulating parameter updates.
 *
 * <p>Subclasses supply the activation ({@link #f(double[])}) and the
 * backward step ({@link #backpropagate(double[])}).
 */
public abstract class Layer implements Serializable {
    private static final long serialVersionUID = 2L;

    /** Output size: row count of {@link #weight} and length of the per-unit arrays. */
    protected int n;
    /** Input size: column count of {@link #weight}. */
    protected int p;
    /** Buffer holding the result of the most recent {@link #propagate(double[])} call. */
    protected double[] output;
    /**
     * Per-unit gradient read by {@link #computeUpdate(double, double, double[])}.
     * NOTE(review): presumably filled in by backpropagate implementations,
     * which are not visible in this class.
     */
    protected double[] gradient;
    /** The {@code n}-by-{@code p} weight matrix. */
    protected DenseMatrix weight;
    /** Pending change for {@link #weight}; same dimensions, starts as zeros. */
    protected DenseMatrix update;
    /** Bias vector of length {@code n}. */
    protected double[] bias;
    /** Pending change for {@link #bias}. */
    protected double[] updateBias;

    /**
     * Creates a layer and allocates all of its buffers.
     *
     * <p>The weight matrix comes from {@code Matrix.randn(n, p, 0.0, sqrt(2 / p))};
     * presumably the last two arguments are the mean and the standard
     * deviation of a Gaussian (confirm against the Matrix API). Every other
     * buffer starts out zero-filled.
     *
     * @param i  the output size, stored as {@code n}
     * @param i2 the input size, stored as {@code p}
     */
    public Layer(int i, int i2) {
        n = i;
        p = i2;

        // The spread of the initial weights depends only on the input size.
        final double spread = Math.sqrt(2.0 / i2);
        weight = Matrix.randn(n, p, 0.0, spread);

        bias = new double[n];
        updateBias = new double[n];
        output = new double[n];
        gradient = new double[n];
        update = Matrix.zeros(n, p);
    }

    /**
     * Returns the output size of this layer.
     *
     * @return the value of {@code n}
     */
    public int getOutputSize() {
        return n;
    }

    /**
     * Returns the input size of this layer.
     *
     * @return the value of {@code p}
     */
    public int getInputSize() {
        return p;
    }

    /**
     * Returns the output buffer itself (not a copy).
     *
     * @return the internal output array
     */
    public double[] output() {
        return output;
    }

    /**
     * Returns the gradient buffer itself (not a copy).
     *
     * @return the internal gradient array
     */
    public double[] gradient() {
        return gradient;
    }

    /**
     * Runs the forward step for one input vector.
     *
     * <p>The bias is copied into the output buffer, {@code weight.axpy} is
     * invoked with the input and that buffer (presumably adding the
     * matrix-vector product onto it; confirm against the matrix API), and
     * finally {@link #f(double[])} is applied to the buffer.
     *
     * @param dArr the input vector
     */
    public void propagate(double[] dArr) {
        // Seed the output with the bias before the weighted input is added.
        System.arraycopy(bias, 0, output, 0, n);
        weight.axpy(dArr, output);
        f(output);
    }

    /**
     * Activation hook. {@link #propagate(double[])} hands over the output
     * buffer and ignores any return value, so implementations are expected
     * to modify the array in place.
     *
     * @param dArr the array to process
     */
    public abstract void f(double[] dArr);

    /**
     * Backward step, defined by subclasses.
     *
     * @param dArr an array supplied by the caller; its meaning is defined by
     *             the implementation
     */
    public abstract void backpropagate(double[] dArr);

    /**
     * Recomputes the pending weight and bias updates from the current gradient.
     *
     * <p>For each weight at (row, col) the new pending value is
     * {@code d * gradient[row] * dArr[col]}; for each bias entry it is
     * {@code d * gradient[row]}. When {@code d2} is positive, {@code d2}
     * times the previous pending value is added on top.
     *
     * @param d    factor applied to the gradient (presumably the learning rate)
     * @param d2   factor applied to the previous pending update when positive
     *             (presumably a momentum term)
     * @param dArr the input vector; entries {@code 0..p-1} are read
     */
    public void computeUpdate(double d, double d2, double[] dArr) {
        final boolean keepPrevious = d2 > 0.0;

        for (int col = 0; col < p; col++) {
            final double input = dArr[col];
            for (int row = 0; row < n; row++) {
                double delta = d * gradient[row] * input;
                if (keepPrevious) {
                    delta += d2 * update.get(row, col);
                }
                update.set(row, col, delta);
            }
        }

        for (int row = 0; row < n; row++) {
            final double step = d * gradient[row];
            updateBias[row] = keepPrevious ? step + d2 * updateBias[row] : step;
        }
    }

    /**
     * Applies the pending updates to the weights and the bias.
     *
     * <p>After the pending values are added, the weight matrix (but not the
     * bias) is multiplied by {@code d2} whenever {@code d2 < 1}. The pending
     * updates are reset to zero only when {@code d} is exactly {@code 1.0}.
     *
     * @param d  the pending updates are cleared only if this equals 1.0
     * @param d2 multiplier applied to the weights when it is less than 1
     *           (presumably a weight-decay factor)
     */
    public void update(double d, double d2) {
        weight.add(update);
        MathEx.add(bias, updateBias);

        if (d2 < 1.0) {
            weight.mul(d2);
        }

        if (d != 1.0) {
            return;
        }

        // d == 1.0: discard the pending updates now that they have been applied.
        update.fill(0.0);
        Arrays.fill(updateBias, 0.0);
    }

    /**
     * Creates a hidden layer builder that uses {@code ActivationFunction.linear()}.
     *
     * @param i the integer handed to the builder (presumably the number of units)
     * @return the new builder
     */
    public static HiddenLayerBuilder linear(int i) {
        return new HiddenLayerBuilder(i, ActivationFunction.linear());
    }

    /**
     * Creates a hidden layer builder that uses {@code ActivationFunction.rectifier()}.
     *
     * @param i the integer handed to the builder (presumably the number of units)
     * @return the new builder
     */
    public static HiddenLayerBuilder rectifier(int i) {
        return new HiddenLayerBuilder(i, ActivationFunction.rectifier());
    }

    /**
     * Creates a hidden layer builder that uses {@code ActivationFunction.sigmoid()}.
     *
     * @param i the integer handed to the builder (presumably the number of units)
     * @return the new builder
     */
    public static HiddenLayerBuilder sigmoid(int i) {
        return new HiddenLayerBuilder(i, ActivationFunction.sigmoid());
    }

    /**
     * Creates a hidden layer builder that uses {@code ActivationFunction.tanh()}.
     *
     * @param i the integer handed to the builder (presumably the number of units)
     * @return the new builder
     */
    public static HiddenLayerBuilder tanh(int i) {
        return new HiddenLayerBuilder(i, ActivationFunction.tanh());
    }

    /**
     * Creates an output layer builder configured with {@code Cost.MEAN_SQUARED_ERROR}.
     *
     * @param i              the integer handed to the builder (presumably the number of units)
     * @param outputFunction the output function handed to the builder
     * @return the new builder
     */
    public static OutputLayerBuilder mse(int i, OutputFunction outputFunction) {
        return new OutputLayerBuilder(i, outputFunction, Cost.MEAN_SQUARED_ERROR);
    }

    /**
     * Creates an output layer builder configured with {@code Cost.LIKELIHOOD}.
     *
     * @param i              the integer handed to the builder (presumably the number of units)
     * @param outputFunction the output function handed to the builder
     * @return the new builder
     */
    public static OutputLayerBuilder mle(int i, OutputFunction outputFunction) {
        return new OutputLayerBuilder(i, outputFunction, Cost.LIKELIHOOD);
    }
}
