package jsat.classifiers.knn;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.DataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.distributions.Distribution;
import jsat.distributions.discrete.UniformDiscrete;
import jsat.distributions.empirical.kernelfunc.EpanechnikovKF;
import jsat.distributions.empirical.kernelfunc.KernelFunction;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.linear.vectorcollection.DefaultVectorCollectionFactory;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.linear.vectorcollection.VectorCollectionFactory;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/knn/LWL.class */
/**
 * Locally weighted learner usable as both a {@link Classifier} and a
 * {@link Regressor}.
 * <p>
 * Training only stores the training vectors (paired with their target value)
 * in a {@link VectorCollection}. At prediction time the {@code k} nearest
 * stored vectors to the query are retrieved, each one is weighted by the
 * kernel function applied to its distance divided by the distance of the last
 * (farthest) returned neighbor, a fresh clone of the base learner is trained
 * on that weighted neighborhood, and the clone produces the prediction.
 */
public class LWL implements Classifier, Regressor, Parameterized {
    private static final long serialVersionUID = 6942465758987345997L;

    /** Class label information captured by the last call to {@code trainC}. */
    private CategoricalData predicting;
    /** Base learner cloned and trained per query in {@link #classify}. */
    private Classifier classifier;
    /** Base learner cloned and trained per query in {@link #regress}. */
    private Regressor regressor;
    /** Number of neighbors requested from the vector collection. */
    private int k;
    private DistanceMetric dm;
    private KernelFunction kf;
    private VectorCollectionFactory<VecPaired<Vec, Double>> vcf;
    /** Stored training vectors; {@code null} until one of the train methods ran. */
    private VectorCollection<VecPaired<Vec, Double>> vc;

    /**
     * Copy constructor backing {@link #mo714clone()}.
     * <p>
     * The class information, distance metric and vector collection are
     * cloned; the base learners, kernel function and collection factory are
     * shared by reference with {@code lwl}.
     *
     * @param lwl the instance to copy
     */
    private LWL(LWL lwl) {
        if (lwl.predicting != null) {
            this.predicting = lwl.predicting.m481clone();
        }
        if (lwl.classifier != null) {
            setClassifier(lwl.classifier);
        }
        if (lwl.regressor != null) {
            setRegressor(lwl.regressor);
        }
        setNeighbors(lwl.k);
        setDistanceMetric(lwl.dm.mo652clone());
        setKernelFunction(lwl.kf);
        this.vcf = lwl.vcf;
        if (lwl.vc != null) {
            this.vc = lwl.vc.clone();
        }
    }

    /**
     * Creates a locally weighted classifier using the Epanechnikov kernel and
     * the default vector collection factory.
     *
     * @param classifier the base classifier trained on each neighborhood
     * @param i the number of neighbors, must be at least 2
     * @param distanceMetric the metric used for the neighbor search
     */
    public LWL(Classifier classifier, int i, DistanceMetric distanceMetric) {
        this(classifier, i, distanceMetric, EpanechnikovKF.getInstance());
    }

    /**
     * Creates a locally weighted classifier using the default vector
     * collection factory.
     *
     * @param classifier the base classifier trained on each neighborhood
     * @param i the number of neighbors, must be at least 2
     * @param distanceMetric the metric used for the neighbor search
     * @param kernelFunction the kernel that turns scaled distances into weights
     */
    public LWL(Classifier classifier, int i, DistanceMetric distanceMetric, KernelFunction kernelFunction) {
        this(classifier, i, distanceMetric, kernelFunction, new DefaultVectorCollectionFactory<VecPaired<Vec, Double>>());
    }

    /**
     * Creates a locally weighted classifier.
     *
     * @param classifier the base classifier trained on each neighborhood
     * @param i the number of neighbors, must be at least 2
     * @param distanceMetric the metric used for the neighbor search
     * @param kernelFunction the kernel that turns scaled distances into weights
     * @param vectorCollectionFactory the factory that builds the neighbor search structure
     */
    public LWL(Classifier classifier, int i, DistanceMetric distanceMetric, KernelFunction kernelFunction, VectorCollectionFactory<VecPaired<Vec, Double>> vectorCollectionFactory) {
        setClassifier(classifier);
        setNeighbors(i);
        setDistanceMetric(distanceMetric);
        setKernelFunction(kernelFunction);
        this.vcf = vectorCollectionFactory;
    }

    /**
     * Creates a locally weighted regressor using the Epanechnikov kernel and
     * the default vector collection factory.
     *
     * @param regressor the base regressor trained on each neighborhood
     * @param i the number of neighbors, must be at least 2
     * @param distanceMetric the metric used for the neighbor search
     */
    public LWL(Regressor regressor, int i, DistanceMetric distanceMetric) {
        this(regressor, i, distanceMetric, EpanechnikovKF.getInstance());
    }

    /**
     * Creates a locally weighted regressor using the default vector
     * collection factory.
     *
     * @param regressor the base regressor trained on each neighborhood
     * @param i the number of neighbors, must be at least 2
     * @param distanceMetric the metric used for the neighbor search
     * @param kernelFunction the kernel that turns scaled distances into weights
     */
    public LWL(Regressor regressor, int i, DistanceMetric distanceMetric, KernelFunction kernelFunction) {
        this(regressor, i, distanceMetric, kernelFunction, new DefaultVectorCollectionFactory<VecPaired<Vec, Double>>());
    }

    /**
     * Creates a locally weighted regressor.
     *
     * @param regressor the base regressor trained on each neighborhood
     * @param i the number of neighbors, must be at least 2
     * @param distanceMetric the metric used for the neighbor search
     * @param kernelFunction the kernel that turns scaled distances into weights
     * @param vectorCollectionFactory the factory that builds the neighbor search structure
     */
    public LWL(Regressor regressor, int i, DistanceMetric distanceMetric, KernelFunction kernelFunction, VectorCollectionFactory<VecPaired<Vec, Double>> vectorCollectionFactory) {
        setRegressor(regressor);
        setNeighbors(i);
        setDistanceMetric(distanceMetric);
        setKernelFunction(kernelFunction);
        this.vcf = vectorCollectionFactory;
    }

    /**
     * Classifies a point by training a clone of the base classifier on the
     * weighted {@code k} nearest stored training vectors.
     *
     * @param dataPoint the point to classify; only its numerical values are used for the search
     * @return the result produced by the locally trained clone
     * @throws UntrainedModelException if no base classifier is set or no training data has been stored
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        if (this.classifier == null || this.vc == null) {
            throw new UntrainedModelException("Model has not been trained");
        }
        List<? extends VecPaired<VecPaired<Vec, Double>, Double>> neighbors = this.vc.search(dataPoint.getNumericalValues(), this.k);
        List<DataPointPair<Integer>> localPoints = new ArrayList<DataPointPair<Integer>>(neighbors.size());
        // The last search result is used as the scale for all distances.
        double maxDistance = neighbors.get(neighbors.size() - 1).getPair().doubleValue();
        for (int i = 0; i < neighbors.size(); i++) {
            VecPaired<VecPaired<Vec, Double>, Double> neighbor = neighbors.get(i);
            double weight = kernelWeight(neighbor.getPair().doubleValue(), maxDistance);
            DataPoint weighted = new DataPoint(neighbor, new int[0], new CategoricalData[0], weight);
            // The class label was stored as a Double alongside the vector.
            Integer label = Integer.valueOf(neighbor.getVector().getPair().intValue());
            localPoints.add(new DataPointPair<Integer>(weighted, label));
        }
        ClassificationDataSet localSet = new ClassificationDataSet(localPoints, this.predicting);
        Classifier localModel = this.classifier.mo714clone();
        localModel.trainC(localSet);
        return localModel.classify(dataPoint);
    }

    /**
     * Stores the training vectors for later neighbor searches, training the
     * distance metric first if it needs it.
     *
     * @param classificationDataSet the training data
     * @param executorService passed through to metric training and collection construction
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        List<VecPaired<Vec, Double>> vecList = getVecList(classificationDataSet);
        TrainableDistanceMetric.trainIfNeeded(this.dm, classificationDataSet, executorService);
        this.vc = this.vcf.getVectorCollection(vecList, this.dm, executorService);
        this.predicting = classificationDataSet.getPredicting();
    }

    /**
     * Stores the training vectors for later neighbor searches, training the
     * distance metric first if it needs it.
     *
     * @param classificationDataSet the training data
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        List<VecPaired<Vec, Double>> vecList = getVecList(classificationDataSet);
        TrainableDistanceMetric.trainIfNeeded(this.dm, classificationDataSet);
        this.vc = this.vcf.getVectorCollection(vecList, this.dm);
        this.predicting = classificationDataSet.getPredicting();
    }

    /**
     * @return always {@code false}
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Predicts a value by training a clone of the base regressor on the
     * weighted {@code k} nearest stored training vectors.
     *
     * @param dataPoint the query point; only its numerical values are used for the search
     * @return the value produced by the locally trained clone
     * @throws UntrainedModelException if no base regressor is set or no training data has been stored
     */
    @Override // jsat.regression.Regressor
    public double regress(DataPoint dataPoint) {
        if (this.regressor == null || this.vc == null) {
            throw new UntrainedModelException("Model has not been trained");
        }
        List<? extends VecPaired<VecPaired<Vec, Double>, Double>> neighbors = this.vc.search(dataPoint.getNumericalValues(), this.k);
        List<DataPointPair<Double>> localPoints = new ArrayList<DataPointPair<Double>>(neighbors.size());
        // The last search result is used as the scale for all distances.
        double maxDistance = neighbors.get(neighbors.size() - 1).getPair().doubleValue();
        for (int i = 0; i < neighbors.size(); i++) {
            VecPaired<VecPaired<Vec, Double>, Double> neighbor = neighbors.get(i);
            double weight = kernelWeight(neighbor.getPair().doubleValue(), maxDistance);
            DataPoint weighted = new DataPoint(neighbor, new int[0], new CategoricalData[0], weight);
            localPoints.add(new DataPointPair<Double>(weighted, neighbor.getVector().getPair()));
        }
        RegressionDataSet localSet = new RegressionDataSet(localPoints);
        Regressor localModel = this.regressor.mo714clone();
        localModel.train(localSet);
        return localModel.regress(dataPoint);
    }

    /**
     * Stores the training vectors for later neighbor searches, training the
     * distance metric first if it needs it.
     *
     * @param regressionDataSet the training data
     * @param executorService passed through to metric training and collection construction
     */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet, ExecutorService executorService) {
        List<VecPaired<Vec, Double>> vecList = getVecList(regressionDataSet);
        TrainableDistanceMetric.trainIfNeeded(this.dm, regressionDataSet, executorService);
        this.vc = this.vcf.getVectorCollection(vecList, this.dm, executorService);
    }

    /**
     * Stores the training vectors for later neighbor searches, training the
     * distance metric first if it needs it.
     *
     * @param regressionDataSet the training data
     */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet) {
        List<VecPaired<Vec, Double>> vecList = getVecList(regressionDataSet);
        TrainableDistanceMetric.trainIfNeeded(this.dm, regressionDataSet);
        this.vc = this.vcf.getVectorCollection(vecList, this.dm);
    }

    /**
     * Creates a copy of this learner via the private copy constructor.
     * (The method name is a decompiler rename of {@code clone}.)
     *
     * @return a new {@code LWL} copied from this one
     */
    @Override // jsat.regression.Regressor
    public LWL mo714clone() {
        return new LWL(this);
    }

    /**
     * Converts a neighbor distance into a kernel weight.
     * <p>
     * The distance is divided by {@code maxDistance} before being passed to
     * the kernel. When {@code maxDistance} is not positive (every returned
     * neighbor coincides with the query) the division would yield
     * {@code 0/0 = NaN}; a scaled distance of 0 is used instead so the local
     * learner never receives NaN weights.
     *
     * @param distance the distance of one neighbor to the query
     * @param maxDistance the distance of the last returned neighbor
     * @return the kernel value for the scaled distance
     */
    private double kernelWeight(double distance, double maxDistance) {
        double scaled = maxDistance > 0.0 ? distance / maxDistance : 0.0;
        return this.kf.k(scaled);
    }

    /**
     * Pairs every training vector with its class index, stored as a Double.
     *
     * @param classificationDataSet the training data
     * @return one paired vector per sample, in data set order
     */
    private List<VecPaired<Vec, Double>> getVecList(ClassificationDataSet classificationDataSet) {
        int sampleCount = classificationDataSet.getSampleSize();
        List<VecPaired<Vec, Double>> paired = new ArrayList<VecPaired<Vec, Double>>(sampleCount);
        for (int i = 0; i < sampleCount; i++) {
            Vec features = classificationDataSet.getDataPoint(i).getNumericalValues();
            double label = classificationDataSet.getDataPointCategory(i);
            paired.add(new VecPaired<Vec, Double>(features, Double.valueOf(label)));
        }
        return paired;
    }

    /**
     * Pairs every training vector with its regression target value.
     *
     * @param regressionDataSet the training data
     * @return one paired vector per sample, in data set order
     */
    private List<VecPaired<Vec, Double>> getVecList(RegressionDataSet regressionDataSet) {
        int sampleCount = regressionDataSet.getSampleSize();
        List<VecPaired<Vec, Double>> paired = new ArrayList<VecPaired<Vec, Double>>(sampleCount);
        for (int i = 0; i < sampleCount; i++) {
            Vec features = regressionDataSet.getDataPoint(i).getNumericalValues();
            paired.add(new VecPaired<Vec, Double>(features, Double.valueOf(regressionDataSet.getTargetValue(i))));
        }
        return paired;
    }

    /**
     * Sets the base classifier; if it also implements {@link Regressor} it
     * becomes the base regressor as well.
     */
    private void setClassifier(Classifier classifier) {
        this.classifier = classifier;
        if (classifier instanceof Regressor) {
            this.regressor = (Regressor) classifier;
        }
    }

    /**
     * Sets the base regressor; if it also implements {@link Classifier} it
     * becomes the base classifier as well.
     */
    private void setRegressor(Regressor regressor) {
        this.regressor = regressor;
        if (regressor instanceof Classifier) {
            this.classifier = (Classifier) regressor;
        }
    }

    /**
     * Sets the number of neighbors used to build each local training set.
     *
     * @param i the number of neighbors, must be at least 2
     * @throws IllegalArgumentException if {@code i} is less than 2
     */
    public void setNeighbors(int i) {
        if (i <= 1) {
            throw new IllegalArgumentException("An average requires at least 2 neighbors to be taken into account");
        }
        this.k = i;
    }

    /**
     * @return the number of neighbors used to build each local training set
     */
    public int getNeighbors() {
        return this.k;
    }

    /**
     * Sets the distance metric handed to the vector collection on the next
     * training call.
     *
     * @param distanceMetric the metric to use
     */
    public void setDistanceMetric(DistanceMetric distanceMetric) {
        this.dm = distanceMetric;
    }

    /**
     * @return the distance metric currently set
     */
    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }

    /**
     * Sets the kernel that converts scaled neighbor distances into weights.
     *
     * @param kernelFunction the kernel to use
     */
    public void setKernelFunction(KernelFunction kernelFunction) {
        this.kf = kernelFunction;
    }

    /**
     * @return the kernel function currently set
     */
    public KernelFunction getKernelFunction() {
        return this.kf;
    }

    /**
     * Suggests a distribution of neighbor counts for a data set: a
     * {@link UniformDiscrete} built from 25 and the smaller of 200 and one
     * fifth of the sample size.
     * <p>
     * NOTE(review): for data sets with fewer than 125 samples the second
     * argument is below 25; confirm how {@code UniformDiscrete} handles that.
     *
     * @param dataSet the data set the neighbor count is to be chosen for
     * @return the suggested distribution
     */
    public static Distribution guessNeighbors(DataSet dataSet) {
        return new UniformDiscrete(25, Math.min(200, dataSet.getSampleSize() / 5));
    }

    /**
     * @return the parameters discovered from this object's methods by {@link Parameter#getParamsFromMethods}
     */
    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    /**
     * Looks up a single parameter by name.
     *
     * @param str the parameter name
     * @return the value mapped to {@code str} in the parameter map built from {@link #getParameters()}
     */
    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }
}
