package jsat.linear.distancemetrics;

import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.linear.IndexValue;
import jsat.linear.Vec;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/linear/distancemetrics/PearsonDistance.class */
/**
 * Distance metric derived from the Pearson correlation {@code r} of two vectors.
 *
 * <p>By default the distance is {@code sqrt((1 - r) / 2)}, so perfectly correlated vectors are at
 * distance 0 and perfectly anti-correlated vectors at distance 1. With {@code absoluteDistance}
 * enabled the distance is {@code sqrt(1 - r^2)}, which ignores the sign of the correlation. If the
 * correlation is undefined ({@code NaN}), {@link #dist(Vec, Vec)} returns {@link Double#MAX_VALUE}.
 */
public class PearsonDistance implements DistanceMetric {
    private static final long serialVersionUID = 1090726755301934198L;

    /** When true, only indices at which both vectors are non-zero contribute to the correlation. */
    private final boolean bothNonZero;
    /** When true, use {@code sqrt(1 - r^2)}; otherwise {@code sqrt((1 - r) / 2)}. */
    private final boolean absoluteDistance;

    /** Creates a metric that uses every index and the signed form {@code sqrt((1 - r) / 2)}. */
    public PearsonDistance() {
        this(false, false);
    }

    /**
     * Creates a Pearson distance metric.
     *
     * @param bothNonZero if true, means are taken over each vector's non-zero entries and only
     *     indices where both vectors are non-zero are used in the correlation
     * @param absoluteDistance if true, the distance is {@code sqrt(1 - r^2)}; otherwise it is
     *     {@code sqrt((1 - r) / 2)}
     */
    public PearsonDistance(boolean bothNonZero, boolean absoluteDistance) {
        this.bothNonZero = bothNonZero;
        this.absoluteDistance = absoluteDistance;
    }

    /**
     * Computes the correlation-based distance between two vectors.
     *
     * @return the distance, or {@link Double#MAX_VALUE} when the correlation is undefined
     */
    @Override
    public double dist(Vec a, Vec b) {
        double r = correlation(a, b, bothNonZero);
        if (Double.isNaN(r)) {
            return Double.MAX_VALUE;
        }
        // Rounding can push |r| a hair past 1, which would make the sqrt argument negative (NaN).
        r = Math.max(-1.0, Math.min(1.0, r));
        return absoluteDistance ? Math.sqrt(1.0 - r * r) : Math.sqrt((1.0 - r) * 0.5);
    }

    @Override
    public boolean isSymmetric() {
        return true;
    }

    @Override
    public boolean isSubadditive() {
        return true;
    }

    // The spelling of this name comes from the DistanceMetric interface and must not change.
    @Override
    public boolean isIndiscemible() {
        return true;
    }

    /** Returns 1.0, the value reported as the upper bound of this metric. */
    @Override
    public double metricBound() {
        return 1.0;
    }

    /** Returns a new metric with the same configuration. */
    @Override
    public PearsonDistance clone() {
        return new PearsonDistance(bothNonZero, absoluteDistance);
    }

    /**
     * Alias for {@link #clone()}, kept only so code written against the decompiler-generated name
     * still compiles.
     */
    public PearsonDistance m654clone() {
        return clone();
    }

    /**
     * Computes the Pearson correlation between {@code a} and {@code b}.
     *
     * <p>When at least one vector is sparse, the computation walks only the non-zero entries and
     * adds the contribution of skipped zero positions in closed form.
     *
     * @param a the first vector
     * @param b the second vector; assumed to have the same length as {@code a}
     * @param bothNonZero if true, each mean is taken over that vector's non-zero entries, and only
     *     indices where both vectors are non-zero contribute
     * @return the correlation. For sparse input, returns 1.0 when both vectors have no non-zero
     *     entries and {@link Double#NaN} when exactly one of them has none. Returns 0.0 when both
     *     accumulated sums of squares are zero.
     */
    public static double correlation(Vec a, Vec b, boolean bothNonZero) {
        final double aMean;
        final double bMean;
        if (bothNonZero) {
            aMean = a.sum() / a.nnz();
            bMean = b.sum() / b.nnz();
        } else {
            aMean = a.mean();
            bMean = b.mean();
        }

        if (a.isSparse() || b.isSparse()) {
            return sparseCorrelation(a, b, aMean, bMean, bothNonZero);
        }
        return denseCorrelation(a, b, aMean, bMean, bothNonZero);
    }

    /** Correlation via a merge over the two vectors' non-zero iterators. */
    private static double sparseCorrelation(
            Vec a, Vec b, double aMean, double bMean, boolean bothNonZero) {
        Iterator<IndexValue> aIter = a.getNonZeroIterator();
        Iterator<IndexValue> bIter = b.getNonZeroIterator();
        if (!aIter.hasNext() && !bIter.hasNext()) {
            return 1.0;
        }
        if (!aIter.hasNext() || !bIter.hasNext()) {
            // Exactly one vector is all zero, so the correlation is undefined. NaN lets dist() map
            // this to MAX_VALUE. MAX_VALUE itself is not a valid correlation and made dist() NaN.
            return Double.NaN;
        }

        double numer = 0.0;
        double aSqrd = 0.0;
        double bSqrd = 0.0;
        IndexValue aCur = null;
        IndexValue bCur = null;
        boolean needA = true;
        boolean needB = true;
        // Last index already accounted for; gaps after it are positions where both are zero.
        int lastIndex = -1;

        while (true) {
            if (needA) {
                if (!aIter.hasNext()) {
                    break;
                }
                aCur = aIter.next();
                needA = false;
            }
            if (needB) {
                if (!bIter.hasNext()) {
                    break;
                }
                bCur = bIter.next();
                needB = false;
            }

            int aIdx = aCur.getIndex();
            int bIdx = bCur.getIndex();
            if (aIdx == bIdx) {
                if (!bothNonZero) {
                    // Skipped positions where both vectors are zero each add (-aMean)*(-bMean).
                    numer += aMean * bMean * (aIdx - lastIndex - 1);
                }
                lastIndex = aIdx;
                double aDev = aCur.getValue() - aMean;
                double bDev = bCur.getValue() - bMean;
                numer += aDev * bDev;
                aSqrd += aDev * aDev;
                bSqrd += bDev * bDev;
                needA = true;
                needB = true;
            } else if (aIdx > bIdx) {
                if (!bothNonZero) {
                    numer += aMean * bMean * (bIdx - lastIndex - 1);
                    lastIndex = bIdx;
                    double bDev = bCur.getValue() - bMean;
                    numer += -aMean * bDev; // a is zero at bIdx
                    bSqrd += bDev * bDev;
                }
                needB = true;
            } else {
                if (!bothNonZero) {
                    numer += aMean * bMean * (aIdx - lastIndex - 1);
                    lastIndex = aIdx;
                    double aDev = aCur.getValue() - aMean;
                    numer += aDev * -bMean; // b is zero at aIdx
                    aSqrd += aDev * aDev;
                }
                needA = true;
            }
        }

        if (!bothNonZero) {
            // The merge stops when one side is exhausted. Only the other side can still hold
            // entries, including one already fetched but not yet consumed (needX == false).
            // Those entries line up against zeros in the exhausted vector.
            while (!needA || aIter.hasNext()) {
                if (needA) {
                    aCur = aIter.next();
                }
                numer += aMean * bMean * (aCur.getIndex() - lastIndex - 1);
                lastIndex = aCur.getIndex();
                double aDev = aCur.getValue() - aMean;
                numer += aDev * -bMean;
                aSqrd += aDev * aDev;
                needA = true;
            }
            while (!needB || bIter.hasNext()) {
                if (needB) {
                    bCur = bIter.next();
                }
                numer += aMean * bMean * (bCur.getIndex() - lastIndex - 1);
                lastIndex = bCur.getIndex();
                double bDev = bCur.getValue() - bMean;
                numer += -aMean * bDev;
                bSqrd += bDev * bDev;
                needB = true;
            }
            // Trailing positions where both are zero, plus every zero's squared deviation.
            numer += aMean * bMean * (a.length() - lastIndex - 1);
            aSqrd += aMean * aMean * (a.length() - a.nnz());
            bSqrd += bMean * bMean * (b.length() - b.nnz());
        }
        return finishCorrelation(numer, aSqrd, bSqrd);
    }

    /** Correlation via a plain index loop over dense vectors. */
    private static double denseCorrelation(
            Vec a, Vec b, double aMean, double bMean, boolean bothNonZero) {
        double numer = 0.0;
        double aSqrd = 0.0;
        double bSqrd = 0.0;
        for (int idx = 0; idx < a.length(); idx++) {
            double aVal = a.get(idx);
            double bVal = b.get(idx);
            if (bothNonZero && (aVal == 0.0 || bVal == 0.0)) {
                continue;
            }
            double aDev = aVal - aMean;
            double bDev = bVal - bMean;
            numer += aDev * bDev;
            aSqrd += aDev * aDev;
            bSqrd += bDev * bDev;
        }
        return finishCorrelation(numer, aSqrd, bSqrd);
    }

    /**
     * Divides the cross-product sum by the root of the squared-deviation sums. Returns 0.0 when
     * both sums are zero. When exactly one is zero, a 1e-10 term is added to each to avoid
     * dividing by zero.
     */
    private static double finishCorrelation(double numer, double aSqrd, double bSqrd) {
        if (bSqrd == 0.0 && aSqrd == 0.0) {
            return 0.0;
        }
        if (bSqrd == 0.0 || aSqrd == 0.0) {
            return numer / Math.sqrt((aSqrd + 1.0E-10) * (bSqrd + 1.0E-10));
        }
        return numer / Math.sqrt(aSqrd * bSqrd);
    }

    /** This metric has no acceleration cache. */
    @Override
    public boolean supportsAcceleration() {
        return false;
    }

    /** Returns null because this metric does not use an acceleration cache. */
    @Override
    public List<Double> getAccelerationCache(List<? extends Vec> vecs) {
        return null;
    }

    /** Distance between {@code vecs.get(a)} and {@code vecs.get(b)}; the cache is ignored. */
    @Override
    public double dist(int a, int b, List<? extends Vec> vecs, List<Double> cache) {
        return dist(vecs.get(a), vecs.get(b));
    }

    /** Distance between {@code vecs.get(a)} and {@code b}; the cache is ignored. */
    @Override
    public double dist(int a, Vec b, List<? extends Vec> vecs, List<Double> cache) {
        return dist(vecs.get(a), b);
    }

    /** Returns null because no per-query information is used. */
    @Override
    public List<Double> getQueryInfo(Vec q) {
        return null;
    }

    /** Returns null because this metric does not use an acceleration cache. */
    @Override
    public List<Double> getAccelerationCache(List<? extends Vec> vecs, ExecutorService threadpool) {
        return null;
    }

    /** Distance between {@code vecs.get(a)} and {@code b}; query info and cache are ignored. */
    @Override
    public double dist(
            int a, Vec b, List<Double> qi, List<? extends Vec> vecs, List<Double> cache) {
        return dist(vecs.get(a), b);
    }
}
