package org.reactome.r3.cluster;

import java.awt.Point;
import java.io.IOException;
import java.io.OutputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.apache.commons.math.random.RandomDataImpl;

/* loaded from: input_file:foundation-1.0.3.jar:org/reactome/r3/cluster/SOMWeightNodes.class */
public class SOMWeightNodes {
    private boolean useDotProduct = false;
    private List<ReferenceNode> nodeList = new ArrayList();
    private Map<String, ReferenceNode> posToNode = new HashMap();

    public void setUpNodes(int i, int i2) {
        for (int i3 = 0; i3 < i; i3++) {
            for (int i4 = 0; i4 < i2; i4++) {
                ReferenceNode referenceNode = new ReferenceNode();
                referenceNode.setLocation(i3, i4);
                this.nodeList.add(referenceNode);
                this.posToNode.put(i3 + "," + i4, referenceNode);
            }
        }
    }

    public List<ReferenceNode> getReferenceNodes() {
        return this.nodeList;
    }

    public void addNode(ReferenceNode referenceNode) {
        this.nodeList.add(referenceNode);
        Point location = referenceNode.getLocation();
        this.posToNode.put(location.x + "," + location.y, referenceNode);
    }

    public void initNodeReferenceVector(Map<String, double[]> map) {
        ArrayList arrayList = new ArrayList(map.keySet());
        int i = 0;
        for (int i2 = 0; i2 < this.nodeList.size(); i2++) {
            if (i == arrayList.size()) {
                i = 0;
            }
            this.nodeList.get(i2).setReferenceVector(map.get((String) arrayList.get(i)));
            i++;
        }
    }

    public void initNodeReferenceData(Map<String, Set<String>> map) {
        HashSet hashSet = new HashSet();
        Iterator<String> it = map.keySet().iterator();
        while (it.hasNext()) {
            hashSet.addAll(map.get(it.next()));
        }
        int size = hashSet.size() / map.size();
        Random random = new Random();
        RandomDataImpl randomDataImpl = new RandomDataImpl();
        for (ReferenceNode referenceNode : this.nodeList) {
            randomDataImpl.reSeedSecure(random.nextLong());
            Object[] nextSample = randomDataImpl.nextSample(hashSet, size);
            HashSet hashSet2 = new HashSet();
            for (Object obj : nextSample) {
                hashSet2.add(obj.toString());
            }
            referenceNode.addInputReferenceData(hashSet2);
        }
    }

    public ReferenceNode getNodeAt(int i, int i2) {
        return this.posToNode.get(i + "," + i2);
    }

    public ReferenceNode searchMatchNodeForDist(double[] dArr) {
        ReferenceNode referenceNode = this.nodeList.get(0);
        double calculateDistanceSqrForLearning = referenceNode.calculateDistanceSqrForLearning(dArr);
        ReferenceNode referenceNode2 = referenceNode;
        for (int i = 1; i < this.nodeList.size(); i++) {
            ReferenceNode referenceNode3 = this.nodeList.get(i);
            double calculateDistanceSqrForLearning2 = referenceNode3.calculateDistanceSqrForLearning(dArr);
            if (calculateDistanceSqrForLearning2 < calculateDistanceSqrForLearning) {
                referenceNode2 = referenceNode3;
                calculateDistanceSqrForLearning = calculateDistanceSqrForLearning2;
            }
        }
        return referenceNode2;
    }

    public ReferenceNode searchMatchNode(double[] dArr) {
        return this.useDotProduct ? searchMatchNodeForDotProduct(dArr) : searchMatchNodeForDist(dArr);
    }

    public ReferenceNode searchMatchNodeForDotProduct(double[] dArr) {
        ReferenceNode referenceNode = this.nodeList.get(0);
        double calculateDotProduct = referenceNode.calculateDotProduct(dArr);
        ReferenceNode referenceNode2 = referenceNode;
        for (int i = 1; i < this.nodeList.size(); i++) {
            ReferenceNode referenceNode3 = this.nodeList.get(i);
            double calculateDotProduct2 = referenceNode3.calculateDotProduct(dArr);
            if (calculateDotProduct2 > calculateDotProduct) {
                calculateDotProduct = calculateDotProduct2;
                referenceNode2 = referenceNode3;
            }
        }
        return referenceNode2;
    }

    public void checkNodes() {
        for (int i = 0; i < this.nodeList.size(); i++) {
            System.out.println(i + ": " + this.nodeList.get(i).getReferenceVector());
        }
    }

    public ReferenceNode searchMatchNode(Set<String> set) {
        ReferenceNode referenceNode = this.nodeList.get(0);
        double calculateSimilarity = referenceNode.calculateSimilarity(set);
        ReferenceNode referenceNode2 = referenceNode;
        for (int i = 1; i < this.nodeList.size(); i++) {
            ReferenceNode referenceNode3 = this.nodeList.get(i);
            double calculateSimilarity2 = referenceNode3.calculateSimilarity(set);
            if (calculateSimilarity2 > calculateSimilarity) {
                referenceNode2 = referenceNode3;
                calculateSimilarity = calculateSimilarity2;
            }
            if (calculateSimilarity2 == 1.0d) {
                break;
            }
        }
        if (calculateSimilarity == 0.0d) {
            Iterator<ReferenceNode> it = this.nodeList.iterator();
            while (true) {
                if (!it.hasNext()) {
                    break;
                }
                ReferenceNode next = it.next();
                if (next.isEmpty()) {
                    referenceNode2 = next;
                    break;
                }
            }
        }
        return referenceNode2;
    }

    public void recalculateNodes(int i) {
        ArrayList arrayList = new ArrayList();
        for (ReferenceNode referenceNode : this.nodeList) {
            Point location = referenceNode.getLocation();
            int i2 = location.x;
            int i3 = location.y;
            int max = Math.max(0, i2 - i);
            int max2 = Math.max(0, i3 - i);
            int min = Math.min(SOM.X_LENGTH - 1, i2 + i);
            int min2 = Math.min(SOM.Y_LENGTH - 1, i3 + i);
            arrayList.clear();
            for (int i4 = max; i4 < min + 1; i4++) {
                for (int i5 = max2; i5 < min2 + 1; i5++) {
                    List<Set<String>> inputData = getNodeAt(i4, i5).getInputData();
                    if (inputData != null) {
                        arrayList.addAll(inputData);
                    }
                }
            }
            referenceNode.setInputReferenceData(arrayList);
        }
        Iterator<ReferenceNode> it = this.nodeList.iterator();
        while (it.hasNext()) {
            it.next().resetInputData();
        }
    }

    public void recalculateNodesVectors(int i) {
        ArrayList arrayList = new ArrayList();
        for (ReferenceNode referenceNode : this.nodeList) {
            List<ReferenceNode> nodesWithin = getNodesWithin(referenceNode.getLocation(), i);
            arrayList.clear();
            Iterator<ReferenceNode> it = nodesWithin.iterator();
            while (it.hasNext()) {
                List<double[]> inputVectors = it.next().getInputVectors();
                if (inputVectors != null) {
                    arrayList.addAll(inputVectors);
                }
            }
            if (arrayList.size() > 0) {
                referenceNode.setReferenceVector(average(arrayList));
            }
        }
        for (ReferenceNode referenceNode2 : this.nodeList) {
            referenceNode2.resetInputVectors();
            referenceNode2.resetLabels();
        }
    }

    public List<ReferenceNode> getNodesWithin(Point point, int i) {
        ArrayList arrayList = new ArrayList();
        int i2 = point.x;
        int i3 = point.y;
        int max = Math.max(0, i2 - i);
        int max2 = Math.max(0, i3 - i);
        int min = Math.min(SOM.X_LENGTH - 1, i2 + i);
        int min2 = Math.min(SOM.Y_LENGTH - 1, i3 + i);
        for (int i4 = max; i4 < min + 1; i4++) {
            for (int i5 = max2; i5 < min2 + 1; i5++) {
                arrayList.add(getNodeAt(i4, i5));
            }
        }
        return arrayList;
    }

    public double calculateDistWithNeighbors(ReferenceNode referenceNode) {
        List<ReferenceNode> nodesWithin = getNodesWithin(referenceNode.getLocation(), 1);
        nodesWithin.remove(referenceNode);
        double d = 0.0d;
        for (ReferenceNode referenceNode2 : nodesWithin) {
            if (referenceNode2 == null) {
                System.out.println(referenceNode.getLocation());
            }
            d += referenceNode.calculateDistance(referenceNode2);
        }
        return d / nodesWithin.size();
    }

    public double[] average(List<double[]> list) {
        double[] dArr = new double[list.get(0).length];
        for (double[] dArr2 : list) {
            for (int i = 0; i < dArr2.length; i++) {
                int i2 = i;
                dArr[i2] = dArr[i2] + dArr2[i];
            }
        }
        int size = list.size();
        for (int i3 = 0; i3 < dArr.length; i3++) {
            int i4 = i3;
            dArr[i4] = dArr[i4] / size;
        }
        return dArr;
    }

    public Set<String> median(List<Set<String>> list) {
        if (list == null || list.size() == 0) {
            return null;
        }
        Collections.sort(list, new Comparator<Set<String>>() { // from class: org.reactome.r3.cluster.SOMWeightNodes.1
            @Override // java.util.Comparator
            public int compare(Set<String> set, Set<String> set2) {
                return set.size() - set2.size();
            }
        });
        return list.get(list.size() / 2);
    }

    public void output(OutputStream outputStream) throws IOException {
        PrintStream printStream = new PrintStream(outputStream);
        printStream.println("SOM results:");
        for (ReferenceNode referenceNode : this.nodeList) {
            Point location = referenceNode.getLocation();
            double calculateAverageDistance = referenceNode.calculateAverageDistance();
            printStream.println("location:" + location.x + ", " + location.y);
            printStream.println("index:" + calculateAverageDistance);
            Set<String> labels = referenceNode.getLabels();
            if (labels != null) {
                printStream.print("label:");
                Iterator<String> it = labels.iterator();
                while (it.hasNext()) {
                    printStream.print(it.next());
                    if (it.hasNext()) {
                        printStream.print(", ");
                    }
                }
                printStream.println();
            }
            StringBuilder sb = new StringBuilder();
            for (double d : referenceNode.getReferenceVector()) {
                sb.append(d).append(",");
            }
            printStream.println("vector:" + sb.toString());
        }
        printStream.close();
        outputStream.close();
    }
}
