package jsat.classifiers;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DenseSparseMetric;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.utils.FakeExecutor;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/Rocchio.class */
/**
 * Rocchio (nearest-centroid) classifier.
 *
 * <p>Training computes, for every class, the weighted mean of that class's numeric feature
 * vectors (its centroid). Classification scores every class by how close the query point is
 * to that class's centroid under the configured {@link DistanceMetric}.
 *
 * <p>Only numerical features are supported; training fails if the data set has categorical
 * variables. Instance weights are honoured when computing centroids.
 */
public class Rocchio implements Classifier {
    private static final long serialVersionUID = 889524967453326517L;

    /** One centroid per class, indexed by class label; {@code null} until trained. */
    private List<Vec> rocVecs;
    /** Metric used to compare a query point against each centroid. */
    private final DistanceMetric dm;
    /** {@link #dm} viewed as a {@link DenseSparseMetric}, or {@code null} if it is not one. */
    private final DenseSparseMetric dsdm;
    /**
     * Per-class vector constants precomputed by {@link #dsdm}, indexed by class label.
     * {@code null} when the metric is not a {@link DenseSparseMetric} or before training.
     */
    private double[] summaryConsts;

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Task that computes the weighted centroid of one class's samples.
     *
     * <p>The centroid is accumulated in place into {@link #rocchioVec}. When the metric is a
     * {@link DenseSparseMetric}, the centroid's vector constant is stored into the enclosing
     * classifier's {@code summaryConsts[index]}. The latch is always counted down, even when
     * the computation fails, so the trainer can never hang. A runtime failure is kept in
     * {@link #error} for the trainer to report.
     */
    public class RocchioAdder implements Runnable {
        double weightSum = 0.0d;
        final CountDownLatch latch;
        final Vec rocchioVec;
        final List<DataPoint> input;
        final int index;
        /**
         * Failure raised while computing this centroid, or {@code null} on success. It is
         * written before {@code countDown()}, so it is visible to a thread returning from
         * {@code await()}.
         */
        RuntimeException error;

        /**
         * @param countDownLatch latch counted down once this task finishes (success or failure)
         * @param i              class label whose centroid this task computes
         * @param vec            zero-initialised vector that receives the centroid in place
         * @param list           samples belonging to class {@code i}
         */
        public RocchioAdder(CountDownLatch countDownLatch, int i, Vec vec, List<DataPoint> list) {
            this.latch = countDownLatch;
            this.index = i;
            this.rocchioVec = vec;
            this.input = list;
        }

        @Override // java.lang.Runnable
        public void run() {
            try {
                for (DataPoint dataPoint : this.input) {
                    double weight = dataPoint.getWeight();
                    this.weightSum += weight;
                    this.rocchioVec.mutableAdd(weight, dataPoint.getNumericalValues());
                }
                // NOTE(review): a class with zero total weight divides by 0 here; confirm how
                // Vec.mutableDivide treats that case.
                this.rocchioVec.mutableDivide(this.weightSum);
                if (Rocchio.this.dsdm != null) {
                    Rocchio.this.summaryConsts[this.index] = Rocchio.this.dsdm.getVectorConstant(this.rocchioVec);
                }
            } catch (RuntimeException e) {
                // ExecutorService.submit would otherwise swallow this into an unread Future.
                this.error = e;
            } finally {
                this.latch.countDown();
            }
        }
    }

    /** Creates a Rocchio classifier using {@link EuclideanDistance}. */
    public Rocchio() {
        this(new EuclideanDistance());
    }

    /**
     * Creates a Rocchio classifier using the given metric.
     *
     * @param distanceMetric metric used to compare query points with class centroids
     */
    public Rocchio(DistanceMetric distanceMetric) {
        this.dm = distanceMetric;
        this.dsdm = distanceMetric instanceof DenseSparseMetric ? (DenseSparseMetric) distanceMetric : null;
        this.rocVecs = null;
    }

    /**
     * Scores every class for the given point.
     *
     * <p>With {@code d_i} the distance to class {@code i}'s centroid and {@code D} the sum of
     * all {@code d_i}, class {@code i} receives {@code 1 - d_i / D}. Closer centroids therefore
     * score higher. The scores are not normalised to sum to 1. If every distance is zero,
     * every class receives 1.0 instead of the NaN that {@code 0 / 0} would produce.
     *
     * @param dataPoint point to classify; only its numerical values are used
     * @return per-class scores
     * @throws IllegalStateException if the classifier has not been trained
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        if (this.rocVecs == null) {
            throw new IllegalStateException("Rocchio classifier has not been trained");
        }
        int numClasses = this.rocVecs.size();
        CategoricalResults categoricalResults = new CategoricalResults(numClasses);
        Vec numericalValues = dataPoint.getNumericalValues();
        boolean useConsts = this.dsdm != null && this.summaryConsts != null;

        double totalDist = 0.0d;
        for (int i = 0; i < numClasses; i++) {
            Vec centroid = this.rocVecs.get(i);
            double dist = useConsts
                    ? this.dsdm.dist(this.summaryConsts[i], centroid, numericalValues)
                    : this.dm.dist(centroid, numericalValues);
            totalDist += dist;
            // Temporarily store the raw distance; converted to a score below.
            categoricalResults.setProb(i, dist);
        }
        for (int i = 0; i < numClasses; i++) {
            double score = totalDist == 0.0d
                    ? 1.0d
                    : 1.0d - (categoricalResults.getProb(i) / totalDist);
            categoricalResults.setProb(i, score);
        }
        return categoricalResults;
    }

    /**
     * Trains the classifier by computing one weighted centroid per class. One task per class
     * is submitted to {@code executorService}.
     *
     * <p>If a {@link FailedToFitException} is thrown, the model state is unspecified and the
     * classifier should be retrained before use.
     *
     * @param classificationDataSet training data; must contain only numerical variables
     * @param executorService       executor that runs the per-class centroid tasks
     * @throws FailedToFitException if the data has categorical variables, if a centroid
     *                              computation fails, or if the calling thread is interrupted
     *                              while waiting (the interrupt flag is then restored)
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        if (classificationDataSet.getNumCategoricalVars() != 0) {
            throw new FailedToFitException("Classifier requires all variables be numerical");
        }
        int classSize = classificationDataSet.getClassSize();
        this.rocVecs = new ArrayList<Vec>(classSize);
        TrainableDistanceMetric.trainIfNeeded(this.dm, classificationDataSet, executorService);
        int numNumericalVars = classificationDataSet.getNumNumericalVars();
        // One constant per class (it is indexed by class label), and only when the metric can
        // use it. A null array makes classify() fall back to the plain DistanceMetric.
        this.summaryConsts = this.dsdm == null ? null : new double[classSize];

        CountDownLatch countDownLatch = new CountDownLatch(classSize);
        List<RocchioAdder> adders = new ArrayList<RocchioAdder>(classSize);
        for (int i = 0; i < classSize; i++) {
            DenseVector centroid = new DenseVector(numNumericalVars);
            this.rocVecs.add(centroid);
            RocchioAdder adder = new RocchioAdder(countDownLatch, i, centroid, classificationDataSet.getSamples(i));
            adders.add(adder);
            executorService.submit(adder);
        }
        try {
            countDownLatch.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new FailedToFitException("Interrupted while computing class centroids");
        }
        for (RocchioAdder adder : adders) {
            if (adder.error != null) {
                FailedToFitException failure = new FailedToFitException(
                        "Failed to compute the centroid for class " + adder.index + ": " + adder.error);
                failure.initCause(adder.error);
                throw failure;
            }
        }
    }

    /**
     * Trains the classifier on the calling thread.
     *
     * @param classificationDataSet training data; must contain only numerical variables
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        trainC(classificationDataSet, new FakeExecutor());
    }

    /** @return {@code true}; sample weights are used when averaging each class centroid */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return true;
    }

    /**
     * Returns a copy of this classifier. The centroids and per-class constants are deep-copied.
     * The distance metric instance is shared with the copy, not cloned.
     */
    @Override // jsat.classifiers.Classifier
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public Rocchio m492clone() {
        Rocchio copy = new Rocchio(this.dm);
        if (this.rocVecs != null) {
            copy.rocVecs = new ArrayList<Vec>(this.rocVecs.size());
            for (Vec centroid : this.rocVecs) {
                copy.rocVecs.add(centroid.mo525clone());
            }
        }
        if (this.summaryConsts != null) {
            copy.summaryConsts = Arrays.copyOf(this.summaryConsts, this.summaryConsts.length);
        }
        return copy;
    }
}
