package org.reactome.pathway.factorgraph;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.math.random.RandomDataImpl;
import org.apache.commons.math3.random.EmpiricalDistribution;
import org.apache.log4j.Logger;
import org.gk.model.GKInstance;
import org.gk.persistence.MySQLAdaptor;
import org.junit.Test;
import org.reactome.factorgraph.Factor;
import org.reactome.factorgraph.FactorGraph;
import org.reactome.factorgraph.FactorValueAssignment;
import org.reactome.factorgraph.GibbsSampling;
import org.reactome.factorgraph.InferenceCannotConvergeException;
import org.reactome.factorgraph.LoopyBeliefPropagation;
import org.reactome.factorgraph.Observation;
import org.reactome.factorgraph.Variable;
import org.reactome.factorgraph.common.DataType;
import org.reactome.pathway.factorgraph.ReactomePathwayFGRunner;
import org.reactome.r3.util.FileUtility;
import org.reactome.r3.util.InteractionUtilities;
import org.reactome.r3.util.MathUtilities;
import org.reactome.r3.util.R3Constants;

/* loaded from: input_file:modeling-1.0.3.jar:org/reactome/pathway/factorgraph/GibbsSampler.class */
public class GibbsSampler extends GibbsSampling {
    /** Relative directory that holds the Gibbs sampling result files used by this class. */
    private static final String DIR_NAME = "results/paradigm/twoCases/GibbsSampling/";
    /**
     * Tab-delimited sample file. It is written by {@code runGenerateSamples()} and read back by
     * {@code learnParametersFromGeneratedSamples()} and {@code getTypeToSamples(boolean)}.
     */
    public static final String SAMPLE_FILE_NAME = "results/paradigm/twoCases/GibbsSampling/Transcription_Regulation_Pluripoten_Samples_AssignedParams.txt";
    /** log4j logger shared by all instances of this class. */
    private static Logger logger = Logger.getLogger(GibbsSampler.class);
    /**
     * Factor values at or below this threshold are treated as zero: 1.0E-12 is the cut-off
     * applied when an initial assignment is picked for a factor.
     */
    private final double MINIMUM_PROB = 1.0E-12d;

    /**
     * Learns parameters from previously generated samples and then runs inference.
     * <p>
     * Steps: convert pathway 452723 into a factor graph, load the samples stored in
     * {@link #SAMPLE_FILE_NAME}, draw a random subset, drop the samples for which loopy belief
     * propagation cannot converge, and hand the remaining samples to the runner as observations.
     * <p>
     * NOTE(review): the subset size is read from {@code EmpiricalDistribution.DEFAULT_BIN_COUNT},
     * which looks like a decompiler substitution for a plain literal - confirm the intended size.
     *
     * @throws Exception if the pathway cannot be fetched, the sample file cannot be read, or
     *                   learning/inference fails
     */
    @Test
    public void learnParametersFromGeneratedSamples() throws Exception {
        FileUtility.initializeLogging();
        GKInstance pathwayInst = getPathway(452723L);
        ReactomePathwayFGRunner runner = new ReactomePathwayFGRunner();
        PathwayToFactorGraphConverter converter = new PathwayToFactorGraphConverter();
        ReactomePathwayFGRunner.ConvertedFactorGraph converted = runner.convertPathway(pathwayInst, converter);
        logger.info("Converted factors: " + converted.fg.getFactors().size()
                + " factors, and " + converted.fg.getVariables().size() + " variables.");

        // Load everything first, then keep only a random subset of the stored samples.
        List<Observation<Integer>> storedSamples = loadSamples(converted.fg, SAMPLE_FILE_NAME);
        Set subset = MathUtilities.randomSampling(storedSamples, EmpiricalDistribution.DEFAULT_BIN_COUNT, new RandomDataImpl());
        filterSamples(subset, converted.fg);
        checkSamples(subset);

        // The surviving samples become the observations used for learning and inference.
        converted.observations = new ArrayList(subset);
        runner.performLearning(pathwayInst, converter, converted);
        runner.performInference(pathwayInst, converter, converted);
    }

    /**
     * Walks over every sample and reads its variable assignment.
     * <p>
     * No validation is performed on the assignments that are read; the method only touches each
     * sample. NOTE(review): this looks like the remainder of a check whose body was removed -
     * confirm whether it is still needed.
     *
     * @param collection samples to walk over
     */
    private void checkSamples(Collection<Observation<Integer>> collection) {
        for (Observation<Integer> sample : collection) {
            sample.getVariableToAssignment();
        }
    }

    /**
     * Removes, in place, every sample for which loopy belief propagation cannot converge.
     * <p>
     * Each sample's assignment is set as the observation on the configured LBP engine and
     * inference is run; a sample is dropped when {@link InferenceCannotConvergeException} is
     * thrown. The number of remaining samples is logged at the end.
     *
     * @param collection  samples to filter; must support {@link Iterator#remove()}
     * @param factorGraph factor graph the inference is run on
     */
    private void filterSamples(Collection<Observation<Integer>> collection, FactorGraph factorGraph) {
        LoopyBeliefPropagation lbp = PathwayPGMConfiguration.getConfig().getLBP();
        lbp.setFactorGraph(factorGraph);
        Iterator<Observation<Integer>> sampleIt = collection.iterator();
        while (sampleIt.hasNext()) {
            Observation<Integer> sample = sampleIt.next();
            lbp.setObservation(sample.getVariableToAssignment());
            try {
                lbp.runInference();
            } catch (InferenceCannotConvergeException e) {
                // Pass the exception along so the cause is not lost from the log.
                logger.error(sample.getName() + " cannot converge!", e);
                sampleIt.remove();
            }
        }
        logger.info("Total samples after filtering: " + collection.size());
    }

    /**
     * Reads samples from a tab-delimited file into observations over the given factor graph.
     * <p>
     * The first line is a header whose column names are matched against variable names. In each
     * following line the first field is used as the sample name and fields from index 2 onwards
     * are values. Only columns whose header ends with {@code _mRNA} or {@code _CNV} and that name
     * a variable of the factor graph are kept.
     *
     * @param factorGraph factor graph supplying the variables to map columns to
     * @param str         path of the sample file
     * @return one observation per data line; an empty list when the file has no header line
     * @throws IOException if the file cannot be read or closed
     */
    private List<Observation<Integer>> loadSamples(FactorGraph factorGraph, String str) throws IOException {
        Map<String, Variable> nameToVariable = new HashMap<>();
        for (Variable variable : factorGraph.getVariables()) {
            nameToVariable.put(variable.getName(), variable);
        }
        List<Observation<Integer>> samples = new ArrayList<>();
        FileUtility reader = new FileUtility();
        reader.setInput(str);
        try {
            String headerLine = reader.readLine();
            if (headerLine == null) {
                // Empty file: nothing to load.
                return samples;
            }
            String[] header = headerLine.split("\t");
            String line;
            while ((line = reader.readLine()) != null) {
                String[] fields = line.split("\t");
                Observation<Integer> sample = new Observation<>();
                Map<Variable, Integer> assignment = new HashMap<>();
                sample.setVariableToAssignment(assignment);
                sample.setName(fields[0]);
                // Values start at index 2; indices 0 and 1 are not read as variable values.
                for (int i = 2; i < fields.length; i++) {
                    String column = header[i];
                    if (!column.endsWith("_mRNA") && !column.endsWith("_CNV")) {
                        continue;
                    }
                    Variable variable = nameToVariable.get(column);
                    if (variable != null) {
                        assignment.put(variable, Integer.valueOf(fields[i]));
                    }
                }
                samples.add(sample);
            }
            return samples;
        } finally {
            // Release the file even when a line fails to parse.
            reader.close();
        }
    }

    /**
     * Generates samples for pathway 452723 and compares each sample's log-likelihood before and
     * after the output entity variables are set to state 2.
     * <p>
     * One line per sample (name, original log-likelihood, clamped log-likelihood) is written to
     * {@code Loglikelihood_Test.txt} in the result directory, and the maximum clamped
     * log-likelihood is printed to standard output.
     *
     * @throws Exception if the pathway cannot be fetched or converted, or the file cannot be written
     */
    @Test
    public void checkLoglikelihood() throws Exception {
        FileUtility.initializeLogging();
        FactorGraph fg = new PathwayToFactorGraphConverter().convertPathway(getPathway(452723L));
        logger.info("Converted factors: " + fg.getFactors().size() + " factors, and " + fg.getVariables().size() + " variables.");
        setFactorGraph(fg);
        setBurnin(100000);
        List<Observation<Integer>> samples = generateSamples(100000);

        // Collect the variables that represent the output entities.
        Set<String> outputNames = getOutputEntityNames();
        Set<Variable> outputVariables = new HashSet<>();
        for (Variable variable : fg.getVariables()) {
            if (outputNames.contains(variable.getName())) {
                outputVariables.add(variable);
            }
        }

        double maxClamped = Double.NEGATIVE_INFINITY;
        FileUtility writer = new FileUtility();
        writer.setOutput(DIR_NAME + "Loglikelihood_Test.txt");
        try {
            for (Observation<Integer> sample : samples) {
                Map<Variable, Integer> assignment = sample.getVariableToAssignment();
                double originalLogLikelihood = fg.getLogLikelihood(assignment);
                // Clamping is done directly on the sample's own assignment map.
                for (Variable outputVariable : outputVariables) {
                    assignment.put(outputVariable, 2);
                }
                double clampedLogLikelihood = fg.getLogLikelihood(assignment);
                if (clampedLogLikelihood > maxClamped) {
                    maxClamped = clampedLogLikelihood;
                }
                writer.printLine(sample.getName() + "\t" + originalLogLikelihood + "\t" + clampedLogLikelihood);
            }
        } finally {
            // Close the output file even when a sample fails.
            writer.close();
        }
        System.out.println("Maximum loglikelihood after clamping the output variables to 2: " + maxClamped);
    }

    /**
     * Converts pathway 452723 into a factor graph, draws 100000 Gibbs samples after a burn-in of
     * 100000, and writes them to {@link #SAMPLE_FILE_NAME}.
     *
     * @throws Exception if the pathway cannot be fetched or converted, or the samples cannot be written
     */
    @Test
    public void runGenerateSamples() throws Exception {
        FileUtility.initializeLogging();
        PathwayToFactorGraphConverter converter = new PathwayToFactorGraphConverter();
        GKInstance pathwayInst = getPathway(452723L);
        FactorGraph fg = converter.convertPathway(pathwayInst);
        logger.info("Converted factors: " + fg.getFactors().size()
                + " factors, and " + fg.getVariables().size() + " variables.");

        setFactorGraph(fg);
        setBurnin(100000);
        List<Observation<Integer>> samples = generateSamples(100000);

        outputSamples(samples, pathwayInst, fg, converter.getInstToVarMap(), SAMPLE_FILE_NAME);
    }

    /**
     * Collects the variables whose name ends with the string form of any of the given data types.
     *
     * @param factorGraph factor graph whose variables are scanned
     * @param dataTypeArr data types whose {@code toString()} value is used as a name suffix
     * @return the matching variables; empty when no data type is given or nothing matches
     */
    private Set<Variable> getObservationVariables(FactorGraph factorGraph, DataType... dataTypeArr) {
        Set<Variable> matches = new HashSet<>();
        for (Variable candidate : factorGraph.getVariables()) {
            String candidateName = candidate.getName();
            for (DataType type : dataTypeArr) {
                if (candidateName.endsWith(type.toString())) {
                    matches.add(candidate);
                    // One matching suffix is enough; skip the remaining data types.
                    break;
                }
            }
        }
        return matches;
    }

    /**
     * Finds every variable of the factor graph that shares its name prefix with one of the given
     * variables. The prefix is the part of the name before the first underscore.
     * <p>
     * Names that contain no underscore, or that start with one, have no usable prefix and are
     * skipped on both sides.
     *
     * @param factorGraph factor graph whose variables are scanned
     * @param set         variables supplying the prefixes to look for
     * @return the variables of {@code factorGraph} whose prefix occurs among the prefixes of {@code set}
     */
    private Set<Variable> getGeneVariables(FactorGraph factorGraph, Set<Variable> set) {
        Set<String> prefixes = new HashSet<>();
        for (Variable source : set) {
            String sourceName = source.getName();
            int underscore = sourceName.indexOf('_');
            // Guard against names without an underscore; substring(0, -1) would throw.
            if (underscore > 0) {
                prefixes.add(sourceName.substring(0, underscore));
            }
        }
        Set<Variable> matches = new HashSet<>();
        for (Variable candidate : factorGraph.getVariables()) {
            String candidateName = candidate.getName();
            int underscore = candidateName.indexOf('_');
            if (underscore > 0 && prefixes.contains(candidateName.substring(0, underscore))) {
                matches.add(candidate);
            }
        }
        return matches;
    }

    /**
     * Writes samples to a tab-delimited file.
     * <p>
     * The columns are the sample name, the sample's log-likelihood, and then the values of the
     * selected variables sorted by variable name. The selected variables are the pathway
     * variables, the CNV and mRNA expression observation variables, and the variables sharing a
     * name prefix with those observation variables.
     *
     * @param list        samples to write
     * @param gKInstance  pathway the samples belong to
     * @param factorGraph factor graph the samples were drawn from
     * @param map         mapping from pathway instances to variables
     * @param str         path of the output file
     * @throws Exception   if the pathway variables cannot be determined
     * @throws IOException if the file cannot be written or closed
     */
    private void outputSamples(List<Observation<Integer>> list, GKInstance gKInstance, FactorGraph factorGraph, Map<GKInstance, Variable> map, String str) throws Exception, IOException {
        List<Variable> pathwayVariables = new ReactomePathwayFGRunner().getPathwayVariables(gKInstance, map, factorGraph);
        Set<Variable> observationVariables = getObservationVariables(factorGraph, DataType.CNV, DataType.mRNA_EXP);
        Set<Variable> geneVariables = getGeneVariables(factorGraph, observationVariables);

        // Merge through a set so a variable found by more than one route is written once.
        Set<Variable> selected = new HashSet<>(pathwayVariables);
        selected.addAll(observationVariables);
        selected.addAll(geneVariables);
        List<Variable> columns = new ArrayList<>(selected);
        Collections.sort(columns, new Comparator<Variable>() {
            @Override
            public int compare(Variable variable, Variable variable2) {
                return variable.getName().compareTo(variable2.getName());
            }
        });

        FileUtility writer = new FileUtility();
        writer.setOutput(str);
        try {
            StringBuilder sb = new StringBuilder();
            sb.append("Sample\tLogLikelihood");
            for (Variable column : columns) {
                sb.append("\t").append(column.getName());
            }
            writer.printLine(sb.toString());
            for (Observation<Integer> sample : list) {
                sb.setLength(0);
                sb.append(sample.getName());
                Map<Variable, Integer> assignment = sample.getVariableToAssignment();
                sb.append("\t").append(factorGraph.getLogLikelihood(assignment));
                for (Variable column : columns) {
                    sb.append("\t").append(assignment.get(column));
                }
                writer.printLine(sb.toString());
            }
        } finally {
            // Close the output file even when writing a sample fails.
            writer.close();
        }
    }

    /**
     * Fetches an instance by its database id from the {@code gk_current_ver50} database on
     * {@code localhost}, using the credentials in {@link R3Constants}.
     *
     * @param l database id of the instance to fetch
     * @return the instance returned by the adaptor for that id
     * @throws Exception if the database cannot be reached or the fetch fails
     */
    private GKInstance getPathway(Long l) throws Exception {
        MySQLAdaptor dba = new MySQLAdaptor("localhost", "gk_current_ver50", R3Constants.DB_USER, R3Constants.DB_PWD);
        return dba.fetchInstance(l);
    }

    /**
     * Draws samples from the factor graph currently set on this sampler.
     * <p>
     * The given count is passed to {@code setMaxIteration}, the cache is reset, an initial
     * assignment is built and passed through {@code burn}, and {@code sample} then fills the
     * returned list.
     *
     * @param i value passed to {@code setMaxIteration}
     * @return the list filled by {@code sample}
     * @throws IllegalStateException if no factor graph has been set
     */
    public List<Observation<Integer>> generateSamples(int i) {
        if (this.factorGraph == null) {
            throw new IllegalStateException("FactorGraph has not been specified.");
        }
        setMaxIteration(i);
        List<Observation<Integer>> collected = new ArrayList<>();
        resetCache();
        Observation<Integer> start = initializeAssignment();
        sample(burn(start), collected);
        return collected;
    }

    /**
     * Builds the starting assignment for sampling by visiting the factors one after another.
     * <p>
     * Each factor is first restricted to the variables that have no value yet (see
     * {@code slice}); the restricted factor then contributes values for those variables. The
     * resulting assignment is printed to standard output.
     *
     * @return an observation holding the assignment that was built
     */
    private Observation<Integer> initializeAssignment() {
        Observation<Integer> start = new Observation<>();
        Map<Variable, Integer> assignment = new HashMap<>();
        start.setVariableToAssignment(assignment);
        // A single helper instance is shared by all factors.
        FactorValueAssignment valueAssignment = new FactorValueAssignment();
        for (Factor factor : this.factorGraph.getFactors()) {
            Factor restricted = slice(factor, assignment, valueAssignment);
            if (restricted == null) {
                continue;
            }
            initializeAssignment(restricted, assignment, valueAssignment);
        }
        System.out.println("Initialize state: " + assignment);
        return start;
    }

    /**
     * Restricts a factor to the variables that do not yet have a value in {@code map}.
     * <p>
     * The factor itself is returned when {@code map} is empty or when all of the factor's
     * variables already have a value. Otherwise a new factor over the unassigned variables is
     * built; each of its values is the original factor's value with the assigned variables fixed
     * to their values in {@code map}.
     * <p>
     * NOTE(review): this method never returns {@code null}, although the caller checks for it -
     * confirm whether a fully assigned factor was meant to be skipped.
     *
     * @param factor                factor to restrict
     * @param map                   values assigned so far
     * @param factorValueAssignment helper used to enumerate the assignments of the new factor
     * @return {@code factor} itself or the restricted factor, as described above
     */
    private Factor slice(Factor factor, Map<Variable, Integer> map, FactorValueAssignment factorValueAssignment) {
        if (map.isEmpty()) {
            return factor;
        }
        List<Variable> allVariables = factor.getVariables();
        List<Variable> unassigned = new ArrayList<>(allVariables);
        unassigned.removeAll(map.keySet());
        if (unassigned.isEmpty()) {
            return factor;
        }
        Factor restricted = new Factor();
        restricted.setVariables(unassigned);
        List<Double> restrictedValues = new ArrayList<>();
        factorValueAssignment.setFactor(restricted);
        List<Map<Variable, Integer>> partialAssignments = factorValueAssignment.iterate();
        // Reused across iterations: every key is overwritten on each pass.
        Map<Variable, Integer> fullAssignment = new HashMap<>();
        for (Map<Variable, Integer> partial : partialAssignments) {
            fullAssignment.putAll(partial);
            for (Variable variable : allVariables) {
                Integer fixedValue = map.get(variable);
                if (fixedValue != null) {
                    fullAssignment.put(variable, fixedValue);
                }
            }
            restrictedValues.add(Double.valueOf(factor.getValue(fullAssignment).doubleValue()));
        }
        restricted.setValues(restrictedValues);
        return restricted;
    }

    /**
     * Adds to {@code map} the first assignment of the factor whose factor value is not at or
     * below {@code MINIMUM_PROB}.
     *
     * @param factor                factor to pick an assignment for
     * @param map                   assignment being built; receives the chosen values
     * @param factorValueAssignment helper used to enumerate assignments and look up their values
     * @throws IllegalStateException if every assignment of the factor is at or below the threshold
     */
    private void initializeAssignment(Factor factor, Map<Variable, Integer> map, FactorValueAssignment factorValueAssignment) {
        factorValueAssignment.setFactor(factor);
        List<Map<Variable, Integer>> assignments = factorValueAssignment.iterate();
        // Bounded scan: stop at the end of the list instead of running past it.
        for (int index = 0; index < assignments.size(); index++) {
            if (!(factorValueAssignment.getFactorValue(index) <= MINIMUM_PROB)) {
                map.putAll(assignments.get(index));
                return;
            }
        }
        throw new IllegalStateException("No assignment with a value above " + MINIMUM_PROB
                + " for the factor over " + factor.getVariables());
    }

    /**
     * Returns the names of the three entities treated as pathway outputs.
     *
     * @return a new, modifiable set holding the NANOG, POU5F1 and SOX2 nucleoplasm entity names
     */
    private Set<String> getOutputEntityNames() {
        Set<String> names = new HashSet<>();
        Collections.addAll(names, "NANOG [nucleoplasm]", "POU5F1 [nucleoplasm]", "SOX2 [nucleoplasm]");
        return names;
    }

    /**
     * Prints, for every type returned by {@code getTypeToSamples(false)}, the type followed by
     * the number of samples assigned to it.
     *
     * @throws IOException if the sample file cannot be read
     */
    @Test
    public void testGetTypeToSamples() throws IOException {
        for (Map.Entry<String, Set<String>> entry : getTypeToSamples(false).entrySet()) {
            System.out.println(entry.getKey() + ": " + entry.getValue().size());
        }
    }

    /**
     * Groups the samples in {@link #SAMPLE_FILE_NAME} by the common state of the output entities.
     * <p>
     * A sample is kept only when all output entity columns found in the header carry the same
     * value and that value is {@code "0"} or {@code "2"}; the value is used as the type key.
     * Samples with mixed values, or for which no output entity column was found, are skipped.
     * <p>
     * When {@code z} is {@code true} the grouping is replaced by a random one: all kept samples
     * are pooled and, for each type, as many samples as the type originally had are drawn at
     * random from the pool and removed from it.
     *
     * @param z whether to return the random regrouping instead of the grouping read from the file
     * @return map from type ({@code "0"} or {@code "2"}) to sample names; empty when the file has
     *         no header line
     * @throws IOException if the file cannot be read or closed
     */
    public Map<String, Set<String>> getTypeToSamples(boolean z) throws IOException {
        Map<String, Set<String>> typeToSamples = new HashMap<>();
        FileUtility reader = new FileUtility();
        reader.setInput(SAMPLE_FILE_NAME);
        try {
            String headerLine = reader.readLine();
            if (headerLine != null) {
                String[] header = headerLine.split("\t");
                Set<String> outputNames = getOutputEntityNames();
                Set<String> outputStates = new HashSet<>();
                String line;
                while ((line = reader.readLine()) != null) {
                    String[] fields = line.split("\t");
                    outputStates.clear();
                    // Values start at index 2; indices 0 and 1 are not read as variable values.
                    for (int i = 2; i < fields.length; i++) {
                        if (outputNames.contains(header[i])) {
                            outputStates.add(fields[i]);
                        }
                    }
                    // Exactly one distinct state is required; an empty set has nothing to read.
                    if (outputStates.size() != 1) {
                        continue;
                    }
                    String state = outputStates.iterator().next();
                    if (state.equals("0") || state.equals("2")) {
                        InteractionUtilities.addElementToSet(typeToSamples, state, fields[0]);
                    }
                }
            }
        } finally {
            // Release the file even when a line cannot be processed.
            reader.close();
        }
        if (!z) {
            return typeToSamples;
        }

        Map<String, Set<String>> randomTypeToSamples = new HashMap<>();
        Set<String> pool = new HashSet<>();
        for (Set<String> samples : typeToSamples.values()) {
            pool.addAll(samples);
        }
        for (Map.Entry<String, Set<String>> entry : typeToSamples.entrySet()) {
            int typeSize = entry.getValue().size();
            if (pool.size() <= typeSize) {
                // Not more left than needed: this type receives the pool object itself.
                randomTypeToSamples.put(entry.getKey(), pool);
            } else {
                Set<String> drawn = MathUtilities.randomSampling(pool, typeSize);
                randomTypeToSamples.put(entry.getKey(), drawn);
                pool.removeAll(drawn);
            }
        }
        return randomTypeToSamples;
    }
}
