package smile.base.mlp;

import java.io.Serializable;
import java.util.Arrays;
import java.util.stream.Collectors;

/* loaded from: input_file:smile-core-2.4.0.jar:smile/base/mlp/MultilayerPerceptron.class */
/**
 * Base class of multilayer perceptron models: a chain of hidden layers followed by a single
 * {@link OutputLayer}.
 *
 * <p>Subclasses drive training through {@link #propagate(double[])},
 * {@link #backpropagate(double[])} and {@link #update()}, which use the learning rate
 * {@code eta}, the momentum factor {@code alpha} and the weight decay factor {@code lambda}.
 */
public abstract class MultilayerPerceptron implements Serializable {
    private static final long serialVersionUID = 2L;
    /** The dimension of the input vector, i.e. the input size of the first layer. */
    protected int p;
    /** The output layer (the last layer passed to the constructor). */
    protected OutputLayer output;
    /** The hidden layers, ordered from the input side to the output side. */
    protected Layer[] net;
    /** Target buffer whose length equals the output layer's output size. */
    protected double[] target;
    /** The learning rate. */
    protected double eta = 0.1d;
    /** The momentum factor. */
    protected double alpha = 0.0d;
    /** The weight decay factor. */
    protected double lambda = 0.0d;

    /**
     * Builds the network from its layers.
     *
     * @param layers the layers from input to output; the last one must be an {@link OutputLayer}.
     * @throws IllegalArgumentException if fewer than two layers are given, if the output size of
     *     a layer does not match the input size of the next one, or if the last layer is not an
     *     {@link OutputLayer}.
     */
    public MultilayerPerceptron(Layer... layers) {
        if (layers.length < 2) {
            throw new IllegalArgumentException("Too few layers: " + layers.length);
        }

        for (int i = 1; i < layers.length; i++) {
            Layer lower = layers[i - 1];
            Layer upper = layers[i];
            if (upper.getInputSize() != lower.getOutputSize()) {
                throw new IllegalArgumentException(String.format("Invalid network architecture. Layer %d has %d neurons while layer %d takes %d inputs", i - 1, lower.getOutputSize(), i, upper.getInputSize()));
            }
        }

        Layer last = layers[layers.length - 1];
        // Report a wrong layer type as an argument error rather than an opaque ClassCastException.
        if (!(last instanceof OutputLayer)) {
            throw new IllegalArgumentException("The last layer must be an OutputLayer: " + last.getClass().getName());
        }

        this.output = (OutputLayer) last;
        this.net = Arrays.copyOf(layers, layers.length - 1);
        this.p = layers[0].getInputSize();
        this.target = new double[this.output.getOutputSize()];
    }

    @Override
    public String toString() {
        String hidden = Arrays.stream(this.net)
                .map(Layer::toString)
                .collect(Collectors.joining(" -> "));
        return String.format("x(%d) -> %s -> %s(eta = %.2f, alpha = %.2f, lambda = %.2f)", this.p, hidden, this.output, this.eta, this.alpha, this.lambda);
    }

    /**
     * Sets the learning rate.
     *
     * @param eta the learning rate; must be positive and finite.
     * @throws IllegalArgumentException if {@code eta} is not a positive finite number
     *     (including {@code NaN}).
     */
    public void setLearningRate(double eta) {
        // Negated form so that NaN (for which every comparison is false) is rejected too.
        if (!(eta > 0.0 && Double.isFinite(eta))) {
            throw new IllegalArgumentException("Invalid learning rate: " + eta);
        }
        this.eta = eta;
    }

    /**
     * Sets the momentum factor.
     *
     * @param alpha the momentum factor, in {@code [0, 1)}.
     * @throws IllegalArgumentException if {@code alpha} is outside {@code [0, 1)} or is NaN.
     */
    public void setMomentum(double alpha) {
        if (!(alpha >= 0.0 && alpha < 1.0)) {
            throw new IllegalArgumentException("Invalid momentum factor: " + alpha);
        }
        this.alpha = alpha;
    }

    /**
     * Sets the weight decay factor.
     *
     * @param lambda the weight decay factor, in {@code [0, 0.1]}.
     * @throws IllegalArgumentException if {@code lambda} is outside {@code [0, 0.1]} or is NaN.
     */
    public void setWeightDecay(double lambda) {
        if (!(lambda >= 0.0 && lambda <= 0.1)) {
            throw new IllegalArgumentException("Invalid weight decay factor: " + lambda);
        }
        this.lambda = lambda;
    }

    /** Returns the learning rate. */
    public double getLearningRate() {
        return this.eta;
    }

    /** Returns the momentum factor. */
    public double getMomentum() {
        return this.alpha;
    }

    /** Returns the weight decay factor. */
    public double getWeightDecay() {
        return this.lambda;
    }

    /**
     * Runs a forward pass. The input is fed to the first hidden layer, each layer's
     * {@code output()} is fed to the next layer, and the last hidden layer's output is fed to
     * the output layer.
     *
     * <p>NOTE(review): the decompiler reports this method was originally {@code protected}; it is
     * kept public here so existing callers are not broken.
     *
     * @param x the input vector.
     */
    public void propagate(double[] x) {
        double[] input = x;
        for (Layer layer : this.net) {
            layer.propagate(input);
            input = layer.output();
        }
        this.output.propagate(input);
    }

    /**
     * Propagates the error back through the network, then computes the pending parameter
     * updates of every layer with the current {@code eta} and {@code alpha}.
     *
     * <p>The output error is computed against {@link #target}, which is expected to have been
     * filled in by the caller beforehand.
     *
     * <p>NOTE(review): the decompiler reports this method was originally {@code protected}; it is
     * kept public here so existing callers are not broken.
     *
     * @param x the input vector of the sample just propagated.
     */
    public void backpropagate(double[] x) {
        this.output.computeError(this.target, 1.0);

        // Walk from the output layer downward; each layer back-propagates into the gradient
        // buffer of the layer beneath it. The walking variable must be a Layer, since it holds
        // hidden layers after the first step.
        Layer upper = this.output;
        for (int i = this.net.length - 1; i >= 0; i--) {
            upper.backpropagate(this.net[i].gradient());
            upper = this.net[i];
        }
        // The first hidden layer has no layer below it.
        upper.backpropagate(null);

        // Each layer computes its update from the input it received during the forward pass.
        double[] input = x;
        for (Layer layer : this.net) {
            layer.computeUpdate(this.eta, this.alpha, input);
            input = layer.output();
        }
        this.output.computeUpdate(this.eta, this.alpha, input);
    }

    /**
     * Applies the pending updates to every layer, passing the momentum factor and the
     * multiplier {@code 1 - 2 * eta * lambda}.
     *
     * <p>NOTE(review): the decompiler reports this method was originally {@code protected}; it is
     * kept public here so existing callers are not broken.
     *
     * @throws IllegalStateException if {@code 1 - 2 * eta * lambda} is below 0.9 or is NaN.
     */
    public void update() {
        double decay = 1.0 - 2.0 * this.eta * this.lambda;
        // Negated form so that a NaN decay (e.g. infinite eta with zero lambda) is also rejected.
        if (!(decay >= 0.9)) {
            throw new IllegalStateException(String.format("Invalid learning rate (eta = %.2f) and/or decay (lambda = %.2f)", this.eta, this.lambda));
        }

        for (Layer layer : this.net) {
            layer.update(this.alpha, decay);
        }
        this.output.update(this.alpha, decay);
    }
}
