package jsat.distributions;

import jsat.linear.Vec;
import jsat.text.GreekLetters;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/distributions/Normal.class */
/**
 * The univariate normal (Gaussian) distribution with mean &mu; and standard
 * deviation &sigma;. The mean must be finite and the standard deviation must be
 * finite and strictly positive.
 */
public class Normal extends ContinuousDistribution {
    private static final long serialVersionUID = -5298346576152986165L;

    /** 2&pi;. */
    private static final double TWO_PI = 6.283185307179586d;

    /** ln(sqrt(2&pi;)); lets the standard normal density be folded into a single exp call. */
    private static final double LOG_SQRT_TWO_PI = 0.9189385332046728d;

    /** |z| at or beyond which {@link #cdfApproxMarsaglia2004(double)} returns exactly 0 or 1. */
    private static final double CDF_SATURATION = 8.22d;

    // Coefficients of Acklam's rational approximation to the inverse standard normal CDF.
    // Hoisted to constants so they are not re-allocated on every invcdf call.
    private static final double[] ACKLAM_A = {-39.69683028665376d, 220.9460984245205d, -275.9285104469687d, 138.357751867269d, -30.66479806614716d, 2.506628277459239d};
    private static final double[] ACKLAM_B = {-54.47609879822406d, 161.5858368580409d, -155.6989798598866d, 66.80131188771972d, -13.28068155288572d};
    private static final double[] ACKLAM_C = {-0.007784894002430293d, -0.3223964580411365d, -2.400758277161838d, -2.549732539343734d, 4.374664141464968d, 2.938163982698783d};
    private static final double[] ACKLAM_D = {0.007784695709041462d, 0.3224671290700398d, 2.445134137142996d, 3.754408661907416d};

    /** Probability below which Acklam's lower-tail approximation is used. */
    private static final double ACKLAM_P_LOW = 0.02425d;
    /** Probability above which Acklam's upper-tail approximation is used. */
    private static final double ACKLAM_P_HIGH = 1.0d - ACKLAM_P_LOW;

    private double mean;
    private double stndDev;

    /** Creates the standard normal distribution, N(0, 1). */
    public Normal() {
        this(0.0d, 1.0d);
    }

    /**
     * Creates a normal distribution with the given parameters.
     *
     * @param mean    the mean; must be finite
     * @param stndDev the standard deviation; must be finite and strictly positive
     * @throws ArithmeticException if either parameter is out of range
     */
    public Normal(double mean, double stndDev) {
        // Validate via private static helpers instead of the overridable setters,
        // so a subclass can never observe a partially constructed instance.
        this.mean = checkMean(mean);
        this.stndDev = checkStndDev(stndDev);
    }

    /**
     * Sets the mean of this distribution.
     *
     * @param d the new mean; must be finite
     * @throws ArithmeticException if {@code d} is infinite or NaN
     */
    public void setMean(double d) {
        this.mean = checkMean(d);
    }

    /**
     * Sets the standard deviation of this distribution.
     *
     * @param d the new standard deviation; must be finite and strictly positive
     * @throws ArithmeticException if {@code d} is infinite, NaN, or {@code <= 0}
     */
    public void setStndDev(double d) {
        this.stndDev = checkStndDev(d);
    }

    /** Returns {@code mean} unchanged if it is a valid mean, otherwise throws. */
    private static double checkMean(double mean) {
        if (Double.isInfinite(mean) || Double.isNaN(mean)) {
            throw new ArithmeticException("Mean can not be infinite or NaN");
        }
        return mean;
    }

    /** Returns {@code stndDev} unchanged if it is a valid standard deviation, otherwise throws. */
    private static double checkStndDev(double stndDev) {
        if (Double.isInfinite(stndDev) || Double.isNaN(stndDev)) {
            throw new ArithmeticException("Standard deviation can not be infinite or NaN");
        }
        if (stndDev <= 0.0d) {
            throw new ArithmeticException("The standard deviation can not be <= 0");
        }
        return stndDev;
    }

    /**
     * Cumulative distribution function of N(mu, sigma) evaluated at {@code x}.
     *
     * @param x     the point to evaluate; must be finite
     * @param mu    the mean
     * @param sigma the standard deviation
     * @return P(X &le; x)
     * @throws ArithmeticException if {@code x} is infinite or NaN
     */
    public static double cdf(double x, double mu, double sigma) {
        if (Double.isNaN(x) || Double.isInfinite(x)) {
            throw new ArithmeticException("X is not a real number");
        }
        return cdfApproxMarsaglia2004(zTransform(x, mu, sigma));
    }

    @Override
    public double cdf(double d) {
        return cdf(d, this.mean, this.stndDev);
    }

    /**
     * Inverse cumulative distribution function (quantile function) of N(mu, sigma).
     * <p>
     * Uses Acklam's piecewise rational approximation. When the CDF approximation
     * used by this class is not saturated, the result is then refined by one Halley step.
     *
     * @param p     the probability; must lie in [0, 1]
     * @param mu    the mean
     * @param sigma the standard deviation
     * @return the value x such that P(X &le; x) = p; {@code -Infinity} for
     *         {@code p == 0} and {@code +Infinity} for {@code p == 1}
     * @throws ArithmeticException if {@code p} is outside [0, 1] or NaN
     */
    public static double invcdf(double p, double mu, double sigma) {
        // Negated conjunction so that NaN is rejected here as well.
        if (!(p >= 0.0d && p <= 1.0d)) {
            throw new ArithmeticException("Inverse of a probability requires a probability in the range [0,1], not " + p);
        }
        // The endpoint quantiles are unbounded. The approximation below would
        // otherwise produce a finite value for p == 0 and NaN for p == 1.
        if (p == 0.0d) {
            return Double.NEGATIVE_INFINITY;
        }
        if (p == 1.0d) {
            return Double.POSITIVE_INFINITY;
        }

        double z;
        if (p < ACKLAM_P_LOW) {
            z = acklamTail(Math.sqrt(-2.0d * Math.log(p)));
        } else if (p > ACKLAM_P_HIGH) {
            // Upper tail by symmetry with the lower tail.
            z = -acklamTail(Math.sqrt(-2.0d * Math.log(1.0d - p)));
        } else {
            double q = p - 0.5d;
            double r = q * q;
            z = horner(ACKLAM_A, r) * q / (horner(ACKLAM_B, r) * r + 1.0d);
        }

        // Halley refinement. It is only meaningful where the CDF approximation is
        // not clamped to exactly 0 or 1. A clamped CDF would yield a bogus error
        // term that moves z away from the true quantile, or to NaN.
        if (Math.abs(z) < CDF_SATURATION) {
            double u = (cdf(z, 0.0d, 1.0d) - p) * Math.sqrt(TWO_PI) * Math.exp((z * z) / 2.0d);
            z -= u / (1.0d + (z * u) / 2.0d);
        }
        return z * sigma + mu;
    }

    /** Acklam's tail rational function evaluated at q = sqrt(-2 ln(tail probability)). */
    private static double acklamTail(double q) {
        return horner(ACKLAM_C, q) / (horner(ACKLAM_D, q) * q + 1.0d);
    }

    /** Evaluates coeffs[0]*x^(n-1) + ... + coeffs[n-1] with Horner's scheme. */
    private static double horner(double[] coeffs, double x) {
        double acc = coeffs[0];
        for (int i = 1; i < coeffs.length; i++) {
            acc = acc * x + coeffs[i];
        }
        return acc;
    }

    @Override
    public double invCdf(double d) {
        return invcdf(d, this.mean, this.stndDev);
    }

    /**
     * Probability density function of N(mu, sigma) evaluated at {@code x}.
     *
     * @param x     the point to evaluate
     * @param mu    the mean
     * @param sigma the standard deviation
     * @return the density at {@code x}
     */
    public static double pdf(double x, double mu, double sigma) {
        return (1.0d / Math.sqrt((TWO_PI * sigma) * sigma)) * Math.exp((-Math.pow(x - mu, 2.0d)) / ((2.0d * sigma) * sigma));
    }

    @Override
    public double pdf(double d) {
        return pdf(d, this.mean, this.stndDev);
    }

    /**
     * Natural log of the probability density function of N(mu, sigma) at {@code x}.
     *
     * @param x     the point to evaluate
     * @param mu    the mean
     * @param sigma the standard deviation
     * @return ln(pdf(x, mu, sigma)), computed directly to avoid underflow
     */
    public static double logPdf(double x, double mu, double sigma) {
        return (((-0.5d) * Math.log(TWO_PI)) - Math.log(sigma)) + ((-Math.pow(x - mu, 2.0d)) / ((2.0d * sigma) * sigma));
    }

    @Override
    public double logPdf(double d) {
        return logPdf(d, this.mean, this.stndDev);
    }

    /**
     * Returns the reciprocal of this distribution's density at {@code d}.
     *
     * @param d the point to evaluate
     * @return 1 / pdf(d)
     */
    public double invPdf(double d) {
        return Math.exp(Math.pow(this.mean - d, 2.0d) / (2.0d * Math.pow(this.stndDev, 2.0d))) * Math.sqrt(TWO_PI) * this.stndDev;
    }

    /**
     * Standardizes {@code x}: returns (x - mu) / sigma.
     *
     * @param x     the value to standardize
     * @param mu    the mean
     * @param sigma the standard deviation
     * @return the z-score of {@code x}
     */
    public static double zTransform(double x, double mu, double sigma) {
        return (x - mu) / sigma;
    }

    /** Standardizes {@code d} using this distribution's mean and standard deviation. */
    public double zTransform(double d) {
        return zTransform(d, this.mean, this.stndDev);
    }

    /**
     * Standard normal CDF via Marsaglia's (2004) series
     * &Phi;(x) = 1/2 + &phi;(x)&middot;(x + x^3/3 + x^5/(3&middot;5) + ...),
     * summed until the partial sum stops changing. Returns exactly 0 or 1 once
     * |x| reaches {@link #CDF_SATURATION}, and NaN for NaN input.
     */
    private static double cdfApproxMarsaglia2004(double x) {
        // NaN would otherwise spin forever, since NaN != NaN.
        if (Double.isNaN(x)) {
            return Double.NaN;
        }
        if (x >= CDF_SATURATION) {
            return 1.0d;
        }
        if (x <= -CDF_SATURATION) {
            return 0.0d;
        }
        final double x2 = x * x;
        double sum = x;
        double prev = 0.0d;
        double term = x;
        double denom = 1.0d;
        while (sum != prev) {
            prev = sum;
            denom += 2.0d;        // 3, 5, 7, ...
            term *= x2 / denom;   // x^(2k+1) / (1*3*...*(2k+1))
            sum = prev + term;
        }
        return 0.5d + sum * Math.exp(-0.5d * x2 - LOG_SQRT_TWO_PI);
    }

    @Override
    public String getDescriptiveName() {
        return "Normal(μ=" + this.mean + ", σ=" + this.stndDev + ")";
    }

    @Override
    public double min() {
        return Double.NEGATIVE_INFINITY;
    }

    @Override
    public double max() {
        return Double.POSITIVE_INFINITY;
    }

    @Override
    public String getDistributionName() {
        return "Normal";
    }

    @Override
    public String[] getVariables() {
        return new String[]{GreekLetters.mu, GreekLetters.sigma};
    }

    /**
     * Sets the parameter named {@code str} (one of {@link #getVariables()}) to {@code d}.
     * Unknown names are ignored.
     *
     * @throws ArithmeticException if {@code d} is not a valid value for that parameter
     */
    @Override
    public void setVariable(String str, double d) {
        if (str.equals(GreekLetters.mu)) {
            // Go through the validating setter, exactly as sigma does.
            setMean(d);
        } else if (str.equals(GreekLetters.sigma)) {
            setStndDev(d);
        }
    }

    @Override
    /* renamed from: clone */
    public ContinuousDistribution mo616clone() {
        return new Normal(this.mean, this.stndDev);
    }

    /**
     * Fits this distribution to {@code vec} using its sample mean and standard deviation.
     *
     * @throws ArithmeticException if either estimate is invalid; this distribution
     *         is then left unchanged
     */
    @Override
    public void setUsingData(Vec vec) {
        // Validate both estimates before assigning either, so a failure cannot
        // leave the distribution half-updated.
        double newMean = checkMean(vec.mean());
        double newStndDev = checkStndDev(vec.standardDeviation());
        this.mean = newMean;
        this.stndDev = newStndDev;
    }

    @Override
    public double[] getCurrentVariableValues() {
        return new double[]{this.mean, this.stndDev};
    }

    @Override
    public double mean() {
        return this.mean;
    }

    @Override
    public double median() {
        return this.mean;
    }

    @Override
    public double mode() {
        return this.mean;
    }

    @Override
    public double variance() {
        return this.stndDev * this.stndDev;
    }

    @Override
    public double standardDeviation() {
        return this.stndDev;
    }

    @Override
    public double skewness() {
        return 0.0d;
    }

    @Override
    public int hashCode() {
        long meanBits = Double.doubleToLongBits(this.mean);
        int result = (31 * 1) + ((int) (meanBits ^ (meanBits >>> 32)));
        long sdBits = Double.doubleToLongBits(this.stndDev);
        return (31 * result) + ((int) (sdBits ^ (sdBits >>> 32)));
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        Normal other = (Normal) obj;
        return Double.doubleToLongBits(this.mean) == Double.doubleToLongBits(other.mean)
                && Double.doubleToLongBits(this.stndDev) == Double.doubleToLongBits(other.stndDev);
    }
}
