package jsat.classifiers.neuralnetwork;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicIntegerArray;
import jsat.DataSet;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.svm.DCDs;
import jsat.clustering.SeedSelectionMethods;
import jsat.clustering.kmeans.HamerlyKMeans;
import jsat.datatransform.DataTransform;
import jsat.distributions.Distribution;
import jsat.distributions.Uniform;
import jsat.distributions.discrete.UniformDiscrete;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.SparseVector;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.math.OnLineStatistics;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.BoundedSortedList;
import jsat.utils.DoubleList;
import jsat.utils.FakeExecutor;
import jsat.utils.IntList;
import jsat.utils.IntSet;
import jsat.utils.ListUtils;
import jsat.utils.SystemInfo;
import jsat.utils.random.XORWOW;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/neuralnetwork/RBFNet.class */
public class RBFNet implements Classifier, Regressor, DataTransform, Parameterized {
    private static final long serialVersionUID = 5418896646203518062L;
    private int numCentroids;
    private Phase1Learner p1l;
    private Phase2Learner p2l;
    private double alpha;
    private int p;
    private DistanceMetric dm;
    private boolean normalize;
    private Classifier baseClassifier;
    private Regressor baseRegressor;
    private List<Double> centroidDistCache;
    private List<Vec> centroids;
    private double[] bandwidths;

    /* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/neuralnetwork/RBFNet$Phase1Learner.class */
    public enum Phase1Learner {
        RANDOM { // from class: jsat.classifiers.neuralnetwork.RBFNet.Phase1Learner.1
            @Override // jsat.classifiers.neuralnetwork.RBFNet.Phase1Learner
            protected List<Vec> getCentroids(DataSet dataSet, int i, DistanceMetric distanceMetric, ExecutorService executorService) {
                XORWOW xorwow = new XORWOW();
                ArrayList arrayList = new ArrayList();
                IntSet intSet = new IntSet();
                while (intSet.size() < i) {
                    intSet.add((IntSet) Integer.valueOf(xorwow.nextInt(dataSet.getSampleSize())));
                }
                Iterator<Integer> it = intSet.iterator();
                while (it.hasNext()) {
                    arrayList.add(dataSet.getDataPoint(it.next().intValue()).getNumericalValues());
                }
                return arrayList;
            }
        },
        K_MEANS { // from class: jsat.classifiers.neuralnetwork.RBFNet.Phase1Learner.2
            @Override // jsat.classifiers.neuralnetwork.RBFNet.Phase1Learner
            protected List<Vec> getCentroids(DataSet dataSet, int i, DistanceMetric distanceMetric, ExecutorService executorService) {
                HamerlyKMeans hamerlyKMeans = new HamerlyKMeans(distanceMetric, SeedSelectionMethods.SeedSelection.KPP);
                if (executorService == null || (executorService instanceof FakeExecutor)) {
                    hamerlyKMeans.cluster(dataSet, i);
                } else {
                    hamerlyKMeans.cluster(dataSet, i, executorService);
                }
                return hamerlyKMeans.getMeans();
            }
        };

        protected abstract List<Vec> getCentroids(DataSet dataSet, int i, DistanceMetric distanceMetric, ExecutorService executorService);
    }

    /* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/neuralnetwork/RBFNet$Phase2Learner.class */
    public enum Phase2Learner {
        CENTROID_DISTANCE { // from class: jsat.classifiers.neuralnetwork.RBFNet.Phase2Learner.1
            @Override // jsat.classifiers.neuralnetwork.RBFNet.Phase2Learner
            protected double[] estimateBandwidths(double d, int i, DataSet dataSet, final List<Vec> list, final List<Double> list2, final DistanceMetric distanceMetric, ExecutorService executorService) {
                final double[] dArr = new double[list.size()];
                OnLineStatistics[] onLineStatisticsArr = new OnLineStatistics[dArr.length];
                for (int i2 = 0; i2 < onLineStatisticsArr.length; i2++) {
                    onLineStatisticsArr[i2] = new OnLineStatistics();
                }
                ArrayList arrayList = new ArrayList(SystemInfo.LogicalCores);
                for (final List list3 : ListUtils.splitList(dataSet.getDataVectors(), SystemInfo.LogicalCores)) {
                    arrayList.add(executorService.submit(new Callable<OnLineStatistics[]>() { // from class: jsat.classifiers.neuralnetwork.RBFNet.Phase2Learner.1.1
                        /* JADX WARN: Can't rename method to resolve collision */
                        @Override // java.util.concurrent.Callable
                        public OnLineStatistics[] call() {
                            OnLineStatistics[] onLineStatisticsArr2 = new OnLineStatistics[dArr.length];
                            for (int i3 = 0; i3 < onLineStatisticsArr2.length; i3++) {
                                onLineStatisticsArr2[i3] = new OnLineStatistics();
                            }
                            for (Vec vec : list3) {
                                double d2 = Double.POSITIVE_INFINITY;
                                int i4 = 0;
                                for (int i5 = 0; i5 < list.size(); i5++) {
                                    double dist = distanceMetric.dist(i5, vec, list, list2);
                                    if (dist < d2) {
                                        d2 = dist;
                                        i4 = i5;
                                    }
                                }
                                onLineStatisticsArr2[i4].add(d2);
                            }
                            return onLineStatisticsArr2;
                        }
                    }));
                }
                try {
                    for (OnLineStatistics[] onLineStatisticsArr2 : ListUtils.collectFutures(arrayList)) {
                        for (int i3 = 0; i3 < onLineStatisticsArr2.length; i3++) {
                            if (onLineStatisticsArr2[i3].getSumOfWeights() != 0.0d) {
                                onLineStatisticsArr[i3] = OnLineStatistics.add(onLineStatisticsArr[i3], onLineStatisticsArr2[i3]);
                            }
                        }
                    }
                    for (int i4 = 0; i4 < dArr.length; i4++) {
                        dArr[i4] = onLineStatisticsArr[i4].getMean() + (onLineStatisticsArr[i4].getStandardDeviation() * d);
                    }
                    return dArr;
                } catch (InterruptedException e) {
                    throw new FailedToFitException(e);
                } catch (ExecutionException e2) {
                    throw new FailedToFitException(e2);
                }
            }
        },
        CLOSEST_OPPOSITE_CENTROID { // from class: jsat.classifiers.neuralnetwork.RBFNet.Phase2Learner.2
            @Override // jsat.classifiers.neuralnetwork.RBFNet.Phase2Learner
            protected double[] estimateBandwidths(final double d, int i, DataSet dataSet, final List<Vec> list, final List<Double> list2, final DistanceMetric distanceMetric, ExecutorService executorService) {
                if (!(dataSet instanceof ClassificationDataSet)) {
                    throw new FailedToFitException("CLOSEST_OPPOSITE_CENTROID only works for classification data sets");
                }
                final ClassificationDataSet classificationDataSet = (ClassificationDataSet) dataSet;
                final double[] dArr = new double[list.size()];
                final CountDownLatch countDownLatch = new CountDownLatch(SystemInfo.LogicalCores);
                final AtomicIntegerArray[] atomicIntegerArrayArr = new AtomicIntegerArray[list.size()];
                for (int i2 = 0; i2 < atomicIntegerArrayArr.length; i2++) {
                    atomicIntegerArrayArr[i2] = new AtomicIntegerArray(classificationDataSet.getClassSize());
                }
                IntList intList = new IntList(dataSet.getSampleSize());
                ListUtils.addRange(intList, 0, dataSet.getSampleSize(), 1);
                for (final List list3 : ListUtils.splitList(intList, SystemInfo.LogicalCores)) {
                    executorService.submit(new Runnable() { // from class: jsat.classifiers.neuralnetwork.RBFNet.Phase2Learner.2.1
                        @Override // java.lang.Runnable
                        public void run() {
                            Iterator it = list3.iterator();
                            while (it.hasNext()) {
                                int intValue = ((Integer) it.next()).intValue();
                                Vec numericalValues = classificationDataSet.getDataPoint(intValue).getNumericalValues();
                                double d2 = Double.POSITIVE_INFINITY;
                                int i3 = 0;
                                for (int i4 = 0; i4 < list.size(); i4++) {
                                    double dist = distanceMetric.dist(i4, numericalValues, list, list2);
                                    if (dist < d2) {
                                        d2 = dist;
                                        i3 = i4;
                                    }
                                }
                                atomicIntegerArrayArr[i3].incrementAndGet(classificationDataSet.getDataPointCategory(intValue));
                            }
                            countDownLatch.countDown();
                        }
                    });
                }
                try {
                    countDownLatch.await();
                    final int[] iArr = new int[list.size()];
                    for (int i3 = 0; i3 < iArr.length; i3++) {
                        int i4 = -1;
                        int i5 = 0;
                        for (int i6 = 0; i6 < atomicIntegerArrayArr[i3].length(); i6++) {
                            if (atomicIntegerArrayArr[i3].get(i6) > i4) {
                                i5 = i6;
                                i4 = atomicIntegerArrayArr[i3].get(i6);
                            }
                        }
                        iArr[i3] = i5;
                    }
                    final CountDownLatch countDownLatch2 = new CountDownLatch(list.size());
                    for (int i7 = 0; i7 < list.size(); i7++) {
                        final int i8 = i7;
                        executorService.submit(new Runnable() { // from class: jsat.classifiers.neuralnetwork.RBFNet.Phase2Learner.2.2
                            @Override // java.lang.Runnable
                            public void run() {
                                double d2 = Double.POSITIVE_INFINITY;
                                for (int i9 = 0; i9 < list.size(); i9++) {
                                    if (iArr[i8] != iArr[i9]) {
                                        d2 = Math.min(d2, distanceMetric.dist(i9, i8, list, list2));
                                    }
                                }
                                if (Double.isInfinite(d2)) {
                                    for (int i10 = 0; i10 < list.size(); i10++) {
                                        if (i8 != i10) {
                                            d2 = Math.min(d2, distanceMetric.dist(i10, i8, list, list2));
                                        }
                                    }
                                }
                                dArr[i8] = d * d2;
                                countDownLatch2.countDown();
                            }
                        });
                    }
                    try {
                        countDownLatch2.await();
                        return dArr;
                    } catch (InterruptedException e) {
                        throw new FailedToFitException(e);
                    }
                } catch (InterruptedException e2) {
                    throw new FailedToFitException(e2);
                }
            }
        },
        NEAREST_OTHER_CENTROID_AVERAGE { // from class: jsat.classifiers.neuralnetwork.RBFNet.Phase2Learner.3
            @Override // jsat.classifiers.neuralnetwork.RBFNet.Phase2Learner
            protected double[] estimateBandwidths(final double d, final int i, DataSet dataSet, final List<Vec> list, final List<Double> list2, final DistanceMetric distanceMetric, ExecutorService executorService) {
                final double[] dArr = new double[list.size()];
                final CountDownLatch countDownLatch = new CountDownLatch(list.size());
                for (int i2 = 0; i2 < list.size(); i2++) {
                    final int i3 = i2;
                    executorService.submit(new Runnable() { // from class: jsat.classifiers.neuralnetwork.RBFNet.Phase2Learner.3.1
                        @Override // java.lang.Runnable
                        public void run() {
                            BoundedSortedList boundedSortedList = new BoundedSortedList(i);
                            for (int i4 = 0; i4 < list.size(); i4++) {
                                if (i4 != i3) {
                                    boundedSortedList.add((BoundedSortedList) Double.valueOf(distanceMetric.dist(i4, i3, list, list2)));
                                }
                            }
                            OnLineStatistics onLineStatistics = new OnLineStatistics();
                            Iterator<E> it = boundedSortedList.iterator();
                            while (it.hasNext()) {
                                onLineStatistics.add(((Double) it.next()).doubleValue());
                            }
                            dArr[i3] = onLineStatistics.getMean() + (d * onLineStatistics.getStandardDeviation());
                            countDownLatch.countDown();
                        }
                    });
                }
                return dArr;
            }
        };

        protected abstract double[] estimateBandwidths(double d, int i, DataSet dataSet, List<Vec> list, List<Double> list2, DistanceMetric distanceMetric, ExecutorService executorService);
    }

    public RBFNet() {
        this(100);
    }

    public RBFNet(int i) {
        this(i, Phase1Learner.K_MEANS, Phase2Learner.NEAREST_OTHER_CENTROID_AVERAGE, 3.0d, 3, (DistanceMetric) new EuclideanDistance(), (Classifier) new DCDs());
    }

    public RBFNet(int i, Phase1Learner phase1Learner, Phase2Learner phase2Learner, double d, int i2, DistanceMetric distanceMetric, Classifier classifier) {
        this.normalize = true;
        setNumCentroids(i);
        setPhase1Learner(phase1Learner);
        setPhase2Learner(phase2Learner);
        setAlpha(d);
        setP(i2);
        setDistanceMetric(distanceMetric);
        this.baseClassifier = classifier;
        if (classifier instanceof Regressor) {
            this.baseRegressor = (Regressor) classifier;
        }
    }

    public RBFNet(int i, Phase1Learner phase1Learner, Phase2Learner phase2Learner, double d, int i2, DistanceMetric distanceMetric, Regressor regressor) {
        this.normalize = true;
        setNumCentroids(i);
        setPhase1Learner(phase1Learner);
        setPhase2Learner(phase2Learner);
        setAlpha(d);
        setP(i2);
        setDistanceMetric(distanceMetric);
        this.baseRegressor = regressor;
        if (regressor instanceof Classifier) {
            this.baseClassifier = (Classifier) regressor;
        }
    }

    public RBFNet(RBFNet rBFNet) {
        this.normalize = true;
        setNumCentroids(rBFNet.getNumCentroids());
        setPhase1Learner(rBFNet.getPhase1Learner());
        setPhase2Learner(rBFNet.getPhase2Learner());
        setAlpha(rBFNet.getAlpha());
        setP(rBFNet.getP());
        setDistanceMetric(rBFNet.getDistanceMetric().mo652clone());
        if (rBFNet.baseRegressor != null) {
            this.baseRegressor = rBFNet.baseRegressor.clone();
            if (this.baseRegressor instanceof Classifier) {
                this.baseClassifier = (Classifier) this.baseRegressor;
            }
        } else if (rBFNet.baseClassifier != null) {
            this.baseClassifier = rBFNet.baseClassifier.mo582clone();
            if (this.baseClassifier instanceof Regressor) {
                this.baseRegressor = (Regressor) this.baseClassifier;
            }
        }
        if (rBFNet.centroids != null) {
            this.centroids = new ArrayList(rBFNet.centroids.size());
            Iterator<Vec> it = rBFNet.centroids.iterator();
            while (it.hasNext()) {
                this.centroids.add(it.next().mo525clone());
            }
            if (rBFNet.centroidDistCache != null) {
                this.centroidDistCache = new DoubleList(rBFNet.centroidDistCache);
            }
        }
        if (rBFNet.bandwidths != null) {
            this.bandwidths = Arrays.copyOf(rBFNet.bandwidths, rBFNet.bandwidths.length);
        }
    }

    @Override // jsat.datatransform.DataTransform
    public DataPoint transform(DataPoint dataPoint) {
        Vec numericalValues = dataPoint.getNumericalValues();
        List<Double> queryInfo = this.dm.getQueryInfo(numericalValues);
        Vec sparseVector = new SparseVector(this.numCentroids);
        double d = 0.0d;
        double d2 = Double.NEGATIVE_INFINITY;
        int i = -1;
        for (int i2 = 0; i2 < this.centroids.size(); i2++) {
            double dist = this.dm.dist(i2, numericalValues, queryInfo, this.centroids, this.centroidDistCache);
            double d3 = this.bandwidths[i2];
            double exp = Math.exp((-(dist * dist)) / ((d3 * d3) * 2.0d));
            if (exp > d2) {
                d2 = exp;
                i = i2;
            }
            if (exp > 1.0E-16d) {
                sparseVector.set(i2, exp);
                d += exp;
            }
        }
        if (sparseVector.nnz() == 0) {
            sparseVector.set(i, d2);
            d = d2;
        }
        if (this.normalize && d != 0.0d) {
            sparseVector.mutableDivide(d);
        }
        if (sparseVector.nnz() > sparseVector.length() / 2) {
            sparseVector = new DenseVector(sparseVector);
        }
        return new DataPoint(sparseVector, dataPoint.getCategoricalValues(), dataPoint.getCategoricalData(), dataPoint.getWeight());
    }

    public void setAlpha(double d) {
        if (d < 0.0d || Double.isInfinite(d) || Double.isNaN(d)) {
            throw new IllegalArgumentException("Alpha must be a positive value, not " + d);
        }
        this.alpha = d;
    }

    public double getAlpha() {
        return this.alpha;
    }

    public static Distribution guessAlpha(DataSet dataSet) {
        return new Uniform(0.8d, 3.5d);
    }

    public void setP(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("neighbors parameter must be positive, not " + i);
        }
        this.p = i;
    }

    public int getP() {
        return this.p;
    }

    public static Distribution guessP(DataSet dataSet) {
        return new UniformDiscrete(2, 5);
    }

    public void setNumCentroids(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("Number of centroids must be positive, not " + i);
        }
        this.numCentroids = i;
    }

    public int getNumCentroids() {
        return this.numCentroids;
    }

    public static Distribution guessNumCentroids(DataSet dataSet) {
        return new UniformDiscrete(25, 1000);
    }

    public void setDistanceMetric(DistanceMetric distanceMetric) {
        this.dm = distanceMetric;
    }

    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }

    public void setPhase1Learner(Phase1Learner phase1Learner) {
        this.p1l = phase1Learner;
    }

    public Phase1Learner getPhase1Learner() {
        return this.p1l;
    }

    public void setPhase2Learner(Phase2Learner phase2Learner) {
        this.p2l = phase2Learner;
    }

    public Phase2Learner getPhase2Learner() {
        return this.p2l;
    }

    public void setNormalize(boolean z) {
        this.normalize = z;
    }

    public boolean isNormalize() {
        return this.normalize;
    }

    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        return this.baseClassifier.classify(transform(dataPoint));
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v7, types: [jsat.classifiers.ClassificationDataSet] */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        if (this.baseClassifier == null) {
            throw new FailedToFitException("RBFNet was not given a base classifier");
        }
        if (executorService == null) {
            executorService = new FakeExecutor();
        }
        this.centroids = this.p1l.getCentroids(classificationDataSet, this.numCentroids, this.dm, executorService);
        this.centroidDistCache = this.dm.getAccelerationCache(this.centroids, executorService);
        this.bandwidths = this.p2l.estimateBandwidths(this.alpha, this.p, classificationDataSet, this.centroids, this.centroidDistCache, this.dm, executorService);
        ?? shallowClone2 = classificationDataSet.shallowClone2();
        shallowClone2.applyTransform(this, executorService);
        this.baseClassifier.trainC(shallowClone2, executorService);
    }

    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        trainC(classificationDataSet, null);
    }

    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return this.baseClassifier != null ? this.baseClassifier.supportsWeightedData() : this.baseRegressor.supportsWeightedData();
    }

    @Override // jsat.regression.Regressor
    public double regress(DataPoint dataPoint) {
        return this.baseRegressor.regress(transform(dataPoint));
    }

    @Override // jsat.datatransform.DataTransform
    public void fit(DataSet dataSet) {
        if (dataSet instanceof ClassificationDataSet) {
            trainC((ClassificationDataSet) dataSet);
        } else {
            if (!(dataSet instanceof RegressionDataSet)) {
                throw new FailedToFitException("Data must be a classifiation or regression dataset, not " + dataSet.getClass().getSimpleName());
            }
            train((RegressionDataSet) dataSet);
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v7, types: [jsat.regression.RegressionDataSet] */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet, ExecutorService executorService) {
        if (this.baseRegressor == null) {
            throw new FailedToFitException("RBFNet was not given a base classifier");
        }
        if (executorService == null) {
            executorService = new FakeExecutor();
        }
        this.centroids = this.p1l.getCentroids(regressionDataSet, this.numCentroids, this.dm, executorService);
        this.centroidDistCache = this.dm.getAccelerationCache(this.centroids, executorService);
        this.bandwidths = this.p2l.estimateBandwidths(this.alpha, this.p, regressionDataSet, this.centroids, this.centroidDistCache, this.dm, executorService);
        ?? shallowClone2 = regressionDataSet.shallowClone2();
        shallowClone2.applyTransform(this, executorService);
        this.baseRegressor.train(shallowClone2, executorService);
    }

    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet) {
        train(regressionDataSet, null);
    }

    @Override // jsat.datatransform.DataTransform
    /* renamed from: clone, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public RBFNet clone() {
        return new RBFNet(this);
    }

    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }
}
