package smile.neighbor;

import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.math.distance.Metric;
import smile.sort.DoubleHeapSelect;

/* loaded from: input_file:smile-core-2.4.0.jar:smile/neighbor/CoverTree.class */
/**
 * Cover tree for nearest neighbor and fixed-radius search under an arbitrary
 * {@link Metric}. The tree is built eagerly at construction time by batch
 * insertion, with {@code data[0]} as the root point.
 * <p>
 * In {@link #knn} and {@link #range}, a data point that is the very same object
 * (reference equality) as the query is excluded from the results.
 */
public class CoverTree<E> implements NearestNeighborSearch<E, E>, KNNSearch<E, E>, RNNSearch<E, E>, Serializable {
    private static final long serialVersionUID = 2;
    private static final Logger logger = LoggerFactory.getLogger(CoverTree.class);

    /** Base used by the two-argument constructor. */
    private static final double DEFAULT_BASE = 1.3;
    /** Sentinel scale assigned to leaves and to zero-radius nodes. */
    private static final int LEAF_SCALE = 100;

    /** The indexed data points. */
    private final E[] data;
    /** Metric used for every distance computation. */
    private final Metric<E> distance;
    /** Root of the tree; its point is {@code data[0]}. */
    private Node root;
    /** Base of the expansion constant: the cover radius at scale {@code s} is {@code base^s}. */
    private final double base;
    /** Cached {@code 1 / ln(base)}, used to convert a distance into a scale. */
    private final double invLogBase;

    /** A tree node paired with its distance to the current query point. */
    /* JADX INFO: Access modifiers changed from: package-private */
    public class DistanceNode implements Comparable<DistanceNode> {
        /** Distance from the query to this node's point. */
        double dist;
        /** The node. */
        Node node;

        DistanceNode(double dist, Node node) {
            this.dist = dist;
            this.node = node;
        }

        @Override
        public int compareTo(DistanceNode other) {
            return Double.compare(dist, other.dist);
        }
    }

    /**
     * A point pending insertion, together with a stack of its distances to the
     * chain of candidate parents. The last entry is the distance to the
     * candidate currently being considered.
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public class DistanceSet {
        /** Index of the point in {@code data}. */
        int idx;
        /** Distances to the candidate parents; the last entry is the current one. */
        ArrayList<Double> dist = new ArrayList<>();

        DistanceSet(int idx) {
            this.idx = idx;
        }

        E getObject() {
            return data[idx];
        }
    }

    /**
     * A node of the cover tree. For an internal node, child 0 always holds the
     * same point as the node itself (its "self-child"); see {@code batchInsert}.
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public class Node implements Serializable {
        /** Index of the node's point in {@code data}. */
        int idx;
        /** Largest recorded distance from this point to points inserted beneath it; used for pruning. */
        double maxDist;
        /** Distance from this point to its parent's point. */
        double parentDist;
        /** Children, or {@code null} for a leaf. */
        ArrayList<Node> children;
        /** Scale label of the node ({@code LEAF_SCALE} for leaves). */
        int scale;

        Node(int idx) {
            this.idx = idx;
        }

        Node(int idx, double maxDist, double parentDist, ArrayList<Node> children, int scale) {
            this.idx = idx;
            this.maxDist = maxDist;
            this.parentDist = parentDist;
            this.children = children;
            this.scale = scale;
        }

        E getObject() {
            return data[idx];
        }

        boolean isLeaf() {
            return children == null;
        }
    }

    /**
     * Builds a cover tree with the default base 1.3.
     *
     * @param data the data points; must be non-empty
     * @param distance the metric
     * @throws IllegalArgumentException if {@code data} is empty
     */
    public CoverTree(E[] data, Metric<E> distance) {
        this(data, distance, DEFAULT_BASE);
    }

    /**
     * Builds a cover tree.
     *
     * @param data the data points; must be non-empty
     * @param distance the metric
     * @param base the base of the expansion constant; must be greater than 1
     * @throws IllegalArgumentException if {@code data} is empty or {@code base} is not greater than 1
     */
    public CoverTree(E[] data, Metric<E> distance, double base) {
        if (data.length == 0) {
            throw new IllegalArgumentException("Empty dataset");
        }
        // ln(base) must be positive; otherwise scales are infinite or inverted. Also rejects NaN.
        if (!(base > 1.0)) {
            throw new IllegalArgumentException("Invalid base: " + base);
        }
        this.data = data;
        this.distance = distance;
        this.base = base;
        this.invLogBase = 1.0 / Math.log(base);
        buildCoverTree();
    }

    @Override
    public String toString() {
        return String.format("Cover Tree (%s)", distance);
    }

    /** Builds the tree rooted at {@code data[0]}, inserting all other points in one batch. */
    private void buildCoverTree() {
        ArrayList<DistanceSet> pointSet = new ArrayList<>();
        ArrayList<DistanceSet> consumedSet = new ArrayList<>();
        E rootObject = data[0];
        double maxDist = -1.0;
        for (int i = 1; i < data.length; i++) {
            DistanceSet set = new DistanceSet(i);
            double d = distance.d(rootObject, data[i]);
            set.dist.add(d);
            pointSet.add(set);
            if (d > maxDist) {
                maxDist = d;
            }
        }
        int topScale = getScale(maxDist);
        root = batchInsert(0, topScale, topScale, pointSet, consumedSet);
    }

    /**
     * Recursively builds the subtree rooted at point {@code p}.
     *
     * @param p index of the subtree's root point
     * @param maxScale the current scale
     * @param topScale the scale at the top of the tree
     * @param pointSet points still to be placed under {@code p}; on return it holds the
     *                 points that could not be covered here
     * @param consumedSet receives the points that were placed in this subtree
     * @return the subtree root
     */
    private Node batchInsert(int p, int maxScale, int topScale, ArrayList<DistanceSet> pointSet, ArrayList<DistanceSet> consumedSet) {
        if (pointSet.isEmpty()) {
            return newLeaf(p);
        }

        int nextScale = Math.min(maxScale - 1, getScale(max(pointSet)));
        if (nextScale == Integer.MIN_VALUE) {
            // getScale(0) saturates to Integer.MIN_VALUE, so this is the case where every
            // remaining point is at distance 0: attach them all as leaves of one node.
            ArrayList<Node> children = new ArrayList<>();
            children.add(newLeaf(p));
            while (!pointSet.isEmpty()) {
                DistanceSet set = pointSet.remove(pointSet.size() - 1);
                children.add(newLeaf(set.idx));
                consumedSet.add(set);
            }
            Node node = new Node(p);
            node.scale = LEAF_SCALE;
            node.maxDist = 0.0;
            node.children = children;
            return node;
        }

        ArrayList<DistanceSet> farSet = new ArrayList<>();
        split(pointSet, farSet, maxScale);
        // The self-child: p continues one scale lower with the points close to it.
        Node selfChild = batchInsert(p, nextScale, topScale, pointSet, consumedSet);
        if (pointSet.isEmpty()) {
            pointSet.addAll(farSet);
            return selfChild;
        }

        ArrayList<Node> children = new ArrayList<>();
        children.add(selfChild);
        ArrayList<DistanceSet> newPointSet = new ArrayList<>();
        ArrayList<DistanceSet> newConsumedSet = new ArrayList<>();
        double coverRadius = getCoverRadius(maxScale);
        while (!pointSet.isEmpty()) {
            // Promote a remaining point to a new child of p.
            DistanceSet set = pointSet.remove(pointSet.size() - 1);
            double newDist = set.dist.get(set.dist.size() - 1);
            consumedSet.add(set);
            distSplit(pointSet, newPointSet, set.getObject(), maxScale);
            distSplit(farSet, newPointSet, set.getObject(), maxScale);
            Node child = batchInsert(set.idx, nextScale, topScale, newPointSet, newConsumedSet);
            child.parentDist = newDist;
            children.add(child);

            // Pop the distance to the new child's point and re-bucket the leftovers by p's radius.
            for (DistanceSet s : newPointSet) {
                s.dist.remove(s.dist.size() - 1);
                if (s.dist.get(s.dist.size() - 1) <= coverRadius) {
                    pointSet.add(s);
                } else {
                    farSet.add(s);
                }
            }
            for (DistanceSet s : newConsumedSet) {
                s.dist.remove(s.dist.size() - 1);
                consumedSet.add(s);
            }
            newPointSet.clear();
            newConsumedSet.clear();
        }

        pointSet.addAll(farSet);
        Node node = new Node(p);
        node.scale = topScale - maxScale;
        node.maxDist = max(consumedSet);
        node.children = children;
        return node;
    }

    /** Returns the cover radius {@code base^scale}. */
    private double getCoverRadius(int scale) {
        return Math.pow(base, scale);
    }

    /** Returns {@code ceil(log_base(d))}; a distance of 0 maps to {@code Integer.MIN_VALUE}. */
    private int getScale(double d) {
        return (int) Math.ceil(invLogBase * Math.log(d));
    }

    /** Creates a leaf for point {@code idx}. */
    private Node newLeaf(int idx) {
        return new Node(idx, 0.0, 0.0, null, LEAF_SCALE);
    }

    /** Returns the largest "current" (last) distance in the set, or 0 if the set is empty. */
    private double max(ArrayList<DistanceSet> sets) {
        double best = 0.0;
        for (DistanceSet set : sets) {
            double d = set.dist.get(set.dist.size() - 1);
            if (best < d) {
                best = d;
            }
        }
        return best;
    }

    /**
     * Keeps in {@code pointSet} the points whose current distance is within the cover
     * radius of {@code scale}, and appends the rest to {@code farSet}.
     */
    private void split(ArrayList<DistanceSet> pointSet, ArrayList<DistanceSet> farSet, int scale) {
        double coverRadius = getCoverRadius(scale);
        ArrayList<DistanceSet> near = new ArrayList<>();
        for (DistanceSet set : pointSet) {
            if (set.dist.get(set.dist.size() - 1) <= coverRadius) {
                near.add(set);
            } else {
                farSet.add(set);
            }
        }
        pointSet.clear();
        pointSet.addAll(near);
    }

    /**
     * Moves the points of {@code pointSet} that lie within the cover radius of
     * {@code scale} from {@code newPoint} into {@code newPointSet}, pushing their
     * distance to {@code newPoint}. The other points stay in {@code pointSet}.
     */
    private void distSplit(ArrayList<DistanceSet> pointSet, ArrayList<DistanceSet> newPointSet, E newPoint, int scale) {
        double coverRadius = getCoverRadius(scale);
        ArrayList<DistanceSet> kept = new ArrayList<>();
        for (DistanceSet set : pointSet) {
            double d = distance.d(newPoint, set.getObject());
            if (d <= coverRadius) {
                set.dist.add(d);
                newPointSet.add(set);
            } else {
                kept.add(set);
            }
        }
        pointSet.clear();
        pointSet.addAll(kept);
    }

    @Override
    public Neighbor<E, E> nearest(E q) {
        return knn(q, 1)[0];
    }

    /**
     * Returns up to {@code k} nearest neighbors of {@code q}. A data point that is the
     * same object as {@code q} is excluded, except in a single-point tree, where the
     * root is always returned.
     *
     * @param q the query
     * @param k the number of neighbors; {@code 1 <= k <= data.length}
     * @return the neighbors; a warning is logged if fewer than {@code k} are found
     * @throws IllegalArgumentException if {@code k} is out of range
     */
    @Override
    public Neighbor<E, E>[] knn(E q, int k) {
        if (k <= 0) {
            throw new IllegalArgumentException("Invalid k: " + k);
        }
        if (k > data.length) {
            throw new IllegalArgumentException("Neighbor array length is larger than the dataset size");
        }

        E rootObject = root.getObject();
        double rootDist = distance.d(rootObject, q);
        Neighbor<E, E> rootNeighbor = new Neighbor<>(rootObject, rootObject, root.idx, rootDist);
        @SuppressWarnings("unchecked")
        Neighbor<E, E>[] template = (Neighbor<E, E>[]) Array.newInstance(rootNeighbor.getClass(), 1);

        if (root.isLeaf()) {
            template[0] = rootNeighbor;
            return template;
        }

        // The MAX_VALUE sentinel keeps peek() at MAX_VALUE until k real distances
        // are in the heap, so peek() is always a valid k-th-distance upper bound.
        // (Previously the bound was ignored for the whole search when q was the root
        // object, which disabled all pruning.)
        DoubleHeapSelect heap = new DoubleHeapSelect(k);
        heap.add(Double.MAX_VALUE);
        if (rootObject != q) {
            heap.add(rootDist);
        }

        ArrayList<DistanceNode> coverSet = new ArrayList<>();
        coverSet.add(new DistanceNode(rootDist, root));
        ArrayList<DistanceNode> leaves = new ArrayList<>();
        while (!coverSet.isEmpty()) {
            ArrayList<DistanceNode> nextCoverSet = new ArrayList<>();
            for (DistanceNode parent : coverSet) {
                ArrayList<Node> children = parent.node.children;
                for (int c = 0; c < children.size(); c++) {
                    Node child = children.get(c);
                    // Child 0 is the parent's own point, so its distance is already known.
                    double d = c == 0 ? parent.dist : distance.d(child.getObject(), q);
                    double upperBound = heap.peek();
                    if (d <= upperBound + child.maxDist) {
                        if (c > 0 && d < upperBound && child.getObject() != q) {
                            heap.add(d);
                        }
                        if (!child.isLeaf()) {
                            nextCoverSet.add(new DistanceNode(d, child));
                        } else if (d <= upperBound) {
                            leaves.add(new DistanceNode(d, child));
                        }
                    }
                }
            }
            coverSet = nextCoverSet;
        }

        double kthDist = heap.peek();
        List<Neighbor<E, E>> found = new ArrayList<>();
        for (DistanceNode leaf : leaves) {
            E object = leaf.node.getObject();
            if (leaf.dist <= kthDist && object != q) {
                found.add(new Neighbor<>(object, object, leaf.node.idx, leaf.dist));
            }
        }

        Neighbor<E, E>[] neighbors = found.toArray(template);
        if (neighbors.length < k) {
            logger.warn("CoverTree.knn({}) returns only {} neighbors", k, neighbors.length);
        }

        Arrays.sort(neighbors);
        MathEx.reverse(neighbors);

        if (neighbors.length > k) {
            // More than k candidates only occur with ties at the k-th distance. After
            // reverse() the farthest neighbors may be at the front (if Neighbor sorts by
            // ascending distance — NOTE(review): confirm), so keep the k entries at the
            // end holding the smallest distances, preserving the array's order.
            int len = neighbors.length;
            if (neighbors[0].distance > neighbors[len - 1].distance) {
                neighbors = Arrays.copyOfRange(neighbors, len - k, len);
            } else {
                neighbors = Arrays.copyOf(neighbors, k);
            }
        }
        return neighbors;
    }

    /**
     * Appends to {@code neighbors} every data point within {@code radius} of
     * {@code q}, excluding a point that is the same object as {@code q}.
     *
     * @param q the query
     * @param radius the search radius; must be positive
     * @param neighbors receives the results
     * @throws IllegalArgumentException if {@code radius <= 0}
     */
    @Override
    public void range(E q, double radius, List<Neighbor<E, E>> neighbors) {
        if (radius <= 0.0) {
            throw new IllegalArgumentException("Invalid radius: " + radius);
        }

        double rootDist = distance.d(root.getObject(), q);
        if (root.isLeaf()) {
            // Single-point dataset: the root has no children to descend into
            // (previously this dereferenced a null child list).
            E object = root.getObject();
            if (rootDist <= radius && object != q) {
                neighbors.add(new Neighbor<>(object, object, root.idx, rootDist));
            }
            return;
        }

        ArrayList<DistanceNode> coverSet = new ArrayList<>();
        coverSet.add(new DistanceNode(rootDist, root));
        ArrayList<DistanceNode> leaves = new ArrayList<>();
        while (!coverSet.isEmpty()) {
            ArrayList<DistanceNode> nextCoverSet = new ArrayList<>();
            for (DistanceNode parent : coverSet) {
                ArrayList<Node> children = parent.node.children;
                for (int c = 0; c < children.size(); c++) {
                    Node child = children.get(c);
                    // Child 0 is the parent's own point, so its distance is already known.
                    double d = c == 0 ? parent.dist : distance.d(child.getObject(), q);
                    if (d <= radius + child.maxDist) {
                        if (!child.isLeaf()) {
                            nextCoverSet.add(new DistanceNode(d, child));
                        } else if (d <= radius) {
                            leaves.add(new DistanceNode(d, child));
                        }
                    }
                }
            }
            coverSet = nextCoverSet;
        }

        for (DistanceNode leaf : leaves) {
            E object = leaf.node.getObject();
            if (object != q) {
                neighbors.add(new Neighbor<>(object, object, leaf.node.idx, leaf.dist));
            }
        }
    }
}
