package jsat.classifiers.neuralnetwork;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.distributions.empirical.kernelfunc.EpanechnikovKF;
import jsat.distributions.empirical.kernelfunc.KernelFunction;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.vectorcollection.DefaultVectorCollectionFactory;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.linear.vectorcollection.VectorCollectionFactory;
import jsat.math.decayrates.DecayRate;
import jsat.math.decayrates.ExponetialDecay;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.ArrayUtils;
import jsat.utils.PairedReturn;
import jsat.utils.SystemInfo;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/neuralnetwork/SOM.class */
/**
 * Self-Organizing Map (SOM) classifier.
 * <p>
 * A {@code somHeight x somWidth} lattice of weight vectors is fit to the
 * numeric features of the training data. Each lattice node is then labeled
 * with the weighted class distribution of the training points whose nearest
 * node (under the configured {@link DistanceMetric}) it is. A new point is
 * classified with the label of its nearest node.
 * <p>
 * Training without an executor updates nodes point by point, scaling each
 * step by the kernel value of the lattice distance. Training with an executor
 * first buckets every point into the nodes of its neighborhood, then moves
 * each node toward the weighted mean of its bucket. The kernel value is not
 * applied in that batch mode.
 */
public class SOM implements Classifier, Parameterized {
    private static final long serialVersionUID = -6444988770441043797L;
    /** Default number of training passes over the data set. */
    public static final int DEFAULT_MAX_ITERS = 500;
    /** Default initial learning rate. */
    public static final double DEFAULT_LEARNING_RATE = 0.1d;
    private int somWidth;
    private int somHeight;
    private int maxIters;
    private KernelFunction kf;
    private double initialLearningRate;
    private DecayRate learningDecay;
    private DecayRate neighborDecay;
    private DistanceMetric dm;
    private VectorCollectionFactory<VecPaired<Vec, Integer>> vcFactory;
    /** Lattice weights, indexed {@code weights[row][column]}, with {@code somHeight} rows and {@code somWidth} columns. */
    private Vec[][] weights;
    /** Class distribution of each lattice node, indexed by the node's row-major position. */
    private CategoricalResults[] crWeightPairs;
    /** Nearest-neighbor index over the lattice nodes, paired with their row-major position. */
    private VectorCollection<VecPaired<Vec, Integer>> vcCollection;
    /** Per-node buckets of data points for the parallel (batch) update, indexed [row][column]. */
    private List<List<List<DataPoint>>> weightUpdates;
    public static final KernelFunction DEFAULT_KF = EpanechnikovKF.getInstance();
    public static final DecayRate DEFAULT_LEARNING_DECAY = new ExponetialDecay();
    public static final DecayRate DEFAULT_NEIGHBOR_DECAY = new ExponetialDecay();

    /**
     * Creates a SOM that uses the Euclidean distance.
     *
     * @param somHeight number of rows in the lattice
     * @param somWidth  number of columns in the lattice
     */
    public SOM(int somHeight, int somWidth) {
        this(new EuclideanDistance(), somHeight, somWidth);
    }

    /**
     * Creates a SOM that uses the default vector collection factory.
     *
     * @param dm        distance metric used to find the closest node
     * @param somHeight number of rows in the lattice
     * @param somWidth  number of columns in the lattice
     */
    public SOM(DistanceMetric dm, int somHeight, int somWidth) {
        this(dm, somHeight, somWidth, new DefaultVectorCollectionFactory());
    }

    /**
     * Creates a SOM with the default iteration count, kernel, learning rate and decays.
     *
     * @param dm        distance metric used to find the closest node
     * @param somHeight number of rows in the lattice
     * @param somWidth  number of columns in the lattice
     * @param vcFactory factory for the nearest-node index used when classifying
     */
    public SOM(DistanceMetric dm, int somHeight, int somWidth, VectorCollectionFactory<VecPaired<Vec, Integer>> vcFactory) {
        this(DEFAULT_MAX_ITERS, DEFAULT_KF, DEFAULT_LEARNING_RATE, DEFAULT_LEARNING_DECAY, DEFAULT_NEIGHBOR_DECAY, dm, somHeight, somWidth, vcFactory);
    }

    private SOM(int maxIters, KernelFunction kf, double initialLearningRate, DecayRate learningDecay, DecayRate neighborDecay, DistanceMetric dm, int somHeight, int somWidth, VectorCollectionFactory<VecPaired<Vec, Integer>> vcFactory) {
        this.somHeight = somHeight;
        this.somWidth = somWidth;
        this.maxIters = maxIters;
        this.kf = kf;
        this.initialLearningRate = initialLearningRate;
        this.learningDecay = learningDecay;
        this.neighborDecay = neighborDecay;
        this.dm = dm;
        this.vcFactory = vcFactory;
    }

    /**
     * Sets the number of training passes over the data.
     *
     * @param maxIters number of iterations, at least 1
     * @throws ArithmeticException if {@code maxIters < 1}
     */
    public void setMaxIterations(int maxIters) {
        if (maxIters < 1) {
            throw new ArithmeticException("At least one iteration must be performed");
        }
        this.maxIters = maxIters;
    }

    public int getMaxIterations() {
        return this.maxIters;
    }

    /**
     * @param somWidth number of lattice columns, at least 1
     * @throws ArithmeticException if {@code somWidth < 1}
     */
    public void setSomWidth(int somWidth) {
        if (somWidth < 1) {
            throw new ArithmeticException("Lattice width must be positive, not " + somWidth);
        }
        this.somWidth = somWidth;
    }

    /**
     * @param somHeight number of lattice rows, at least 1
     * @throws ArithmeticException if {@code somHeight < 1}
     */
    public void setSomHeight(int somHeight) {
        if (somHeight < 1) {
            throw new ArithmeticException("Lattice height must be positive, not " + somHeight);
        }
        this.somHeight = somHeight;
    }

    public int getSomHeight() {
        return this.somHeight;
    }

    public int getSomWidth() {
        return this.somWidth;
    }

    /**
     * @param learningRate finite, strictly positive initial learning rate
     * @throws ArithmeticException if the rate is infinite, NaN or not positive
     */
    public void setInitialLearningRate(double learningRate) {
        if (Double.isInfinite(learningRate) || Double.isNaN(learningRate) || learningRate <= 0.0d) {
            throw new ArithmeticException("Learning rate must be a positive constant, not " + learningRate);
        }
        this.initialLearningRate = learningRate;
    }

    public double getInitialLearningRate() {
        return this.initialLearningRate;
    }

    /**
     * @param learningDecay how the learning rate decays over the iterations
     * @throws NullPointerException if {@code learningDecay} is null
     */
    public void setLearningDecay(DecayRate learningDecay) {
        if (learningDecay == null) {
            throw new NullPointerException("Can not set a decay rate to null");
        }
        this.learningDecay = learningDecay;
    }

    public DecayRate getLearningDecay() {
        return this.learningDecay;
    }

    /**
     * @param neighborDecay how the neighborhood radius decays over the iterations
     * @throws NullPointerException if {@code neighborDecay} is null
     */
    public void setNeighborDecay(DecayRate neighborDecay) {
        if (neighborDecay == null) {
            throw new NullPointerException("Can not set a decay rate to null");
        }
        this.neighborDecay = neighborDecay;
    }

    public DecayRate getNeighborDecay() {
        return this.neighborDecay;
    }

    /**
     * Fills every lattice node with {@code DenseVector.random(dims)}.
     *
     * @param dims dimension of the weight vectors
     * @return the initial neighborhood radius, i.e. the longer lattice side
     */
    private double initializeWeights(int dims) {
        for (int row = 0; row < this.somHeight; row++) {
            for (int col = 0; col < this.somWidth; col++) {
                this.weights[row][col] = DenseVector.random(dims);
            }
        }
        return Math.max(this.somWidth, this.somHeight);
    }

    /**
     * Processes one training point. It finds the best matching unit (BMU) and,
     * for every node whose squared lattice distance to the BMU is below
     * {@code radiusSqrd}, does one of two things. In serial mode
     * ({@code threadPool == null}) it moves the node toward the point by
     * {@code kernel * learningRate}. In parallel mode it records the point in
     * the node's update bucket.
     * <p>
     * Used internally by training. It is kept public to preserve the existing
     * interface.
     *
     * @param threadPool   executor used for training, or {@code null} for serial training
     * @param index        index of the point in {@code dataSet}
     * @param dataSet      training data
     * @param radius       current neighborhood radius on the lattice
     * @param radiusSqrd   {@code radius * radius}
     * @param scratch      scratch vector of the data's dimension (used in serial mode)
     * @param learningRate current learning rate
     */
    public void iterationStep(ExecutorService threadPool, int index, DataSet dataSet, double radius, double radiusSqrd, Vec scratch, double learningRate) {
        DataPoint dp = dataSet.getDataPoint(index);
        Vec x = dp.getNumericalValues();
        PairedReturn<Integer, Integer> bmu = getBMU(x);
        int bmuRow = bmu.getFirstItem().intValue();
        int bmuCol = bmu.getSecondItem().intValue();

        // Bounding box of the neighborhood, clipped to the lattice.
        // Rows are bounded by the height and columns by the width.
        int rowStart = Math.max(((int) (bmuRow - radius)) - 1, 0);
        int colStart = Math.max(((int) (bmuCol - radius)) - 1, 0);
        int rowEnd = Math.min(((int) (bmuRow + radius)) + 1, this.somHeight);
        int colEnd = Math.min(((int) (bmuCol + radius)) + 1, this.somWidth);

        for (int row = rowStart; row < rowEnd; row++) {
            Vec[] weightRow = this.weights[row];
            for (int col = colStart; col < colEnd; col++) {
                int dRow = bmuRow - row;
                int dCol = bmuCol - col;
                int latticeDistSqrd = dRow * dRow + dCol * dCol;
                if (latticeDistSqrd >= radiusSqrd) {
                    continue;
                }
                if (threadPool == null) {
                    double k = this.kf.k(Math.sqrt(latticeDistSqrd) / radius);
                    updateWeight(x, scratch, weightRow[col], k * learningRate);
                } else {
                    // Batch mode: the point is only recorded; the kernel value is not applied.
                    this.weightUpdates.get(row).get(col).add(dp);
                }
            }
        }
    }

    /**
     * Builds the nearest-node index over the current lattice. Each node is
     * paired with its row-major position.
     *
     * @param threadPool executor passed to the vector collection factory, or {@code null}
     * @return the list of paired nodes the index was built from
     */
    private List<VecPaired<Vec, Integer>> setUpVectorCollection(ExecutorService threadPool) {
        List<VecPaired<Vec, Integer>> nodes = new ArrayList<VecPaired<Vec, Integer>>(this.somWidth * this.somHeight);
        for (Vec[] weightRow : this.weights) {
            for (Vec node : weightRow) {
                nodes.add(new VecPaired<Vec, Integer>(node, Integer.valueOf(nodes.size())));
            }
        }
        if (threadPool == null) {
            this.vcCollection = this.vcFactory.getVectorCollection(nodes, this.dm);
        } else {
            this.vcCollection = this.vcFactory.getVectorCollection(nodes, this.dm, threadPool);
        }
        return nodes;
    }

    /**
     * Moves {@code weight} toward {@code target} in place:
     * {@code weight += rate * (target - weight)}. Both {@code target} and
     * {@code weight} are left unchanged except for the update to {@code weight}.
     * <p>
     * Used internally by training. It is kept public to preserve the existing
     * interface.
     *
     * @param target  point to move toward
     * @param scratch scratch vector of the same dimension; overwritten
     * @param weight  node weight that is updated
     * @param rate    step size
     */
    public void updateWeight(Vec target, Vec scratch, Vec weight, double rate) {
        target.copyTo(scratch);
        scratch.mutableSubtract(weight);
        weight.mutableAdd(rate, scratch);
    }

    /**
     * Finds the lattice node closest to {@code x} under {@link #dm}.
     *
     * @return (row, column) of the closest node; (-1, -1) if no distance is below {@code Double.MAX_VALUE}
     */
    private PairedReturn<Integer, Integer> getBMU(Vec x) {
        double bestDist = Double.MAX_VALUE;
        int bestRow = -1;
        int bestCol = -1;
        for (int row = 0; row < this.weights.length; row++) {
            Vec[] weightRow = this.weights[row];
            for (int col = 0; col < weightRow.length; col++) {
                double dist = this.dm.dist(weightRow[col], x);
                if (dist < bestDist) {
                    bestDist = dist;
                    bestRow = row;
                    bestCol = col;
                }
            }
        }
        return new PairedReturn<>(Integer.valueOf(bestRow), Integer.valueOf(bestCol));
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }

    /**
     * Fits the lattice weights to the numeric features of {@code dataSet}.
     * Training runs serially when {@code threadPool} is null and otherwise
     * uses the batch update described on the class.
     */
    private void trainSOM(final DataSet dataSet, final ExecutorService threadPool) throws InterruptedException {
        final int dims = dataSet.getNumNumericalVars();
        this.weights = new Vec[this.somHeight][this.somWidth];
        final double initialRadius = initializeWeights(dims);
        Random rand = new Random();
        Vec scratch = new DenseVector(dims);
        int[] order = new int[dataSet.getSampleSize()];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }

        ThreadLocal<Vec> localMean = null;
        ThreadLocal<Vec> localScratch = null;
        if (threadPool != null) {
            this.weightUpdates = createUpdateBuckets();
            localMean = newLocalVector(dims);
            localScratch = newLocalVector(dims);
        }

        for (int iter = 0; iter < this.maxIters; iter++) {
            final double radius = this.neighborDecay.rate(iter, this.maxIters, initialRadius);
            final double radiusSqrd = radius * radius;
            final double learningRate = this.learningDecay.rate(iter, this.maxIters, this.initialLearningRate);

            if (threadPool == null) {
                ArrayUtils.shuffle(order, rand);
                for (int i : order) {
                    iterationStep(null, i, dataSet, radius, radiusSqrd, scratch, learningRate);
                }
            } else {
                for (List<List<DataPoint>> rowBuckets : this.weightUpdates) {
                    for (List<DataPoint> bucket : rowBuckets) {
                        bucket.clear();
                    }
                }
                assignPointsInParallel(dataSet, threadPool, radius, radiusSqrd, learningRate, localScratch);
                applyBatchUpdates(threadPool, learningRate, localMean, localScratch);
            }
        }
    }

    /** Creates one synchronized (thread-safe for concurrent adds) point bucket per lattice node. */
    private List<List<List<DataPoint>>> createUpdateBuckets() {
        List<List<List<DataPoint>>> buckets = new ArrayList<List<List<DataPoint>>>(this.somHeight);
        for (int row = 0; row < this.somHeight; row++) {
            List<List<DataPoint>> rowBuckets = new ArrayList<List<DataPoint>>(this.somWidth);
            for (int col = 0; col < this.somWidth; col++) {
                rowBuckets.add(Collections.synchronizedList(new ArrayList<DataPoint>()));
            }
            buckets.add(rowBuckets);
        }
        return buckets;
    }

    /** Creates a per-thread {@link DenseVector} of the given dimension. */
    private static ThreadLocal<Vec> newLocalVector(final int dims) {
        return new ThreadLocal<Vec>() {
            @Override
            protected Vec initialValue() {
                return new DenseVector(dims);
            }
        };
    }

    /**
     * Splits the data into exactly {@code SystemInfo.LogicalCores} contiguous
     * (possibly empty) ranges. Each range is bucketed into node neighborhoods
     * on the executor, and the method blocks until all ranges are done.
     */
    private void assignPointsInParallel(final DataSet dataSet, final ExecutorService threadPool, final double radius, final double radiusSqrd, final double learningRate, final ThreadLocal<Vec> localScratch) throws InterruptedException {
        final int workers = SystemInfo.LogicalCores;
        final int n = dataSet.getSampleSize();
        final int baseSize = n / workers;
        final int remainder = n % workers;
        // One task per worker, even when its range is empty, so the latch always reaches zero.
        final CountDownLatch latch = new CountDownLatch(workers);
        int start = 0;
        for (int id = 0; id < workers; id++) {
            final int from = start;
            final int to = from + baseSize + (id < remainder ? 1 : 0);
            start = to;
            threadPool.submit(new Runnable() {
                @Override
                public void run() {
                    try {
                        for (int i = from; i < to; i++) {
                            iterationStep(threadPool, i, dataSet, radius, radiusSqrd, localScratch.get(), learningRate);
                        }
                    } finally {
                        latch.countDown();
                    }
                }
            });
        }
        latch.await();
    }

    /**
     * Moves every node toward the weighted mean of its bucket by
     * {@code learningRate} on the executor, and blocks until all nodes are done.
     */
    private void applyBatchUpdates(final ExecutorService threadPool, final double learningRate, final ThreadLocal<Vec> localMean, final ThreadLocal<Vec> localScratch) throws InterruptedException {
        final CountDownLatch latch = new CountDownLatch(this.somHeight * this.somWidth);
        for (int row = 0; row < this.somHeight; row++) {
            for (int col = 0; col < this.somWidth; col++) {
                final List<DataPoint> bucket = this.weightUpdates.get(row).get(col);
                final Vec weight = this.weights[row][col];
                threadPool.submit(new Runnable() {
                    @Override
                    public void run() {
                        try {
                            Vec mean = localMean.get();
                            mean.zeroOut();
                            double totalWeight = 0.0d;
                            for (DataPoint dp : bucket) {
                                totalWeight += dp.getWeight();
                                mean.mutableAdd(dp.getWeight(), dp.getNumericalValues());
                            }
                            // A node that no point reached (or only zero-weight points reached)
                            // has no target; dividing by zero would fill it with NaN.
                            if (totalWeight == 0.0d) {
                                return;
                            }
                            mean.mutableDivide(totalWeight);
                            updateWeight(mean, localScratch.get(), weight, learningRate);
                        } finally {
                            latch.countDown();
                        }
                    }
                });
            }
        }
        latch.await();
    }

    /** Returns the row-major position of the lattice node closest to {@code x}. */
    private int nearestNodeIndex(Vec x) {
        return this.vcCollection.search(x, 1).get(0).getVector().getPair().intValue();
    }

    /**
     * Returns the class distribution of the node closest to the point.
     *
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override
    public CategoricalResults classify(DataPoint dataPoint) {
        if (this.crWeightPairs == null) {
            throw new UntrainedModelException();
        }
        return this.crWeightPairs[nearestNodeIndex(dataPoint.getNumericalValues())];
    }

    /**
     * Trains the lattice and labels every node with the normalized, weighted
     * class distribution of the training points closest to it. If training is
     * interrupted, the interrupt is logged, the thread's interrupt status is
     * restored, and the previous model (if any) is kept.
     *
     * @param dataSet    training data
     * @param threadPool executor for parallel training, or {@code null} for serial training
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        try {
            trainSOM(dataSet, threadPool);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            Logger.getLogger(SOM.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
            return;
        }
        int nodeCount = setUpVectorCollection(threadPool).size();
        CategoricalResults[] labels = new CategoricalResults[nodeCount];
        for (int i = 0; i < nodeCount; i++) {
            labels[i] = new CategoricalResults(dataSet.getClassSize());
        }
        for (int i = 0; i < dataSet.getSampleSize(); i++) {
            DataPoint dp = dataSet.getDataPoint(i);
            labels[nearestNodeIndex(dp.getNumericalValues())].incProb(dataSet.getDataPointCategory(i), dp.getWeight());
        }
        for (CategoricalResults label : labels) {
            label.normalize();
        }
        this.crWeightPairs = labels;
    }

    @Override
    public void trainC(ClassificationDataSet dataSet) {
        trainC(dataSet, null);
    }

    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    /**
     * Deep copy of this SOM, including its configuration, lattice weights,
     * nearest-node index and node labels. The lattice height and width are
     * both preserved.
     */
    @Override
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public SOM m558clone() {
        SOM copy = new SOM(this.maxIters, this.kf, this.initialLearningRate, this.learningDecay, this.neighborDecay, this.dm.mo651clone(), this.somHeight, this.somWidth, this.vcFactory.m676clone());
        if (this.weights != null) {
            copy.weights = new Vec[this.weights.length][];
            for (int row = 0; row < this.weights.length; row++) {
                copy.weights[row] = new Vec[this.weights[row].length];
                for (int col = 0; col < this.weights[row].length; col++) {
                    copy.weights[row][col] = this.weights[row][col].mo524clone();
                }
            }
        }
        if (this.vcCollection != null) {
            copy.vcCollection = this.vcCollection.clone();
        }
        if (this.crWeightPairs != null) {
            copy.crWeightPairs = new CategoricalResults[this.crWeightPairs.length];
            for (int i = 0; i < this.crWeightPairs.length; i++) {
                copy.crWeightPairs[i] = this.crWeightPairs[i].m481clone();
            }
        }
        return copy;
    }
}
