package jsat.classifiers.neuralnetwork;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.classifiers.PriorClassifier;
import jsat.classifiers.bayesian.MultivariateNormals;
import jsat.classifiers.neuralnetwork.LVQ;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.math.decayrates.DecayRate;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/neuralnetwork/LVQLLC.class */
/**
 * LVQ with Locally Learned Classifiers.
 * <p>
 * After the underlying {@link LVQ} prototypes are trained, each prototype gets its own local
 * classifier:
 * <ul>
 * <li>It is trained on the training points for which that prototype is the nearest one.</li>
 * <li>Points whose two nearest prototypes are within the epsilon distance ratio are also given to
 * the second-nearest prototype.</li>
 * <li>Prototypes with too little support instead get a {@link PriorClassifier} that always
 * predicts the prototype's own class.</li>
 * </ul>
 * At prediction time the nearest prototype's local classifier is used. For
 * {@link LVQ.LVQVersion#LVQ2} and later, when the two nearest prototypes are epsilon-close, the
 * two local predictions are blended with inverse-distance weights.
 */
public class LVQLLC extends LVQ {
    private static final long serialVersionUID = 3602640001545233744L;

    /**
     * Prototypes that won fewer than this many training points (see {@code wins}) get a prior
     * classifier for their own class instead of a trained clone of {@link #localClassifier}.
     */
    private static final int MIN_WINS_FOR_LOCAL_MODEL = 10;

    /** Template classifier; a fresh clone of it is trained for each well-supported prototype. */
    private Classifier localClassifier;

    /**
     * One local model per prototype, indexed like {@code weights}. The misspelled name is kept
     * intentionally: Java serialization matches fields by name, so renaming it would silently drop
     * the local models from previously serialized instances.
     */
    private Classifier[] localClassifeirs;

    /**
     * Creates an LVQ-LLC model whose local classifier is {@code new MultivariateNormals(true)}.
     *
     * @param distanceMetric the metric used to find nearest prototypes
     * @param i forwarded to the {@link LVQ} constructor (presumably the number of prototypes per
     *        class — confirm against {@code LVQ})
     */
    public LVQLLC(DistanceMetric distanceMetric, int i) {
        this(distanceMetric, i, new MultivariateNormals(true));
    }

    /**
     * Creates an LVQ-LLC model with the given local classifier.
     *
     * @param distanceMetric the metric used to find nearest prototypes
     * @param i forwarded to the {@link LVQ} constructor
     * @param classifier template cloned for each prototype's local model
     */
    public LVQLLC(DistanceMetric distanceMetric, int i, Classifier classifier) {
        super(distanceMetric, i);
        setLocalClassifier(classifier);
    }

    /**
     * Creates an LVQ-LLC model with the given local classifier and extra LVQ settings.
     *
     * @param distanceMetric the metric used to find nearest prototypes
     * @param i forwarded to the {@link LVQ} constructor
     * @param classifier template cloned for each prototype's local model
     * @param d forwarded to the {@link LVQ} constructor
     * @param i2 forwarded to the {@link LVQ} constructor
     */
    public LVQLLC(DistanceMetric distanceMetric, int i, Classifier classifier, double d, int i2) {
        super(distanceMetric, i, d, i2);
        setLocalClassifier(classifier);
    }

    /**
     * Creates an LVQ-LLC model with full control over the underlying LVQ settings.
     *
     * @param distanceMetric the metric used to find nearest prototypes
     * @param i forwarded to the {@link LVQ} constructor
     * @param classifier template cloned for each prototype's local model
     * @param d forwarded to the {@link LVQ} constructor
     * @param i2 forwarded to the {@link LVQ} constructor
     * @param lVQVersion the LVQ variant; LVQ2 and later enable blending in {@link #classify}
     * @param decayRate forwarded to the {@link LVQ} constructor
     */
    public LVQLLC(DistanceMetric distanceMetric, int i, Classifier classifier, double d, int i2, LVQ.LVQVersion lVQVersion, DecayRate decayRate) {
        super(distanceMetric, i, d, i2, lVQVersion, decayRate);
        setLocalClassifier(classifier);
    }

    /**
     * Copy constructor; deep-copies the template and every trained local model.
     *
     * @param lvqllc the model to copy
     */
    protected LVQLLC(LVQLLC lvqllc) {
        super(lvqllc);
        if (lvqllc.localClassifier != null) {
            this.localClassifier = lvqllc.localClassifier.clone();
        }
        if (lvqllc.localClassifeirs != null) {
            this.localClassifeirs = new Classifier[lvqllc.localClassifeirs.length];
            for (int i = 0; i < this.localClassifeirs.length; i++) {
                // Entries can be null in models trained/serialized by older versions of trainC.
                Classifier local = lvqllc.localClassifeirs[i];
                this.localClassifeirs[i] = local == null ? null : local.clone();
            }
        }
    }

    /**
     * Sets the template classifier cloned for each prototype's local model.
     *
     * @param classifier the template classifier
     */
    public void setLocalClassifier(Classifier classifier) {
        this.localClassifier = classifier;
    }

    /**
     * Returns the template classifier cloned for each prototype's local model.
     *
     * @return the template classifier
     */
    public Classifier getLocalClassifier() {
        return this.localClassifier;
    }

    /**
     * Classifies a point with the local model of its nearest prototype.
     * <p>
     * For LVQ2 and later, if the two nearest prototypes are epsilon-close, the two local
     * predictions are combined. Each one is weighted by the other prototype's distance, so the
     * nearer prototype contributes more. The combined result is then normalized.
     *
     * @param dataPoint the point to classify
     * @return the class probabilities
     */
    @Override // jsat.classifiers.neuralnetwork.LVQ, jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> nearest = this.vc.search(dataPoint.getNumericalValues(), 2);
        VecPaired<VecPaired<Vec, Integer>, Double> first = nearest.get(0);
        double bestDist = first.getPair().doubleValue();
        int best = first.getVector().getPair().intValue();
        CategoricalResults bestResult = this.localClassifeirs[best].classify(dataPoint);

        // With a single prototype there is no second opinion to blend in.
        if (nearest.size() < 2 || getLVQMethod().ordinal() < LVQ.LVQVersion.LVQ2.ordinal()) {
            return bestResult;
        }
        VecPaired<VecPaired<Vec, Integer>, Double> runnerUp = nearest.get(1);
        double secondDist = runnerUp.getPair().doubleValue();
        int second = runnerUp.getVector().getPair().intValue();
        if (!epsClose(bestDist, secondDist)) {
            return bestResult;
        }

        CategoricalResults secondResult = this.localClassifeirs[second].classify(dataPoint);
        CategoricalResults blended = new CategoricalResults(bestResult.size());
        for (int c = 0; c < bestResult.size(); c++) {
            // Inverse-distance weighting: each vote is scaled by the *other* prototype's distance.
            blended.incProb(c, bestResult.getProb(c) * secondDist);
            blended.incProb(c, secondResult.getProb(c) * bestDist);
        }
        blended.normalize();
        return blended;
    }

    /**
     * Trains the LVQ prototypes, then one local model per prototype.
     * <p>
     * Prototypes with fewer than {@value #MIN_WINS_FOR_LOCAL_MODEL} wins get a
     * {@link PriorClassifier} that always predicts the prototype's own class. So do prototypes
     * with no assigned training points (including prototypes that never won). This ensures every
     * entry is non-null, so {@link #classify} never dereferences a missing model.
     *
     * @param classificationDataSet the training data
     * @param executorService forwarded to {@link LVQ#trainC(ClassificationDataSet, ExecutorService)}
     */
    @Override // jsat.classifiers.neuralnetwork.LVQ, jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        super.trainC(classificationDataSet, executorService);
        List<List<DataPointPair<Integer>>> localPoints = assignLocalPoints(classificationDataSet);
        int numCategories = classificationDataSet.getPredicting().getNumOfCategories();

        this.localClassifeirs = new Classifier[this.weights.length];
        for (int i = 0; i < this.weights.length; i++) {
            List<DataPointPair<Integer>> points = localPoints.get(i);
            if (this.wins[i] < MIN_WINS_FOR_LOCAL_MODEL || points.isEmpty()) {
                // Too little local support: just predict the prototype's own class.
                CategoricalResults prior = new CategoricalResults(numCategories);
                prior.setProb(this.weightClass[i], 1.0d);
                this.localClassifeirs[i] = new PriorClassifier(prior);
            } else {
                Classifier local = this.localClassifier.clone();
                local.trainC(new ClassificationDataSet(points, classificationDataSet.getPredicting()));
                this.localClassifeirs[i] = local;
            }
        }
    }

    /**
     * Groups the training points by prototype.
     * <p>
     * Every point goes to its nearest prototype. It also goes to the second-nearest prototype
     * when the two distances are within the epsilon ratio band.
     *
     * @param dataSet the training data
     * @return one list of points per prototype, indexed like {@code weights}
     */
    private List<List<DataPointPair<Integer>>> assignLocalPoints(ClassificationDataSet dataSet) {
        List<List<DataPointPair<Integer>>> localPoints = new ArrayList<List<DataPointPair<Integer>>>(this.weights.length);
        for (int i = 0; i < this.weights.length; i++) {
            localPoints.add(new ArrayList<DataPointPair<Integer>>((this.wins[i] * 3) / 2));
        }

        double eps = getEpsilonDistance();
        for (DataPointPair<Integer> dpp : dataSet.getAsDPPList()) {
            List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> nearest = this.vc.search(dpp.getVector(), 2);
            VecPaired<VecPaired<Vec, Integer>, Double> first = nearest.get(0);
            int best = first.getVector().getPair().intValue();
            double bestDist = first.getPair().doubleValue();
            localPoints.get(best).add(dpp);

            if (nearest.size() < 2) {
                continue;
            }
            // The runner-up is element 1; reading element 0 again would compare the winner with itself.
            VecPaired<VecPaired<Vec, Integer>, Double> runnerUp = nearest.get(1);
            int second = runnerUp.getVector().getPair().intValue();
            double secondDist = runnerUp.getPair().doubleValue();
            double lowRatio = Math.min(bestDist / secondDist, secondDist / bestDist);
            double highRatio = Math.max(bestDist / secondDist, secondDist / bestDist);
            if (lowRatio > 1.0d - eps && highRatio < 1.0d + eps) {
                localPoints.get(second).add(dpp);
            }
        }
        return localPoints;
    }

    /**
     * Trains this model by delegating to {@link #trainC(ClassificationDataSet, ExecutorService)}
     * with a {@code null} executor (presumably meaning single-threaded training in {@code LVQ} —
     * confirm).
     *
     * @param classificationDataSet the training data
     */
    @Override // jsat.classifiers.neuralnetwork.LVQ, jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        trainC(classificationDataSet, null);
    }

    /**
     * Returns a deep copy of this model via the copy constructor.
     *
     * @return an independent copy
     */
    @Override // jsat.classifiers.neuralnetwork.LVQ
    public LVQLLC mo548clone() {
        return new LVQLLC(this);
    }
}
