package smile.base.rbf;

import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import smile.clustering.CLARANS;
import smile.clustering.KMeans;
import smile.math.MathEx;
import smile.math.distance.EuclideanDistance;
import smile.math.distance.Metric;
import smile.math.rbf.GaussianRadialBasis;
import smile.math.rbf.RadialBasisFunction;

/* loaded from: input_file:smile-core-2.4.0.jar:smile/base/rbf/RBF.class */
/**
 * A single radial basis function neuron: a center, a radial basis function
 * and a distance metric. The activation for an input is the radial basis
 * function applied to the distance between the input and the center.
 *
 * <p>The static {@code fit} factories pick the centers by clustering
 * (k-means for {@code double[]} data, CLARANS for arbitrary objects with a
 * metric) and attach Gaussian basis functions whose widths are estimated
 * from the geometry of the centers or of the clusters.
 *
 * @param <T> the type of the input objects.
 */
public class RBF<T> implements Serializable {
    private static final long serialVersionUID = 2;

    /** The center of the neuron. */
    private final T center;
    /** The radial basis function applied to the distance from the center. */
    private final RadialBasisFunction rbf;
    /** The metric used to measure the distance from the center. */
    private final Metric<T> distance;

    /**
     * Creates a neuron.
     *
     * @param t the center of the neuron.
     * @param radialBasisFunction the radial basis function.
     * @param metric the distance metric.
     */
    public RBF(T t, RadialBasisFunction radialBasisFunction, Metric<T> metric) {
        this.center = t;
        this.rbf = radialBasisFunction;
        this.distance = metric;
    }

    /**
     * Returns the activation of this neuron for the given input, i.e. the
     * radial basis function evaluated at the distance between the input and
     * the center.
     *
     * @param t the input object.
     * @return the activation value.
     */
    public double f(T t) {
        return rbf.f(distance.d(t, center));
    }

    /**
     * Creates one neuron per center, all sharing the same basis function.
     *
     * @param centers the neuron centers.
     * @param basis the radial basis function shared by all neurons.
     * @param distance the distance metric.
     * @param <T> the type of the input objects.
     * @return the neurons, in the same order as {@code centers}.
     */
    public static <T> RBF<T>[] of(T[] centers, RadialBasisFunction basis, Metric<T> distance) {
        int k = centers.length;
        // Generic arrays cannot be created directly; the raw array only ever holds RBF<T>.
        @SuppressWarnings("unchecked")
        RBF<T>[] neurons = new RBF[k];
        for (int i = 0; i < k; i++) {
            neurons[i] = new RBF<>(centers[i], basis, distance);
        }
        return neurons;
    }

    /**
     * Creates one neuron per center, pairing {@code centers[i]} with
     * {@code basis[i]}.
     *
     * @param centers the neuron centers.
     * @param basis the radial basis functions, one per center.
     * @param distance the distance metric.
     * @param <T> the type of the input objects.
     * @return the neurons, in the same order as {@code centers}.
     */
    public static <T> RBF<T>[] of(T[] centers, RadialBasisFunction[] basis, Metric<T> distance) {
        int k = centers.length;
        // Generic arrays cannot be created directly; the raw array only ever holds RBF<T>.
        @SuppressWarnings("unchecked")
        RBF<T>[] neurons = new RBF[k];
        for (int i = 0; i < k; i++) {
            neurons[i] = new RBF<>(centers[i], basis[i], distance);
        }
        return neurons;
    }

    /**
     * Estimates a single width shared by all centers as
     * {@code dmax / sqrt(2 * k)}, where {@code dmax} is the largest pairwise
     * distance between centers and {@code k} is the number of centers.
     *
     * @param centers the neuron centers.
     * @param distance the distance metric.
     * @return the estimated width.
     */
    private static <T> double estimateWidth(T[] centers, Metric<T> distance) {
        int k = centers.length;
        double dmax = 0.0;
        for (int i = 0; i < k; i++) {
            for (int j = 0; j < i; j++) {
                double d = distance.d(centers[i], centers[j]);
                if (dmax < d) {
                    dmax = d;
                }
            }
        }
        return dmax / Math.sqrt(2 * k);
    }

    /**
     * Estimates one width per center as the average distance from that
     * center to its {@code p} nearest other centers.
     *
     * <p>Callers in this class guarantee {@code 1 <= p < centers.length}.
     *
     * @param centers the neuron centers.
     * @param distance the distance metric.
     * @param p the number of nearest neighbors to average over.
     * @return the estimated widths, one per center.
     */
    private static <T> double[] estimateWidth(T[] centers, Metric<T> distance, int p) {
        int k = centers.length;
        double[] width = new double[k];
        // Scratch buffer, refilled and sorted for every center. It must stay
        // separate from the result: returning it would hand back the sorted
        // distances of the last center rather than the per-center widths.
        double[] dist = new double[k];
        for (int i = 0; i < k; i++) {
            for (int j = 0; j < k; j++) {
                dist[j] = distance.d(centers[i], centers[j]);
            }
            Arrays.sort(dist);

            // Start at 1: the sum deliberately skips the smallest entry, which
            // for a metric is the zero distance of the center to itself.
            double sum = 0.0;
            for (int j = 1; j <= p; j++) {
                sum += dist[j];
            }
            width[i] = sum / p;
        }
        return width;
    }

    /**
     * Estimates one width per center from the spread of its cluster, scaled
     * by {@code r}. The spread is the root mean squared distance from the
     * cluster members to the center. When a cluster has fewer than 5 members
     * and zero spread, half of the distance to the nearest other center is
     * used instead.
     *
     * @param x the training samples.
     * @param y the index of the cluster each sample belongs to.
     * @param centers the cluster centers.
     * @param size the number of samples in each cluster.
     * @param distance the distance metric.
     * @param r the scaling factor applied to every width.
     * @return the estimated widths, one per center.
     */
    private static <T> double[] estimateWidth(T[] x, int[] y, T[] centers, int[] size, Metric<T> distance, double r) {
        int k = centers.length;
        double[] width = new double[k];

        // Accumulate the squared distance of every sample to its own center.
        for (int i = 0; i < x.length; i++) {
            int c = y[i];
            width[c] += MathEx.sqr(distance.d(x[i], centers[c]));
        }

        for (int i = 0; i < k; i++) {
            if (size[i] >= 5 || width[i] != 0.0) {
                width[i] = Math.sqrt(width[i] / size[i]);
            } else {
                // Tiny cluster with no spread: fall back to half of the
                // distance to the nearest other center.
                double nearest = Double.POSITIVE_INFINITY;
                for (int j = 0; j < k; j++) {
                    if (i != j) {
                        double d = distance.d(centers[i], centers[j]);
                        if (d < nearest) {
                            nearest = d;
                        }
                    }
                }
                width[i] = nearest / 2.0;
            }
            width[i] *= r;
        }
        return width;
    }

    /**
     * Wraps each width in a Gaussian radial basis function.
     *
     * @param width the widths.
     * @return the Gaussian basis functions, one per width.
     */
    private static GaussianRadialBasis[] gaussian(double[] width) {
        int k = width.length;
        GaussianRadialBasis[] basis = new GaussianRadialBasis[k];
        for (int i = 0; i < k; i++) {
            basis[i] = new GaussianRadialBasis(width[i]);
        }
        return basis;
    }

    /**
     * Fits Gaussian RBF neurons on numeric data. The centers are the
     * centroids returned by {@code KMeans.fit(x, k, 10, 1E-4)} and all
     * neurons share one width, {@code dmax / sqrt(2 * k)}, where
     * {@code dmax} is the largest distance between centers.
     *
     * @param x the training samples.
     * @param k the number of clusters requested, i.e. the number of neurons.
     * @return the fitted neurons.
     */
    public static RBF<double[]>[] fit(double[][] x, int k) {
        KMeans kmeans = KMeans.fit(x, k, 10, 1.0E-4);
        double[][] centers = (double[][]) kmeans.centroids;
        EuclideanDistance distance = new EuclideanDistance();
        GaussianRadialBasis basis = new GaussianRadialBasis(estimateWidth(centers, distance));
        return of(centers, basis, distance);
    }

    /**
     * Fits Gaussian RBF neurons on numeric data. The centers are the
     * centroids returned by {@code KMeans.fit(x, k, 10, 1E-4)} and the
     * width of each neuron is the average distance to its {@code p}
     * nearest other centers.
     *
     * @param x the training samples.
     * @param k the number of clusters requested, i.e. the number of neurons.
     * @param p the number of nearest neighbors used to estimate the width.
     * @return the fitted neurons.
     * @throws IllegalArgumentException if {@code p < 1} or {@code p >= k}.
     */
    public static RBF<double[]>[] fit(double[][] x, int k, int p) {
        if (p < 1 || p >= k) {
            throw new IllegalArgumentException("Invalid number of nearest neighbors: " + p);
        }
        KMeans kmeans = KMeans.fit(x, k, 10, 1.0E-4);
        double[][] centers = (double[][]) kmeans.centroids;
        EuclideanDistance distance = new EuclideanDistance();
        double[] width = estimateWidth(centers, distance, p);
        return of(centers, gaussian(width), distance);
    }

    /**
     * Fits Gaussian RBF neurons on numeric data. The centers are the
     * centroids returned by {@code KMeans.fit(x, k, 10, 1E-4)} and the
     * width of each neuron is the spread of its cluster scaled by {@code r}.
     *
     * @param x the training samples.
     * @param k the number of clusters requested, i.e. the number of neurons.
     * @param r the scaling factor applied to the cluster spread.
     * @return the fitted neurons.
     * @throws IllegalArgumentException if {@code r <= 0}.
     */
    public static RBF<double[]>[] fit(double[][] x, int k, double r) {
        if (r <= 0.0) {
            throw new IllegalArgumentException("Invalid scaling parameter: " + r);
        }
        KMeans kmeans = KMeans.fit(x, k, 10, 1.0E-4);
        double[][] centers = (double[][]) kmeans.centroids;
        EuclideanDistance distance = new EuclideanDistance();
        double[] width = estimateWidth(x, kmeans.y, centers, kmeans.size, distance, r);
        return of(centers, gaussian(width), distance);
    }

    /**
     * Fits Gaussian RBF neurons on arbitrary objects. The centers are the
     * centroids found by CLARANS under the given metric and all neurons
     * share one width, {@code dmax / sqrt(2 * k)}, where {@code dmax} is
     * the largest distance between centers.
     *
     * @param x the training samples.
     * @param distance the distance metric.
     * @param k the number of clusters requested, i.e. the number of neurons.
     * @param <T> the type of the input objects.
     * @return the fitted neurons.
     * @throws NullPointerException if {@code distance} is null.
     */
    public static <T> RBF<T>[] fit(T[] x, Metric<T> distance, int k) {
        // Evaluating distance::d already throws NullPointerException for a null metric.
        CLARANS<T> clarans = CLARANS.fit(x, distance::d, k);
        T[] centers = clarans.centroids;
        GaussianRadialBasis basis = new GaussianRadialBasis(estimateWidth(centers, distance));
        return of(centers, basis, distance);
    }

    /**
     * Fits Gaussian RBF neurons on arbitrary objects. The centers are the
     * centroids found by CLARANS under the given metric and the width of
     * each neuron is the average distance to its {@code p} nearest other
     * centers.
     *
     * @param x the training samples.
     * @param distance the distance metric.
     * @param k the number of clusters requested, i.e. the number of neurons.
     * @param p the number of nearest neighbors used to estimate the width.
     * @param <T> the type of the input objects.
     * @return the fitted neurons.
     * @throws IllegalArgumentException if {@code p < 1} or {@code p >= k}.
     * @throws NullPointerException if {@code distance} is null.
     */
    public static <T> RBF<T>[] fit(T[] x, Metric<T> distance, int k, int p) {
        if (p < 1 || p >= k) {
            throw new IllegalArgumentException("Invalid number of nearest neighbors: " + p);
        }
        // Evaluating distance::d already throws NullPointerException for a null metric.
        CLARANS<T> clarans = CLARANS.fit(x, distance::d, k);
        T[] centers = clarans.centroids;
        double[] width = estimateWidth(centers, distance, p);
        return of(centers, gaussian(width), distance);
    }

    /**
     * Fits Gaussian RBF neurons on arbitrary objects. The centers are the
     * centroids found by CLARANS under the given metric and the width of
     * each neuron is the spread of its cluster scaled by {@code r}.
     *
     * @param x the training samples.
     * @param distance the distance metric.
     * @param k the number of clusters requested, i.e. the number of neurons.
     * @param r the scaling factor applied to the cluster spread.
     * @param <T> the type of the input objects.
     * @return the fitted neurons.
     * @throws IllegalArgumentException if {@code r <= 0}.
     * @throws NullPointerException if {@code distance} is null.
     */
    public static <T> RBF<T>[] fit(T[] x, Metric<T> distance, int k, double r) {
        if (r <= 0.0) {
            throw new IllegalArgumentException("Invalid scaling parameter: " + r);
        }
        // Evaluating distance::d already throws NullPointerException for a null metric.
        CLARANS<T> clarans = CLARANS.fit(x, distance::d, k);
        T[] centers = clarans.centroids;
        double[] width = estimateWidth(x, clarans.y, centers, clarans.size, distance, r);
        return of(centers, gaussian(width), distance);
    }

    // NOTE: the decompiled $deserializeLambda$ method was intentionally dropped.
    // It is a compiler-synthesized method, not hand-written source (and the
    // decompiled form was not valid Java); the compiler generates it again for
    // the distance::d references above if their target type is serializable.
}
