package com.tcb.cluster;

import com.google.common.collect.ImmutableList;
import com.tcb.common.util.ArrayUtil;
import com.tcb.matrix.LabeledMatrix;
import com.tcb.tree.node.Node;
import com.tcb.tree.tree.LeafNodeTreeSearcher;
import com.tcb.tree.tree.Tree;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;

/* loaded from: input_file:cluster-0.1.3.jar:com/tcb/cluster/ClusterTree.class */
public class ClusterTree {
    /** Hierarchy whose root is cut into clusters by {@link #getClusters(Integer)}. */
    private final Tree tree;
    /** Distance matrix labeled by data-point name; supplies row count and cluster labels. */
    private final LabeledMatrix<String> distances;
    /**
     * Distance matrix labeled by tree node. NOTE(review): its indices are used to look up labels in
     * {@link #distances}, so the two matrices are assumed to share the same index order — confirm.
     */
    private final LabeledMatrix<Node> nodeDistances;
    /** Opaque list supplied by the caller; only stored and handed back by {@link #getClosestDistances()}. */
    private final List<Double> closestDistances;

    /**
     * Creates a cluster tree view over an already-built hierarchy.
     *
     * @param tree the hierarchy to cut into clusters
     * @param labeledMatrix distances labeled by data-point name
     * @param labeledMatrix2 distances labeled by tree node (indices assumed to match {@code labeledMatrix})
     * @param list caller-supplied distances, stored as-is (not copied)
     */
    public ClusterTree(Tree tree, LabeledMatrix<String> labeledMatrix, LabeledMatrix<Node> labeledMatrix2, List<Double> list) {
        this.tree = tree;
        this.distances = labeledMatrix;
        this.nodeDistances = labeledMatrix2;
        this.closestDistances = list;
    }

    /**
     * Returns the list passed to the constructor.
     *
     * @return the same list instance given at construction (not a copy)
     */
    public List<Double> getClosestDistances() {
        return this.closestDistances;
    }

    /**
     * Returns the number of data points, taken as the row count of the labeled distance matrix.
     *
     * @return number of rows in the name-labeled distance matrix
     */
    public Integer getDataPointCount() {
        return Integer.valueOf(this.distances.getRowCount());
    }

    /**
     * Cuts the tree into clusters.
     *
     * <p>Starting from the root, the node with the largest SUID is repeatedly replaced by its
     * children. This continues until the working list holds at least {@code num} nodes.
     * (Presumably the largest SUID marks the most recent merge — TODO confirm against the tree
     * builder.) Each remaining node becomes one cluster made of its leaf nodes.
     *
     * @param num requested number of clusters; must be in {@code [1, getDataPointCount()]}
     * @return the clusters, sorted with {@code ClusterComparator}
     * @throws IllegalArgumentException if {@code num} is out of range
     * @throws NullPointerException if {@code num} is {@code null}
     */
    public List<Cluster> getClusters(Integer num) {
        int clusterCount = num.intValue();
        if (clusterCount < 1) {
            throw new IllegalArgumentException("Cannot have less than 1 cluster");
        }
        if (clusterCount > getDataPointCount().intValue()) {
            throw new IllegalArgumentException("Cannot have more clusters than data points");
        }
        List<Node> frontier = new ArrayList<>();
        frontier.add(this.tree.getRoot());
        while (frontier.size() < clusterCount) {
            expandLevel(frontier);
        }
        return toClusters(frontier);
    }

    /** Replaces the largest-SUID node in {@code frontier} with its children (appended at the end). */
    private void expandLevel(List<Node> frontier) {
        Node expanded = frontier.remove(getMaxSuidNodeIndex(frontier));
        frontier.addAll(this.tree.getChildren(expanded));
    }

    /**
     * Builds one cluster per node in {@code roots}.
     *
     * <p>A cluster's labels are its leaves' names, ordered by matrix index. Its centroid is the leaf
     * with the smallest summed distance to the other leaves. Its error is the sum of squared
     * distances from each leaf to that centroid.
     */
    private List<Cluster> toClusters(List<Node> roots) {
        List<Cluster> clusters = new ArrayList<>(roots.size());
        for (Node root : roots) {
            List<Node> leafNodes = LeafNodeTreeSearcher.getLeafNodes(root, this.tree);
            List<String> labels = getIndices(leafNodes).stream()
                    .sorted()
                    .map(index -> this.distances.getLabel(index))
                    .collect(ImmutableList.toImmutableList());
            // The centroid search is O(k^2) in the leaf count; compute it once and reuse it for
            // both the centroid label and the squared error.
            Node centroidNode = getCentroidNode(leafNodes);
            clusters.add(new ClusterImpl(labels, getCentroid(centroidNode),
                    getCentroidSquaredError(leafNodes, centroidNode)));
        }
        clusters.sort(new ClusterComparator());
        return clusters;
    }

    /**
     * Returns the position of the node with the largest SUID.
     *
     * @return the first such position on ties, or -1 if {@code nodes} is empty
     */
    private int getMaxSuidNodeIndex(List<Node> nodes) {
        int maxIndex = -1;
        long maxSuid = Long.MIN_VALUE;
        for (int i = 0; i < nodes.size(); i++) {
            long suid = nodes.get(i).getSuid();
            if (suid > maxSuid) {
                maxSuid = suid;
                maxIndex = i;
            }
        }
        return maxIndex;
    }

    /** Maps a centroid node to its data-point label via the shared matrix index. */
    private String getCentroid(Node centroidNode) {
        return this.distances.getLabel(this.nodeDistances.getIndex(centroidNode));
    }

    /**
     * Returns the node whose summed distance to every other node in {@code nodes} is smallest.
     * The tie-break is whatever {@code ArrayUtil.indexOf} returns for the minimum value.
     */
    private Node getCentroidNode(List<Node> nodes) {
        int size = nodes.size();
        double[] distanceSums = new double[size];
        for (int i = 0; i < size; i++) {
            for (int j = i + 1; j < size; j++) {
                // Distances are assumed symmetric: each pair is visited once and credited to both.
                double d = this.nodeDistances.get(nodes.get(i), nodes.get(j));
                distanceSums[i] += d;
                distanceSums[j] += d;
            }
        }
        return nodes.get(ArrayUtil.indexOf(distanceSums, ArrayUtil.getMin(distanceSums)));
    }

    /** Returns each node's index in the node-labeled distance matrix, in input order. */
    private List<Integer> getIndices(List<Node> nodes) {
        return nodes.stream()
                .map(node -> this.nodeDistances.getIndex(node))
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Sums the squared distance from every non-centroid node in {@code nodes} to {@code centroidNode}.
     *
     * @return the sum of squared distances; 0.0 when {@code nodes} holds only the centroid
     */
    private Double getCentroidSquaredError(List<Node> nodes, Node centroidNode) {
        double sum = 0.0d;
        for (Node node : nodes) {
            if (!node.equals(centroidNode)) {
                double d = this.nodeDistances.get(node, centroidNode);
                sum += d * d;
            }
        }
        return sum;
    }
}
