package smile.regression;

import smile.base.mlp.Cost;
import smile.base.mlp.Layer;
import smile.base.mlp.LayerBuilder;
import smile.base.mlp.MultilayerPerceptron;
import smile.base.mlp.OutputFunction;
import smile.base.mlp.OutputLayer;

/* loaded from: input_file:smile-core-2.4.0.jar:smile/regression/MLP.class */
/**
 * Multilayer perceptron for regression with a single real-valued output.
 *
 * <p>The network is assembled from the caller-supplied hidden-layer builders
 * followed by one output layer with a single unit, a linear output function
 * and mean-squared-error cost.
 */
public class MLP extends MultilayerPerceptron implements OnlineRegression<double[]> {
    private static final long serialVersionUID = 2;

    /**
     * Creates a regression network.
     *
     * @param i               the dimension of the input vector.
     * @param layerBuilderArr builders for the hidden layers, in order from the
     *                        input side towards the output.
     */
    public MLP(int i, LayerBuilder... layerBuilderArr) {
        super(net(i, layerBuilderArr));
    }

    /**
     * Builds the layer array: one layer per builder plus the final output layer.
     *
     * @param inputSize the dimension of the input vector.
     * @param builders  the hidden-layer builders.
     * @return the layers, with the output layer in the last slot.
     */
    private static Layer[] net(int inputSize, LayerBuilder... builders) {
        int hiddenCount = builders.length;
        Layer[] layers = new Layer[hiddenCount + 1];

        // Each layer is built with the neuron count of the layer before it.
        int width = inputSize;
        for (int k = 0; k < hiddenCount; k++) {
            LayerBuilder builder = builders[k];
            layers[k] = builder.build(width);
            width = builder.neurons();
        }

        layers[hiddenCount] =
                new OutputLayer(1, width, OutputFunction.LINEAR, Cost.MEAN_SQUARED_ERROR);
        return layers;
    }

    /**
     * Predicts the response for one sample.
     *
     * @param dArr the input vector.
     * @return the first (and only) element of the output layer after a forward pass.
     */
    @Override // smile.regression.Regression
    public double predict(double[] dArr) {
        propagate(dArr);
        return this.output.output()[0];
    }

    /**
     * Updates the model with a single sample: forward pass, set the target,
     * backward pass, then the inherited {@code update()}.
     *
     * @param dArr the input vector.
     * @param d    the response value for the sample.
     */
    @Override // smile.regression.OnlineRegression
    public void update(double[] dArr, double d) {
        propagate(dArr);
        this.target[0] = d;
        backpropagate(dArr);
        update();
    }

    /**
     * Updates the model with a mini-batch. Every sample is propagated and
     * back-propagated, and the inherited {@code update()} is called once at the end.
     *
     * <p>While the batch is processed {@code alpha} is temporarily set to 1.0
     * (presumably the momentum factor, so that per-sample contributions accumulate
     * undamped; confirm against {@code MultilayerPerceptron}). The previous value
     * is restored even if processing fails.
     *
     * @param dArr  the input vectors, one row per sample.
     * @param dArr2 the response values; must contain at least one value per sample.
     * @throws IllegalArgumentException if there are fewer responses than samples.
     */
    @Override // smile.regression.OnlineRegression
    public void update(double[][] dArr, double[] dArr2) {
        // Fail before touching model state rather than partway through the batch.
        if (dArr2.length < dArr.length) {
            throw new IllegalArgumentException(String.format(
                    "Fewer responses (%d) than samples (%d)", dArr2.length, dArr.length));
        }

        double savedAlpha = this.alpha;
        this.alpha = 1.0d;
        try {
            for (int row = 0; row < dArr.length; row++) {
                double[] sample = dArr[row];
                propagate(sample);
                this.target[0] = dArr2[row];
                backpropagate(sample);
            }
            update();
        } finally {
            // Always restore the caller's setting, including on exceptions.
            this.alpha = savedAlpha;
        }
    }
}
