package org.ojalgo.ann;

import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;
import org.ojalgo.ann.ArtificialNeuralNetwork;
import org.ojalgo.data.DataBatch;
import org.ojalgo.matrix.store.MatrixStore;
import org.ojalgo.matrix.store.PhysicalStore;
import org.ojalgo.structure.Access1D;
import org.ojalgo.structure.Structure2D;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:ojalgo-49.2.1.jar:org/ojalgo/ann/WrappedANN.class */
/**
 * Wraps an {@link ArtificialNeuralNetwork} together with the working storage used to push a batch of
 * {@code batchSize} rows through it: an input store and one pre-allocated output store per layer index in
 * {@code [0, depth())}.
 * <p>
 * Almost every method simply delegates to the wrapped network. Instances are stateful (the input/output stores
 * are reused between calls) and perform no synchronisation.
 */
public abstract class WrappedANN implements Supplier<ArtificialNeuralNetwork> {

    private final int myBatchSize;
    /** The store most recently used as network input; may be a caller-supplied store (never written to). */
    private PhysicalStore<Double> myInput;
    /** Input buffer owned by this wrapper, allocated lazily; the only input store this class ever fills. */
    private PhysicalStore<Double> myInputBuffer;
    private final ArtificialNeuralNetwork myNetwork;
    /** {@code myOutputs[layer]} is passed to the network as the destination for that layer's output. */
    private final PhysicalStore<Double>[] myOutputs;

    /**
     * Creates a wrapper and allocates one output store per layer index of the network.
     *
     * @param network   the network to wrap
     * @param batchSize number of rows in every input/output store
     */
    @SuppressWarnings("unchecked") // generic array creation; every element is assigned below
    public WrappedANN(ArtificialNeuralNetwork network, int batchSize) {
        myNetwork = network;
        myBatchSize = batchSize;
        myOutputs = new PhysicalStore[network.depth()];
        for (int layer = 0; layer < myOutputs.length; layer++) {
            myOutputs[layer] = network.newStore(batchSize, network.countOutputNodes(layer));
        }
    }

    /**
     * Two wrappers are equal when they wrap equal networks; batch size and working stores are ignored.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof WrappedANN)) {
            return false;
        }
        return Objects.equals(myNetwork, ((WrappedANN) obj).myNetwork);
    }

    /**
     * @return the wrapped network itself (not a copy)
     */
    @Override // java.util.function.Supplier
    public ArtificialNeuralNetwork get() {
        return myNetwork;
    }

    /**
     * Consistent with {@link #equals(Object)}: depends only on the wrapped network.
     */
    @Override
    public int hashCode() {
        return Objects.hash(myNetwork);
    }

    /**
     * @return a new batch sized {@code batchSize x countInputNodes()}, created by the network
     */
    public DataBatch newInputBatch() {
        return myNetwork.newBatch(myBatchSize, myNetwork.countInputNodes());
    }

    /**
     * Makes {@code input} the current network input.
     * <p>
     * A {@link PhysicalStore} with exactly {@code batchSize} rows is used directly, without copying. Any other
     * input is copied into a buffer owned by this wrapper. A previously supplied caller store is never used as a
     * copy destination, so the caller's data is not overwritten.
     */
    private void setInput(Access1D<Double> input) {
        if (input instanceof PhysicalStore && ((PhysicalStore<Double>) input).getRowDim() == myBatchSize) {
            myInput = (PhysicalStore<Double>) input;
            return;
        }
        if (myInputBuffer == null) {
            myInputBuffer = myNetwork.newStore(myBatchSize, myNetwork.countInputNodes());
        }
        myInputBuffer.fillMatching(input);
        myInput = myInputBuffer;
    }

    /**
     * Delegates to {@link ArtificialNeuralNetwork}{@code .adjust} with the same arguments.
     * NOTE(review): the meaning of the four stores is defined by the network's {@code adjust} - confirm there.
     */
    public void adjust(int layer, PhysicalStore<Double> first, PhysicalStore<Double> second,
            PhysicalStore<Double> third, PhysicalStore<Double> fourth) {
        myNetwork.adjust(layer, first, second, third, fourth);
    }

    /**
     * @return the wrapped network's depth, i.e. the number of output stores held by this wrapper
     */
    public int depth() {
        return myNetwork.depth();
    }

    /** Delegates to the wrapped network's {@code getActivator}. */
    ArtificialNeuralNetwork.Activator getActivator(int layer) {
        return myNetwork.getActivator(layer);
    }

    /** @return the batch size (row count) given at construction */
    int getBatchSize() {
        return myBatchSize;
    }

    /** Delegates to the wrapped network's {@code getBias}; index meanings are defined there. */
    double getBias(int layer, int index) {
        return myNetwork.getBias(layer, index);
    }

    /** @return the current input store, or {@code null} if no input has been set yet */
    PhysicalStore<Double> getInput() {
        return myInput;
    }

    /**
     * @param layer layer index
     * @return the current input for {@code layer <= 0}, otherwise the output store of {@code layer - 1}
     */
    public PhysicalStore<Double> getInput(int layer) {
        return layer <= 0 ? myInput : myOutputs[layer - 1];
    }

    /**
     * @return the output store of the last layer index (throws if {@code depth() == 0})
     */
    PhysicalStore<Double> getOutput() {
        return myOutputs[myOutputs.length - 1];
    }

    /** @return the output store for {@code layer} */
    public PhysicalStore<Double> getOutput(int layer) {
        return myOutputs[layer];
    }

    /** Delegates to the wrapped network's {@code getOutputActivator}. */
    public ArtificialNeuralNetwork.Activator getOutputActivator() {
        return myNetwork.getOutputActivator();
    }

    /** Delegates to the wrapped network's {@code getWeight}; index meanings are defined there. */
    double getWeight(int layer, int index1, int index2) {
        return myNetwork.getWeight(layer, index1, index2);
    }

    /** @return whatever list the wrapped network's {@code getWeights()} returns */
    List<MatrixStore<Double>> getWeights() {
        return myNetwork.getWeights();
    }

    /**
     * Feeds {@code input} through every layer of the network.
     * <p>
     * Sets the input (see {@link #setInput}), applies {@code configuration} to the network, and then calls the
     * network's {@code invoke} once per layer index. Each call receives the previous result and that layer's
     * output store.
     *
     * @param input         batch input
     * @param configuration passed to the network's {@code setConfiguration} before evaluation
     * @return the result of the last layer's call, or the input store itself if {@code depth() == 0}
     */
    public MatrixStore<Double> invoke(Access1D<Double> input, TrainingConfiguration configuration) {
        this.setInput(input);
        myNetwork.setConfiguration(configuration);
        PhysicalStore<Double> current = myInput;
        final int depth = this.depth();
        for (int layer = 0; layer < depth; layer++) {
            current = myNetwork.invoke(layer, current, myOutputs[layer]);
        }
        return current;
    }

    /**
     * @return a new batch sized {@code batchSize x countOutputNodes()}, created by the network
     */
    public DataBatch newOutputBatch() {
        return myNetwork.newBatch(myBatchSize, myNetwork.countOutputNodes());
    }

    /** Delegates to the wrapped network's {@code randomise}. */
    void randomise() {
        myNetwork.randomise();
    }

    /** Delegates to the wrapped network's {@code setActivator}. */
    public void setActivator(int layer, ArtificialNeuralNetwork.Activator activator) {
        myNetwork.setActivator(layer, activator);
    }

    /** Delegates to the wrapped network's {@code setBias}; index meanings are defined there. */
    public void setBias(int layer, int index, double value) {
        myNetwork.setBias(layer, index, value);
    }

    /** Delegates to the wrapped network's {@code setWeight}; index meanings are defined there. */
    public void setWeight(int layer, int index1, int index2, double value) {
        myNetwork.setWeight(layer, index1, index2, value);
    }

    /** @return whatever the wrapped network's {@code structure()} returns */
    public Structure2D[] structure() {
        return myNetwork.structure();
    }
}
