package edu.princeton.safe.grouping;

import com.carrotsearch.hppc.IntArrayList;
import com.carrotsearch.hppc.cursors.IntCursor;
import edu.princeton.safe.AnnotationProvider;
import edu.princeton.safe.GroupingMethod;
import edu.princeton.safe.ProgressReporter;
import edu.princeton.safe.internal.ScoringFunction;
import edu.princeton.safe.internal.cluster.Dendrogram;
import edu.princeton.safe.internal.cluster.DendrogramBuilder;
import edu.princeton.safe.internal.cluster.DendrogramNode;
import edu.princeton.safe.internal.cluster.ObservationNode;
import edu.princeton.safe.internal.fastcluster.HierarchicalClusterer;
import edu.princeton.safe.internal.fastcluster.MethodCode;
import edu.princeton.safe.internal.fastcluster.Node;
import edu.princeton.safe.io.DomainConsumer;
import edu.princeton.safe.model.CompositeMap;
import edu.princeton.safe.model.EnrichmentLandscape;
import edu.princeton.safe.model.Neighborhood;
import java.util.Iterator;
import java.util.List;
import java.util.OptionalDouble;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/* loaded from: input_file:safe-core-1.0.0-beta4.jar:edu/princeton/safe/grouping/ClusterBasedGroupingMethod.class */
public class ClusterBasedGroupingMethod implements GroupingMethod {
    // Relative cut height: computeClusters multiplies it by the largest node
    // distance reported by the clusterer to get the height passed to Dendrogram.cut.
    double threshold;
    // Distance applied to pairs of attribute score vectors in group(); its id is
    // also returned as this grouping method's id.
    DistanceMethod distanceMethod;

    /* loaded from: input_file:safe-core-1.0.0-beta4.jar:edu/princeton/safe/grouping/ClusterBasedGroupingMethod$ClusterConsumer.class */
    /**
     * Callback interface for receiving clusters member by member.
     *
     * NOTE(review): nothing in the visible part of this file implements or calls
     * this interface; the protocol described on the methods below is inferred from
     * their names only — confirm against actual implementers before relying on it.
     */
    interface ClusterConsumer {
        /** Presumably called once before the members of a cluster are reported. */
        void startCluster();

        /** Presumably reports one member of the current cluster; {@code i} looks like a member index — TODO confirm. */
        void addMember(int i);

        /** Presumably called once after the last member of a cluster was reported. */
        void endCluster();
    }

    /**
     * Creates a grouping method backed by hierarchical clustering.
     *
     * @param d              relative cut height; it is scaled by the largest node
     *                       distance of the cluster tree when the tree is cut
     * @param distanceMethod distance used between attribute score vectors; its id
     *                       is reported as the id of this grouping method
     */
    public ClusterBasedGroupingMethod(double d, DistanceMethod distanceMethod) {
        this.distanceMethod = distanceMethod;
        threshold = d;
    }

    /**
     * Returns the id of the configured distance method, which serves as the id of
     * this grouping method.
     */
    @Override // edu.princeton.safe.Identifiable
    public String getId() {
        String id = distanceMethod.getId();
        return id;
    }

    /**
     * Partitions the "top" attributes of the given analysis type into domains.
     *
     * <p>Attributes for which {@code compositeMap.isTop(attribute, i)} holds are
     * collected first. With no such attribute nothing is emitted; with exactly one,
     * a single one-attribute domain is emitted. Otherwise the attributes are scored
     * per neighborhood, pairwise distances are computed with the configured distance
     * method, the attributes are clustered, and every non-empty cluster is reported
     * to the consumer as one domain.
     *
     * @param enrichmentLandscape source of the annotation provider and neighborhoods
     * @param compositeMap        decides which attributes are "top" for type {@code i}
     * @param i                   analysis type index; passed to {@code isTop},
     *                            {@code startDomain} and the score computation
     * @param domainConsumer      receives the resulting domains
     * @param progressReporter    receives status messages
     */
    @Override // edu.princeton.safe.GroupingMethod
    public void group(EnrichmentLandscape enrichmentLandscape, CompositeMap compositeMap, int i, DomainConsumer domainConsumer, ProgressReporter progressReporter) {
        AnnotationProvider annotations = enrichmentLandscape.getAnnotationProvider();
        int totalAttributes = annotations.getAttributeCount();

        // Positions in this list are the "filtered" indexes used by the clustering.
        IntArrayList topAttributes = new IntArrayList();
        for (int attribute = 0; attribute < totalAttributes; attribute++) {
            if (compositeMap.isTop(attribute, i)) {
                topAttributes.add(attribute);
            }
        }

        int topCount = topAttributes.size();
        if (topCount < 2) {
            // Too few attributes to cluster: emit the lone attribute (if any) as its own domain.
            progressReporter.setStatus("Warning: Less than two attributes remain after filtering", Integer.valueOf(topCount));
            if (topCount == 1) {
                domainConsumer.startDomain(i);
                domainConsumer.attribute(topAttributes.get(0));
                domainConsumer.endDomain();
            }
            return;
        }

        progressReporter.setStatus("Top attributes: %d", Integer.valueOf(topCount));
        progressReporter.setStatus("Computing attribute distances...", new Object[0]);
        double[][] scores = computeScores(enrichmentLandscape, totalAttributes, topAttributes, i);

        progressReporter.setStatus("Computing dissimilarity matrix...", new Object[0]);
        double[] distances = pdist(scores, this.distanceMethod);
        List<IntArrayList> clusters = computeClusters(topCount, distances, progressReporter, annotations, topAttributes);

        progressReporter.setStatus("Assigning clusters...", new Object[0]);
        for (IntArrayList cluster : clusters) {
            if (cluster.isEmpty()) {
                continue;
            }
            domainConsumer.startDomain(i);
            for (IntCursor member : cluster) {
                // Cluster members are filtered indexes; map them back to attribute indexes.
                domainConsumer.attribute(topAttributes.get(member.value));
            }
            domainConsumer.endDomain();
        }
    }

    /* JADX WARN: Type inference failed for: r0v7, types: [double[], double[][]] */
    static double[][] computeScores(EnrichmentLandscape enrichmentLandscape, int i, IntArrayList intArrayList, int i2) {
        ScoringFunction scoringFunction = Neighborhood.getScoringFunction(i2);
        List<? extends Neighborhood> neighborhoods = enrichmentLandscape.getNeighborhoods();
        int size = intArrayList.size();
        ?? r0 = new double[size];
        IntStream.range(0, size).parallel().forEach(i3 -> {
            int i3 = intArrayList.get(i3);
            r0[i3] = neighborhoods.stream().mapToDouble(neighborhood -> {
                return scoringFunction.get(neighborhood, i3);
            }).toArray();
        });
        return r0;
    }

    /**
     * Clusters {@code i} observations hierarchically and cuts the resulting tree.
     *
     * <p>The tree is built with {@code MethodCode.METHOD_METR_AVERAGE} from the given
     * distances, then cut at {@code threshold} times the largest node distance.
     * Each subtree returned by the cut is flattened into the list of its observations.
     *
     * @param i                  number of observations
     * @param dArr               pairwise distances handed to the clusterer
     * @param progressReporter   receives status messages
     * @param annotationProvider not used by this method
     * @param intArrayList       not used by this method
     * @return one list of observation indexes per subtree produced by the cut
     */
    List<IntArrayList> computeClusters(int i, double[] dArr, ProgressReporter progressReporter, AnnotationProvider annotationProvider, IntArrayList intArrayList) {
        progressReporter.setStatus("Computing cluster tree...", new Object[0]);

        // Every observation starts with the value 1 (presumably the initial
        // cluster sizes expected by the clusterer — TODO confirm).
        int[] members = new int[i];
        for (int k = 0; k < members.length; k++) {
            members[k] = 1;
        }

        List<Node> merges = HierarchicalClusterer.NN_chain_core(i, dArr, members, MethodCode.METHOD_METR_AVERAGE);
        OptionalDouble tallestMerge = merges.stream().mapToDouble(merge -> merge.dist).max();

        DendrogramBuilder builder = new DendrogramBuilder(i);
        HierarchicalClusterer.buildClusters(false, merges, builder);
        DendrogramNode root = builder.getRoot();

        double treeHeight = tallestMerge.getAsDouble();
        List<DendrogramNode> subtrees = Dendrogram.cut(root, this.threshold * treeHeight);
        progressReporter.setStatus("Cluster tree height: %f", Double.valueOf(treeHeight));

        List<IntArrayList> clusters = subtrees.stream()
                .map(subtree -> getObservations(subtree))
                .collect(Collectors.toList());
        return clusters;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Collects the observations found in the subtree rooted at the given node.
     *
     * @param dendrogramNode root of the subtree to flatten
     * @return a new list holding the observations of all observation nodes below
     *         (or at) the given node, left subtree before right subtree
     */
    public static IntArrayList getObservations(DendrogramNode dendrogramNode) {
        IntArrayList observations = new IntArrayList();
        getObservations(observations, dendrogramNode);
        return observations;
    }

    /**
     * Appends the observations of the subtree rooted at {@code dendrogramNode} to
     * {@code intArrayList}, visiting the left child before the right child.
     *
     * @param intArrayList   list that receives the observations
     * @param dendrogramNode root of the subtree to traverse
     */
    static void getObservations(IntArrayList intArrayList, DendrogramNode dendrogramNode) {
        if (!(dendrogramNode instanceof ObservationNode)) {
            // Inner node: descend into both children, left first.
            getObservations(intArrayList, dendrogramNode.getLeft());
            getObservations(intArrayList, dendrogramNode.getRight());
            return;
        }
        ObservationNode leaf = (ObservationNode) dendrogramNode;
        intArrayList.add(leaf.getObservation());
    }

    /**
     * Computes the pairwise distances between the rows of {@code dArr}.
     *
     * <p>The result is a condensed vector with one entry per unordered pair
     * {@code (row, col)}, {@code row < col}, stored at
     * {@code getIndex(dArr.length, row, col)}. Rows are processed in parallel;
     * every pair writes to its own slot.
     *
     * @param dArr           one vector per observation
     * @param distanceMethod distance applied to each pair of vectors
     * @return the condensed distance vector of length {@code n * (n - 1) / 2}
     * @throws ArithmeticException if the number of pairs does not fit in an {@code int}
     */
    static double[] pdist(double[][] dArr, DistanceMethod distanceMethod) {
        int count = dArr.length;
        // Compute the pair count in long arithmetic: count * (count - 1) overflows
        // int well before the halved result does.
        int pairCount = Math.toIntExact((long) count * (count - 1) / 2);
        double[] distances = new double[pairCount];
        IntStream.range(0, count).parallel().forEach(row -> {
            for (int col = row + 1; col < count; col++) {
                distances[getIndex(count, row, col)] = distanceMethod.apply(dArr[row], dArr[col]);
            }
        });
        return distances;
    }

    /**
     * Maps a pair of observation indexes to its position in a condensed
     * (upper-triangular, row-major, diagonal excluded) distance vector.
     *
     * @param i  number of observations
     * @param i2 row index of the pair; expected to be smaller than {@code i3}
     * @param i3 column index of the pair
     * @return {@code i2 * (2 * i - i2 - 1) / 2 + i3 - i2 - 1}
     */
    static int getIndex(int i, int i2, int i3) {
        // The product i2 * (2 * i - i2 - 1) can exceed the int range even when the
        // final index fits, so the arithmetic is carried out in long.
        long rowOffset = ((long) i2 * (2L * i - i2 - 1)) / 2;
        return (int) (rowOffset + i3 - i2 - 1);
    }
}
