package org.reactome.factorgraph;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:caBIGR3-minimal-3.0.jar:org/reactome/factorgraph/AbstractInferencer.class */
/**
 * Base implementation of {@link Inferencer} that holds the target {@link FactorGraph}, the current
 * {@link Observation}, and the iteration/tolerance settings used by concrete inference algorithms.
 *
 * <p>It also provides helpers that subclasses can call around an inference run:
 * <ul>
 *   <li>{@link #attachObservation()} / {@link #detachObservation()} temporarily add one
 *       single-variable evidence factor per observed variable;</li>
 *   <li>{@link #truncateContinuousFactors()} / {@link #addBackContinuosFactors()} temporarily
 *       replace every {@link ContinuousFactor} with a discrete single-variable factor.</li>
 * </ul>
 *
 * <p>Instances keep mutable per-run state and perform no synchronization.
 */
public abstract class AbstractInferencer implements Inferencer {
    /** The graph inference runs on; must be assigned before {@link #runInference()}. */
    protected FactorGraph factorGraph;
    /** Current evidence, or {@code null} when nothing is observed. */
    private Observation<? extends Number> observation;
    /** Evidence factors added by {@link #attachObservation()} and not yet detached. */
    private List<Factor> observationFactors;
    /** Convergence tolerance; how it is applied is up to subclasses. */
    protected double tolerance = 1.0E-6d;
    /** Iteration cap; how it is applied is up to subclasses. */
    protected int maxIteration = 10000;
    /** Iteration counter maintained by subclasses. */
    protected int iteration;
    /** Largest change recorded by subclasses (presumably per iteration — verify in subclasses). */
    protected double maxDiff;
    /** Debug flag exposed to subclasses. */
    protected boolean debug;
    /** Discrete replacements created by {@link #truncateContinuousFactors()}. */
    private List<Factor> factorsFromContinuous;
    /** Continuous factors removed by {@link #truncateContinuousFactors()}, awaiting restore. */
    private List<ContinuousFactor> continuousFactors;

    /**
     * Sets the graph to run inference on.
     *
     * @param factorGraph the target graph
     */
    @Override
    public void setFactorGraph(FactorGraph factorGraph) {
        this.factorGraph = factorGraph;
    }

    /**
     * Returns the graph inference runs on.
     *
     * @return the target graph, or {@code null} if none has been assigned
     */
    @Override
    public FactorGraph getFactorGraph() {
        return this.factorGraph;
    }

    /**
     * Replaces the current observation with one built from a variable-to-assignment map.
     *
     * @param map observed assignments; {@code null} clears the observation
     * @param <T> numeric type of the assignments
     */
    @Override
    public <T extends Number> void setObservation(Map<Variable, T> map) {
        if (map == null) {
            clearObservation();
            return;
        }
        // Typed as Observation<T> (not a wildcard) so that setVariableToAssignment type-checks.
        Observation<T> newObservation = new Observation<>();
        newObservation.setVariableToAssignment(map);
        this.observation = newObservation;
    }

    /**
     * Replaces the current observation.
     *
     * @param observation the new observation; {@code null} clears it
     * @param <T> numeric type of the assignments
     */
    @Override
    public <T extends Number> void setObservation(Observation<T> observation) {
        if (observation == null) {
            clearObservation();
        } else {
            this.observation = observation;
        }
    }

    /** Removes the current observation. Factors already attached stay until detached. */
    @Override
    public void clearObservation() {
        this.observation = null;
    }

    /**
     * Returns the current observation.
     *
     * @return the observation, or {@code null} if none is set
     */
    @Override
    public Observation<? extends Number> getObservation() {
        return this.observation;
    }

    /** @return the convergence tolerance (default {@code 1.0E-6}) */
    public double getTolerance() {
        return this.tolerance;
    }

    /** @param d the new convergence tolerance */
    public void setTolerance(double d) {
        this.tolerance = d;
    }

    /** @return the iteration cap (default {@code 10000}) */
    public int getMaxIteration() {
        return this.maxIteration;
    }

    /** @param i the new iteration cap */
    public void setMaxIteration(int i) {
        this.maxIteration = i;
    }

    /** @return the iteration counter as last recorded by a subclass */
    public int getIteration() {
        return this.iteration;
    }

    /** @return whether debug mode is enabled */
    public boolean getDebug() {
        return this.debug;
    }

    /** @param z {@code true} to enable debug mode */
    public void setDebug(boolean z) {
        this.debug = z;
    }

    /** @return the largest change as last recorded by a subclass */
    public double getMaxDiff() {
        return this.maxDiff;
    }

    /**
     * Checks the preconditions shared by all inference algorithms. This base implementation
     * performs no inference itself; subclasses presumably call it before running — verify.
     *
     * @throws IllegalArgumentException if no graph is assigned, or if
     *     {@link FactorGraph#isInferreable()} reports the graph as unsupported
     * @throws InferenceCannotConvergeException declared for subclasses; never thrown here
     */
    @Override
    public void runInference() throws InferenceCannotConvergeException {
        if (this.factorGraph == null) {
            throw new IllegalArgumentException("The target FactorGraph has not been assigned.");
        }
        if (!this.factorGraph.isInferreable()) {
            throw new IllegalArgumentException("This type of FactorGraph is not supported yet: probably continuous variables are not leaf nodes.");
        }
    }

    /**
     * Replaces every {@link ContinuousFactor} in the graph with a discrete single-variable factor.
     * The factor values come from {@link ContinuousFactor#marginalizeForDiscrete} given the
     * observed assignment of the continuous variable ({@code null} if it is not observed). The
     * removed factors are remembered so that {@link #addBackContinuosFactors()} can restore them.
     */
    public void truncateContinuousFactors() {
        if (this.continuousFactors == null) {
            this.continuousFactors = new ArrayList<>();
        } else {
            this.continuousFactors.clear();
        }
        if (this.factorsFromContinuous == null) {
            this.factorsFromContinuous = new ArrayList<>();
        } else {
            this.factorsFromContinuous.clear();
        }
        // Collect first and mutate the graph's factor collection only after iteration ends.
        for (Factor factor : this.factorGraph.getFactors()) {
            if (factor instanceof ContinuousFactor) {
                ContinuousFactor continuousFactor = (ContinuousFactor) factor;
                this.continuousFactors.add(continuousFactor);
                this.factorsFromContinuous.add(convertContinuousFactor(continuousFactor));
            }
        }
        this.factorGraph.getFactors().removeAll(this.continuousFactors);
        for (Factor factor : this.factorsFromContinuous) {
            this.factorGraph.addFactor(factor);
        }
        this.factorGraph.validatVariables();
    }

    /**
     * Undoes {@link #truncateContinuousFactors()}. It detaches the discrete replacement factors and
     * adds the original continuous factors back to the graph. It is a no-op if truncation never
     * ran, found nothing to replace, or has already been undone.
     */
    public void addBackContinuosFactors() {
        // Both lists are created together by truncateContinuousFactors(); null means it never ran.
        if (this.continuousFactors == null || this.factorsFromContinuous == null) {
            return;
        }
        if (this.continuousFactors.isEmpty() && this.factorsFromContinuous.isEmpty()) {
            return;
        }
        detachFactors(this.factorsFromContinuous);
        for (ContinuousFactor continuousFactor : this.continuousFactors) {
            this.factorGraph.addFactor(continuousFactor);
        }
        // Forget restored factors so that a repeated call cannot add them a second time.
        this.continuousFactors.clear();
        this.factorGraph.validatVariables();
    }

    /**
     * Builds a discrete factor over the discrete variable of {@code continuousFactor} and unlinks
     * {@code continuousFactor} from that variable.
     *
     * @param continuousFactor the factor to convert
     * @return the new single-variable discrete factor (already linked to the variable)
     */
    private Factor convertContinuousFactor(ContinuousFactor continuousFactor) {
        ContinuousVariable continuousVariable = continuousFactor.getContinuousVariable();
        Variable discreteVariable = continuousFactor.getDiscreteVariable();
        Factor factor = new Factor();
        List<Variable> factorVariables = new ArrayList<>();
        factorVariables.add(discreteVariable);
        factor.setVariables(factorVariables);
        discreteVariable.addFactor(factor);
        VariableAssignment<? extends Number> variableAssignment = null;
        if (this.observation != null) {
            variableAssignment = this.observation.getVariableAssignment(continuousVariable);
        }
        double[] marginalizeForDiscrete = continuousFactor.marginalizeForDiscrete(variableAssignment);
        if (Double.isNaN(marginalizeForDiscrete[0])) {
            System.out.println("Wrong factor values!");
        }
        factor.setValues(marginalizeForDiscrete);
        discreteVariable.removeFactor(continuousFactor);
        return factor;
    }

    /**
     * Computes an estimate of log Z from the current variable and factor beliefs. The expression
     * matches the Bethe approximation:
     * {@code sum_i (d_i - 1) * sum_x b_i(x) log b_i(x) + sum_a sum_x b_a(x) (log f_a(x) - log b_a(x))},
     * where {@code d_i} is the number of factors attached to variable {@code i}. Zero beliefs are
     * skipped, and factor entries with a zero value are skipped too. Beliefs are assumed to have
     * been populated by a prior inference run — verify in callers.
     *
     * @return the log Z estimate
     * @throws IllegalStateException if the running sum becomes NaN
     */
    @Override
    public double calculateLogZ() {
        double logZ = 0.0d;
        for (Variable variable : this.factorGraph.getVariables()) {
            double[] belief = variable.getBelief();
            double sumBeliefLogBelief = 0.0d;
            for (double b : belief) {
                if (b != 0.0d) {
                    sumBeliefLogBelief += b * Math.log(b);
                }
            }
            logZ += (variable.getFactors().size() - 1) * sumBeliefLogBelief;
            if (Double.isNaN(logZ)) {
                // Arrays.toString: plain concatenation would print the array's identity hash.
                throw new IllegalStateException("NaN encountered in variable: " + variable + " with belief " + Arrays.toString(belief));
            }
        }
        for (Factor factor : this.factorGraph.getFactors()) {
            double[] belief = factor.getBelief();
            double[] values = factor.getValues();
            for (int i = 0; i < belief.length; i++) {
                if (belief[i] != 0.0d && values[i] != 0.0d) {
                    logZ += belief[i] * (Math.log(values[i]) - Math.log(belief[i]));
                    if (Double.isNaN(logZ)) {
                        throw new IllegalStateException("NaN encountered in factor: " + factor + " with belief " + belief[i]);
                    }
                }
            }
        }
        return logZ;
    }

    /**
     * Adds one single-variable evidence factor to the graph for each observed variable in the
     * graph. Observed variables that are not in the graph are ignored. It does nothing when no
     * observation is set.
     *
     * <p>For binary variables, the assignment is read as P(state 1) and must lie in [0, 1]. For
     * other variables, the assignment's {@code intValue()} selects the single state with value 1.
     *
     * @throws IllegalStateException if an assignment is out of range. The check runs before the
     *     factor is linked to the variable, so the rejected variable is left untouched.
     */
    public void attachObservation() {
        if (this.observation == null) {
            return;
        }
        if (this.observationFactors == null) {
            this.observationFactors = new ArrayList<>();
        } else {
            this.observationFactors.clear();
        }
        Set<Variable> variables = this.factorGraph.getVariables();
        Map<Variable, ? extends Number> variableToAssignment = this.observation.getVariableToAssignment();
        for (Map.Entry<Variable, ? extends Number> entry : variableToAssignment.entrySet()) {
            Variable variable = entry.getKey();
            if (!variables.contains(variable)) {
                continue;
            }
            double[] values = buildObservationValues(variable, entry.getValue());
            Factor factor = new Factor();
            List<Variable> factorVariables = new ArrayList<>();
            factorVariables.add(variable);
            factor.setVariables(factorVariables);
            variable.addFactor(factor);
            factor.setValues(values);
            this.factorGraph.addFactor(factor);
            this.observationFactors.add(factor);
        }
    }

    /** Computes and validates the evidence-factor values for one observed variable. */
    private static double[] buildObservationValues(Variable variable, Number assignment) {
        int states = variable.getStates();
        double[] values = new double[states];
        if (states == 2) {
            double doubleValue = assignment.doubleValue();
            // Negated range test so that NaN is rejected as well.
            if (!(doubleValue >= 0.0d && doubleValue <= 1.0d)) {
                throw new IllegalStateException(variable + " value should be in [0, 1]. The assigne value is: " + doubleValue);
            }
            values[0] = 1.0d - doubleValue;
            values[1] = doubleValue;
        } else {
            int state = assignment.intValue();
            if (state < 0 || state >= states) {
                throw new IllegalStateException(variable + " state should be in [0, " + (states - 1) + "]. The assigned value is: " + state);
            }
            values[state] = 1.0d;
        }
        return values;
    }

    /**
     * Removes the evidence factors added by {@link #attachObservation()} from the graph and from
     * their variables. The check is keyed on the attached factors rather than on the current
     * observation. This avoids an NPE when attach never ran, and it still cleans up when the
     * observation was cleared after attaching.
     */
    public void detachObservation() {
        if (this.observationFactors == null || this.observationFactors.isEmpty()) {
            return;
        }
        detachFactors(this.observationFactors);
    }

    /** Removes {@code list}'s factors from the graph, unlinks each one, then empties {@code list}. */
    private void detachFactors(List<Factor> list) {
        this.factorGraph.getFactors().removeAll(list);
        for (Factor factor : list) {
            detachFactor(factor);
        }
        list.clear();
    }

    /** Unlinks {@code factor} from its variables and removes its outgoing edges from their targets. */
    private void detachFactor(Factor factor) {
        for (Variable variable : factor.getVariables()) {
            variable.removeFactor(factor);
        }
        if (factor.getOutEdges() != null) {
            for (Edge edge : factor.getOutEdges()) {
                Variable variable = (Variable) edge.getToNode();
                variable.removeInEdge(edge);
                variable.removeOutEdge(factor);
            }
        }
    }
}
