package edu.virginia.uvacluster.internal;

import edu.virginia.uvacluster.internal.feature.FeatureSet;
import edu.virginia.uvacluster.internal.feature.FeatureUtil;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;
import java.util.regex.Pattern;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang.StringUtils;
import org.cytoscape.model.CyEdge;
import org.cytoscape.model.CyNetwork;
import org.cytoscape.model.CyNode;
import org.cytoscape.model.SavePolicy;
import org.cytoscape.model.subnetwork.CyRootNetwork;
import org.cytoscape.model.subnetwork.CySubNetwork;

/* loaded from: input_file:edu/virginia/uvacluster/internal/SupervisedModel.class */
public class SupervisedModel implements Model {
    /** Smallest negative-example size, in nodes, produced by {@link #generateNegativeExamples}. */
    private static final int MIN_NEGATIVE_SIZE = 3;

    /** Largest negative-example size, in nodes, produced by {@link #generateNegativeExamples}. */
    private static final int MAX_NEGATIVE_SIZE = 20;

    /**
     * Number of candidates in a row that may be rejected for one size before generation of that
     * size is abandoned. Without this bound, a network that cannot yield a valid example of some
     * size would make {@link #generateNegativeExamples} retry forever.
     */
    private static final int MAX_CONSECUTIVE_REJECTIONS = 10000;

    /** A training-file line is used only when its first tab-separated column matches this. */
    private static final Pattern COMPLEX_ID_PATTERN = Pattern.compile("[0-9]+");

    private CyRootNetwork rootNetwork;
    private Graph negBayesGraph;
    private Graph posBayesGraph;
    private List<Graph> bayesGraphs;
    private List<FeatureSet> features;
    private double complexPrior;
    private InputTask userInput;

    /** Creates the "NonCluster" (negative) and "Cluster" (positive) graphs, in that order. */
    private void setup() {
        this.negBayesGraph = new Graph("NonCluster");
        this.posBayesGraph = new Graph("Cluster");
        this.bayesGraphs = Arrays.asList(this.negBayesGraph, this.posBayesGraph);
    }

    /**
     * Builds the model structure from {@code cyNetwork}, trains it on the complexes in
     * {@code inputTask.trainingFile} plus generated negative examples, and writes the trained
     * model into {@code cyNetwork2}.
     *
     * <p>Note: the decompiler reports this constructor was package-private in the original source.
     *
     * @param cyRootNetwork root network that supplies nodes for training and negative examples
     * @param cyNetwork network the graph structure is loaded from
     * @param cyNetwork2 network the trained model is saved to
     * @param inputTask user settings (prior, training file, number of negative examples, ...)
     * @throws Exception if the training file cannot be read or a protein cannot be resolved
     */
    public SupervisedModel(CyRootNetwork cyRootNetwork, CyNetwork cyNetwork, CyNetwork cyNetwork2, InputTask inputTask) throws Exception {
        setup();
        this.rootNetwork = cyRootNetwork;
        this.complexPrior = inputTask.clusterPrior;
        this.userInput = inputTask;
        // Both graphs load their structure; the feature names returned by the last one are used.
        Set<String> featureNames = null;
        for (Graph graph : this.bayesGraphs) {
            featureNames = graph.loadModelFrom(cyNetwork);
        }
        this.features = FeatureUtil.parse(featureNames);
        System.out.println("Cluster model init is finished.");
        List<CySubNetwork> positives = loadTrainingComplexes(inputTask.trainingFile);
        System.out.println("Positive examples are loaded.");
        List<CySubNetwork> negatives = generateNegativeExamples(inputTask.negativeExamples, positives);
        System.out.println("Dup");
        train(positives, negatives);
        System.out.println("Model trained...");
        saveGraphicalBayesianNetwork(cyNetwork2, this.features);
        System.out.println("Model saved to network...");
    }

    /**
     * Restores an already-trained model from {@code cyNetwork}.
     *
     * <p>Note: the decompiler reports this constructor was package-private in the original source.
     *
     * @param cyRootNetwork root network the model operates on
     * @param cyNetwork network holding a previously saved trained model
     * @param inputTask user settings (prior, supervised-learning flag, ...)
     */
    public SupervisedModel(CyRootNetwork cyRootNetwork, CyNetwork cyNetwork, InputTask inputTask) {
        setup();
        this.rootNetwork = cyRootNetwork;
        this.userInput = inputTask;
        this.complexPrior = inputTask.clusterPrior;
        // Both graphs load their trained state; the feature list returned by the last one is kept.
        for (Graph graph : this.bayesGraphs) {
            this.features = graph.loadTrainedModelFrom(cyNetwork);
        }
    }

    @Override // edu.virginia.uvacluster.internal.Model
    public List<FeatureSet> getFeatures() {
        return this.features;
    }

    /**
     * Writes the trained state of both graphs into {@code cyNetwork}.
     *
     * @param cyNetwork destination network
     * @param list features to save alongside the graphs
     */
    public void saveGraphicalBayesianNetwork(CyNetwork cyNetwork, List<FeatureSet> list) {
        for (Graph graph : this.bayesGraphs) {
            graph.saveTrainedModelTo(cyNetwork, list);
        }
    }

    /**
     * Scores a cluster.
     *
     * <p>With supervised learning enabled, the result is
     * {@code log((prior * pos.score(c)) / ((1 - prior) * neg.score(c)))}. Otherwise the result is
     * delegated to {@link ClusterScore#score(Cluster)}.
     *
     * @throws Exception if the underlying scorer throws
     */
    @Override // edu.virginia.uvacluster.internal.Model
    public double score(Cluster cluster) throws Exception {
        if (!this.userInput.supervisedLearning) {
            return ClusterScore.score(cluster);
        }
        return Math.log((this.complexPrior * this.posBayesGraph.score(cluster)) / ((1.0d - this.complexPrior) * this.negBayesGraph.score(cluster)));
    }

    /**
     * Scores a subnetwork. In supervised mode it is wrapped in a {@link Cluster} over the current
     * features first.
     *
     * @throws Exception if the underlying scorer throws
     */
    public double score(CySubNetwork cySubNetwork) throws Exception {
        return this.userInput.supervisedLearning ? score(new Cluster(this.features, cySubNetwork)) : ClusterScore.score(cySubNetwork);
    }

    /**
     * Generates roughly {@code i} random connected subnetworks to use as negative examples.
     *
     * <p>Sizes range from {@value #MIN_NEGATIVE_SIZE} to {@value #MAX_NEGATIVE_SIZE} nodes. Each
     * size {@code k} receives {@code round(i * k^-a / sum_j j^-a)} examples, where {@code a} is
     * estimated from the sizes of the training complexes in {@code list}. Candidates that are too
     * small or overlap a training complex are discarded and retried. After
     * {@value #MAX_CONSECUTIVE_REJECTIONS} rejections in a row, that size is abandoned with a
     * console message.
     *
     * @param i requested number of negative examples
     * @param list positive training complexes
     * @return the accepted negative examples; each one is a subnetwork of the root network
     * @throws Exception declared for API compatibility
     */
    public List<CySubNetwork> generateNegativeExamples(int i, List<CySubNetwork> list) throws Exception {
        List<CySubNetwork> negatives = new ArrayList<CySubNetwork>(i);

        int[] complexSizes = new int[list.size()];
        int index = 0;
        for (CySubNetwork complex : list) {
            complexSizes[index++] = complex.getNodeCount();
        }
        double exponent = getSizeDistributionExponent(complexSizes);

        int bucketCount = MAX_NEGATIVE_SIZE - MIN_NEGATIVE_SIZE + 1;
        double[] weights = new double[bucketCount];
        double totalWeight = 0.0d;
        for (int b = 0; b < bucketCount; b++) {
            weights[b] = 1.0d / Math.pow(b + MIN_NEGATIVE_SIZE, exponent);
            totalWeight += weights[b];
        }

        for (int b = 0; b < bucketCount; b++) {
            int size = b + MIN_NEGATIVE_SIZE;
            // Hoisted: the original recomputed this bound on every loop iteration.
            long target = Math.round((weights[b] / totalWeight) * i);
            int produced = 0;
            int consecutiveRejections = 0;
            while (produced < target) {
                CySubNetwork candidate = genNegativeExample(size, list);
                if (candidate != null && candidate.getNodeCount() >= size) {
                    negatives.add(candidate);
                    produced++;
                    consecutiveRejections = 0;
                    continue;
                }
                if (candidate != null) {
                    // Too small: detach it so it does not linger in the root network.
                    this.rootNetwork.removeSubNetwork(candidate);
                }
                if (++consecutiveRejections >= MAX_CONSECUTIVE_REJECTIONS) {
                    System.out.println("Gave up generating negative examples of size " + size
                            + " after " + consecutiveRejections + " consecutive rejected candidates ("
                            + produced + " of " + target + " generated).");
                    break;
                }
            }
        }
        return negatives;
    }

    /**
     * Grows one random connected subnetwork of up to {@code i} nodes.
     *
     * <p>A random seed node that has at least one neighbour other than itself is chosen. The
     * subnetwork then grows by repeatedly adding a uniformly random neighbour of the current
     * members. For each added node, the first connecting edge to every earlier member is also
     * added. The candidate is rejected when it shares more than one node with any complex in
     * {@code list}.
     *
     * @param i target node count
     * @param list training complexes the candidate must not overlap
     * @return the candidate, which may have fewer than {@code i} nodes if its neighbourhood ran
     *     out; or {@code null} if no seed was found or the overlap check failed. A rejected
     *     candidate is removed from the root network before returning.
     */
    private CySubNetwork genNegativeExample(int i, List<CySubNetwork> list) {
        List<CyNode> allNodes = this.rootNetwork.getNodeList();
        int nodeTotal = allNodes.size();
        CySubNetwork candidate = this.rootNetwork.addSubNetwork(SavePolicy.DO_NOT_SAVE);
        List<CyNode> frontier = new ArrayList<CyNode>();
        Set<CyNode> frontierMembers = new HashSet<CyNode>();

        // Seed search. The bound is computed in long arithmetic, because nodeTotal^2 overflows
        // int once the network has more than 46340 nodes. Only a node that actually has
        // neighbours becomes the seed; probes with no neighbours are not added to the candidate.
        long maxSeedAttempts = (long) nodeTotal * nodeTotal;
        for (long attempt = 0; attempt < maxSeedAttempts && frontier.isEmpty(); attempt++) {
            CyNode probe = allNodes.get(ThreadLocalRandom.current().nextInt(nodeTotal));
            for (CyNode neighbor : this.rootNetwork.getNeighborList(probe, CyEdge.Type.ANY)) {
                if (!neighbor.equals(probe) && frontierMembers.add(neighbor)) {
                    frontier.add(neighbor);
                }
            }
            if (!frontier.isEmpty()) {
                candidate.addNode(probe);
            }
        }
        if (candidate.getNodeCount() == 0) {
            this.rootNetwork.removeSubNetwork(candidate);
            return null;
        }

        while (candidate.getNodeCount() < i && !frontier.isEmpty()) {
            CyNode next = frontier.remove(ThreadLocalRandom.current().nextInt(frontier.size()));
            frontierMembers.remove(next);
            // Capture the existing members before adding, so only edges to earlier members are added.
            List<CyNode> earlierMembers = candidate.getNodeList();
            candidate.addNode(next);
            for (CyNode member : earlierMembers) {
                List<CyEdge> connecting = this.rootNetwork.getConnectingEdgeList(next, member, CyEdge.Type.ANY);
                if (!connecting.isEmpty()) {
                    candidate.addEdge(connecting.get(0));
                }
            }
            for (CyNode neighbor : this.rootNetwork.getNeighborList(next, CyEdge.Type.ANY)) {
                if (!candidate.containsNode(neighbor) && frontierMembers.add(neighbor)) {
                    frontier.add(neighbor);
                }
            }
        }

        // Reject candidates that share more than one node with any training complex.
        List<CyNode> candidateNodes = candidate.getNodeList();
        for (CySubNetwork complex : list) {
            int shared = 0;
            for (CyNode node : candidateNodes) {
                if (complex.containsNode(node)) {
                    shared++;
                }
            }
            if (shared > 1) {
                this.rootNetwork.removeSubNetwork(candidate);
                return null;
            }
        }
        return candidate;
    }

    /**
     * Trains the positive graph on {@code list} and the negative graph on {@code list2}.
     *
     * @param list positive examples
     * @param list2 negative examples
     */
    public void train(List<CySubNetwork> list, List<CySubNetwork> list2) {
        System.out.println("Entered TRAIN");
        List<Cluster> positiveClusters = new ArrayList<Cluster>(list.size());
        List<Cluster> negativeClusters = new ArrayList<Cluster>(list2.size());
        for (CySubNetwork positive : list) {
            positiveClusters.add(new Cluster(this.features, positive));
        }
        System.out.println("Lists of pos and neg training examples created");
        this.posBayesGraph.trainOn(positiveClusters);
        System.out.println("Model has finished training on " + list.size() + " positive Examples.");
        for (CySubNetwork negative : list2) {
            negativeClusters.add(new Cluster(this.features, negative));
        }
        this.negBayesGraph.trainOn(negativeClusters);
        System.out.println("Model has finished training on " + list2.size() + " negative Examples.");
    }

    /**
     * Reads training complexes from a tab-separated file.
     *
     * <p>Each trimmed line is split at tabs. The first column must be all digits, otherwise the
     * line is skipped. The text between the first and last tab becomes the complex name. The
     * column after the last tab holds whitespace-separated protein names, which are matched
     * case-insensitively against each node's "shared name". Complexes with 2 or fewer resolved
     * nodes are discarded.
     *
     * @param file training file
     * @return one subnetwork of the root network per accepted complex
     * @throws Exception if {@code file} is not a readable file, or if a protein is missing and
     *     {@code ignoreMissing} is off. In the latter case, every subnetwork created by this call
     *     is removed from the root network first.
     */
    public List<CySubNetwork> loadTrainingComplexes(File file) throws Exception {
        if (!file.isFile()) {
            throw new Exception("This is not a file");
        }
        if (!file.canRead()) {
            throw new Exception("This file is not readable");
        }
        Map<String, Long> suidByName = buildNameIndex();
        List<CySubNetwork> complexes = new ArrayList<CySubNetwork>();
        for (String rawLine : FileUtils.readLines(file, (String) null)) {
            String line = StringUtils.trim(rawLine);
            String[] proteins = StringUtils.split(StringUtils.substringAfterLast(line, "\t"));
            String beforeMembers = StringUtils.substringBeforeLast(line, "\t");
            String complexName = StringUtils.substringAfter(beforeMembers, "\t");
            if (!COMPLEX_ID_PATTERN.matcher(StringUtils.substringBefore(beforeMembers, "\t")).matches()) {
                continue;
            }
            CySubNetwork complex = this.rootNetwork.addSubNetwork(SavePolicy.DO_NOT_SAVE);
            for (String protein : proteins) {
                Long suid = suidByName.get(protein.toUpperCase(Locale.ROOT));
                if (suid != null) {
                    complex.addNode(this.rootNetwork.getNode(suid.longValue()));
                    continue;
                }
                if (!this.userInput.ignoreMissing) {
                    // Do not leave half-loaded DO_NOT_SAVE subnetworks attached to the root.
                    this.rootNetwork.removeSubNetwork(complex);
                    for (CySubNetwork loaded : complexes) {
                        this.rootNetwork.removeSubNetwork(loaded);
                    }
                    throw new Exception("Protein not found in network: " + protein);
                }
                System.out.println("Protein not found in network" + protein);
            }
            List<CyNode> members = complex.getNodeList();
            for (CyNode a : members) {
                for (CyNode b : members) {
                    if (this.rootNetwork.containsEdge(a, b)) {
                        complex.addEdge(this.rootNetwork.getConnectingEdgeList(a, b, CyEdge.Type.ANY).get(0));
                    }
                }
            }
            if (complex.getNodeCount() > 2) {
                complex.getRow(complex).set("name", complexName);
                complexes.add(complex);
            } else {
                this.rootNetwork.removeSubNetwork(complex);
            }
        }
        return complexes;
    }

    /**
     * Maps each node's upper-cased "shared name" (Locale.ROOT) to its SUID.
     *
     * <p>Nodes without a name are skipped. When names collide, the first node in root node-list
     * order wins.
     */
    private Map<String, Long> buildNameIndex() {
        Map<String, Long> index = new HashMap<String, Long>();
        for (CyNode node : this.rootNetwork.getNodeList()) {
            String name = this.rootNetwork.getDefaultNodeTable().getRow(node.getSUID()).get("shared name", String.class);
            if (name == null) {
                continue;
            }
            String key = name.toUpperCase(Locale.ROOT);
            if (!index.containsKey(key)) {
                index.put(key, node.getSUID());
            }
        }
        return index;
    }

    /**
     * Estimates a power-law exponent from complex sizes as {@code 1 + n / sum(ln(x / x_min))}.
     * This has the form of the continuous power-law maximum-likelihood estimator.
     *
     * <p>If every size equals the minimum, the sum is 0 and the result is infinite.
     *
     * @param iArr complex sizes
     * @return the estimated exponent
     */
    private double getSizeDistributionExponent(int[] iArr) {
        double logSum = 0.0d;
        double minSize = ClusterUtil.arrayMin(iArr);
        double count = iArr.length;
        for (int size : iArr) {
            logSum += Math.log(size / minSize);
        }
        return 1.0d + (count / logSum);
    }
}
