package org.reactome.factorgraph;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.log4j.Logger;

/* loaded from: input_file:caBIGR3-minimal-3.0.jar:org/reactome/factorgraph/LoopyBeliefPropagation.class */
public class LoopyBeliefPropagation extends AbstractInferencer {
    private static final Logger logger = Logger.getLogger(LoopyBeliefPropagation.class);
    /** Minimum number of recorded sweeps before the divergence heuristic in validateConverge is applied. */
    private static final int CONVERGENCE_CHECK_MIN_HISTORY = 50;
    /** Number of sweep-to-sweep increases of maxDiff at which inference is declared divergent. */
    private static final int CONVERGENCE_CHECK_MAX_INCREASES = 10;

    private boolean logSpace;
    private boolean updateViaFactors;
    private InferenceType inferenceType = InferenceType.SUM_PRODUCT;
    private double initialMessage = 1.0d;
    // NOTE(review): stored and exposed via getter/setter, but no code path in this class applies it
    // (addDumping below has no caller). Presumably intended as a message damping factor — confirm.
    private double dumping = 0.0d;
    private boolean enableConvergenceCheck = true;

    /**
     * Returns whether the divergence heuristic (see validateConverge) is active.
     *
     * @return true if the divergence check is enabled
     */
    public boolean isEnableConvergenceCheck() {
        return this.enableConvergenceCheck;
    }

    /**
     * Enables or disables the divergence heuristic applied after every sweep.
     *
     * @param z true to enable the check
     */
    public void setEnableConvergenceCheck(boolean z) {
        this.enableConvergenceCheck = z;
    }

    /**
     * Returns whether each sweep iterates over factors (true) or variables (false).
     *
     * @return the update order flag
     */
    public boolean isUpdateViaFactors() {
        return this.updateViaFactors;
    }

    /**
     * Stores the damping value. Note: the value is currently not applied by this class.
     *
     * @param d the damping value
     */
    public void setDumping(double d) {
        this.dumping = d;
    }

    /**
     * Returns the stored damping value.
     *
     * @return the damping value
     */
    public double getDumping() {
        return this.dumping;
    }

    /**
     * Selects whether each sweep iterates over factors (true) or variables (false).
     *
     * @param z the update order flag
     */
    public void setUpdateViaFactors(boolean z) {
        this.updateViaFactors = z;
    }

    /**
     * Same as {@link #isUpdateViaFactors()}.
     *
     * @return the update order flag
     */
    public boolean getUpdateViaFactors() {
        return this.updateViaFactors;
    }

    /**
     * Sets the value every edge message is initialized with before inference.
     *
     * @param d the initial message value
     */
    public void setInitialMessage(double d) {
        this.initialMessage = d;
    }

    /**
     * Returns the value every edge message is initialized with before inference.
     *
     * @return the initial message value
     */
    public double getInitialMessage() {
        return this.initialMessage;
    }

    /**
     * Selects whether messages and beliefs are computed in log space.
     *
     * @param z true to use log space
     */
    public void setUseLogSpace(boolean z) {
        this.logSpace = z;
    }

    /**
     * Returns whether messages and beliefs are computed in log space.
     *
     * @return true if log space is used
     */
    public boolean getUseLogSpace() {
        return this.logSpace;
    }

    /**
     * Sets the message-passing flavour (e.g. SUM_PRODUCT or MAX_PRODUCT).
     *
     * @param inferenceType the inference type
     */
    public void setInferenceType(InferenceType inferenceType) {
        this.inferenceType = inferenceType;
    }

    /**
     * Returns the message-passing flavour.
     *
     * @return the inference type
     */
    public InferenceType getInferenceType() {
        return this.inferenceType;
    }

    /**
     * Runs loopy belief propagation on the configured factor graph.
     *
     * <p>Sweeps are repeated in a shuffled order until the largest message change drops to the
     * tolerance or more than maxIteration sweeps have run. Beliefs are then computed for every
     * variable and factor.
     *
     * <p>Observations are detached and the factors removed by truncateContinuousFactors are added
     * back in a {@code finally} block, so the graph is restored even when inference fails. Both
     * helpers are inherited and not visible here; that they form a reversible pair is assumed from
     * the success path — TODO confirm.
     *
     * @throws InferenceCannotConvergeException if the divergence heuristic fires
     * @throws IllegalStateException if a message contains NaN
     */
    @Override // org.reactome.factorgraph.AbstractInferencer, org.reactome.factorgraph.Inferencer
    public synchronized void runInference() throws InferenceCannotConvergeException {
        super.runInference();
        truncateContinuousFactors();
        attachObservation();
        try {
            initializeMessages(this.factorGraph);
            List<Factor> factors = new ArrayList<>(this.factorGraph.getFactors());
            List<Variable> variables = new ArrayList<>(this.factorGraph.getVariables());
            this.maxDiff = Double.MAX_VALUE;
            this.iteration = 0;
            long startTime = System.currentTimeMillis();
            List<Double> diffHistory = new ArrayList<>();
            while (this.iteration <= this.maxIteration && this.maxDiff > this.tolerance) {
                // maxDiff is raised by sendMessage to the largest change seen during this sweep.
                this.maxDiff = 0.0d;
                if (this.updateViaFactors) {
                    Collections.shuffle(factors);
                    for (Factor factor : factors) {
                        for (Edge edge : factor.getOutEdges()) {
                            sendMessage(edge);
                        }
                        for (Edge edge : factor.getInEdges()) {
                            sendMessage(edge);
                        }
                    }
                } else {
                    Collections.shuffle(variables);
                    for (Variable variable : variables) {
                        for (Edge edge : variable.getInEdges()) {
                            sendMessage(edge);
                        }
                        for (Edge edge : variable.getOutEdges()) {
                            sendMessage(edge);
                        }
                    }
                }
                this.iteration++;
                if (this.debug) {
                    logger.info("Iteration: " + this.iteration + ", maxDiff: " + this.maxDiff);
                }
                diffHistory.add(this.maxDiff);
                if (!validateConverge(diffHistory)) {
                    throw new InferenceCannotConvergeException("Inference for " + this.factorGraph + ": cannot converge.");
                }
            }
            long endTime = System.currentTimeMillis();
            if (this.debug) {
                logger.info("Inference is done: " + this.iteration + ", maxDiff: " + this.maxDiff + ", using " + ((endTime - startTime) / 1000.0d) + " seconds.");
            }
            // Only warn when the loop stopped on the iteration cap without reaching the tolerance;
            // converging exactly on the last allowed sweep is not a failure.
            if (this.iteration > this.maxIteration && this.maxDiff > this.tolerance) {
                logger.warn("Inferece for " + this.factorGraph + ": reach max iterations " + this.iteration + " with maxDiff " + this.maxDiff);
            }
            calculateBeliefs(this.factorGraph);
        } finally {
            detachObservation();
            addBackContinuosFactors();
        }
    }

    /**
     * Heuristic divergence check over the per-sweep maxDiff history.
     *
     * <p>The check is skipped when it is disabled or when fewer than
     * CONVERGENCE_CHECK_MIN_HISTORY sweeps have been recorded. Otherwise it counts how often
     * maxDiff increased from one sweep to the next over the whole history.
     *
     * @param list maxDiff value of every completed sweep, in order
     * @return false if the number of increases reached CONVERGENCE_CHECK_MAX_INCREASES
     */
    private boolean validateConverge(List<Double> list) {
        if (!this.enableConvergenceCheck || list.size() < CONVERGENCE_CHECK_MIN_HISTORY) {
            return true;
        }
        int increases = 0;
        for (int i = 0; i < list.size() - 1; i++) {
            if (list.get(i + 1).doubleValue() > list.get(i).doubleValue()) {
                increases++;
            }
        }
        return increases < CONVERGENCE_CHECK_MAX_INCREASES;
    }

    /**
     * Updates the belief of every variable, then of every factor, in the given graph.
     *
     * @param factorGraph the graph whose node beliefs are refreshed
     */
    private void calculateBeliefs(FactorGraph factorGraph) {
        for (Variable variable : factorGraph.getVariables()) {
            variable.updateBelief(this.logSpace);
        }
        for (Factor factor : factorGraph.getFactors()) {
            factor.updateBelief(this.logSpace);
        }
    }

    /**
     * Resets all edges in the graph, then creates a fresh pair of directed edges for every
     * factor/variable connection, each initialized with initialMessage.
     *
     * @param factorGraph the graph to (re)wire
     */
    private void initializeMessages(FactorGraph factorGraph) {
        for (Factor factor : factorGraph.getFactors()) {
            factor.resetEdges();
        }
        for (Variable variable : factorGraph.getVariables()) {
            variable.resetEdges();
        }
        for (Factor factor : factorGraph.getFactors()) {
            initializeMessages(factor);
        }
    }

    /**
     * Creates a factor-to-variable and a variable-to-factor edge for each variable of the factor.
     * Each edge is registered as an out-edge on its source and an in-edge on its target.
     *
     * @param factor the factor whose edges are created
     */
    private void initializeMessages(Factor factor) {
        for (Variable variable : factor.getVariables()) {
            Edge toVariable = new Edge(factor, variable);
            toVariable.initializeMessage(this.initialMessage, this.logSpace);
            Edge toFactor = new Edge(variable, factor);
            toFactor.initializeMessage(this.initialMessage, this.logSpace);
            factor.addOutEdge(toVariable);
            factor.addInEdge(toFactor);
            variable.addOutEdge(toFactor);
            variable.addInEdge(toVariable);
        }
    }

    /**
     * Extracts a maximum-belief assignment for every variable after a MAX_PRODUCT inference.
     *
     * <p>Each variable with a unique maximum in its own belief is assigned that index. A variable
     * with a single-state belief is assigned index 0. For the remaining (tied) variables, every
     * factor touching one of them is asked for its best joint assignment consistent with what has
     * already been assigned.
     *
     * @return variable to state index; variables that could not be assigned are omitted and logged
     * @throws IllegalArgumentException if the inference type is not MAX_PRODUCT
     */
    public Map<Variable, Integer> findMaximum() {
        if (this.inferenceType != InferenceType.MAX_PRODUCT) {
            throw new IllegalArgumentException("findMaximum should be called after a MAX_PROD inference.");
        }
        Map<Variable, Integer> assignment = new HashMap<>();
        HashSet<Variable> unassigned = new HashSet<>(this.factorGraph.getVariables());
        for (Variable variable : unassigned) {
            double[] belief = variable.getBelief();
            if (belief.length == 1) {
                assignment.put(variable, 0);
            } else {
                int maxIndex = findMaxiumIndex(belief);
                if (maxIndex > -1) {
                    assignment.put(variable, maxIndex);
                }
            }
        }
        unassigned.removeAll(assignment.keySet());
        if (!unassigned.isEmpty() && this.debug) {
            logger.info("There are ambiguous variables: " + unassigned.size());
        }
        // The candidate set is fixed for this pass: a factor is consulted if it touches any
        // variable that was ambiguous after the per-variable pass above.
        for (Factor factor : this.factorGraph.getFactors()) {
            for (Variable variable : factor.getVariables()) {
                if (unassigned.contains(variable)) {
                    extractMaximumFromFactor(factor, assignment);
                    break;
                }
            }
        }
        unassigned.removeAll(assignment.keySet());
        if (!unassigned.isEmpty()) {
            logger.error("There are variables that cannot be assigned: " + unassigned);
        }
        return assignment;
    }

    /**
     * Among the factor's top-belief configurations (all entries tied with the maximum), chooses
     * the one that agrees with the most variables already in {@code map}. It then adds that
     * configuration's values for variables not yet in {@code map}.
     *
     * @param factor the factor to read the joint belief from
     * @param map the assignment being built; updated in place
     */
    private void extractMaximumFromFactor(Factor factor, Map<Variable, Integer> map) {
        double[] belief = factor.getBelief();
        List<Integer> sortedIndices = sortIndices(belief);
        int bestMatches = -1;
        Map<Variable, Integer> best = null;
        for (int rank = 0; rank < sortedIndices.size(); rank++) {
            int index = sortedIndices.get(rank);
            // Indices are sorted by decreasing belief; stop once we leave the tie at the top.
            if (rank > 0 && belief[index] < belief[sortedIndices.get(rank - 1)]) {
                break;
            }
            Map<Variable, Integer> candidate = factor.getAssignment(index);
            int matches = 0;
            for (Map.Entry<Variable, Integer> entry : candidate.entrySet()) {
                Integer current = map.get(entry.getKey());
                if (current != null && current.equals(entry.getValue())) {
                    matches++;
                }
            }
            if (matches > bestMatches) {
                bestMatches = matches;
                best = candidate;
            }
        }
        if (best == null) {
            // Empty belief vector: nothing to extract (previously a NullPointerException).
            logger.warn("Factor has no belief entries to extract a maximum from: " + factor);
            return;
        }
        for (Map.Entry<Variable, Integer> entry : best.entrySet()) {
            Variable variable = entry.getKey();
            if (!map.containsKey(variable)) {
                map.put(variable, entry.getValue());
            } else if (!sameValue(map.get(variable), entry.getValue())) {
                // Compare by value: `!=` on Integer compares references and misfires outside
                // the Integer.valueOf cache range.
                logger.warn("Factor has a different assignment for Variable " + variable + " in Factor " + factor);
            }
        }
    }

    /**
     * Null-safe value equality for boxed integers.
     *
     * @param a first value, may be null
     * @param b second value, may be null
     * @return true if both are null or both hold the same int value
     */
    private static boolean sameValue(Integer a, Integer b) {
        return a == null ? b == null : a.equals(b);
    }

    /**
     * Returns the indices of {@code dArr} ordered by decreasing value. The sort is stable, so
     * equal values keep their index order.
     *
     * @param dArr values to order by
     * @return a new list of indices 0..dArr.length-1
     */
    private List<Integer> sortIndices(final double[] dArr) {
        List<Integer> indices = new ArrayList<>(dArr.length);
        for (int i = 0; i < dArr.length; i++) {
            indices.add(i);
        }
        Collections.sort(indices, new Comparator<Integer>() {
            @Override // java.util.Comparator
            public int compare(Integer num, Integer num2) {
                double d = dArr[num.intValue()];
                double d2 = dArr[num2.intValue()];
                if (d > d2) {
                    return -1;
                }
                return d < d2 ? 1 : 0;
            }
        });
        return indices;
    }

    /**
     * Finds the index of the unique maximum of {@code dArr}.
     *
     * <p>The search starts from negative infinity, so non-positive (e.g. log-space) values are
     * handled. The previous start value, Double.MIN_VALUE, is the smallest positive double, so any
     * array whose entries were all at or below it yielded -1.
     *
     * @param dArr values to search
     * @return the index of the maximum, or -1 if the array is empty, contains no value greater
     *     than negative infinity (e.g. all NaN), or the maximum is tied
     */
    private int findMaxiumIndex(double[] dArr) {
        double max = Double.NEGATIVE_INFINITY;
        int maxIndex = -1;
        // Largest value seen that did not replace the running maximum; it equals max on a tie.
        double runnerUp = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < dArr.length; i++) {
            if (dArr[i] > max) {
                max = dArr[i];
                maxIndex = i;
            } else if (dArr[i] > runnerUp) {
                runnerUp = dArr[i];
            }
        }
        if (max == runnerUp) {
            return -1;
        }
        return maxIndex;
    }

    /**
     * Recomputes the message along {@code edge} from its source node and stores it on the edge.
     * It also raises maxDiff to the largest absolute change of any message entry.
     *
     * @param edge the edge to update
     * @throws IllegalStateException if the new message contains NaN
     */
    private void sendMessage(Edge edge) {
        double[] newMessage = edge.getFromNode().sendMessage(edge.getToNode(), this.inferenceType, this.logSpace);
        double[] oldMessage = edge.getMessage();
        for (int i = 0; i < newMessage.length; i++) {
            if (Double.isNaN(newMessage[i])) {
                throw new IllegalStateException("A Message contains NaN: a possible numerical underflow occurs. Probably the log-space should be used for computation.");
            }
            double diff = Math.abs(newMessage[i] - oldMessage[i]);
            if (diff > this.maxDiff) {
                this.maxDiff = diff;
            }
        }
        edge.setMessage(newMessage);
    }

    /**
     * Blends {@code dArr2} into {@code dArr} in place: dArr = (1 - dumping) * dArr + dumping * dArr2.
     * When {@code z} is true, both arrays are converted from log to probability space first and
     * back afterwards, so {@code dArr2} is modified in place too.
     *
     * <p>NOTE(review): this method currently has no caller in this class.
     *
     * @param dArr array to update in place
     * @param dArr2 array blended in
     * @param z whether the arrays are in log space
     * @param fGNode node providing the log/probability conversions
     */
    private void addDumping(double[] dArr, double[] dArr2, boolean z, FGNode fGNode) {
        if (z) {
            fGNode.convertLogToProb(dArr);
            fGNode.convertLogToProb(dArr2);
        }
        for (int i = 0; i < dArr.length; i++) {
            dArr[i] = ((1.0d - this.dumping) * dArr[i]) + (this.dumping * dArr2[i]);
        }
        if (z) {
            fGNode.convertProbToLog(dArr);
            fGNode.convertProbToLog(dArr2);
        }
    }
}
