package jsat.linear.vectorcollection;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Stack;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.RejectedExecutionException;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.VecPairedComparable;
import jsat.linear.distancemetrics.ChebyshevDistance;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.ManhattanDistance;
import jsat.linear.distancemetrics.MinkowskiDistance;
import jsat.math.OnLineStatistics;
import jsat.utils.BoundedSortedList;
import jsat.utils.DoubleList;
import jsat.utils.FakeExecutor;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.ModifiableCountDownLatch;
import jsat.utils.ProbailityMatch;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/linear/vectorcollection/KDTree.class */
/**
 * A KD-tree over a fixed set of vectors, answering k-nearest-neighbor and
 * fixed-radius queries.
 * <p>
 * Only Euclidean, Chebyshev, Manhattan and Minkowski distances are accepted.
 * For each of these, the distance between two points is never smaller than
 * their absolute difference along any single coordinate. That property is
 * what makes pruning by the distance to a node's splitting plane sound.
 *
 * @param <V> the vector type stored in the tree
 */
public class KDTree<V extends Vec> implements VectorCollection<V> {
    private static final long serialVersionUID = -7401342201406776463L;
    /** Metric used for every distance computation. */
    private DistanceMetric distanceMetric;
    /** Root of the tree; {@code null} when no vectors are stored. */
    private KDTree<V>.KDNode root;
    /** Strategy for choosing the split axis of each internal node. */
    private PivotSelection pvSelection;
    /** Number of vectors stored. */
    private int size;
    /** Private copy of the indexed vectors; nodes refer to them by position. */
    private List<V> allVecs;
    /** Acceleration cache produced by the metric for {@link #allVecs}; may be {@code null}. */
    private List<Double> distCache;

    /**
     * A tree node. It holds the position (in {@code allVecs}) of its vector and
     * the coordinate axis it splits on.
     * <p>
     * The decompiler widened this class to public; it was private in the
     * original source and should be treated as internal.
     */
    public class KDNode implements Cloneable, Serializable {
        /** Position in {@code allVecs} of the vector stored at this node. */
        int locatin;
        /** Coordinate used to split this node's children. */
        int axis;
        /** Subtree with coordinates on the low side of the split (may be {@code null}). */
        KDTree<V>.KDNode left;
        /** Subtree with coordinates on the high side of the split (may be {@code null}). */
        KDTree<V>.KDNode right;

        public KDNode(int i, int i2) {
            this.locatin = i;
            this.axis = i2;
        }

        public void setAxis(int i) {
            this.axis = i;
        }

        public void setLeft(KDTree<V>.KDNode kDNode) {
            this.left = kDNode;
        }

        public void setLocatin(int i) {
            this.locatin = i;
        }

        public void setRight(KDTree<V>.KDNode kDNode) {
            this.right = kDNode;
        }

        public int getAxis() {
            return this.axis;
        }

        public KDTree<V>.KDNode getLeft() {
            return this.left;
        }

        public int getLocatin() {
            return this.locatin;
        }

        public KDTree<V>.KDNode getRight() {
            return this.right;
        }

        /**
         * Deep-copies this node together with both of its subtrees.
         * (The decompiler renamed this method from {@code clone}.)
         *
         * @return an independent copy of the subtree rooted here
         */
        public KDTree<V>.KDNode m661clone() {
            KDNode copy = new KDNode(this.locatin, this.axis);
            if (this.left != null) {
                copy.left = this.left.m661clone();
            }
            if (this.right != null) {
                copy.right = this.right.m661clone();
            }
            return copy;
        }
    }

    /**
     * Factory that creates {@link KDTree} instances with a configurable pivot
     * selection strategy.
     *
     * @param <V> the vector type stored in the produced trees
     */
    public static class KDTreeFactory<V extends Vec> implements VectorCollectionFactory<V> {
        private static final long serialVersionUID = 3508731608962277804L;
        private PivotSelection pivotSelectionMethod;

        public KDTreeFactory(PivotSelection pivotSelection) {
            this.pivotSelectionMethod = pivotSelection;
        }

        /** Creates a factory that uses {@link PivotSelection#Variance}. */
        public KDTreeFactory() {
            this(PivotSelection.Variance);
        }

        public PivotSelection getPivotSelectionMethod() {
            return this.pivotSelectionMethod;
        }

        public void setPivotSelectionMethod(PivotSelection pivotSelection) {
            this.pivotSelectionMethod = pivotSelection;
        }

        /** Builds a tree serially on the calling thread. */
        @Override
        public VectorCollection<V> getVectorCollection(List<V> list, DistanceMetric distanceMetric) {
            return getVectorCollection(list, distanceMetric, null);
        }

        /** Builds a tree, using {@code executorService} for parallel construction when non-null. */
        @Override
        public VectorCollection<V> getVectorCollection(List<V> list, DistanceMetric distanceMetric, ExecutorService executorService) {
            return new KDTree<V>(list, distanceMetric, this.pivotSelectionMethod, executorService);
        }

        /**
         * Returns a new factory with the same pivot selection strategy.
         * (The decompiler renamed this method from {@code clone}.)
         */
        @Override
        public KDTreeFactory<V> m662clone() {
            return new KDTreeFactory<V>(this.pivotSelectionMethod);
        }
    }

    /** How the split axis of each internal node is chosen. */
    public enum PivotSelection {
        /** Cycle through the axes by tree depth. */
        Incremental,
        /** Pick the axis with the largest variance among the node's vectors. */
        Variance
    }

    /**
     * Orders vector positions by the value of one coordinate of the vectors
     * they refer to in {@code allVecs}.
     * <p>
     * The decompiler widened this class to public; it was private in the
     * original source.
     */
    public class VecIndexComparator implements Comparator<Integer> {
        private final int index;

        public VecIndexComparator(int i) {
            this.index = i;
        }

        @Override
        public int compare(Integer num, Integer num2) {
            double a = KDTree.this.allVecs.get(num).get(this.index);
            double b = KDTree.this.allVecs.get(num2).get(this.index);
            return Double.compare(a, b);
        }
    }

    /**
     * Builds a KD-tree over the given vectors.
     *
     * @param list the vectors to index; the list is copied, so later changes to it do not affect the tree
     * @param distanceMetric the metric; must be Euclidean, Chebyshev, Manhattan or Minkowski
     * @param pivotSelection strategy for choosing split axes
     * @param executorService pool used to compute the metric cache and to build subtrees
     *        in parallel, or {@code null} to do everything on the calling thread
     * @throws ArithmeticException if the metric is not one of the supported types
     */
    public KDTree(List<V> list, DistanceMetric distanceMetric, PivotSelection pivotSelection, ExecutorService executorService) {
        if (!(distanceMetric instanceof EuclideanDistance) && !(distanceMetric instanceof ChebyshevDistance) && !(distanceMetric instanceof ManhattanDistance) && !(distanceMetric instanceof MinkowskiDistance)) {
            throw new ArithmeticException("KD Trees are not compatible with the given distance metric.");
        }
        this.distanceMetric = distanceMetric;
        this.pvSelection = pivotSelection;
        this.size = list.size();
        this.allVecs = new ArrayList<V>(list);
        if (executorService == null || (executorService instanceof FakeExecutor)) {
            this.distCache = distanceMetric.getAccelerationCache(this.allVecs);
        } else {
            this.distCache = distanceMetric.getAccelerationCache(this.allVecs, executorService);
        }
        if (executorService == null) {
            this.root = buildTree(newIndexList(), 0, null, null);
            return;
        }
        // The initial count of 1 belongs to this thread. It keeps the latch above
        // zero while subtree tasks count it up, and is released below once the
        // synchronous part of the build has returned.
        ModifiableCountDownLatch latch = new ModifiableCountDownLatch(1);
        this.root = buildTree(newIndexList(), 0, executorService, latch);
        latch.countDown();
        try {
            latch.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            // Fall back to a serial build. Use a fresh index list, because tasks
            // that are still running keep reordering the old one.
            this.root = buildTree(newIndexList(), 0, null, null);
        }
    }

    /** Builds a tree serially on the calling thread. */
    public KDTree(List<V> list, DistanceMetric distanceMetric, PivotSelection pivotSelection) {
        this(list, distanceMetric, pivotSelection, null);
    }

    /** Builds a tree serially using {@link PivotSelection#Variance}. */
    public KDTree(List<V> list, DistanceMetric distanceMetric) {
        this(list, distanceMetric, PivotSelection.Variance);
    }

    /** Creates an empty, unbuilt tree; used by the no-arg constructor and by cloning. */
    private KDTree(DistanceMetric distanceMetric, PivotSelection pivotSelection) {
        this.distanceMetric = distanceMetric;
        this.pvSelection = pivotSelection;
    }

    /** Creates an empty tree using Euclidean distance and variance-based pivots. */
    public KDTree() {
        this(new EuclideanDistance(), PivotSelection.Variance);
    }

    /** Returns the list {@code 0, 1, ..., size-1} of positions in {@link #allVecs}. */
    private IntList newIndexList() {
        IntList indices = new IntList(this.size);
        ListUtils.addRange(indices, 0, this.size, 1);
        return indices;
    }

    /**
     * Recursively builds the subtree over the vectors whose positions are in {@code list}.
     * <p>
     * {@code list} is sorted in place along the chosen axis. Its median entry
     * becomes this subtree's root, and the entries on either side become the
     * left and right subtrees.
     * <p>
     * When {@code executorService} is non-null, the right subtree is built by a
     * task submitted to it. The latch is counted up before submission and counted
     * down only after that task has attached its subtree (even if the build
     * throws). A thread waiting on the latch therefore sees a fully linked tree.
     * Any count the caller put on the latch is the caller's to release.
     * (The decompiler widened this method to public; it was private.)
     *
     * @param list positions in {@code allVecs} to place in this subtree; reordered in place
     * @param i depth of this subtree's root; drives {@link PivotSelection#Incremental} and leaf axes
     * @param executorService pool for parallel construction, or {@code null} for a serial build
     * @param modifiableCountDownLatch latch tracking outstanding subtree tasks; unused when the pool is {@code null}
     * @return the subtree root, or {@code null} if {@code list} is {@code null} or empty
     */
    public KDTree<V>.KDNode buildTree(final List<Integer> list, final int i, final ExecutorService executorService, final ModifiableCountDownLatch modifiableCountDownLatch) {
        if (list == null || list.isEmpty()) {
            return null;
        }
        final int dims = this.allVecs.get(0).length();
        if (list.size() == 1) {
            return new KDNode(list.get(0), i % dims);
        }
        final int axis = selectAxis(list, i, dims);
        Collections.sort(list, new VecIndexComparator(axis));
        final int mid = list.size() / 2;
        final KDNode node = new KDNode(list.get(mid), axis);
        final List<Integer> leftIndices = list.subList(0, mid);
        final List<Integer> rightIndices = list.subList(mid + 1, list.size());
        if (executorService == null) {
            node.setLeft(buildTree(leftIndices, i + 1, null, null));
            node.setRight(buildTree(rightIndices, i + 1, null, null));
            return node;
        }
        Runnable rightTask = new Runnable() {
            @Override
            public void run() {
                try {
                    node.setRight(buildTree(rightIndices, i + 1, executorService, modifiableCountDownLatch));
                } finally {
                    // Release only after the subtree is attached, and always, so
                    // waiters can neither see a half-linked tree nor hang.
                    modifiableCountDownLatch.countDown();
                }
            }
        };
        modifiableCountDownLatch.countUp();
        try {
            executorService.submit(rightTask);
        } catch (RejectedExecutionException e) {
            // The pool will not run the task (e.g. it was shut down). Run it here
            // so the subtree is still built and the count above is still released.
            rightTask.run();
        }
        // The left and right halves are disjoint ranges of the backing list, so
        // this thread and the task never reorder the same entries.
        node.setLeft(buildTree(leftIndices, i + 1, executorService, modifiableCountDownLatch));
        return node;
    }

    /**
     * Chooses the split axis for a subtree rooted at depth {@code depth}.
     * <p>
     * {@link PivotSelection#Incremental} uses {@code depth % dims}. Any other
     * setting picks the axis whose values have the largest variance over the
     * vectors at {@code indices}. It falls back to {@code depth % dims} when no
     * variance exceeds -1 (for example, when every variance is NaN).
     */
    private int selectAxis(List<Integer> indices, int depth, int dims) {
        if (this.pvSelection == PivotSelection.Incremental) {
            return depth % dims;
        }
        OnLineStatistics[] stats = new OnLineStatistics[dims];
        for (int d = 0; d < dims; d++) {
            stats[d] = new OnLineStatistics();
        }
        for (int idx : indices) {
            V v = this.allVecs.get(idx);
            for (int d = 0; d < dims; d++) {
                stats[d].add(v.get(d));
            }
        }
        int best = -1;
        double bestVariance = -1.0d;
        for (int d = 0; d < dims; d++) {
            double variance = stats[d].getVarance();
            if (variance > bestVariance) {
                bestVariance = variance;
                best = d;
            }
        }
        return best < 0 ? depth % dims : best;
    }

    /**
     * Collects the nearest neighbors of {@code vec} into {@code boundedSortedList}.
     * The list's own capacity determines how many neighbors are kept.
     */
    private void knnKDSearch(Vec vec, BoundedSortedList<ProbailityMatch<V>> boundedSortedList) {
        List<Double> queryInfo = this.distanceMetric.supportsAcceleration() ? this.distanceMetric.getQueryInfo(vec) : null;
        knnKDSearch(vec, queryInfo, this.root, boundedSortedList);
    }

    /**
     * Depth-first k-NN search of the subtree at {@code node}. It visits the side
     * of the split that contains the query first, so the k-th distance shrinks
     * before the far side is considered.
     */
    private void knnKDSearch(Vec query, List<Double> queryInfo, KDNode node, BoundedSortedList<ProbailityMatch<V>> knns) {
        if (node == null) {
            return;
        }
        V current = this.allVecs.get(node.locatin);
        double dist = this.distanceMetric.dist(node.locatin, query, queryInfo, this.allVecs, this.distCache);
        knns.add(new ProbailityMatch<V>(dist, current));

        double diff = query.get(node.axis) - current.get(node.axis);
        KDNode near = node.left;
        KDNode far = node.right;
        if (diff > 0.0d) {
            near = node.right;
            far = node.left;
        }
        knnKDSearch(query, queryInfo, near, knns);
        // Any point on the far side is at least |diff| away (see the class doc),
        // so it can only improve the result when the list is not yet full or
        // |diff| is within the current worst kept distance.
        if (knns.size() < knns.maxSize() || Math.abs(diff) <= knns.last().getProbability()) {
            knnKDSearch(query, queryInfo, far, knns);
        }
    }

    /**
     * Returns up to {@code i} stored vectors nearest to {@code vec}, each paired
     * with its distance, in the order kept by {@link BoundedSortedList}.
     *
     * @throws IllegalArgumentException if {@code i < 1}
     */
    @Override
    public List<? extends VecPaired<V, Double>> search(Vec vec, int i) {
        if (i < 1) {
            throw new IllegalArgumentException("Invalid number of neighbors to search for");
        }
        BoundedSortedList<ProbailityMatch<V>> knns = new BoundedSortedList<ProbailityMatch<V>>(i);
        knnKDSearch(vec, knns);
        List<VecPaired<V, Double>> result = new ArrayList<VecPaired<V, Double>>(knns.size());
        for (int j = 0; j < knns.size(); j++) {
            ProbailityMatch<V> match = knns.get(j);
            result.add(new VecPaired<V, Double>(match.getMatch(), match.getProbability()));
        }
        return result;
    }

    /**
     * Adds to {@code list2} every vector in the subtree at {@code kDNode} whose
     * distance to {@code vec} is at most {@code d}.
     *
     * @param vec the query
     * @param list metric query info for {@code vec}, or {@code null}
     * @param kDNode subtree root (may be {@code null})
     * @param list2 output list of (vector, distance) pairs
     * @param d the search radius
     */
    private void distanceSearch(Vec vec, List<Double> list, KDTree<V>.KDNode kDNode, List<VecPairedComparable<V, Double>> list2, double d) {
        if (kDNode == null) {
            return;
        }
        V v = this.allVecs.get(kDNode.locatin);
        double dist = this.distanceMetric.dist(kDNode.locatin, vec, list, this.allVecs, this.distCache);
        if (dist <= d) {
            list2.add(new VecPairedComparable<V, Double>(v, dist));
        }
        double diff = vec.get(kDNode.axis) - v.get(kDNode.axis);
        KDTree<V>.KDNode near = kDNode.left;
        KDTree<V>.KDNode far = kDNode.right;
        if (diff > 0.0d) {
            near = kDNode.right;
            far = kDNode.left;
        }
        distanceSearch(vec, list, near, list2, d);
        // Far-side points are at least |diff| away. Compare the gap itself to the
        // radius; comparing diff*diff to an unsquared radius skipped reachable
        // subtrees whenever the radius exceeded 1.
        if (Math.abs(diff) <= d) {
            distanceSearch(vec, list, far, list2, d);
        }
    }

    /** Returns the number of stored vectors. */
    @Override
    public int size() {
        return this.size;
    }

    /**
     * Returns all stored vectors within distance {@code d} of {@code vec}, each
     * paired with its distance, sorted by that pairing's natural order.
     *
     * @throws IllegalArgumentException if {@code d <= 0}
     */
    @Override
    public List<? extends VecPaired<V, Double>> search(Vec vec, double d) {
        if (d <= 0.0d) {
            throw new IllegalArgumentException("Range must be a positive number");
        }
        List<VecPairedComparable<V, Double>> hits = new ArrayList<VecPairedComparable<V, Double>>();
        List<Double> queryInfo = this.distanceMetric.supportsAcceleration() ? this.distanceMetric.getQueryInfo(vec) : null;
        distanceSearch(vec, queryInfo, this.root, hits, d);
        Collections.sort(hits);
        return hits;
    }

    /**
     * Returns a copy of this tree that shares the metric and the vector objects
     * but has its own node structure, vector list and cache list.
     * (The decompiler renamed this method from {@code clone}.)
     */
    @Override
    public KDTree<V> m660clone() {
        KDTree<V> copy = new KDTree<V>(this.distanceMetric, this.pvSelection);
        if (this.distCache != null) {
            copy.distCache = new DoubleList(this.distCache);
        }
        if (this.allVecs != null) {
            copy.allVecs = new ArrayList<V>(this.allVecs);
        }
        copy.size = this.size;
        if (this.root != null) {
            copy.root = this.root.m661clone();
        }
        return copy;
    }
}
