package jsat.linear.vectorcollection;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.VecPairedComparable;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.utils.BoundedSortedList;
import jsat.utils.DoubleList;
import jsat.utils.FakeExecutor;
import jsat.utils.IntList;
import jsat.utils.ModifiableCountDownLatch;
import jsat.utils.Pair;
import jsat.utils.ProbailityMatch;
import jsat.utils.SimpleList;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/linear/vectorcollection/VPTree.class */
/**
 * A vantage-point tree (VP-tree) answering k-nearest-neighbor and range queries over a fixed list of
 * vectors.
 *
 * <p>Each internal node picks a vantage point, sorts the remaining points of its subtree by their
 * distance to that point and splits them into a "left" (closer) and a "right" (farther) part,
 * remembering the distance interval each part covers. Searches use the triangle inequality to skip
 * parts that cannot contain a match, which is why the constructor rejects metrics that are not
 * subadditive.
 *
 * @param <V> the vector type stored in the tree
 */
public class VPTree<V extends Vec> implements VectorCollection<V> {
    private static final long serialVersionUID = -7271540108746353762L;

    /**
     * Placeholder stored as the "distance to the parent vantage point" of points that have no parent
     * vantage point. Metric distances are non-negative, so leaves treat any negative bound as
     * "unknown" and never prune on it.
     */
    private static final double NO_PARENT_DISTANCE = -1.0;

    /** Orders (distance, index) pairs by ascending distance. */
    private static final Comparator<Pair<Double, Integer>> BY_DISTANCE = new Comparator<Pair<Double, Integer>>() {
        @Override
        public int compare(Pair<Double, Integer> a, Pair<Double, Integer> b) {
            return Double.compare(a.getFirstItem().doubleValue(), b.getFirstItem().doubleValue());
        }
    };

    private DistanceMetric dm;
    private List<Double> distCache;
    private List<V> allVecs;
    private Random rand;
    private int sampleSize;
    private int searchIterations;
    private volatile VPTree<V>.TreeNode root;
    private VPSelection vpSelection;
    private int size;
    private int maxLeafSize = 5;

    /** Common base of the tree's internal nodes and leaves. */
    public abstract class TreeNode implements Cloneable, Serializable {
        private TreeNode() {
        }

        /**
         * Offers every point of this subtree that may beat the current k-th best match to
         * {@code boundedSortedList}.
         *
         * @param vec the query vector
         * @param i the number of neighbors wanted (k)
         * @param boundedSortedList the best matches found so far; its last element is treated as the
         *     current worst (k-th) match
         * @param d the distance from the query to the vantage point of this node's parent
         *     ({@code 0} at the root)
         * @param list the query information returned by {@code DistanceMetric.getQueryInfo}
         */
        public abstract void searchKNN(Vec vec, int i, BoundedSortedList<ProbailityMatch<V>> boundedSortedList, double d, List<Double> list);

        /**
         * Adds every point of this subtree whose distance to the query is at most {@code d} to
         * {@code list}.
         *
         * @param vec the query vector
         * @param d the search radius
         * @param list receives the matches paired with their distance to the query
         * @param d2 the distance from the query to the vantage point of this node's parent
         *     ({@code 0} at the root)
         * @param list2 the query information returned by {@code DistanceMetric.getQueryInfo}
         */
        public abstract void searchRange(Vec vec, double d, List<VecPaired<V, Double>> list, double d2, List<Double> list2);

        /**
         * Deep-copies this subtree, creating every copied node as an inner instance of
         * {@code owner} so that the copy reads {@code owner}'s metric, vectors and cache.
         */
        abstract VPTree<V>.TreeNode copyFor(VPTree<V> owner);

        /** Deep-copies this subtree; the copy belongs to the same tree as this node. */
        public VPTree<V>.TreeNode mo673clone() {
            return copyFor(VPTree.this);
        }
    }

    /** A bucket of at most {@code maxLeafSize} points that is searched by brute force. */
    public class VPLeaf extends VPTree<V>.TreeNode {
        /** Indices into {@code allVecs}. */
        int[] points;
        /**
         * {@code bounds[j]} is the distance from {@code points[j]} to the parent's vantage point, or
         * negative when the point has no parent vantage point.
         */
        double[] bounds;

        /** Builds a leaf from (distance to parent vantage point, vector index) pairs. */
        public VPLeaf(List<Pair<Double, Integer>> list) {
            points = new int[list.size()];
            bounds = new double[points.length];
            int slot = 0;
            for (Pair<Double, Integer> entry : list) {
                points[slot] = entry.getSecondItem().intValue();
                bounds[slot] = entry.getFirstItem().doubleValue();
                slot++;
            }
        }

        /** Builds a leaf from copies of the given index and bound arrays. */
        public VPLeaf(int[] iArr, double[] dArr) {
            points = Arrays.copyOf(iArr, iArr.length);
            bounds = Arrays.copyOf(dArr, dArr.length);
        }

        /**
         * Returns whether the point in slot {@code j} can lie within {@code radius} of the query,
         * given the query's distance {@code x} to the parent vantage point.
         */
        private boolean mayContain(int j, double x, double radius) {
            double bound = bounds[j];
            // By the triangle inequality |x - bound| <= dist(query, point), so the point can only be
            // within radius when x lies in [bound - radius, bound + radius]. A negative bound means
            // there is no parent vantage point (e.g. the whole tree is this leaf): never prune then.
            return bound < 0 || (bound - radius <= x && x <= bound + radius);
        }

        @Override
        public void searchKNN(Vec vec, int i, BoundedSortedList<ProbailityMatch<V>> boundedSortedList, double d, List<Double> list) {
            double worst = worstDistance(boundedSortedList);
            for (int slot = 0; slot < points.length; slot++) {
                int idx = points[slot];
                if (boundedSortedList.size() < i) {
                    // Not enough candidates yet: accept unconditionally.
                    double dist = dm.dist(idx, vec, list, allVecs, distCache);
                    boundedSortedList.add(new ProbailityMatch<V>(dist, allVecs.get(idx)));
                    worst = worstDistance(boundedSortedList);
                } else if (mayContain(slot, d, worst)) {
                    double dist = dm.dist(idx, vec, list, allVecs, distCache);
                    if (dist < worst) {
                        boundedSortedList.add(new ProbailityMatch<V>(dist, allVecs.get(idx)));
                        worst = worstDistance(boundedSortedList);
                    }
                }
            }
        }

        @Override
        public void searchRange(Vec vec, double d, List<VecPaired<V, Double>> list, double d2, List<Double> list2) {
            for (int slot = 0; slot < points.length; slot++) {
                if (!mayContain(slot, d2, d)) {
                    continue;
                }
                int idx = points[slot];
                double dist = dm.dist(idx, vec, list2, allVecs, distCache);
                // Inclusive, matching VPNode.searchRange.
                if (dist <= d) {
                    list.add(new VecPairedComparable<V, Double>(allVecs.get(idx), dist));
                }
            }
        }

        @Override
        VPTree<V>.TreeNode copyFor(VPTree<V> owner) {
            return owner.new VPLeaf(points, bounds);
        }
    }

    /**
     * An internal node: vantage point {@code p} plus the distance intervals
     * [{@code left_low}, {@code left_high}] and [{@code right_low}, {@code right_high}] that the
     * points of each child lie in, measured from {@code p}.
     */
    public class VPNode extends VPTree<V>.TreeNode {
        int p;
        double left_low;
        double left_high;
        double right_low;
        double right_high;
        VPTree<V>.TreeNode right;
        VPTree<V>.TreeNode left;

        /** Creates a node whose vantage point is {@code allVecs.get(i)}. */
        public VPNode(int i) {
            this.p = i;
        }

        /** Whether the left child may hold a point within {@code d2} of a query at distance {@code d} from {@code p}. */
        private boolean searchInLeft(double d, double d2) {
            return left != null && left_low - d2 <= d && d <= left_high + d2;
        }

        /** Whether the right child may hold a point within {@code d2} of a query at distance {@code d} from {@code p}. */
        private boolean searchInRight(double d, double d2) {
            return right != null && right_low - d2 <= d && d <= right_high + d2;
        }

        /**
         * Decides whether a k-NN search must descend into one child: always when the candidate list
         * is not yet full, otherwise only when the child's interval overlaps the current k-th radius.
         */
        private boolean shouldVisitKNN(boolean leftSide, double dist, int k, BoundedSortedList<ProbailityMatch<V>> knn) {
            VPTree<V>.TreeNode child = leftSide ? left : right;
            if (child == null) {
                return false;
            }
            if (knn.size() < k) {
                return true;
            }
            double radius = worstDistance(knn);
            return leftSide ? searchInLeft(dist, radius) : searchInRight(dist, radius);
        }

        @Override
        public void searchKNN(Vec vec, int i, BoundedSortedList<ProbailityMatch<V>> boundedSortedList, double d, List<Double> list) {
            double dist = dm.dist(p, vec, list, allVecs, distCache);
            if (boundedSortedList.size() < i || dist < boundedSortedList.get(i - 1).getProbability()) {
                boundedSortedList.add(new ProbailityMatch<V>(dist, allVecs.get(p)));
            }
            // Visit the side nearer the split first so the k-th radius shrinks before the other
            // side is considered; the second check re-reads the (possibly smaller) radius.
            if (dist < (left_high + right_low) * 0.5d) {
                if (shouldVisitKNN(true, dist, i, boundedSortedList)) {
                    left.searchKNN(vec, i, boundedSortedList, dist, list);
                }
                if (shouldVisitKNN(false, dist, i, boundedSortedList)) {
                    right.searchKNN(vec, i, boundedSortedList, dist, list);
                }
            } else {
                if (shouldVisitKNN(false, dist, i, boundedSortedList)) {
                    right.searchKNN(vec, i, boundedSortedList, dist, list);
                }
                if (shouldVisitKNN(true, dist, i, boundedSortedList)) {
                    left.searchKNN(vec, i, boundedSortedList, dist, list);
                }
            }
        }

        @Override
        public void searchRange(Vec vec, double d, List<VecPaired<V, Double>> list, double d2, List<Double> list2) {
            double dist = dm.dist(p, vec, list2, allVecs, distCache);
            if (dist <= d) {
                list.add(new VecPairedComparable<V, Double>(allVecs.get(p), dist));
            }
            if (searchInLeft(dist, d)) {
                left.searchRange(vec, d, list, dist, list2);
            }
            if (searchInRight(dist, d)) {
                right.searchRange(vec, d, list, dist, list2);
            }
        }

        @Override
        VPTree<V>.TreeNode copyFor(VPTree<V> owner) {
            VPTree<V>.VPNode copy = owner.new VPNode(p);
            copy.left_low = left_low;
            copy.left_high = left_high;
            copy.right_low = right_low;
            copy.right_high = right_high;
            copy.left = left == null ? null : left.copyFor(owner);
            copy.right = right == null ? null : right.copyFor(owner);
            return copy;
        }
    }

    /** How vantage points are chosen while building the tree. */
    public enum VPSelection {
        /** Score candidates by the spread of their distances to a sample of points. */
        Sampling,
        /** Pick a uniformly random point. */
        Random
    }

    /** Factory producing {@link VPTree}s with a fixed {@link VPSelection} strategy. */
    public static class VPTreeFactory<V extends Vec> implements VectorCollectionFactory<V> {
        private static final long serialVersionUID = -2739851193676265510L;
        private VPSelection vpSelectionMethod;

        public VPTreeFactory(VPSelection vPSelection) {
            this.vpSelectionMethod = vPSelection;
        }

        /** Creates a factory using {@link VPSelection#Random}. */
        public VPTreeFactory() {
            this(VPSelection.Random);
        }

        @Override
        public VectorCollection<V> getVectorCollection(List<V> list, DistanceMetric distanceMetric) {
            return new VPTree<V>(list, distanceMetric, vpSelectionMethod);
        }

        /** Builds the tree in parallel; uses a fixed seed (10), sample size 80 and 40 search iterations. */
        @Override
        public VectorCollection<V> getVectorCollection(List<V> list, DistanceMetric distanceMetric, ExecutorService executorService) {
            return new VPTree<V>(list, distanceMetric, vpSelectionMethod, new Random(10L), 80, 40, executorService);
        }

        @Override
        public VectorCollectionFactory<V> m675clone() {
            return new VPTreeFactory<V>(vpSelectionMethod);
        }
    }

    /**
     * Builds a VP-tree over {@code list}.
     *
     * @param list the vectors to index; the list is referenced, not copied
     * @param distanceMetric the metric; must report {@code isSubadditive()}
     * @param vPSelection how vantage points are chosen
     * @param random source of randomness for vantage-point selection
     * @param i sample size used by {@link VPSelection#Sampling}
     * @param i2 number of candidate vantage points scored by {@link VPSelection#Sampling}
     * @param executorService executor for a parallel build, or {@code null} for a sequential one
     * @throws RuntimeException if the metric does not satisfy the triangle inequality
     */
    public VPTree(List<V> list, DistanceMetric distanceMetric, VPSelection vPSelection, Random random, int i, int i2, ExecutorService executorService) {
        this.dm = distanceMetric;
        if (!distanceMetric.isSubadditive()) {
            throw new RuntimeException("VPTree only supports metrics that support the triangle inequality");
        }
        this.rand = random;
        this.sampleSize = i;
        this.searchIterations = i2;
        this.size = list.size();
        this.vpSelection = vPSelection;
        this.allVecs = list;
        if (executorService == null || (executorService instanceof FakeExecutor)) {
            this.distCache = distanceMetric.getAccelerationCache(allVecs);
        } else {
            this.distCache = distanceMetric.getAccelerationCache(allVecs, executorService);
        }
        if (executorService == null) {
            this.root = makeVPTree(indexPairs(list.size()));
        } else {
            buildInParallel(executorService);
        }
    }

    public VPTree(List<V> list, DistanceMetric distanceMetric, VPSelection vPSelection, Random random, int i, int i2) {
        this(list, distanceMetric, vPSelection, random, i, i2, null);
    }

    /** Builds sequentially with an unseeded {@link Random}, sample size 80 and 40 search iterations. */
    public VPTree(List<V> list, DistanceMetric distanceMetric, VPSelection vPSelection) {
        this(list, distanceMetric, vPSelection, new Random(), 80, 40);
    }

    /** Builds sequentially using {@link VPSelection#Random}. */
    public VPTree(List<V> list, DistanceMetric distanceMetric) {
        this(list, distanceMetric, VPSelection.Random);
    }

    /** Deep copy: clones the metric, the vector list, the cache and every node of the tree. */
    protected VPTree(VPTree<V> vPTree) {
        this.dm = vPTree.dm.m655clone();
        this.rand = vPTree.rand == null ? null : new Random(vPTree.rand.nextInt());
        this.sampleSize = vPTree.sampleSize;
        this.searchIterations = vPTree.searchIterations;
        this.vpSelection = vPTree.vpSelection;
        this.size = vPTree.size;
        this.maxLeafSize = vPTree.maxLeafSize;
        if (vPTree.allVecs != null) {
            this.allVecs = new ArrayList<V>(vPTree.allVecs);
        }
        if (vPTree.distCache != null) {
            this.distCache = new DoubleList(vPTree.distCache);
        }
        // Copied nodes must belong to this tree so they read this copy's metric, vectors and cache.
        this.root = vPTree.root == null ? null : vPTree.root.copyFor(this);
    }

    protected VPTree() {
    }

    /** Creates one (NO_PARENT_DISTANCE, index) pair per vector index in {@code [0, n)}. */
    private static List<Pair<Double, Integer>> indexPairs(int n) {
        List<Pair<Double, Integer>> pairs = new SimpleList<Pair<Double, Integer>>(n);
        for (int idx = 0; idx < n; idx++) {
            pairs.add(new Pair<Double, Integer>(NO_PARENT_DISTANCE, idx));
        }
        return pairs;
    }

    /**
     * Builds the tree on {@code executorService}, falling back to a sequential build if the wait is
     * interrupted or any worker task fails.
     */
    private void buildInParallel(ExecutorService executorService) {
        List<Throwable> failures = Collections.synchronizedList(new ArrayList<Throwable>());
        ModifiableCountDownLatch latch = new ModifiableCountDownLatch(1);
        this.root = makeVPTree(indexPairs(allVecs.size()), executorService, latch, failures);
        latch.countDown();
        try {
            latch.await();
        } catch (InterruptedException e) {
            Logger.getLogger(VPTree.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
            Thread.currentThread().interrupt();
            // Workers may still be running on the old point list, so rebuild from a fresh one.
            fallBackToSequentialBuild();
            return;
        }
        if (!failures.isEmpty()) {
            Logger.getLogger(VPTree.class.getName()).log(Level.SEVERE, "Parallel VPTree construction failed", failures.get(0));
            fallBackToSequentialBuild();
        }
    }

    /** Rebuilds the whole tree on the calling thread from a freshly created point list. */
    private void fallBackToSequentialBuild() {
        System.err.println("Falling back to single threaded VPTree constructor");
        this.root = makeVPTree(indexPairs(allVecs.size()));
    }

    @Override
    public int size() {
        return size;
    }

    /**
     * Returns all vectors whose distance to {@code vec} is at most {@code d}, sorted by the natural
     * ordering of {@link VecPairedComparable}.
     *
     * @throws RuntimeException if {@code d <= 0}
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public List<? extends VecPaired<V, Double>> search(Vec vec, double d) {
        if (d <= 0.0d) {
            throw new RuntimeException("Range must be a positive number");
        }
        List<VecPaired<V, Double>> inRange = new ArrayList<VecPaired<V, Double>>();
        if (root == null) {
            return inRange; // nothing indexed
        }
        root.searchRange(VecPaired.extractTrueVec(vec), d, inRange, 0.0d, dm.getQueryInfo(vec));
        // Every element added by searchRange is a VecPairedComparable, so natural ordering applies.
        Collections.sort((List) inRange);
        return inRange;
    }

    /** Returns (up to) the {@code i} nearest vectors to {@code vec}, each paired with its distance. */
    @Override
    public List<? extends VecPaired<V, Double>> search(Vec vec, int i) {
        if (root == null) {
            return new ArrayList<VecPaired<V, Double>>(0); // nothing indexed
        }
        BoundedSortedList<ProbailityMatch<V>> knn = new BoundedSortedList<ProbailityMatch<V>>(i, i);
        root.searchKNN(VecPaired.extractTrueVec(vec), i, knn, 0.0d, dm.getQueryInfo(vec));
        List<VecPaired<V, Double>> neighbors = new ArrayList<VecPaired<V, Double>>(knn.size());
        for (ProbailityMatch<V> match : knn) {
            neighbors.add(new VecPaired<V, Double>(match.getMatch(), match.getProbability()));
        }
        return neighbors;
    }

    /** Distance of the current worst candidate, or {@code Double.MAX_VALUE} if there is none. */
    private double worstDistance(BoundedSortedList<ProbailityMatch<V>> knn) {
        return knn.size() == 0 ? Double.MAX_VALUE : knn.get(knn.size() - 1).getProbability();
    }

    /**
     * Stores each point's distance to {@code vPNode}'s vantage point in its pair, sorts by that
     * distance, records the children's distance intervals on the node and returns the split index
     * (the last index of the left part).
     */
    private int sortSplitSet(List<Pair<Double, Integer>> list, VPTree<V>.VPNode vPNode) {
        for (Pair<Double, Integer> pair : list) {
            pair.setFirstItem(Double.valueOf(dm.dist(vPNode.p, pair.getSecondItem().intValue(), (List<? extends Vec>) allVecs, distCache)));
        }
        Collections.sort(list, BY_DISTANCE);
        int split = splitListIndex(list);
        vPNode.left_low = list.get(0).getFirstItem().doubleValue();
        vPNode.left_high = list.get(split).getFirstItem().doubleValue();
        vPNode.right_low = list.get(split + 1).getFirstItem().doubleValue();
        vPNode.right_high = list.get(list.size() - 1).getFirstItem().doubleValue();
        return split;
    }

    /** Index of the last element of the left part of a sorted split set; the median by default. */
    protected int splitListIndex(List<Pair<Double, Integer>> list) {
        return list.size() / 2;
    }

    public int getMaxLeafSize() {
        return maxLeafSize;
    }

    /**
     * Sets the largest number of points a leaf may hold; values below 5 are raised to 5. Only
     * consulted while building, so it does not restructure an already-built tree.
     */
    public void setMaxLeafSize(int i) {
        this.maxLeafSize = Math.max(5, i);
    }

    /** Recursively builds a subtree over {@code list} on the calling thread. */
    private VPTree<V>.TreeNode makeVPTree(List<Pair<Double, Integer>> list) {
        if (list.isEmpty()) {
            return null;
        }
        if (list.size() <= maxLeafSize) {
            return new VPLeaf(list);
        }
        int vpIndex = selectVantagePointIndex(list);
        VPTree<V>.VPNode node = new VPNode(list.get(vpIndex).getSecondItem().intValue());
        // Move the vantage point to the front so the remaining points form one contiguous range.
        Collections.swap(list, 0, vpIndex);
        int split = sortSplitSet(list.subList(1, list.size()), node) + 1;
        node.right = makeVPTree(list.subList(split + 1, list.size()));
        node.left = makeVPTree(list.subList(1, split + 1));
        return node;
    }

    /**
     * Recursively builds a subtree over {@code list}, submitting right subtrees to
     * {@code executorService}. Every submitted task counts the latch up before submission and down
     * when it finishes. Worker failures are not reported through this overload.
     */
    public VPTree<V>.TreeNode makeVPTree(List<Pair<Double, Integer>> list, ExecutorService executorService, ModifiableCountDownLatch modifiableCountDownLatch) {
        return makeVPTree(list, executorService, modifiableCountDownLatch, Collections.synchronizedList(new ArrayList<Throwable>()));
    }

    /** Parallel build that records any exception thrown by a worker task into {@code failures}. */
    private VPTree<V>.TreeNode makeVPTree(List<Pair<Double, Integer>> list, final ExecutorService executorService, final ModifiableCountDownLatch latch, final List<Throwable> failures) {
        if (list.isEmpty()) {
            return null;
        }
        if (list.size() <= maxLeafSize) {
            return new VPLeaf(list);
        }
        int vpIndex = selectVantagePointIndex(list);
        final VPTree<V>.VPNode node = new VPNode(list.get(vpIndex).getSecondItem().intValue());
        Collections.swap(list, 0, vpIndex);
        int split = sortSplitSet(list.subList(1, list.size()), node) + 1;
        final List<Pair<Double, Integer>> rightPoints = list.subList(split + 1, list.size());
        List<Pair<Double, Integer>> leftPoints = list.subList(1, split + 1);

        latch.countUp();
        Runnable buildRight = new Runnable() {
            @Override
            public void run() {
                try {
                    node.right = makeVPTree(rightPoints, executorService, latch, failures);
                } catch (Throwable t) {
                    // Record instead of losing it inside the executor; the builder falls back.
                    failures.add(t);
                } finally {
                    // Always release, otherwise the constructor's await() would block forever.
                    latch.countDown();
                }
            }
        };
        try {
            executorService.submit(buildRight);
        } catch (java.util.concurrent.RejectedExecutionException e) {
            buildRight.run(); // run inline; its finally releases the count taken above
        }
        node.left = makeVPTree(leftPoints, executorService, latch, failures);
        return node;
    }

    /**
     * Chooses the index (into {@code list}) of the next vantage point.
     *
     * <p>{@link VPSelection#Random} picks uniformly. {@link VPSelection#Sampling} draws
     * {@code sampleSize} reference points, scores up to {@code searchIterations} candidates by the sum
     * of absolute deviations of their distances to the references from the median distance, and
     * keeps the highest-scoring candidate.
     */
    private int selectVantagePointIndex(List<Pair<Double, Integer>> list) {
        if (vpSelection == VPSelection.Random) {
            return rand.nextInt(list.size());
        }

        IntList sample = new IntList(sampleSize);
        if (sampleSize <= list.size()) {
            for (int s = 0; s < sampleSize; s++) {
                sample.add(list.get(s).getSecondItem());
            }
        } else {
            for (int s = 0; s < sampleSize; s++) {
                sample.add(list.get(rand.nextInt(list.size())).getSecondItem());
            }
        }

        double[] dists = new double[sampleSize];
        int best = -1;
        double bestSpread = Double.NEGATIVE_INFINITY;
        int candidates = Math.min(searchIterations, list.size());
        for (int c = 0; c < candidates; c++) {
            int candidateIdx = searchIterations <= list.size() ? c : rand.nextInt(list.size());
            int vp = list.get(candidateIdx).getSecondItem().intValue();
            for (int s = 0; s < sample.size(); s++) {
                dists[s] = dm.dist(vp, sample.get(s).intValue(), (List<? extends Vec>) allVecs, distCache);
            }
            Arrays.sort(dists);
            double median = dists[dists.length / 2];
            double spread = 0.0d;
            for (double dist : dists) {
                spread += Math.abs(dist - median);
            }
            if (spread > bestSpread) {
                bestSpread = spread;
                best = candidateIdx;
            }
        }
        // No candidate scored (searchIterations <= 0, or every spread was NaN): use the first point.
        return best < 0 ? 0 : best;
    }

    @Override
    public VPTree<V> m672clone() {
        return new VPTree<V>(this);
    }
}
