package jsat.linear.distancemetrics;

import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.linear.Vec;
import jsat.utils.DoubleList;
import jsat.utils.FakeExecutor;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.ParallelUtils;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/linear/distancemetrics/CosineDistance.class */
/**
 * Distance metric derived from the cosine of the angle between two vectors.
 *
 * <p>The cosine {@code c = (a . b) / (||a||_2 * ||b||_2)} is mapped to a
 * distance by {@link #cosineToDistance(double)}, i.e. {@code sqrt(0.5 * (1 - c))},
 * so a cosine of 1 yields distance 0 and a cosine of -1 yields distance 1.
 * When either vector has a 2-norm of zero the cosine is undefined and the
 * maximum distance {@code cosineToDistance(-1)} is returned instead.
 *
 * <p>The acceleration cache used by the indexed {@code dist} overloads holds
 * the 2-norm of every vector in the data set, so those norms are not
 * recomputed on every distance evaluation.
 */
public class CosineDistance implements DistanceMetric {
    private static final long serialVersionUID = -6475546704095989078L;

    /**
     * Computes the cosine distance between two vectors of equal length.
     *
     * @param vec  the first vector
     * @param vec2 the second vector
     * @return the distance, {@code cosineToDistance(-1)} if either vector has a zero 2-norm
     * @throws ArithmeticException if the two vectors differ in length
     */
    @Override
    public double dist(Vec vec, Vec vec2) {
        if (vec.length() != vec2.length()) {
            throw new ArithmeticException("vectors a and b are of differeing legnths " + vec.length() + " and " + vec2.length());
        }
        return distFromNormProduct(vec, vec2, vec.pNorm(2.0d) * vec2.pNorm(2.0d));
    }

    /** @return {@code true} */
    @Override
    public boolean isSymmetric() {
        return true;
    }

    /** @return {@code true} */
    @Override
    public boolean isSubadditive() {
        return true;
    }

    /** @return {@code true} */
    @Override
    public boolean isIndiscemible() {
        return true;
    }

    /** @return {@code 1.0}, the value {@link #cosineToDistance(double)} produces for a cosine of -1 */
    @Override
    public double metricBound() {
        return 1.0d;
    }

    /** @return the fixed display name {@code "Cosine Distance"} */
    @Override
    public String toString() {
        return "Cosine Distance";
    }

    /**
     * Returns a fresh {@code CosineDistance}; the class has no instance state to copy.
     *
     * <p>NOTE(review): the decompiler reports this method was renamed from
     * {@code clone} (merged with a bridge method). The mangled name is kept here
     * so the visible interface is unchanged; confirm before restoring {@code clone()}.
     */
    @Override
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public CosineDistance m647clone() {
        return new CosineDistance();
    }

    /** @return {@code true}; this metric can use a cache of precomputed 2-norms */
    @Override
    public boolean supportsAcceleration() {
        return true;
    }

    /**
     * Builds the acceleration cache sequentially.
     *
     * @param list the vectors of the data set
     * @return a list whose i-th entry is the 2-norm of {@code list.get(i)}
     */
    @Override
    public List<Double> getAccelerationCache(List<? extends Vec> list) {
        DoubleList norms = new DoubleList(list.size());
        for (Vec v : list) {
            norms.add(v.pNorm(2.0d));
        }
        return norms;
    }

    /**
     * Builds the acceleration cache, splitting the work across tasks submitted to
     * the given executor. Falls back to the sequential version when the executor
     * is {@code null} or a {@link FakeExecutor}.
     *
     * <p>If the calling thread is interrupted while waiting, the interruption is
     * logged, the thread's interrupt status is restored, and the cache is returned
     * as filled so far (entries not yet computed are {@code 0.0}).
     *
     * @param list            the vectors of the data set
     * @param executorService executor to run the work on, may be {@code null}
     * @return a list whose i-th entry is the 2-norm of {@code list.get(i)}
     * @throws RuntimeException the first runtime failure recorded from a worker task, if any
     */
    @Override
    public List<Double> getAccelerationCache(final List<? extends Vec> list, ExecutorService executorService) {
        if (executorService == null || (executorService instanceof FakeExecutor)) {
            return getAccelerationCache(list);
        }
        final double[] norms = new double[list.size()];
        final int tasks = Math.min(SystemInfo.LogicalCores, list.size());
        final CountDownLatch pending = new CountDownLatch(tasks);
        // Single-slot holder for a worker failure. Writes precede that worker's
        // countDown(), which happens-before await() returning, so no extra locking
        // is needed to read it afterwards.
        final RuntimeException[] failure = new RuntimeException[1];
        for (int t = 0; t < tasks; t++) {
            // presumably a half-open [start, end) slice of the array for task t — verify against ParallelUtils
            final int startBlock = ParallelUtils.getStartBlock(norms.length, t, tasks);
            final int endBlock = ParallelUtils.getEndBlock(norms.length, t, tasks);
            executorService.submit(new Runnable() {
                @Override
                public void run() {
                    try {
                        for (int idx = startBlock; idx < endBlock; idx++) {
                            norms[idx] = list.get(idx).pNorm(2.0d);
                        }
                    } catch (RuntimeException ex) {
                        failure[0] = ex;
                    } finally {
                        // Always count down, otherwise a failing task would leave
                        // the caller blocked in await() forever.
                        pending.countDown();
                    }
                }
            });
        }
        try {
            pending.await();
        } catch (InterruptedException e) {
            Logger.getLogger(CosineDistance.class.getName()).log(Level.SEVERE, null, e);
            // Restore the interrupt status so callers further up can react to it.
            Thread.currentThread().interrupt();
        }
        if (failure[0] != null) {
            throw failure[0];
        }
        return DoubleList.view(norms, norms.length);
    }

    /**
     * Computes the distance between two vectors of the data set, using cached norms when available.
     *
     * @param i     index of the first vector
     * @param i2    index of the second vector
     * @param list  the vectors of the data set
     * @param list2 the acceleration cache for {@code list}, or {@code null} to compute without it
     * @return the cosine distance between {@code list.get(i)} and {@code list.get(i2)}
     */
    @Override
    public double dist(int i, int i2, List<? extends Vec> list, List<Double> list2) {
        if (list2 == null) {
            return dist(list.get(i), list.get(i2));
        }
        double normProduct = list2.get(i).doubleValue() * list2.get(i2).doubleValue();
        return distFromNormProduct(list.get(i), list.get(i2), normProduct);
    }

    /**
     * Computes the distance between a data set vector and a query vector. The
     * query's norm is computed on the fly; the data set vector's norm comes from the cache.
     *
     * @param i     index of the data set vector
     * @param vec   the query vector
     * @param list  the vectors of the data set
     * @param list2 the acceleration cache for {@code list}, or {@code null} to compute without it
     * @return the cosine distance between {@code list.get(i)} and {@code vec}
     */
    @Override
    public double dist(int i, Vec vec, List<? extends Vec> list, List<Double> list2) {
        if (list2 == null) {
            return dist(list.get(i), vec);
        }
        double normProduct = list2.get(i).doubleValue() * vec.pNorm(2.0d);
        return distFromNormProduct(list.get(i), vec, normProduct);
    }

    /**
     * Precomputes the information needed for repeated queries with the same vector.
     *
     * @param vec the query vector
     * @return a single-element list containing the 2-norm of {@code vec}
     */
    @Override
    public List<Double> getQueryInfo(Vec vec) {
        DoubleList info = new DoubleList(1);
        info.add(vec.pNorm(2.0d));
        return info;
    }

    /**
     * Computes the distance between a data set vector and a query vector, using
     * both the cached data set norms and the precomputed query norm.
     *
     * @param i     index of the data set vector
     * @param vec   the query vector
     * @param list  query info for {@code vec}; element 0 is read as its 2-norm
     * @param list2 the vectors of the data set
     * @param list3 the acceleration cache for {@code list2}, or {@code null} to compute without it
     * @return the cosine distance between {@code list2.get(i)} and {@code vec}
     */
    @Override
    public double dist(int i, Vec vec, List<Double> list, List<? extends Vec> list2, List<Double> list3) {
        if (list3 == null) {
            return dist(list2.get(i), vec);
        }
        double normProduct = list3.get(i).doubleValue() * list.get(0).doubleValue();
        return distFromNormProduct(list2.get(i), vec, normProduct);
    }

    /**
     * Shared tail of every {@code dist} overload: turns a dot product and a
     * product of norms into a distance.
     *
     * @param a           the first vector
     * @param b           the second vector
     * @param normProduct product of the 2-norms of {@code a} and {@code b}
     * @return the maximum distance if {@code normProduct} is zero, otherwise the
     *         distance for the cosine clamped to {@code [-1, 1]}
     */
    private static double distFromNormProduct(Vec a, Vec b, double normProduct) {
        if (normProduct == 0.0d) {
            // Cosine is undefined for a zero vector; treat as maximally distant.
            return cosineToDistance(-1.0d);
        }
        double cosine = a.dot(b) / normProduct;
        // Rounding can push the ratio marginally outside [-1, 1]; above 1 would give
        // sqrt of a negative (NaN), below -1 would exceed metricBound().
        return cosineToDistance(Math.max(-1.0d, Math.min(cosine, 1.0d)));
    }

    /**
     * Converts a cosine similarity into a distance.
     *
     * @param d the cosine value
     * @return {@code sqrt(0.5 * (1 - d))}
     */
    public static double cosineToDistance(double d) {
        return Math.sqrt(0.5d * (1.0d - d));
    }

    /**
     * Inverse of {@link #cosineToDistance(double)}.
     *
     * @param d the distance value
     * @return {@code 1 - 2 * d * d}
     */
    public static double distanceToCosine(double d) {
        return 1.0d - (2.0d * (d * d));
    }
}
