package smile.validation;

import java.util.Arrays;
import java.util.function.BiFunction;
import java.util.function.Function;
import smile.classification.Classifier;
import smile.classification.DataFrameClassifier;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.math.MathEx;
import smile.regression.DataFrameRegression;
import smile.regression.Regression;
import smile.sort.QuickSort;

/* loaded from: input_file:smile-core-2.4.0.jar:smile/validation/GroupKFold.class */
/**
 * Group k-fold cross validation.
 * <p>
 * Samples are partitioned into {@code k} rounds such that all samples sharing
 * a group label end up in the same test fold, so a group never appears in both
 * the training set and the test set of the same round. Groups are assigned to
 * folds greedily: largest group first, each into the fold that currently holds
 * the fewest test samples, which keeps test fold sizes approximately balanced.
 */
public class GroupKFold {
    /** The number of folds (rounds). */
    public final int k;
    /** The sample indices of the training set in each round; {@code train[i]} pairs with {@code test[i]}. */
    public final int[][] train;
    /** The sample indices of the test set in each round. */
    public final int[][] test;
    /** The number of samples the folds were built for. */
    private final int n;

    /**
     * The outcome of assigning groups to test folds.
     */
    private static final class TestFolds {
        /** The number of test samples in each fold. */
        private final int[] numTestSamplesPerFold;
        /** The test fold index assigned to each group label. */
        private final int[] groupToTestFoldIndex;

        private TestFolds(int[] numTestSamplesPerFold, int[] groupToTestFoldIndex) {
            this.numTestSamplesPerFold = numTestSamplesPerFold;
            this.groupToTestFoldIndex = groupToTestFoldIndex;
        }
    }

    /**
     * Builds the group k-fold splits.
     *
     * @param n the number of samples.
     * @param k the number of folds; must be positive and not exceed the number of groups.
     * @param groups the group label of each sample. The set of labels must be exactly
     *               {@code 0, 1, ..., numGroups - 1}.
     * @throws IllegalArgumentException if {@code n} is negative, {@code k} is not positive,
     *         {@code groups.length != n}, {@code k} exceeds the number of distinct groups,
     *         or the group labels are not contiguous starting from 0.
     */
    public GroupKFold(int n, int k, int[] groups) {
        if (n < 0) {
            throw new IllegalArgumentException("Invalid sample size: " + n);
        }

        // k == 0 would otherwise fail later with an ArrayIndexOutOfBoundsException
        // while assigning groups to (nonexistent) folds.
        if (k <= 0) {
            throw new IllegalArgumentException("Invalid number of folds: " + k);
        }

        if (groups.length != n) {
            throw new IllegalArgumentException("Groups array must be of size n, but length is " + groups.length);
        }

        int[] uniqueGroups = MathEx.unique(groups);
        int numGroups = uniqueGroups.length;
        if (k > numGroups) {
            throw new IllegalArgumentException("Number of splits mustn't be greater than number of groups");
        }

        // Group labels are used directly as array indices below, so they must be
        // exactly 0, 1, ..., numGroups - 1.
        Arrays.sort(uniqueGroups);
        for (int g = 0; g < numGroups; g++) {
            if (uniqueGroups[g] != g) {
                throw new IllegalArgumentException("Invalid encoding of groups, all group indices between [0, numGroups) have to exist");
            }
        }

        this.n = n;
        this.k = k;
        this.train = new int[k][];
        this.test = new int[k][];

        TestFolds folds = calculateTestFolds(groups, numGroups, k);
        for (int fold = 0; fold < k; fold++) {
            int testSize = folds.numTestSamplesPerFold[fold];
            int[] trainIndex = new int[n - testSize];
            int[] testIndex = new int[testSize];
            int trainPos = 0;
            int testPos = 0;
            // Samples whose group belongs to this fold are tested; all others train.
            for (int i = 0; i < n; i++) {
                if (folds.groupToTestFoldIndex[groups[i]] == fold) {
                    testIndex[testPos++] = i;
                } else {
                    trainIndex[trainPos++] = i;
                }
            }
            train[fold] = trainIndex;
            test[fold] = testIndex;
        }
    }

    /**
     * Assigns every group to a test fold, greedily balancing the number of test
     * samples per fold: largest group first, each into the currently smallest fold.
     *
     * @param groups the group label of each sample, each in {@code [0, numGroups)}.
     * @param numGroups the number of distinct groups.
     * @param k the number of folds.
     * @return the per-fold test sample counts and the group-to-fold assignment.
     */
    private static TestFolds calculateTestFolds(int[] groups, int numGroups, int k) {
        int[] groupSize = new int[numGroups];
        for (int g : groups) {
            groupSize[g]++;
        }

        // QuickSort.sort sorts groupSize in place and returns, for each sorted
        // position, the original group label that the size belongs to.
        int[] order = QuickSort.sort(groupSize);

        int[] foldSize = new int[k];
        int[] groupToFold = new int[numGroups];
        for (int i = numGroups - 1; i >= 0; i--) {
            int fold = MathEx.whichMin(foldSize);
            foldSize[fold] += groupSize[i];
            groupToFold[order[i]] = fold;
        }
        return new TestFolds(foldSize, groupToFold);
    }

    /**
     * Verifies that a data set has the number of samples the folds were built for.
     *
     * @param name the argument name, used in the error message.
     * @param size the number of samples in the argument.
     * @throws IllegalArgumentException if {@code size} differs from the sample size.
     */
    private void checkSize(String name, int size) {
        if (size != n) {
            throw new IllegalArgumentException(name + " has " + size + " samples, but the folds were built for " + n);
        }
    }

    /**
     * Runs group k-fold cross validation of a classification model.
     *
     * @param x the samples.
     * @param y the class labels.
     * @param trainer the lambda that trains a model on a training subset.
     * @param <T> the data type of samples.
     * @return the out-of-fold prediction for every sample.
     * @throws IllegalArgumentException if {@code x} or {@code y} does not have the
     *         sample size the folds were built for.
     */
    public <T> int[] classification(T[] x, int[] y, BiFunction<T[], int[], Classifier<T>> trainer) {
        checkSize("x", x.length);
        checkSize("y", y.length);

        int[] prediction = new int[n];
        for (int i = 0; i < k; i++) {
            Classifier<T> model = trainer.apply(MathEx.slice(x, train[i]), MathEx.slice(y, train[i]));
            for (int j : test[i]) {
                prediction[j] = model.predict(x[j]);
            }
        }
        return prediction;
    }

    /**
     * Runs group k-fold cross validation of a data frame classification model.
     *
     * @param data the data frame, one row per sample.
     * @param trainer the lambda that trains a model on a training subset of rows.
     * @return the out-of-fold prediction for every row.
     * @throws IllegalArgumentException if {@code data} does not have the sample size
     *         the folds were built for.
     */
    public int[] classification(DataFrame data, Function<DataFrame, DataFrameClassifier> trainer) {
        checkSize("data", data.size());

        int[] prediction = new int[n];
        for (int i = 0; i < k; i++) {
            DataFrameClassifier model = trainer.apply(data.of(train[i]));
            for (int j : test[i]) {
                prediction[j] = model.predict((Tuple) data.get(j));
            }
        }
        return prediction;
    }

    /**
     * Runs group k-fold cross validation of a regression model.
     *
     * @param x the samples.
     * @param y the response variable.
     * @param trainer the lambda that trains a model on a training subset.
     * @param <T> the data type of samples.
     * @return the out-of-fold prediction for every sample.
     * @throws IllegalArgumentException if {@code x} or {@code y} does not have the
     *         sample size the folds were built for.
     */
    public <T> double[] regression(T[] x, double[] y, BiFunction<T[], double[], Regression<T>> trainer) {
        checkSize("x", x.length);
        checkSize("y", y.length);

        double[] prediction = new double[n];
        for (int i = 0; i < k; i++) {
            Regression<T> model = trainer.apply(MathEx.slice(x, train[i]), MathEx.slice(y, train[i]));
            for (int j : test[i]) {
                prediction[j] = model.predict(x[j]);
            }
        }
        return prediction;
    }

    /**
     * Runs group k-fold cross validation of a data frame regression model.
     *
     * @param data the data frame, one row per sample.
     * @param trainer the lambda that trains a model on a training subset of rows.
     * @return the out-of-fold prediction for every row.
     * @throws IllegalArgumentException if {@code data} does not have the sample size
     *         the folds were built for.
     */
    public double[] regression(DataFrame data, Function<DataFrame, DataFrameRegression> trainer) {
        checkSize("data", data.size());

        double[] prediction = new double[n];
        for (int i = 0; i < k; i++) {
            DataFrameRegression model = trainer.apply(data.of(train[i]));
            for (int j : test[i]) {
                prediction[j] = model.predict((Tuple) data.get(j));
            }
        }
        return prediction;
    }
}
