package org.ojalgo.ann;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;
import org.ojalgo.data.DataBatch;
import org.ojalgo.function.BinaryFunction;
import org.ojalgo.function.PrimitiveFunction;
import org.ojalgo.function.aggregator.Aggregator;
import org.ojalgo.function.constant.PrimitiveMath;
import org.ojalgo.function.special.MissingMath;
import org.ojalgo.matrix.store.MatrixStore;
import org.ojalgo.matrix.store.PhysicalStore;
import org.ojalgo.matrix.store.Primitive32Store;
import org.ojalgo.matrix.store.Primitive64Store;
import org.ojalgo.structure.Access1D;
import org.ojalgo.structure.Structure2D;

/* loaded from: input_file:ojalgo-51.3.0.jar:org/ojalgo/ann/ArtificialNeuralNetwork.class */
/**
 * A feed-forward artificial neural network consisting of a sequence of {@link CalculationLayer}s.
 * <p>
 * Networks are created with {@link #builder(int)} (or one of its overloads), or read back from data
 * previously produced by {@link #writeTo(DataOutput)} using one of the {@code from(...)} methods. Use
 * {@link #newInvoker()} to evaluate a network and {@link #newTrainer()} to train it.
 * <p>
 * Members documented as "package-internal" were package-private in the original source and were widened
 * to {@code public} by the decompiler. They remain public for compatibility but are not intended as API.
 */
public final class ArtificialNeuralNetwork {

    /** Training configuration set via {@link #setConfiguration(TrainingConfiguration)}; {@code null} when none is set. */
    private transient TrainingConfiguration myConfiguration = null;
    /** Factory handed to every layer and used by {@link #newBatch(int, int)} / {@link #newStore(int, int)}. */
    private final PhysicalStore.Factory<Double, ?> myFactory;
    /** The calculation layers, in evaluation order. The input "layer" is not represented here. */
    private final CalculationLayer[] myLayers;

    /**
     * Activation functions available to a layer.
     * <p>
     * Each constant pairs an in-place activation of a {@link PhysicalStore} with that activation's
     * derivative expressed in terms of the activation's <em>output</em> value (not its input).
     */
    public enum Activator {

        /** f(x) = x, derivative 1. */
        IDENTITY(ArtificialNeuralNetwork::doIdentity, y -> PrimitiveMath.ONE, true),
        /** f(x) = max(0, x), derivative 1 for positive output and 0 otherwise. */
        RELU(ArtificialNeuralNetwork::doReLU, y -> y > PrimitiveMath.ZERO ? PrimitiveMath.ONE : PrimitiveMath.ZERO, true),
        /** Logistic function; derivative y(1 - y) in terms of the output y. */
        SIGMOID(ArtificialNeuralNetwork::doSigmoid, y -> y * (PrimitiveMath.ONE - y), true),
        /**
         * Row-wise softmax. Its derivative is reported as 1. {@link ArtificialNeuralNetwork#newTrainer(int)} pairs
         * a SOFTMAX output layer with {@link Error#CROSS_ENTROPY}, whose derivative (current - target) presumably
         * already represents the combined softmax/cross-entropy gradient.
         */
        SOFTMAX(ArtificialNeuralNetwork::doSoftMax, y -> PrimitiveMath.ONE, false),
        /** Hyperbolic tangent; derivative 1 - y² in terms of the output y. */
        TANH(ArtificialNeuralNetwork::doTanh, y -> PrimitiveMath.ONE - y * y, true);

        private final PrimitiveFunction.Unary myDerivativeInTermsOfOutput;
        private final Consumer<PhysicalStore<Double>> myFunction;
        /** Exposed via {@link #isSingleFolded()}; {@code false} only for SOFTMAX. */
        private final boolean mySingleFolded;

        Activator(final Consumer<PhysicalStore<Double>> function, final PrimitiveFunction.Unary derivativeInTermsOfOutput,
                final boolean singleFolded) {
            myFunction = function;
            myDerivativeInTermsOfOutput = derivativeInTermsOfOutput;
            mySingleFolded = singleFolded;
        }

        /**
         * Applies this activation to every element of {@code store}, in place.
         * <p>
         * Package-internal.
         *
         * @param store the values to activate; modified in place
         */
        public void activate(final PhysicalStore<Double> store) {
            myFunction.accept(store);
        }

        /**
         * Applies this activation in place, then modifies every element with {@code NodeDropper.of(probability)}.
         * <p>
         * Package-internal.
         *
         * @param store the values to activate; modified in place
         * @param probability must lie in the half-open interval (0, 1]
         * @throws IllegalArgumentException if {@code probability} is outside (0, 1] or is NaN
         */
        public void activate(final PhysicalStore<Double> store, final double probability) {
            // Written as a negated range check so that NaN is rejected as well.
            if (!(probability > PrimitiveMath.ZERO && probability <= PrimitiveMath.ONE)) {
                throw new IllegalArgumentException("Probability must be in (0, 1]: " + probability);
            }
            myFunction.accept(store);
            store.modifyAll(NodeDropper.of(probability));
        }

        /**
         * Package-internal.
         *
         * @return the derivative of this activation as a function of the activation's output value
         */
        public PrimitiveFunction.Unary getDerivativeInTermsOfOutput() {
            return myDerivativeInTermsOfOutput;
        }

        boolean isSingleFolded() {
            return mySingleFolded;
        }
    }

    /**
     * Error (loss) functions, each with its derivative with respect to the second ("current") argument.
     * Arguments are ordered {@code (target, current)}.
     */
    public enum Error implements PrimitiveFunction.Binary {

        /** -target * ln(current); derivative reported as (current - target). */
        CROSS_ENTROPY((target, current) -> -target * Math.log(current), (target, current) -> current - target),
        /** ½(target - current)²; derivative (current - target). */
        HALF_SQUARED_DIFFERENCE((target, current) -> PrimitiveMath.HALF * (target - current) * (target - current),
                (target, current) -> current - target);

        private final PrimitiveFunction.Binary myDerivative;
        private final PrimitiveFunction.Binary myFunction;

        Error(final PrimitiveFunction.Binary function, final PrimitiveFunction.Binary derivative) {
            myFunction = function;
            myDerivative = derivative;
        }

        /**
         * Sums this error function element-wise over the common prefix of the two vectors.
         *
         * @param target expected values
         * @param current actual values
         * @return the summed error over {@code min(target.count(), current.count())} elements
         */
        public double invoke(final Access1D<?> target, final Access1D<?> current) {
            int limit = MissingMath.toMinIntExact(target.count(), current.count());
            double total = PrimitiveMath.ZERO;
            for (int index = 0; index < limit; index++) {
                total += myFunction.invoke(target.doubleValue(index), current.doubleValue(index));
            }
            return total;
        }

        @Override // org.ojalgo.function.BinaryFunction
        public double invoke(final double target, final double current) {
            return myFunction.invoke(target, current);
        }

        /**
         * Package-internal.
         *
         * @return the derivative of this error function with respect to its second argument
         */
        public PrimitiveFunction.Binary getDerivative() {
            return myDerivative;
        }
    }

    /**
     * Starts building a network backed by {@link Primitive64Store#FACTORY}.
     *
     * @param numberOfInputNodes the number of input nodes
     * @return a new builder
     */
    public static NetworkBuilder builder(final int numberOfInputNodes) {
        return ArtificialNeuralNetwork.builder(Primitive64Store.FACTORY, numberOfInputNodes);
    }

    /**
     * Builds a network backed by {@link Primitive64Store#FACTORY} and returns a trainer for it.
     *
     * @param numberOfInputNodes the number of input nodes
     * @param nodesPerCalculationLayer output node count for each calculation layer, in order
     * @return a trainer for the newly built network
     * @deprecated Use {@link #builder(int)}, add layers with {@code NetworkBuilder.layer(int)}, then call
     *             {@code get().newTrainer()}.
     */
    @Deprecated
    public static NetworkTrainer builder(final int numberOfInputNodes, final int... nodesPerCalculationLayer) {
        return ArtificialNeuralNetwork.builder(Primitive64Store.FACTORY, numberOfInputNodes, nodesPerCalculationLayer);
    }

    /**
     * Starts building a network backed by the given factory.
     *
     * @param factory the matrix factory the network will use
     * @param numberOfInputNodes the number of input nodes
     * @return a new builder
     */
    public static NetworkBuilder builder(final PhysicalStore.Factory<Double, ?> factory, final int numberOfInputNodes) {
        return new NetworkBuilder(factory, numberOfInputNodes);
    }

    /**
     * Builds a network backed by the given factory and returns a trainer for it.
     *
     * @param factory the matrix factory the network will use
     * @param numberOfInputNodes the number of input nodes
     * @param nodesPerCalculationLayer output node count for each calculation layer, in order
     * @return a trainer for the newly built network
     * @deprecated Use {@link #builder(PhysicalStore.Factory, int)}, add layers with
     *             {@code NetworkBuilder.layer(int)}, then call {@code get().newTrainer()}.
     */
    @Deprecated
    public static NetworkTrainer builder(final PhysicalStore.Factory<Double, ?> factory, final int numberOfInputNodes,
            final int... nodesPerCalculationLayer) {
        NetworkBuilder networkBuilder = ArtificialNeuralNetwork.builder(factory, numberOfInputNodes);
        for (int nodes : nodesPerCalculationLayer) {
            networkBuilder.layer(nodes);
        }
        return networkBuilder.get().newTrainer();
    }

    /**
     * Reads a network with {@code FileFormat.read}, passing a {@code null} factory.
     *
     * @param input the source to read from
     * @return the network read
     * @throws IOException if reading fails
     */
    public static ArtificialNeuralNetwork from(final DataInput input) throws IOException {
        return FileFormat.read(null, input);
    }

    /**
     * Reads a network from a file, passing a {@code null} factory to the reader.
     *
     * @param file the file to read
     * @return the network read
     * @throws RuntimeException wrapping any {@link IOException}
     */
    public static ArtificialNeuralNetwork from(final File file) {
        return ArtificialNeuralNetwork.from((PhysicalStore.Factory<Double, ?>) null, file);
    }

    /**
     * Reads a network from a path, passing a {@code null} factory to the reader.
     *
     * @param path the file to read
     * @param options options used to open the file
     * @return the network read
     * @throws RuntimeException wrapping any {@link IOException}
     */
    public static ArtificialNeuralNetwork from(final Path path, final OpenOption... options) {
        return ArtificialNeuralNetwork.from((PhysicalStore.Factory<Double, ?>) null, path, options);
    }

    /**
     * Reads a network with {@code FileFormat.read}.
     *
     * @param factory the factory handed to the reader (may be {@code null})
     * @param input the source to read from
     * @return the network read
     * @throws IOException if reading fails
     */
    public static ArtificialNeuralNetwork from(final PhysicalStore.Factory<Double, ?> factory, final DataInput input)
            throws IOException {
        return FileFormat.read(factory, input);
    }

    /**
     * Reads a network from a file.
     *
     * @param factory the factory handed to the reader (may be {@code null})
     * @param file the file to read
     * @return the network read
     * @throws RuntimeException wrapping any {@link IOException}
     */
    public static ArtificialNeuralNetwork from(final PhysicalStore.Factory<Double, ?> factory, final File file) {
        // try-with-resources closes the stream even when reading fails part-way through.
        try (DataInputStream input = new DataInputStream(new BufferedInputStream(new FileInputStream(file)))) {
            return ArtificialNeuralNetwork.from(factory, input);
        } catch (IOException cause) {
            throw new RuntimeException(cause);
        }
    }

    /**
     * Reads a network from a path.
     *
     * @param factory the factory handed to the reader (may be {@code null})
     * @param path the file to read
     * @param options options used to open the file
     * @return the network read
     * @throws RuntimeException wrapping any {@link IOException}
     */
    public static ArtificialNeuralNetwork from(final PhysicalStore.Factory<Double, ?> factory, final Path path,
            final OpenOption... options) {
        // try-with-resources closes the stream even when reading fails part-way through.
        try (DataInputStream input = new DataInputStream(new BufferedInputStream(Files.newInputStream(path, options)))) {
            return ArtificialNeuralNetwork.from(factory, input);
        } catch (IOException cause) {
            throw new RuntimeException(cause);
        }
    }

    /** Identity activation: leaves {@code store} unchanged. */
    static void doIdentity(final PhysicalStore<Double> store) {
        // Intentionally empty.
    }

    /** ReLU activation: replaces every element x with max(x, 0), in place. */
    static void doReLU(final PhysicalStore<Double> store) {
        store.modifyAll(PrimitiveMath.MAX.second(PrimitiveMath.ZERO));
    }

    /** Sigmoid activation: applies the logistic function to every element, in place. */
    static void doSigmoid(final PhysicalStore<Double> store) {
        store.modifyAll(PrimitiveMath.LOGISTIC);
    }

    /**
     * Softmax activation, applied independently to each row, in place: every element becomes exp(x) divided
     * by the sum of exp over its row.
     */
    static void doSoftMax(final PhysicalStore<Double> store) {
        // Softmax is invariant to a per-row shift. Subtracting each row's maximum first keeps EXP from
        // overflowing to Infinity, which would otherwise turn the normalisation below into NaN.
        Access1D<Double> rowMaxima = store.reduceRows(Aggregator.MAXIMUM).collect(Primitive64Store.FACTORY);
        store.onRows((BinaryFunction<Double>) PrimitiveMath.SUBTRACT, rowMaxima).supplyTo(store);
        store.modifyAll(PrimitiveMath.EXP);
        Access1D<Double> rowSums = store.reduceRows(Aggregator.SUM).collect(Primitive64Store.FACTORY);
        store.onRows((BinaryFunction<Double>) PrimitiveMath.DIVIDE, rowSums).supplyTo(store);
    }

    /** Tanh activation: applies the hyperbolic tangent to every element, in place. */
    static void doTanh(final PhysicalStore<Double> store) {
        store.modifyAll(PrimitiveMath.TANH);
    }

    /**
     * Creates one calculation layer per layer template in the builder.
     * <p>
     * Package-internal.
     *
     * @param networkBuilder supplies the factory and the layer templates
     */
    public ArtificialNeuralNetwork(final NetworkBuilder networkBuilder) {
        myFactory = networkBuilder.getFactory();
        List<LayerTemplate> templates = networkBuilder.getLayers();
        myLayers = new CalculationLayer[templates.size()];
        for (int layer = 0; layer < myLayers.length; layer++) {
            LayerTemplate template = templates.get(layer);
            myLayers[layer] = new CalculationLayer(myFactory, template.inputs, template.outputs, template.activator);
        }
    }

    /**
     * Creates a network where every calculation layer uses {@link Activator#SIGMOID}. Each layer's input
     * count is the previous layer's output count; the first layer takes {@code numberOfInputNodes}.
     * <p>
     * Package-internal.
     *
     * @param factory the matrix factory
     * @param numberOfInputNodes the number of input nodes
     * @param nodesPerCalculationLayer output node count for each calculation layer, in order
     */
    public ArtificialNeuralNetwork(final PhysicalStore.Factory<Double, ?> factory, final int numberOfInputNodes,
            final int[] nodesPerCalculationLayer) {
        myFactory = factory;
        myLayers = new CalculationLayer[nodesPerCalculationLayer.length];
        int inputs = numberOfInputNodes;
        for (int layer = 0; layer < nodesPerCalculationLayer.length; layer++) {
            int outputs = nodesPerCalculationLayer[layer];
            myLayers[layer] = new CalculationLayer(factory, inputs, outputs, Activator.SIGMOID);
            inputs = outputs;
        }
    }

    /** @return the number of calculation layers (the input nodes do not form a layer here) */
    public int depth() {
        return myLayers.length;
    }

    /** Two networks are equal when their calculation layers are equal, element by element. */
    @Override
    public boolean equals(final Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof ArtificialNeuralNetwork)) {
            return false;
        }
        return Arrays.equals(myLayers, ((ArtificialNeuralNetwork) obj).myLayers);
    }

    /**
     * @param layer 0-based calculation-layer index
     * @return the activator of that layer
     */
    public Activator getActivator(final int layer) {
        return myLayers[layer].getActivator();
    }

    /**
     * @param layer 0-based calculation-layer index
     * @param output index forwarded to the layer's {@code getBias(int)}
     * @return the bias value
     */
    public double getBias(final int layer, final int output) {
        return myLayers[layer].getBias(output);
    }

    /**
     * @param layer 0-based calculation-layer index
     * @param input first index forwarded to the layer's {@code getWeight(int, int)}
     * @param output second index forwarded to the layer's {@code getWeight(int, int)}
     * @return the weight value
     */
    public double getWeight(final int layer, final int input, final int output) {
        return myLayers[layer].getWeight(input, output);
    }

    @Override
    public int hashCode() {
        // Kept identical to the original value: 31 * 1 + Arrays.hashCode(myLayers).
        return 31 + Arrays.hashCode(myLayers);
    }

    /** @return a new invoker for this network, created with argument 1 */
    public NetworkInvoker newInvoker() {
        return this.newInvoker(1);
    }

    /**
     * @param batchSize passed to the {@code NetworkInvoker} constructor
     * @return a new invoker for this network
     */
    public NetworkInvoker newInvoker(final int batchSize) {
        return new NetworkInvoker(this, batchSize);
    }

    /** @return a new trainer for this network, created with argument 1 */
    public NetworkTrainer newTrainer() {
        return this.newTrainer(1);
    }

    /**
     * Creates a trainer and selects its default error function: {@link Error#CROSS_ENTROPY} when the output
     * layer uses {@link Activator#SOFTMAX}, otherwise {@link Error#HALF_SQUARED_DIFFERENCE}.
     *
     * @param batchSize passed to the {@code NetworkTrainer} constructor
     * @return a new trainer for this network
     */
    public NetworkTrainer newTrainer(final int batchSize) {
        NetworkTrainer trainer = new NetworkTrainer(this, batchSize);
        Error defaultError = this.getOutputActivator() == Activator.SOFTMAX ? Error.CROSS_ENTROPY : Error.HALF_SQUARED_DIFFERENCE;
        trainer.error(defaultError);
        return trainer;
    }

    /** @return one {@link Structure2D} per calculation layer, in layer order */
    public Structure2D[] structure() {
        Structure2D[] structures = new Structure2D[myLayers.length];
        for (int layer = 0; layer < structures.length; layer++) {
            structures[layer] = myLayers[layer].getStructure();
        }
        return structures;
    }

    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder("ArtificialNeuralNetwork [Layers=");
        for (CalculationLayer layer : myLayers) {
            builder.append("\n").append(layer);
        }
        return builder.append("\n").append("]").toString();
    }

    /**
     * @return the largest node count among the first layer's inputs and every layer's outputs
     */
    public int width() {
        int widest = myLayers[0].countInputNodes();
        for (CalculationLayer layer : myLayers) {
            widest = Math.max(widest, layer.countOutputNodes());
        }
        return widest;
    }

    /**
     * Writes this network with {@code FileFormat.write}.
     *
     * @param output the destination
     * @throws IOException if writing fails
     */
    public void writeTo(final DataOutput output) throws IOException {
        // Format id 2 for a Primitive32Store-backed network, 1 otherwise.
        // NOTE(review): presumably selects float vs double serialisation — confirm in FileFormat.
        int format = myFactory == Primitive32Store.FACTORY ? 2 : 1;
        FileFormat.write(this, format, output);
    }

    /**
     * Writes this network to a file.
     *
     * @param file the destination file
     * @throws RuntimeException wrapping any {@link IOException}
     */
    public void writeTo(final File file) {
        // try-with-resources closes (and flushes) the stream even when writing fails part-way through.
        try (DataOutputStream output = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(file)))) {
            this.writeTo(output);
        } catch (IOException cause) {
            throw new RuntimeException(cause);
        }
    }

    /**
     * Writes this network to a path.
     *
     * @param path the destination file
     * @param options options used to open the file
     * @throws RuntimeException wrapping any {@link IOException}
     */
    public void writeTo(final Path path, final OpenOption... options) {
        // try-with-resources closes (and flushes) the stream even when writing fails part-way through.
        try (DataOutputStream output = new DataOutputStream(new BufferedOutputStream(Files.newOutputStream(path, options)))) {
            this.writeTo(output);
        } catch (IOException cause) {
            throw new RuntimeException(cause);
        }
    }

    /**
     * Forwards the four stores to {@code CalculationLayer.adjust}. Also passes the negated learning rate,
     * {@code probabilityDidKeepInput(layer)} and {@code regularisation()} from the current training
     * configuration.
     * <p>
     * Package-internal. Requires a configuration to have been set (dereferences it unconditionally).
     */
    public void adjust(final int layer, final PhysicalStore<Double> input, final PhysicalStore<Double> output,
            final PhysicalStore<Double> upstreamGradient, final PhysicalStore<Double> downstreamGradient) {
        myLayers[layer].adjust(input, output, upstreamGradient, downstreamGradient, -myConfiguration.learningRate,
                myConfiguration.probabilityDidKeepInput(layer), myConfiguration.regularisation());
    }

    /** Package-internal. @return the input node count of the first calculation layer */
    public int countInputNodes() {
        return myLayers[0].countInputNodes();
    }

    int countInputNodes(final int layer) {
        return myLayers[layer].countInputNodes();
    }

    /** Package-internal. @return the output node count of the last calculation layer */
    public int countOutputNodes() {
        return myLayers[myLayers.length - 1].countOutputNodes();
    }

    /** Package-internal. @return the output node count of the given calculation layer */
    public int countOutputNodes(final int layer) {
        return myLayers[layer].countOutputNodes();
    }

    /** Package-internal. @return the activator of the last calculation layer */
    public Activator getOutputActivator() {
        return myLayers[myLayers.length - 1].getActivator();
    }

    /** Package-internal. @return a new list with each layer's {@code getLogicalWeights()}, in layer order */
    public List<MatrixStore<Double>> getWeights() {
        List<MatrixStore<Double>> weights = new ArrayList<>(myLayers.length);
        for (CalculationLayer layer : myLayers) {
            weights.add(layer.getLogicalWeights());
        }
        return weights;
    }

    /**
     * Evaluates one calculation layer. While a training configuration is set, the layer is invoked with the
     * configuration's {@code probabilityWillKeepOutput(layer, depth())}; otherwise it is invoked without it.
     * <p>
     * Package-internal.
     */
    public PhysicalStore<Double> invoke(final int layer, final PhysicalStore<Double> input, final PhysicalStore<Double> output) {
        if (myConfiguration == null) {
            return myLayers[layer].invoke(input, output);
        }
        return myLayers[layer].invoke(input, output, myConfiguration.probabilityWillKeepOutput(layer, this.depth()));
    }

    /** Package-internal. Creates a {@link DataBatch} with this network's factory; arguments forwarded as-is. */
    public DataBatch newBatch(final int rows, final int columns) {
        return DataBatch.from(myFactory, rows, columns);
    }

    /** Package-internal. @return a new store of the given shape made by this network's factory */
    public PhysicalStore<Double> newStore(final int rows, final int columns) {
        return myFactory.make(rows, columns);
    }

    /** Package-internal. Calls {@code randomise()} on every calculation layer. */
    public void randomise() {
        for (CalculationLayer layer : myLayers) {
            layer.randomise();
        }
    }

    void scale(final int layer, final double factor) {
        myLayers[layer].scale(factor);
    }

    /** Package-internal. Replaces the activator of the given calculation layer. */
    public void setActivator(final int layer, final Activator activator) {
        myLayers[layer].setActivator(activator);
    }

    /** Package-internal. Sets a bias value; {@code output} is forwarded to the layer's {@code setBias}. */
    public void setBias(final int layer, final int output, final double bias) {
        myLayers[layer].setBias(output, bias);
    }

    /**
     * Sets (or clears, with {@code null}) the training configuration.
     * <p>
     * When a previously set configuration is being cleared, layers 1 through {@code depth() - 1} are first
     * scaled by that configuration's {@code probabilityDidKeepInput(layer)}. This presumably compensates the
     * weights for dropout before inference. Layer 0 is not scaled.
     * <p>
     * Package-internal.
     */
    public void setConfiguration(final TrainingConfiguration trainingConfiguration) {
        boolean clearingExisting = myConfiguration != null && trainingConfiguration == null;
        if (clearingExisting) {
            int depth = this.depth();
            for (int layer = 1; layer < depth; layer++) {
                this.scale(layer, myConfiguration.probabilityDidKeepInput(layer));
            }
        }
        myConfiguration = trainingConfiguration;
    }

    /** Package-internal. Sets a weight; {@code input}/{@code output} are forwarded to the layer's {@code setWeight}. */
    public void setWeight(final int layer, final int input, final int output, final double weight) {
        myLayers[layer].setWeight(input, output, weight);
    }
}
