package elvira.inference.approximate;

import elvira.Bnet;
import elvira.Configuration;
import elvira.Evidence;
import elvira.FiniteStates;
import elvira.Node;
import elvira.NodeList;
import elvira.PairTable;
import elvira.Relation;
import elvira.RelationList;
import elvira.parser.ParseException;
import elvira.potential.CanonicalPotential;
import elvira.potential.Function;
import elvira.potential.FunctionSumNormIdf;
import elvira.potential.Potential;
import elvira.potential.PotentialFunction;
import elvira.potential.PotentialTable;
import elvira.potential.PotentialTree;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Date;
import java.util.Hashtable;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.lazy.kstar.KStarConstants;
import weka.core.TestInstances;

/* loaded from: input_file:bayelvira-1.0-SNAPSHOT.jar:elvira/inference/approximate/ImportanceSamplingFunctionTree.class */
/**
 * Importance-sampling propagation in which the sampling distributions are
 * obtained by a variable-deletion pass over the network relations, with the
 * intermediate potentials handled as {@link PotentialTree}s that are pruned
 * ({@code limitBound}) and size-bounded ({@code sortAndBound}).
 *
 * <p>Command-line arguments accepted by {@link #main(String[])}:
 * <ol start="0">
 *   <li>pruning parameter passed to {@link #setLimitForPrunning(double)}</li>
 *   <li>limit size passed to {@code setLimitSize}</li>
 *   <li>sample size</li>
 *   <li>number of runs</li>
 *   <li>network file</li>
 *   <li>output file for the run statistics</li>
 *   <li>(optional, only read when exactly seven arguments are given) evidence file</li>
 * </ol>
 */
public class ImportanceSamplingFunctionTree extends ImportanceSampling {
    /** Threshold handed to {@code PotentialTree.limitBound}; see {@link #setLimitForPrunning(double)}. */
    double limitForPrunning;

    /**
     * Command-line entry point: loads the network (and optionally the
     * evidence), runs the propagation and writes the statistics file.
     *
     * @param strArr the arguments described in the class comment
     * @throws ParseException if the network or evidence file cannot be parsed
     * @throws IOException if a file cannot be read or written
     */
    public static void main(String[] strArr) throws ParseException, IOException {
        if (strArr.length < 6) {
            System.out.println("Too few arguments");
            return;
        }

        Bnet bnet;
        FileInputStream networkStream = new FileInputStream(strArr[4]);
        try {
            bnet = new Bnet(networkStream);
        } finally {
            // The stream is only needed while the network is being read.
            networkStream.close();
        }

        Evidence evidence;
        if (strArr.length == 7) {
            FileInputStream evidenceStream = new FileInputStream(strArr[6]);
            try {
                evidence = new Evidence(evidenceStream, bnet.getNodeList());
            } finally {
                evidenceStream.close();
            }
        } else {
            evidence = new Evidence();
        }

        double pruningParameter = Double.parseDouble(strArr[0]);
        int limitSize = Integer.parseInt(strArr[1]);
        int sampleSize = Integer.parseInt(strArr[2]);
        int numberOfRuns = Integer.parseInt(strArr[3]);

        ImportanceSamplingFunctionTree propagation = new ImportanceSamplingFunctionTree(
                bnet, evidence, pruningParameter, limitSize, sampleSize, numberOfRuns);
        propagation.propagate(strArr[5]);
    }

    /** Creates an instance with no network and no evidence attached. */
    ImportanceSamplingFunctionTree() {
    }

    /**
     * Creates an instance for the given network and evidence, leaving every
     * numeric setting untouched.
     *
     * @param bnet the network to propagate on
     * @param evidence the observations
     */
    ImportanceSamplingFunctionTree(Bnet bnet, Evidence evidence) {
        this.observations = evidence;
        this.network = bnet;
    }

    /**
     * Creates a fully configured instance.
     *
     * @param bnet the network to propagate on
     * @param evidence the observations
     * @param d pruning parameter, forwarded to {@link #setLimitForPrunning(double)}
     * @param i limit size, forwarded to {@code setLimitSize}
     * @param i2 sample size
     * @param i3 number of runs
     */
    public ImportanceSamplingFunctionTree(Bnet bnet, Evidence evidence, double d, int i, int i2, int i3) {
        this.observations = evidence;
        this.network = bnet;
        setLimitSize(i);
        setLimitForPrunning(d);
        setSampleSize(i2);
        setNumberOfRuns(i3);
    }

    /**
     * Sets the pruning threshold to
     * {@code 1 + (0.5-d)*log2(0.5-d) + (0.5+d)*log2(0.5+d)}.
     *
     * <p>A term whose factor is exactly zero contributes zero, so
     * {@code d = 0.5} and {@code d = -0.5} give {@code 1.0} instead of NaN.
     * For {@code |d| > 0.5} the logarithm of a negative number is taken and
     * the threshold is NaN.
     *
     * @param d the pruning parameter
     */
    public void setLimitForPrunning(double d) {
        this.limitForPrunning = 1.0d + xLog2x(0.5d - d) + xLog2x(0.5d + d);
    }

    /** Returns {@code x * log2(x)}, using the convention that the value at zero is zero. */
    private static double xLog2x(double x) {
        if (x == 0.0d) {
            return 0.0d;
        }
        return (x * Math.log(x)) / Math.log(2.0d);
    }

    /** Tells whether {@code candidate}'s runtime class is exactly {@code type} (subclasses do not match). */
    private static boolean isExactly(Object candidate, Class<?> type) {
        return candidate.getClass().equals(type);
    }

    /**
     * Builds a fresh relation list from the active relations of the network.
     * Each new relation gets a copy of the variable list; a
     * {@link PotentialTable} or {@link CanonicalPotential} is converted with
     * {@code toTree()}, any other potential is shared with the network.
     *
     * <p>The class checks compare against class literals:
     * {@code Class.getName()} returns the fully-qualified name, so comparing
     * it with a simple name such as {@code "PotentialTable"} never matches.
     *
     * @return the list of new relations
     */
    @Override // elvira.inference.Propagation
    public RelationList getInitialRelations() {
        RelationList relationList = new RelationList();
        int total = this.network.getRelationList().size();
        for (int index = 0; index < total; index++) {
            Relation source = (Relation) this.network.getRelationList().elementAt(index);
            if (!source.getActive()) {
                continue;
            }
            Relation copy = new Relation();
            copy.setVariables(source.getVariables().copy());

            Potential values = source.getValues();
            if (isExactly(values, PotentialTable.class)) {
                copy.setValues(((PotentialTable) values).toTree());
            } else if (isExactly(values, CanonicalPotential.class)) {
                copy.setValues(((CanonicalPotential) values).toTree());
            } else {
                if (isExactly(values, PotentialFunction.class)) {
                    Function function = ((PotentialFunction) values).getFunction();
                    if (isExactly(function, FunctionSumNormIdf.class)) {
                        // As before, the converted potential is stored back
                        // into the network's own relation, not only the copy.
                        source.setValues(FunctionSumNormIdf.sumToAddNormIdf((PotentialFunction) values));
                    }
                }
                copy.setValues(source.getValues());
            }
            relationList.insertRelation(copy);
        }
        return relationList;
    }

    /** Prints the names of the variables of {@code potential}, each followed by a space. */
    private static void printPotentialVariableNames(Potential potential) {
        Vector variables = potential.getVariables();
        for (int k = 0; k < variables.size(); k++) {
            System.out.print(((FiniteStates) variables.elementAt(k)).getName() + " ");
        }
    }

    /**
     * Computes the deletion sequence, the position of every unobserved
     * variable in the configuration array, and one sampling distribution per
     * unobserved variable.
     *
     * <p>For each variable chosen by the {@link PairTable}, all relations
     * containing it are combined; the combination is stored as the sampling
     * distribution of that variable. Unless it is the last variable, the
     * variable is then removed with {@code addVariable} and the result is
     * pruned with {@code limitBound}. The potential is finally bounded with
     * {@code sortAndBound} and re-inserted as a new relation.
     */
    public void getSamplingDistributions() {
        NodeList pending = new NodeList();
        PairTable pairTable = new PairTable();
        this.positions = new Hashtable(20);
        this.deletionSequence = new NodeList();
        this.samplingDistributions = new Vector();

        int nodeCount = this.network.getNodeList().size();
        for (int n = 0; n < nodeCount; n++) {
            FiniteStates node = (FiniteStates) this.network.getNodeList().elementAt(n);
            if (!this.observations.isObserved(node)) {
                pending.insertNode(node);
                pairTable.addElement(node);
            }
        }

        RelationList relations = getInitialRelations();
        if (this.observations.size() > 0) {
            restrictToObservations(relations);
        }
        for (int r = 0; r < relations.size(); r++) {
            pairTable.addRelation(relations.elementAt(r));
        }

        for (int remaining = pending.size(); remaining > 0; remaining--) {
            Node toRemove = pairTable.nextToRemove();
            // Variables deleted first receive the highest positions.
            this.positions.put(toRemove, Integer.valueOf(remaining - 1));
            pending.removeNode(toRemove);
            pairTable.removeVariable(toRemove);
            this.deletionSequence.insertNode(toRemove);

            RelationList involved = relations.getRelationsOfAndRemove(toRemove);
            Relation first = involved.elementAt(0);
            pairTable.removeRelation(first);
            Potential combined = first.getValues();
            for (int r = 1; r < involved.size(); r++) {
                Relation next = involved.elementAt(r);
                pairTable.removeRelation(next);
                System.out.print("\nCombining...P1 * P2\n");
                printPotentialVariableNames(next.getValues());
                System.out.print(" and ");
                printPotentialVariableNames(combined);
                combined = combined.combine(next.getValues());
                System.out.print("\n Resulting  ");
                printPotentialVariableNames(combined);
            }

            this.samplingDistributions.addElement(combined);
            if (remaining > 1) {
                combined = combined.addVariable(toRemove);
                ((PotentialTree) combined).limitBound(this.limitForPrunning);
            }
            Potential bounded = ((PotentialTree) combined).sortAndBound(this.limitSize);

            // Drop variables that no longer appear in the tree but are still
            // covered by some remaining relation.
            for (int v = bounded.getVariables().size() - 1; v >= 0; v--) {
                FiniteStates variable = (FiniteStates) bounded.getVariables().elementAt(v);
                if (!((PotentialTree) bounded).getTree().isIn(variable) && relations.isIn(variable)) {
                    bounded.getVariables().removeElementAt(bounded.getVariables().indexOf(variable));
                }
            }

            Relation replacement = new Relation();
            replacement.setKind(1);
            replacement.getVariables().setNodes((Vector) bounded.getVariables().clone());
            replacement.setValues(bounded);
            relations.insertRelation(replacement);
            pairTable.addRelation(replacement);
        }
    }

    /**
     * Simulates a value for every variable of the deletion sequence, visiting
     * it in reverse order, and stores the values in {@code currentConf}.
     *
     * @param random the random number source
     * @return {@code false} as soon as a variable cannot be simulated,
     *         {@code true} when all of them were
     */
    public boolean simulateConfiguration(Random random) {
        int last = this.samplingDistributions.size() - 1;
        for (int k = last; k >= 0; k--) {
            int position = last - k;
            int value = simulateValue(
                    (FiniteStates) this.deletionSequence.elementAt(k),
                    position,
                    (Potential) this.samplingDistributions.elementAt(k),
                    random);
            if (value == -1) {
                return false;
            }
            this.currentConf[position] = value;
        }
        return true;
    }

    /**
     * Draws a state of {@code finiteStates} with probability proportional to
     * the value of {@code potential} for each state (given the values already
     * stored in {@code currentConf}) and divides {@code currentWeight} by the
     * probability of the chosen state.
     *
     * <p>Only states with strictly positive mass can be chosen. If rounding
     * leaves the cumulative probability below the random draw, the last state
     * with positive mass is used rather than indexing the array with -1.
     *
     * @param finiteStates the variable to simulate
     * @param i the position of the variable in {@code currentConf}
     * @param potential the sampling distribution
     * @param random the random number source
     * @return the chosen state, or -1 if no state can be drawn
     */
    public int simulateValue(FiniteStates finiteStates, int i, Potential potential, Random random) {
        int numStates = finiteStates.getNumStates();
        double[] mass = new double[numStates];
        double total = 0.0d;
        for (int state = 0; state < numStates; state++) {
            this.currentConf[i] = state;
            mass[state] = potential.getValue(this.positions, this.currentConf);
            total += mass[state];
        }
        if (total == 0.0d) {
            System.out.println("Zero valuation");
            return -1;
        }

        double draw = random.nextDouble();
        double cumulative = 0.0d;
        int chosen = -1;
        int lastPositive = -1;
        for (int state = 0; state < numStates; state++) {
            cumulative += mass[state] / total;
            if (mass[state] > 0.0d) {
                lastPositive = state;
                if (draw <= cumulative) {
                    chosen = state;
                    break;
                }
            }
        }
        if (chosen == -1) {
            chosen = lastPositive;
        }
        if (chosen == -1) {
            // No state has positive mass (e.g. NaN values): nothing to draw.
            System.out.println("Zero valuation");
            return -1;
        }
        this.currentWeight /= mass[chosen] / total;
        return chosen;
    }

    /**
     * Runs one batch of {@code sampleSize} simulations. Each successfully
     * simulated configuration is weighted by {@link #evaluate()} and recorded
     * with {@code updateSimulationInformation()}.
     */
    public void simulate() {
        Random random = new Random();
        this.currentConf = new int[this.network.getNodeList().size()];
        for (int sample = 0; sample < this.sampleSize; sample++) {
            this.currentWeight = 1.0d;
            if (simulateConfiguration(random)) {
                this.currentWeight *= evaluate();
                updateSimulationInformation();
            }
        }
    }

    /**
     * Multiplies the values of all initial relations for a configuration.
     *
     * @param configuration the configuration to evaluate
     * @return the product of the relation values
     */
    public double evaluate(Configuration configuration) {
        double product = 1.0d;
        int count = this.initialRelations.size();
        for (int r = 0; r < count; r++) {
            product *= this.initialRelations.elementAt(r).getValues().getValue(configuration);
        }
        return product;
    }

    /**
     * Multiplies the values of all initial relations for the configuration
     * held in {@code currentConf}; potentials without variables are skipped.
     *
     * @return the product of the relation values
     */
    public double evaluate() {
        double product = 1.0d;
        int count = this.initialRelations.size();
        for (int r = 0; r < count; r++) {
            Potential values = this.initialRelations.elementAt(r).getValues();
            if (values.getVariables().size() > 0) {
                product *= values.getValue(this.positions, this.currentConf);
            }
        }
        return product;
    }

    /**
     * Restricts every relation of the list to the observations and updates
     * its variable list to that of the restricted potential.
     *
     * @param relationList the relations to restrict, modified in place
     */
    public void restrictToObservations(RelationList relationList) {
        int count = relationList.size();
        for (int r = 0; r < count; r++) {
            Relation relation = relationList.elementAt(r);
            relation.setValues(relation.getValues().restrictVariable(this.observations));
            relation.getVariables().setNodes(relation.getValues().getVariables());
        }
    }

    /**
     * Shared body of both {@code propagate} variants: computes the sampling
     * distributions, performs {@code numberOfRuns} simulation runs (printing
     * the normalized results after each one) and returns the report lines.
     *
     * <p>The run counter is independent of the counter used to print the
     * results, so exactly {@code numberOfRuns} runs are executed.
     *
     * @return the five report lines, in output order
     */
    private String[] runSimulationAndBuildReport() {
        this.sumW = 0.0d;
        this.sumW2 = 0.0d;
        this.initialRelations = getInitialRelations();
        if (this.observations.size() > 0) {
            restrictToObservations(this.initialRelations);
        }

        long samplingStart = System.currentTimeMillis();
        System.out.println("Computing sampling distributions");
        getSamplingDistributions();
        System.out.println("Sampling distributions computed");
        double samplingSeconds = (System.currentTimeMillis() - samplingStart) / 1000.0d;

        initSimulationInformation();
        System.out.println("Simulating");
        long simulationStart = System.currentTimeMillis();
        for (int run = 0; run < this.numberOfRuns; run++) {
            simulate();
            System.out.println("Tras Simulate");
            normalizeResults();
            for (int r = 0; r < this.results.size(); r++) {
                ((Potential) this.results.elementAt(r)).print();
            }
        }
        double averageSimulationSeconds =
                (System.currentTimeMillis() - simulationStart) / (this.numberOfRuns * 1000.0d);

        // Computed in double so that large settings cannot overflow an int.
        double totalSamples = (double) this.sampleSize * this.numberOfRuns;
        this.sumW /= totalSamples;
        double variance = (this.sumW2 / ((totalSamples * this.sumW) * this.sumW)) - 1.0d;
        // No error measure is accumulated here; both averages start from zero.
        double g = 0.0d / this.numberOfRuns;
        double mse = 0.0d / this.numberOfRuns;

        return new String[] {
            "Time computing sampling distributions (secs): " + samplingSeconds,
            "Time simulating (avg) : " + averageSimulationSeconds,
            "G : " + g,
            "MSE : " + mse,
            "Variance : " + variance
        };
    }

    /**
     * Runs the propagation and writes the run statistics to a file.
     *
     * @param str the name of the output file
     * @throws ParseException declared for compatibility with callers
     * @throws IOException if the output file cannot be written
     */
    public void propagate(String str) throws ParseException, IOException {
        String[] report = runSimulationAndBuildReport();
        FileWriter fileWriter = new FileWriter(str);
        try {
            PrintWriter printWriter = new PrintWriter(fileWriter);
            for (int line = 0; line < report.length; line++) {
                printWriter.println(report[line]);
            }
            printWriter.flush();
        } finally {
            // Closed even if writing fails, so the file handle is not leaked.
            fileWriter.close();
        }
        System.out.println("Done");
    }

    /** Runs the propagation and prints the run statistics to standard output. */
    public void propagate() {
        String[] report = runSimulationAndBuildReport();
        for (int line = 0; line < report.length; line++) {
            System.out.println(report[line]);
        }
        System.out.println("Done");
    }
}
