package com.tcb.cluster.agglomerative;

import com.google.common.collect.ImmutableList;
import com.tcb.cluster.Cluster;
import com.tcb.cluster.ClusterFactory;
import com.tcb.cluster.ClusterStep;
import com.tcb.cluster.ClusterTree;
import com.tcb.cluster.Clusterer;
import com.tcb.cluster.TreeClusterer;
import com.tcb.cluster.limit.ClusterLimit;
import com.tcb.cluster.linkage.Linkage;
import com.tcb.common.util.ArrayUtil;
import com.tcb.common.util.Combinatorics;
import com.tcb.common.util.Tuple;
import com.tcb.matrix.LabeledMatrix;
import com.tcb.matrix.LabeledSquareMatrixImpl;
import com.tcb.tree.node.Node;
import com.tcb.tree.tree.Tree;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

/* loaded from: input_file:cluster-0.1.3.jar:com/tcb/cluster/agglomerative/AgglomerativeClusterer.class */
/**
 * Hierarchical agglomerative clusterer.
 *
 * <p>Starts from the clusters that {@link ClusterFactory} creates for the row count of the input
 * distance matrix (presumably one singleton cluster per row — confirm in ClusterFactory). On every
 * step it merges the pair of current top-level clusters with the smallest {@link Linkage}
 * distance, until at most one cluster remains. The distance of each merge is recorded, in merge
 * order, in the returned {@link ClusterTree}.
 *
 * <p>Each step re-evaluates all pairs of current clusters, so pairwise linkage distances are
 * memoized for the duration of a single {@link #cluster()} run.
 */
public class AgglomerativeClusterer implements Clusterer, TreeClusterer {
    private final Linkage linkage;
    // Memoized linkage distances: first key is one node of the pair, inner key the other.
    // Only non-null while cluster() is running; written concurrently by the parallel stream in
    // planNextStep, hence the concurrent maps.
    private ConcurrentHashMap<Node, Map<Node, Double>> linkageDistanceCache;
    // Number of top-level clusters; volatile so getCurrentClusterCount() can be polled from
    // another thread while cluster() runs.
    private volatile int currentClusterCount;
    // Set by cancel(); checked by cluster() before each merge. Never reset by this class.
    protected volatile boolean cancelled = false;
    private final LabeledMatrix<String> distances;

    /**
     * Creates a clusterer over the given distance matrix.
     *
     * @param labeledMatrix pairwise distances between the items to cluster
     * @param linkage strategy used to compute the distance between two clusters
     */
    public AgglomerativeClusterer(LabeledMatrix<String> labeledMatrix, Linkage linkage) {
        this.distances = labeledMatrix;
        this.linkage = linkage;
        this.currentClusterCount = labeledMatrix.getRowCount();
    }

    /** Returns a fresh, empty, thread-safe linkage-distance cache. */
    private ConcurrentHashMap<Node, Map<Node, Double>> createLinkageDistanceCache() {
        return new ConcurrentHashMap<>();
    }

    /**
     * Runs a full clustering and lets {@code clusterLimit} pick the clusters to return from the
     * resulting tree.
     */
    @Override // com.tcb.cluster.Clusterer
    public List<Cluster> cluster(ClusterLimit clusterLimit) {
        return clusterLimit.getClusters(cluster());
    }

    /**
     * Runs agglomerative clustering until at most one top-level cluster remains.
     *
     * @return the merge tree together with the original distances, the node distance matrix and
     *     the list of merge distances in merge order
     * @throws RuntimeException with message "Cancelled clustering" if {@link #cancel()} has been
     *     called before a merge is performed
     */
    @Override // com.tcb.cluster.TreeClusterer
    public ClusterTree cluster() {
        this.linkageDistanceCache = createLinkageDistanceCache();
        try {
            Tree tree = new ClusterFactory(Integer.valueOf(this.distances.getRowCount())).createStartClusters();
            this.currentClusterCount = getClusters(tree).size();
            LabeledMatrix<Node> nodeDistances = createNodeDistances(tree);
            ImmutableList.Builder<Double> mergeDistances = ImmutableList.builder();

            Optional<ClusterStep> step = planNextStep(tree, nodeDistances);
            while (step.isPresent()) {
                if (this.cancelled) {
                    throw new RuntimeException("Cancelled clustering");
                }
                ClusterStep next = step.get();
                mergeDistances.add(next.closestDistance);
                runStep(tree, next);
                step = planNextStep(tree, nodeDistances);
            }
            return new ClusterTree(tree, this.distances, nodeDistances, mergeDistances.build());
        } finally {
            // The cache holds an entry per cluster pair evaluated and references every tree node;
            // nothing reads it outside a run, so release it on completion or cancellation.
            this.linkageDistanceCache = null;
        }
    }

    /** Builds the node-labelled distance matrix for the tree's current top-level clusters. */
    private LabeledMatrix<Node> createNodeDistances(Tree tree) {
        return LabeledSquareMatrixImpl.create(getClusters(tree), this.distances);
    }

    /** Returns the current top-level clusters, i.e. the children of the tree's root. */
    private List<Node> getClusters(Tree tree) {
        return tree.getChildren(tree.getRoot());
    }

    /** Merges the closest pair chosen by {@code clusterStep} and decrements the cluster count. */
    private void runStep(Tree tree, ClusterStep clusterStep) {
        Tuple<Node, Node> pair = clusterStep.closestPair;
        tree.mergeNodes(pair.one(), pair.two());
        this.currentClusterCount--;
    }

    /** Returns the number of top-level clusters in the current (or most recent) run. */
    @Override // com.tcb.cluster.Clusterer
    public Integer getCurrentClusterCount() {
        return Integer.valueOf(this.currentClusterCount);
    }

    /**
     * Finds the pair of current top-level clusters with the smallest linkage distance.
     *
     * @return the next merge step, or empty when at most one cluster remains
     */
    private Optional<ClusterStep> planNextStep(Tree tree, LabeledMatrix<Node> labeledMatrix) {
        List<Node> clusters = getClusters(tree);
        int clusterCount = clusters.size();
        if (clusterCount <= 1) {
            return Optional.empty();
        }
        List<Tuple<Node, Node>> pairs = Combinatorics.getCombinationsNoSelf(clusters);
        // The list stream is ordered, so toArray() keeps pairDistances[i] aligned with pairs.get(i).
        double[] pairDistances = pairs.parallelStream()
                .mapToDouble(pair -> getCachedLinkageDistance(pair.one(), pair.two(), tree, labeledMatrix))
                .toArray();
        double min = ArrayUtil.getMin(pairDistances);
        Tuple<Node, Node> closestPair = pairs.get(ArrayUtil.indexOf(pairDistances, min));
        return Optional.of(new ClusterStep(Integer.valueOf(clusterCount), closestPair, Double.valueOf(min)));
    }

    /** Computes the linkage distance between two clusters without caching. */
    private double getLinkageDistance(Node node, Node node2, Tree tree, LabeledMatrix<Node> labeledMatrix) {
        return this.linkage.getDistance(node, node2, tree, labeledMatrix);
    }

    /**
     * Returns the linkage distance for the ordered pair ({@code node}, {@code node2}), computing
     * and memoizing it on first request.
     *
     * <p>Called concurrently from a parallel stream. The inner map is installed atomically via
     * {@code computeIfAbsent}. The distance itself is computed outside any map lock, so two
     * threads may occasionally compute the same pair. Both store the same value, so this is
     * harmless.
     */
    private double getCachedLinkageDistance(Node node, Node node2, Tree tree, LabeledMatrix<Node> labeledMatrix) {
        Map<Node, Double> row = this.linkageDistanceCache.computeIfAbsent(node, key -> new ConcurrentHashMap<>());
        Double cached = row.get(node2);
        if (cached != null) {
            return cached.doubleValue();
        }
        double linkageDistance = getLinkageDistance(node, node2, tree, labeledMatrix);
        row.put(node2, Double.valueOf(linkageDistance));
        return linkageDistance;
    }

    /**
     * Requests cancellation: {@link #cluster()} throws before its next merge. The flag is never
     * reset, so later runs on this instance are cancelled as well.
     */
    @Override // com.tcb.cluster.Clusterer
    public void cancel() {
        this.cancelled = true;
    }
}
