package jsat.math;

import jsat.linear.Vec;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/math/MathTricks.class */
/**
 * Static numerical helpers: min/max over varargs, numerically stabilized
 * log-sum-exp and softmax, Horner polynomial evaluation, and a few
 * single-argument {@link Function} constants.
 *
 * <p>This is a non-instantiable utility class.
 */
public final class MathTricks {
    // NOTE(review): javac names anonymous classes by declaration order
    // (MathTricks$1 .. MathTricks$6). These classes declare serialVersionUIDs, so
    // they are presumably serialized somewhere; keep their order stable and do not
    // insert new anonymous classes above them.

    /** Returns {@code sqrt(v[0])}; only index 0 of the input vector is read. */
    public static final Function sqrtFunc = new FunctionBase() {
        private static final long serialVersionUID = -5898515135319116600L;

        @Override
        public double f(Vec vec) {
            return Math.sqrt(vec.get(0));
        }
    };

    /** Returns {@code v[0] * v[0]}; only index 0 of the input vector is read. */
    public static final Function sqrdFunc = new FunctionBase() {
        private static final long serialVersionUID = 6831886040279358142L;

        @Override
        public double f(Vec vec) {
            double value = vec.get(0);
            return value * value;
        }
    };

    /** Returns {@code 1 / v[0]}; only index 0 of the input vector is read. */
    public static final Function invsFunc = new FunctionBase() {
        private static final long serialVersionUID = -7745316806635400174L;

        @Override
        public double f(Vec vec) {
            return 1.0 / vec.get(0);
        }
    };

    /** Returns the natural log of {@code v[0]}; only index 0 of the input vector is read. */
    public static final Function logFunc = new FunctionBase() {
        private static final long serialVersionUID = -4653355640520837353L;

        @Override
        public double f(Vec vec) {
            return Math.log(vec.get(0));
        }
    };

    /** Returns {@code exp(v[0])}; only index 0 of the input vector is read. */
    public static final Function expFunc = new FunctionBase() {
        private static final long serialVersionUID = 7075309263321302492L;

        @Override
        public double f(Vec vec) {
            return Math.exp(vec.get(0));
        }
    };

    /** Returns {@code |v[0]|}; only index 0 of the input vector is read. */
    public static final Function absFunc = new FunctionBase() {
        private static final long serialVersionUID = -3706702191562872641L;

        @Override
        public double f(Vec vec) {
            return Math.abs(vec.get(0));
        }
    };

    private MathTricks() {
    }

    /**
     * Returns the largest of the given values.
     *
     * @param values the values to compare
     * @return the maximum; {@link Double#NEGATIVE_INFINITY} if {@code values} is
     *     empty; NaN if any value is NaN
     */
    public static double max(double... values) {
        // -infinity is the identity element for max.
        double result = Double.NEGATIVE_INFINITY;
        for (double v : values) {
            result = Math.max(v, result);
        }
        return result;
    }

    /**
     * Returns the smallest of the given values.
     *
     * @param values the values to compare
     * @return the minimum; {@link Double#POSITIVE_INFINITY} if {@code values} is
     *     empty; NaN if any value is NaN
     */
    public static double min(double... values) {
        // +infinity is the identity element for min. Seeding with -infinity would
        // make every call return -infinity regardless of the inputs.
        double result = Double.POSITIVE_INFINITY;
        for (double v : values) {
            result = Math.min(v, result);
        }
        return result;
    }

    /**
     * Computes {@code log(sum_i exp(vec[i]))} as
     * {@code max + log(sum_i exp(vec[i] - max))}.
     *
     * @param vec the values to combine
     * @param max the shift subtracted before exponentiating. Any finite value gives
     *     the same result mathematically, but passing the largest element of
     *     {@code vec} keeps {@code exp} from overflowing.
     * @return the log-sum-exp; negative infinity if {@code vec} is empty and
     *     {@code max} is finite
     */
    public static double logSumExp(Vec vec, double max) {
        double sum = 0.0;
        for (int i = 0; i < vec.length(); i++) {
            sum += Math.exp(vec.get(i) - max);
        }
        return max + Math.log(sum);
    }

    /**
     * Computes {@code log(sum_i exp(values[i]))} as
     * {@code max + log(sum_i exp(values[i] - max))}.
     *
     * @param values the values to combine
     * @param max the shift subtracted before exponentiating. Any finite value gives
     *     the same result mathematically, but passing the largest element of
     *     {@code values} keeps {@code exp} from overflowing.
     * @return the log-sum-exp; negative infinity if {@code values} is empty and
     *     {@code max} is finite
     */
    public static double logSumExp(double[] values, double max) {
        double sum = 0.0;
        for (double v : values) {
            sum += Math.exp(v - max);
        }
        return max + Math.log(sum);
    }

    /**
     * Replaces {@code x} in place with its softmax: {@code x[i] = exp(x[i]) / Z}.
     *
     * @param x the values to transform, overwritten in place
     * @param implicitExtra if {@code true}, the normalizer {@code Z} also includes
     *     the term {@code exp(0)}. The result is then the softmax of {@code x}
     *     extended with one extra element of value 0 that is not stored.
     */
    public static void softmax(double[] x, boolean implicitExtra) {
        // Shift by the running maximum so exp() cannot overflow. With an implicit
        // extra element the shift starts at 1.0 rather than 0. Any shift gives the
        // same result mathematically, because the implicit term is computed
        // relative to the same shift.
        double max = implicitExtra ? 1.0 : Double.NEGATIVE_INFINITY;
        for (double v : x) {
            max = Math.max(max, v);
        }
        // exp(0 - max) is the implicit extra element's contribution to Z.
        double z = implicitExtra ? Math.exp(-max) : 0.0;
        for (int i = 0; i < x.length; i++) {
            x[i] = Math.exp(x[i] - max);
            z += x[i];
        }
        for (int i = 0; i < x.length; i++) {
            x[i] /= z;
        }
    }

    /**
     * Replaces {@code vec} in place with its softmax: {@code vec[i] = exp(vec[i]) / Z}.
     *
     * @param vec the values to transform, overwritten in place via
     *     {@code set}/{@code mutableDivide}
     * @param implicitExtra if {@code true}, the normalizer {@code Z} also includes
     *     the term {@code exp(0)}, as if {@code vec} had one extra element of
     *     value 0 that is not stored
     */
    public static void softmax(Vec vec, boolean implicitExtra) {
        // Same stabilizing shift as the double[] overload.
        double max = Math.max(implicitExtra ? 1.0 : Double.NEGATIVE_INFINITY, vec.max());
        // exp(0 - max) is the implicit extra element's contribution to Z.
        double z = implicitExtra ? Math.exp(-max) : 0.0;
        for (int i = 0; i < vec.length(); i++) {
            double e = Math.exp(vec.get(i) - max);
            vec.set(i, e);
            z += e;
        }
        vec.mutableDivide(z);
    }

    /**
     * Evaluates a polynomial with Horner's method, taking coefficients in
     * descending order of degree:
     * {@code coef[0]*x^(n-1) + coef[1]*x^(n-2) + ... + coef[n-1]}.
     *
     * @param coef the coefficients, highest degree first
     * @param x the point at which to evaluate
     * @return the polynomial's value; 0 for an empty coefficient array
     */
    public static double hornerPolyR(double[] coef, double x) {
        double result = 0.0;
        for (double c : coef) {
            result = result * x + c;
        }
        return result;
    }

    /**
     * Evaluates a polynomial with Horner's method, taking coefficients in
     * ascending order of degree:
     * {@code coef[0] + coef[1]*x + ... + coef[n-1]*x^(n-1)}.
     *
     * @param coef the coefficients, constant term first
     * @param x the point at which to evaluate
     * @return the polynomial's value; 0 for an empty coefficient array
     */
    public static double hornerPoly(double[] coef, double x) {
        double result = 0.0;
        for (int i = coef.length - 1; i >= 0; i--) {
            result = result * x + coef[i];
        }
        return result;
    }
}
