package org.ojalgo.ann;

import java.util.function.DoubleUnaryOperator;
import org.ojalgo.ann.ArtificialNeuralNetwork;
import org.ojalgo.function.constant.PrimitiveMath;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:ojalgo-51.3.0.jar:org/ojalgo/ann/TrainingConfiguration.class */
/**
 * Mutable bag of training options for an artificial neural network: dropout toggle, error measure,
 * learning rate and optional L1/L2 regularisation.
 *
 * <p>Fields are package-private and mutable. The operator returned by {@link #regularisation()}
 * reads the regularisation factors each time it is applied, not when it is created. No
 * synchronisation is performed.
 */
public final class TrainingConfiguration {

    /** When {@code true}, the keep-probability methods return one half (for eligible layers) instead of one. */
    boolean dropouts = false;
    /** Error measure; not read inside this class (presumably consumed by the trainer — verify). */
    ArtificialNeuralNetwork.Error error = ArtificialNeuralNetwork.Error.HALF_SQUARED_DIFFERENCE;
    /** Learning rate, initialised to {@code PrimitiveMath.HUNDREDTH}; not read inside this class. */
    double learningRate = PrimitiveMath.HUNDREDTH;
    /** Whether {@link #regularisation()} includes the L1 term. */
    boolean regularisationL1 = false;
    /** Magnitude of the sign-dependent L1 term. */
    double regularisationL1Factor = PrimitiveMath.ZERO;
    /** Whether {@link #regularisation()} includes the L2 term. */
    boolean regularisationL2 = false;
    /** Multiplier of the (linear) L2 term. */
    double regularisationL2Factor = PrimitiveMath.ZERO;

    /**
     * Compares every option field. Doubles are compared by bit pattern, so {@code 0.0} and
     * {@code -0.0} differ while any two NaNs are equal.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof TrainingConfiguration)) {
            return false;
        }
        final TrainingConfiguration other = (TrainingConfiguration) obj;
        return dropouts == other.dropouts
                && error == other.error
                && Double.compare(learningRate, other.learningRate) == 0
                && regularisationL1 == other.regularisationL1
                && Double.compare(regularisationL1Factor, other.regularisationL1Factor) == 0
                && regularisationL2 == other.regularisationL2
                && Double.compare(regularisationL2Factor, other.regularisationL2Factor) == 0;
    }

    /**
     * Combines every field that {@link #equals(Object)} inspects, using the conventional
     * prime-31 accumulation in declaration order.
     */
    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result + Boolean.hashCode(dropouts);
        result = prime * result + (error != null ? error.hashCode() : 0);
        result = prime * result + Double.hashCode(learningRate);
        result = prime * result + Boolean.hashCode(regularisationL1);
        result = prime * result + Double.hashCode(regularisationL1Factor);
        result = prime * result + Boolean.hashCode(regularisationL2);
        result = prime * result + Double.hashCode(regularisationL2Factor);
        return result;
    }

    /**
     * L1 term: the configured factor, signed by the argument. Only strictly negative values flip
     * the sign. Zero, {@code -0.0} and NaN all yield {@code +regularisationL1Factor}.
     */
    private double doL1(double value) {
        if (value < PrimitiveMath.ZERO) {
            return -regularisationL1Factor;
        }
        return regularisationL1Factor;
    }

    /** L2 term: the argument scaled by the configured factor. */
    private double doL2(double value) {
        return regularisationL2Factor * value;
    }

    /**
     * Probability that a layer's input is kept.
     *
     * @param layer index passed by the caller (assumed to be a layer index — confirm); index 0 is
     *        never dropped
     * @return {@code HALF} when dropouts are enabled and {@code layer != 0}, otherwise {@code ONE}
     */
    public double probabilityDidKeepInput(int layer) {
        final boolean mayDrop = dropouts && layer != 0;
        return mayDrop ? PrimitiveMath.HALF : PrimitiveMath.ONE;
    }

    /**
     * Probability that a layer's output is kept.
     *
     * @param layer index passed by the caller (assumed to be a layer index — confirm)
     * @param layerCount upper bound for {@code layer}. Indices at or beyond {@code layerCount - 1}
     *        are never dropped.
     * @return {@code HALF} when dropouts are enabled and {@code layer < layerCount - 1}, otherwise
     *         {@code ONE}
     */
    public double probabilityWillKeepOutput(int layer, int layerCount) {
        final boolean mayDrop = dropouts && layer < layerCount - 1;
        return mayDrop ? PrimitiveMath.HALF : PrimitiveMath.ONE;
    }

    /**
     * Builds the regularisation operator for the enabled terms.
     *
     * @return the sum of the L1 and L2 terms when both are enabled, or the single enabled term.
     *         Returns {@code null} when neither is enabled, so callers must null-check the result.
     */
    public DoubleUnaryOperator regularisation() {
        if (regularisationL1 && regularisationL2) {
            return value -> doL1(value) + doL2(value);
        }
        if (regularisationL1) {
            return this::doL1;
        }
        if (regularisationL2) {
            return this::doL2;
        }
        return null;
    }
}
