package org.reactome.fi.pgm;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.log4j.PropertyConfigurator;
import org.junit.Test;
import org.reactome.factorgraph.Factor;
import org.reactome.factorgraph.FactorGraph;
import org.reactome.factorgraph.InferenceCannotConvergeException;
import org.reactome.factorgraph.InferenceType;
import org.reactome.factorgraph.LoopyBeliefPropagation;
import org.reactome.factorgraph.Variable;
import org.reactome.factorgraph.common.CentralDogmaHandler;
import org.reactome.factorgraph.common.DataType;
import org.reactome.factorgraph.common.PGMConfiguration;

/* loaded from: input_file:modeling-1.0.3.jar:org/reactome/fi/pgm/CentralDogmaSubGraphHandler.class */
/**
 * Builds a small factor graph around three two-state variables ("protein", {@link PGMConfiguration#mRNA}
 * and {@link PGMConfiguration#DNA}) and uses {@link LoopyBeliefPropagation} to compute the belief of the
 * "protein" variable, optionally conditioned on observed data types.
 *
 * <p>Data types are attached to the graph on demand via {@link #enableDataType(DataType)}. Results of
 * {@link #calculateBelief(Map)} are cached per observation; the cache is dropped whenever the graph changes.
 *
 * <p>NOTE(review): instances keep mutable working state ({@code varToAssgn}, the belief cache) that is reused
 * across calls without any synchronization in this class, so sharing one instance between threads looks
 * unsafe — confirm against callers.
 */
public class CentralDogmaSubGraphHandler {
    /** Name of the variable whose belief is reported by {@link #calculateBelief(Map)}. */
    private static final String PROTEIN = "protein";

    private FactorGraph fg;
    /** Variables of the graph, keyed by variable name (for data types: {@code DataType.toString()}). */
    private Map<String, Variable> nameToVariable;
    private LoopyBeliefPropagation lbp;
    /** Scratch map of observed variable -> observed state, reused between inference runs. */
    private Map<Variable, Integer> varToAssgn;
    /** Cache of protein beliefs keyed by {@link #generateObservationKey(Map)}. */
    private Map<Integer, double[]> observationToBelief;

    /**
     * Creates the base graph (protein, mRNA and DNA variables plus their central-dogma factors) and wires
     * a loopy-belief-propagation engine to it.
     */
    public CentralDogmaSubGraphHandler() {
        construct();
        this.lbp = new LoopyBeliefPropagation();
        this.lbp.setFactorGraph(this.fg);
        this.observationToBelief = new HashMap<>();
    }

    /**
     * Sets the inference type on the underlying belief-propagation engine.
     *
     * @param inferenceType inference type forwarded to {@link LoopyBeliefPropagation}
     */
    public void setInferenceType(InferenceType inferenceType) {
        this.lbp.setInferenceType(inferenceType);
    }

    /**
     * @return the inference type currently configured on the underlying belief-propagation engine
     */
    public InferenceType getInferenceType() {
        return this.lbp.getInferenceType();
    }

    /**
     * Computes the belief of the "protein" variable given the supplied observations.
     *
     * <p>Results are cached per observation key. The returned array is always a private copy, so callers
     * may modify it (e.g. through {@link #mergeBeliefToFactorValues(double[], double[])}) without
     * affecting the cache or the graph's variable.
     *
     * @param map observed state per data type; may be {@code null} for "no observation". Entries whose
     *            value is {@code null} are treated as unobserved, but their data type must still be enabled.
     * @return a fresh copy of the protein belief
     * @throws IllegalArgumentException if a data type in {@code map} has not been enabled
     * @throws InferenceCannotConvergeException propagated from {@link LoopyBeliefPropagation#runInference()}
     */
    public double[] calculateBelief(Map<DataType, Integer> map) throws InferenceCannotConvergeException {
        Integer key = generateObservationKey(map);
        double[] cached = this.observationToBelief.get(key);
        if (cached != null) {
            // Hand out a copy so that a caller mutating the result cannot corrupt the cache.
            return cached.clone();
        }
        if (this.varToAssgn == null) {
            this.varToAssgn = new HashMap<>();
        } else {
            this.varToAssgn.clear();
        }
        if (map != null) {
            for (Map.Entry<DataType, Integer> entry : map.entrySet()) {
                DataType dataType = entry.getKey();
                Variable variable = this.nameToVariable.get(dataType.toString());
                if (variable == null) {
                    throw new IllegalArgumentException(dataType + " is not enabled!");
                }
                Integer state = entry.getValue();
                if (state != null) {
                    this.varToAssgn.put(variable, state);
                }
            }
        }
        this.lbp.setObservation(this.varToAssgn);
        this.lbp.runInference();
        double[] belief = this.nameToVariable.get(PROTEIN).getBelief();
        this.observationToBelief.put(key, belief.clone());
        // Return a copy rather than the array obtained from the variable, for the same reason as above.
        return belief.clone();
    }

    /**
     * Folds an observation into a single integer used as cache key: each data type contributes
     * {@code DataType.getKeyStride(type) * (state + 1)}, with a {@code null} state counted as -1
     * (i.e. contributing 0, the same as an absent entry).
     *
     * @param map observation, may be {@code null}
     * @return the cache key; 0 for a {@code null} observation
     */
    private Integer generateObservationKey(Map<DataType, Integer> map) {
        if (map == null) {
            return 0;
        }
        int key = 0;
        for (Map.Entry<DataType, Integer> entry : map.entrySet()) {
            Integer state = entry.getValue();
            int stateValue = (state == null) ? -1 : state.intValue();
            key += DataType.getKeyStride(entry.getKey()) * (stateValue + 1);
        }
        return Integer.valueOf(key);
    }

    /**
     * Computes the belief of the "protein" variable without any observation.
     *
     * @return a fresh copy of the protein belief
     * @throws InferenceCannotConvergeException propagated from the inference engine
     */
    public double[] calculateBelief() throws InferenceCannotConvergeException {
        return calculateBelief(null);
    }

    /**
     * Creates the factor graph with the protein, mRNA and DNA variables and two central-dogma factors:
     * one created for (mRNA, protein) and one for (DNA, protein).
     */
    private void construct() {
        this.fg = new FactorGraph();
        this.nameToVariable = new HashMap<>();
        Set<Factor> factors = new HashSet<>();
        Variable protein = createVariable(PROTEIN);
        Variable mRNA = createVariable(PGMConfiguration.mRNA);
        Variable dna = createVariable(PGMConfiguration.DNA);
        CentralDogmaHandler centralDogmaHandler = new CentralDogmaHandler();
        centralDogmaHandler.setConfiguration(FIPGMConfiguration.getConfig());
        centralDogmaHandler.createCentralDogmaFactor(mRNA, protein, factors, null);
        centralDogmaHandler.createCentralDogmaFactor(dna, protein, factors, null);
        this.fg.setFactors(new HashSet<>(factors));
        this.fg.validatVariables();
    }

    /**
     * Attaches a two-state evidence variable for the given data type to its central-dogma node
     * (CNV and Methylation to DNA, mRNA_EXP and miRNA to mRNA, Mutation to protein).
     *
     * <p>Calling this for an already enabled data type is a no-op. Enabling a new data type changes the
     * graph, so previously cached beliefs are discarded.
     *
     * @param dataType the data type to enable; must not be {@code null}
     * @throws IllegalArgumentException if {@code dataType} is {@code null}, has no configured factor values,
     *                                  or cannot be mapped to a central-dogma node
     */
    public void enableDataType(DataType dataType) {
        if (dataType == null) {
            throw new IllegalArgumentException("dataType must not be null");
        }
        // nameToVariable is keyed by name, so the lookup must use the data type's name, not the enum itself.
        if (this.nameToVariable.containsKey(dataType.toString())) {
            return;
        }
        double[] dataTypeValues = FIPGMConfiguration.getConfig().getDataTypeValues(dataType);
        if (dataTypeValues == null) {
            throw new IllegalArgumentException(dataType + " has not supported yet: no values have been assigned.");
        }
        // Resolve the parent node before registering anything, so a rejected data type leaves no
        // half-registered variable behind.
        String parentName = getCentralDogmaNodeName(dataType);
        Variable parent = (parentName == null) ? null : this.nameToVariable.get(parentName);
        if (parent == null) {
            throw new IllegalArgumentException(dataType + " has not supported yet: no central dogma node can be assigned.");
        }
        Variable evidence = createVariable(dataType.toString());
        Set<Factor> factors = this.fg.getFactors();
        factors.add(new Factor(parent, evidence, dataTypeValues));
        this.fg.validatVariables();
        // Beliefs computed on the previous graph structure are no longer valid.
        this.observationToBelief.clear();
    }

    /**
     * Maps a data type to the name of the central-dogma variable it is attached to.
     *
     * @return the variable name, or {@code null} if the data type has no mapping
     */
    private String getCentralDogmaNodeName(DataType dataType) {
        switch (dataType) {
            case CNV:
            case Methylation:
                return PGMConfiguration.DNA;
            case mRNA_EXP:
            case miRNA:
                return PGMConfiguration.mRNA;
            case Mutation:
                return PROTEIN;
            default:
                return null;
        }
    }

    /**
     * Creates a two-state variable with the given name and registers it in {@code nameToVariable}.
     */
    private Variable createVariable(String str) {
        Variable variable = new Variable();
        variable.setName(str);
        variable.setStates(2);
        this.nameToVariable.put(str, variable);
        return variable;
    }

    /**
     * Prints the prior protein belief and then the protein belief for every 0/1 combination of
     * mRNA_EXP and Mutation observations.
     */
    @Test
    public void testInference() throws InferenceCannotConvergeException {
        PropertyConfigurator.configure("resources/log4j.properties");
        enableDataType(DataType.mRNA_EXP);
        enableDataType(DataType.Mutation);
        setInferenceType(InferenceType.SUM_PRODUCT);
        StringBuilder sb = new StringBuilder();
        for (double d : calculateBelief()) {
            sb.append(d).append("\t");
        }
        System.out.println("Prior belief: " + sb.toString());
        Map<DataType, Integer> observation = new HashMap<>();
        List<DataType> types = new ArrayList<>();
        types.add(DataType.mRNA_EXP);
        types.add(DataType.Mutation);
        List<Integer> states = new ArrayList<>();
        for (int i = 0; i < types.size(); i++) {
            states.add(0);
        }
        System.out.println("mRNA_Exp\tMutation\tp(i=0)\tp(i=1)");
        while (!isDone(states)) {
            for (int i = 0; i < states.size(); i++) {
                observation.put(types.get(i), states.get(i));
            }
            testInference(sb, observation);
            // Advance to the next combination like a binary counter: bump the first position that is
            // still 0 and reset all positions before it.
            for (int i = 0; i < states.size(); i++) {
                int state = states.get(i).intValue();
                if (state < 1) {
                    states.set(i, Integer.valueOf(state + 1));
                    for (int j = i - 1; j >= 0; j--) {
                        states.set(j, 0);
                    }
                    break;
                }
            }
        }
        // The loop stops once every state is 1, so that last combination still has to be evaluated.
        for (int i = 0; i < states.size(); i++) {
            observation.put(types.get(i), states.get(i));
        }
        testInference(sb, observation);
    }

    /**
     * @return {@code true} when every state in the list has reached 1 (vacuously true for an empty list)
     */
    private boolean isDone(List<Integer> list) {
        for (Integer state : list) {
            if (state.intValue() < 1) {
                return false;
            }
        }
        return true;
    }

    /**
     * Runs inference for one observation and prints a tab-separated line: the observed states in the
     * natural order of their data types, followed by the protein belief values.
     *
     * @param sb  reusable buffer; its previous content is discarded
     * @param map observation to evaluate
     */
    private void testInference(StringBuilder sb, Map<DataType, Integer> map) throws InferenceCannotConvergeException {
        double[] belief = calculateBelief(map);
        sb.setLength(0);
        List<DataType> types = new ArrayList<>(map.keySet());
        Collections.sort(types);
        for (DataType type : types) {
            sb.append(map.get(type)).append("\t");
        }
        for (double d : belief) {
            sb.append(d).append("\t");
        }
        System.out.println(sb.toString());
    }

    /**
     * Convenience overload of {@link #mergeBeliefToFactorValues(double[], List)} for a single belief.
     *
     * @param dArr  values to update in place
     * @param dArr2 belief to multiply into {@code dArr}
     */
    public void mergeBeliefToFactorValues(double[] dArr, double[] dArr2) {
        mergeBeliefToFactorValues(dArr, Collections.singletonList(dArr2));
    }

    /**
     * Multiplies every belief in {@code list} element-wise into {@code dArr} and then rescales
     * {@code dArr} so that its entries sum to 1. {@code dArr} is modified in place.
     *
     * <p>Each belief must have at least {@code dArr.length} entries. If the products sum to 0 the
     * division yields {@code NaN} (or infinite) entries; no guard is applied here.
     *
     * @param dArr values to update in place
     * @param list beliefs to multiply into {@code dArr}
     */
    public void mergeBeliefToFactorValues(double[] dArr, List<double[]> list) {
        for (double[] belief : list) {
            for (int i = 0; i < dArr.length; i++) {
                dArr[i] *= belief[i];
            }
        }
        double sum = 0.0d;
        for (double value : dArr) {
            sum += value;
        }
        for (int i = 0; i < dArr.length; i++) {
            dArr[i] /= sum;
        }
    }
}
