package elvira.inference.approximate;

import elvira.Bnet;
import elvira.Configuration;
import elvira.Evidence;
import elvira.FiniteStates;
import elvira.Node;
import elvira.NodeList;
import elvira.PairTable;
import elvira.Relation;
import elvira.RelationList;
import elvira.parser.ParseException;
import elvira.potential.CanonicalPotential;
import elvira.potential.PotentialTable;
import elvira.potential.PotentialTree;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Date;
import java.util.Hashtable;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.lazy.kstar.KStarConstants;

/* loaded from: input_file:bayelvira-1.0-SNAPSHOT.jar:elvira/inference/approximate/ImportanceSamplingTree.class */
public class ImportanceSamplingTree extends ImportanceSampling {

    /**
     * Threshold passed to {@code PotentialTree.limitBound} when pruning the intermediate
     * potentials built by {@link #getSamplingDistributions()}. It is derived from the
     * user-supplied value by {@link #setLimitForPrunning(double)}. Constructors that do not call
     * that setter leave it at 0.0.
     */
    double limitForPrunning;

    /**
     * Command-line entry point. Two argument layouts are accepted.
     *
     * <ul>
     *   <li>5 or 6 arguments: {@code network output pruneLimit limitSize sampleSize [evidence]}.
     *       Performs a single run, prints timing/error statistics to stdout and saves the
     *       approximate results to {@code output}.</li>
     *   <li>7 or more arguments:
     *       {@code pruneLimit limitSize sampleSize runs network output exactResults [evidence]}.
     *       Reads the exact results and runs {@link #propagate(String)}. The evidence file is
     *       only read when exactly 8 arguments are given.</li>
     * </ul>
     *
     * @param args command-line arguments as described above
     * @throws ParseException if the network or evidence file cannot be parsed
     * @throws IOException if an input or output file cannot be read or written
     */
    public static void main(String[] args) throws ParseException, IOException {
        if (args.length < 5) {
            System.out.println("Wrong number of arguments.");
            return;
        }
        if (args.length < 7) {
            Bnet net = loadNetworkFile(args[0]);
            Evidence evidence = args.length == 6 ? loadEvidenceFile(args[5], net) : new Evidence();
            ImportanceSamplingTree sampler = new ImportanceSamplingTree(net, evidence,
                    Double.parseDouble(args[2]), Integer.parseInt(args[3]),
                    Integer.parseInt(args[4]), 1);
            sampler.propagate();
            sampler.saveResults(args[1]);
            return;
        }
        Bnet net = loadNetworkFile(args[4]);
        Evidence evidence = args.length == 8 ? loadEvidenceFile(args[7], net) : new Evidence();
        ImportanceSamplingTree sampler = new ImportanceSamplingTree(net, evidence,
                Double.parseDouble(args[0]), Integer.parseInt(args[1]),
                Integer.parseInt(args[2]), Integer.parseInt(args[3]));
        System.out.println("Reading exact results");
        sampler.readExactResults(args[6]);
        System.out.println("Done");
        sampler.propagate(args[5]);
    }

    /** Parses a network from the file at {@code path}, closing the file afterwards. */
    private static Bnet loadNetworkFile(String path) throws ParseException, IOException {
        try (FileInputStream in = new FileInputStream(path)) {
            return new Bnet(in);
        }
    }

    /** Parses evidence over {@code net}'s nodes from {@code path}, closing the file afterwards. */
    private static Evidence loadEvidenceFile(String path, Bnet net)
            throws ParseException, IOException {
        try (FileInputStream in = new FileInputStream(path)) {
            return new Evidence(in, net.getNodeList());
        }
    }

    /**
     * Creates a sampler with every field left at its default value.
     * The decompiler (JADX) reports that this constructor was widened from package-private.
     * It is kept public for compatibility.
     */
    public ImportanceSamplingTree() {
    }

    /**
     * Creates a sampler for {@code bnet} conditioned on {@code evidence}. The pruning limit,
     * size limit, sample size and number of runs are not set here.
     */
    ImportanceSamplingTree(Bnet bnet, Evidence evidence) {
        this.observations = evidence;
        this.network = bnet;
        this.positions = new Hashtable(20);
    }

    /**
     * Creates a fully configured sampler.
     *
     * @param bnet network to propagate over
     * @param evidence observed variables
     * @param pruneLimit raw pruning parameter, converted by {@link #setLimitForPrunning(double)}
     * @param limitSize bound passed to {@code sortAndBound}; values {@code <= 0} disable it
     * @param sampleSize number of accepted samples per run
     * @param runs number of independent simulation runs
     */
    public ImportanceSamplingTree(Bnet bnet, Evidence evidence, double pruneLimit, int limitSize,
            int sampleSize, int runs) {
        this.observations = evidence;
        this.network = bnet;
        setLimitSize(limitSize);
        setLimitForPrunning(pruneLimit);
        setSampleSize(sampleSize);
        setNumberOfRuns(runs);
        this.positions = new Hashtable(20);
    }

    /**
     * Stores {@code 1 - H(0.5 - d, 0.5 + d)} as the pruning threshold, where H is the binary
     * entropy in bits. So {@code d = 0} yields 0 and {@code d = ±0.5} yields 1. The usual
     * convention {@code 0 * log 0 = 0} is applied, so the endpoints no longer produce NaN.
     * For {@code |d| > 0.5} the result is NaN.
     *
     * @param d distance from the uniform binary distribution
     */
    public void setLimitForPrunning(double d) {
        this.limitForPrunning = 1.0d + xLog2x(0.5d - d) + xLog2x(0.5d + d);
    }

    /** Returns {@code x * log2(x)}, defined as 0 at {@code x == 0}. */
    private static double xLog2x(double x) {
        if (x == 0.0d) {
            return 0.0d;
        }
        return (x * Math.log(x)) / Math.log(2.0d);
    }

    /**
     * Copies the network's active relations, converting table and canonical potentials to
     * trees. The class match is exact, so subclasses of those types are not converted. Any
     * other potential object is shared with the network, not copied.
     *
     * @return a new list holding one copied relation per active network relation
     */
    @Override // elvira.inference.Propagation
    public RelationList getInitialRelations() {
        RelationList copies = new RelationList();
        int count = this.network.getRelationList().size();
        for (int r = 0; r < count; r++) {
            Relation source = (Relation) this.network.getRelationList().elementAt(r);
            if (!source.getActive()) {
                continue;
            }
            Relation copy = new Relation();
            copy.setVariables(source.getVariables().copy());
            Class<?> potentialClass = source.getValues().getClass();
            if (potentialClass == PotentialTable.class) {
                copy.setValues(((PotentialTable) source.getValues()).toTree());
            } else if (potentialClass == CanonicalPotential.class) {
                copy.setValues(((CanonicalPotential) source.getValues()).toTree());
            } else {
                // NOTE(review): shared, not copied. Later code casts it to PotentialTree,
                // so presumably it already is one.
                copy.setValues(source.getValues());
            }
            copies.insertRelation(copy);
        }
        return copies;
    }

    /**
     * Builds one sampling distribution per unobserved variable by variable elimination.
     * The elimination order comes from {@code PairTable.nextToRemove()}.
     *
     * <p>For each step, all relations mentioning the removed variable are combined. The
     * result is stored in {@code samplingDistributions}. It is then marginalized over that
     * variable ({@code addVariable}), pruned, optionally bounded in size, and put back into the
     * working relation list. The removed variable is appended to {@code deletionSequence}. It is
     * mapped in {@code positions} to slot {@code remaining - 1}, so the first variable removed
     * gets the highest slot.
     */
    public void getSamplingDistributions() {
        NodeList pending = new NodeList();
        PairTable pairTable = new PairTable();
        this.deletionSequence = new NodeList();
        this.samplingDistributions = new Vector();

        int nodeCount = this.network.getNodeList().size();
        for (int n = 0; n < nodeCount; n++) {
            FiniteStates variable = (FiniteStates) this.network.getNodeList().elementAt(n);
            if (!this.observations.isObserved(variable)) {
                pending.insertNode(variable);
                pairTable.addElement(variable);
            }
        }

        // Separate working copy: this list is consumed during elimination, unlike the
        // initialRelations field used by evaluate().
        RelationList workingRelations = getInitialRelations();
        if (this.observations.size() > 0) {
            restrictToObservations(workingRelations);
        }
        for (int r = 0; r < workingRelations.size(); r++) {
            pairTable.addRelation(workingRelations.elementAt(r));
        }

        for (int remaining = pending.size(); remaining > 0; remaining--) {
            Node next = pairTable.nextToRemove();
            this.positions.put(next, Integer.valueOf(remaining - 1));
            pending.removeNode(next);
            pairTable.removeVariable(next);
            this.deletionSequence.insertNode(next);

            // Combine every relation that mentions the variable being removed.
            RelationList bucket = workingRelations.getRelationsOfAndRemove(next);
            Relation first = bucket.elementAt(0);
            pairTable.removeRelation(first);
            PotentialTree potential = (PotentialTree) first.getValues();
            for (int k = 1; k < bucket.size(); k++) {
                Relation other = bucket.elementAt(k);
                pairTable.removeRelation(other);
                potential = (PotentialTree) potential.combine((PotentialTree) other.getValues());
            }
            this.samplingDistributions.addElement(potential);

            if (remaining > 1) {
                potential = (PotentialTree) potential.addVariable(next);
                potential.limitBound(this.limitForPrunning);
            }
            if (this.limitSize > 0) {
                potential = (PotentialTree) potential.sortAndBound(this.limitSize);
            }

            // Drop variables the (possibly pruned) tree no longer depends on, but only when
            // some remaining relation still mentions them.
            for (int v = potential.getVariables().size() - 1; v >= 0; v--) {
                FiniteStates variable = (FiniteStates) potential.getVariables().elementAt(v);
                if (!potential.getTree().isIn(variable) && workingRelations.isIn(variable)) {
                    potential.getVariables().removeElementAt(
                            potential.getVariables().indexOf(variable));
                }
            }

            Relation message = new Relation();
            message.setKind(1);
            message.getVariables().setNodes((Vector) potential.getVariables().clone());
            message.setValues(potential);
            workingRelations.insertRelation(message);
            pairTable.addRelation(message);
        }
    }

    /**
     * Samples one full configuration into {@code currentConf}. Variables are visited in
     * reverse deletion order, so the variable at {@code deletionSequence} index {@code k} fills
     * slot {@code size - 1 - k}, which matches the slot stored in {@code positions}.
     *
     * @param random source of randomness
     * @return {@code false} if some sampling distribution had zero total mass for the
     *     configuration sampled so far. {@code currentConf} is then only partially filled.
     */
    public boolean simulateConfiguration(Random random) {
        int last = this.samplingDistributions.size() - 1;
        for (int k = last; k >= 0; k--) {
            int slot = last - k;
            int value = simulateValue((FiniteStates) this.deletionSequence.elementAt(k), slot,
                    (PotentialTree) this.samplingDistributions.elementAt(k), random);
            if (value == -1) {
                return false;
            }
            this.currentConf[slot] = value;
        }
        return true;
    }

    /**
     * Draws a value for {@code finiteStates} from {@code potentialTree}, normalized over the
     * variable's states. The other slots of {@code currentConf} are held fixed. The
     * importance weight {@code currentWeight} is divided by the probability of the drawn
     * state.
     *
     * <p>Zero-probability states are never selected. If rounding leaves the cumulative sum
     * just below the random draw, the last positive state is used rather than failing.
     *
     * @param finiteStates variable being sampled
     * @param slot index of the variable in {@code currentConf}. On return, that slot holds the
     *     variable's last state, not the sampled one. The caller stores the result.
     * @param potentialTree sampling distribution for this variable
     * @param random source of randomness
     * @return the sampled state index, or -1 if all states have zero mass
     */
    public int simulateValue(FiniteStates finiteStates, int slot, PotentialTree potentialTree,
            Random random) {
        int numStates = finiteStates.getNumStates();
        double[] weights = new double[numStates];
        double total = 0.0d;
        for (int state = 0; state < numStates; state++) {
            this.currentConf[slot] = state;
            weights[state] = potentialTree.getValue(this.positions, this.currentConf);
            total += weights[state];
        }
        if (total == 0.0d) {
            return -1;
        }

        double draw = random.nextDouble();
        double cumulative = 0.0d;
        int chosen = -1;
        for (int state = 0; state < numStates; state++) {
            if (weights[state] <= 0.0d) {
                continue;
            }
            cumulative += weights[state] / total;
            // Remember the last positive state as a fallback for rounding shortfall.
            chosen = state;
            if (draw < cumulative) {
                break;
            }
        }
        this.currentWeight /= weights[chosen] / total;
        return chosen;
    }

    /**
     * Draws {@code sampleSize} accepted configurations. Each one is weighted by
     * {@link #evaluate()} divided by its sampling probability, and recorded via
     * {@code updateSimulationInformation()}.
     *
     * <p>NOTE(review): rejected draws are retried without limit. If every draw is rejected
     * (e.g. presumably for zero-probability evidence), this loops forever.
     */
    public void simulate() {
        Random random = new Random();
        this.currentConf = new int[this.network.getNodeList().size()];
        int accepted = 0;
        while (accepted < this.sampleSize) {
            this.currentWeight = 1.0d;
            if (simulateConfiguration(random)) {
                this.currentWeight *= evaluate();
                updateSimulationInformation();
                accepted++;
            }
        }
    }

    /**
     * Returns the product of all {@code initialRelations} potentials at
     * {@code configuration}.
     */
    public double evaluate(Configuration configuration) {
        double product = 1.0d;
        int count = this.initialRelations.size();
        for (int r = 0; r < count; r++) {
            product *= ((PotentialTree) this.initialRelations.elementAt(r).getValues())
                    .getValue(configuration);
        }
        return product;
    }

    /**
     * Returns the product of all {@code initialRelations} potentials at {@code currentConf},
     * skipping potentials with no variables.
     */
    public double evaluate() {
        double product = 1.0d;
        int count = this.initialRelations.size();
        for (int r = 0; r < count; r++) {
            PotentialTree potential = (PotentialTree) this.initialRelations.elementAt(r).getValues();
            if (potential.getVariables().size() > 0) {
                product *= potential.getValue(this.positions, this.currentConf);
            }
        }
        return product;
    }

    /**
     * Restricts every relation in {@code relationList} to the observed values, in place.
     * Each relation's variable list is reset to match its restricted potential.
     */
    public void restrictToObservations(RelationList relationList) {
        int count = relationList.size();
        for (int r = 0; r < count; r++) {
            Relation relation = relationList.elementAt(r);
            relation.setValues(((PotentialTree) relation.getValues()).restrictVariable(this.observations));
            relation.getVariables().setNodes(relation.getValues().getVariables());
        }
    }

    /**
     * Runs {@code numberOfRuns} simulations and compares each against the exact results. The
     * exact results must already have been loaded.
     *
     * <p>Writes one {@code TIME\tERROR} line per run to {@code str + ".et"}. Writes a summary
     * (sampling-distribution time, average simulation time, average G, average MSE, weight
     * variance) to {@code str}.
     *
     * @param str base path of the output files
     * @throws IOException if either output file cannot be written
     */
    public void propagate(String str) throws IOException {
        double errorSum = 0.0d;
        double mseSum = 0.0d;
        double simulationSeconds = 0.0d;
        double[] errors = new double[2];

        double samplingSeconds = prepareSamplingDistributionsTimed();

        try (PrintWriter runLog = new PrintWriter(new FileWriter(str + ".et"))) {
            runLog.println("TIME\tERROR");
            for (int run = 0; run < this.numberOfRuns; run++) {
                long start = System.currentTimeMillis();
                simulate();
                normalizeResults();
                double elapsed = (System.currentTimeMillis() - start) / 1000.0d;
                simulationSeconds += elapsed;
                computeError(errors);
                errorSum += errors[0];
                mseSum += errors[1];
                runLog.println(elapsed + "\t" + errors[0]);
                // Keep the final run's statistics for the summary and saveResults.
                if (run < this.numberOfRuns - 1) {
                    clearSimulationInformation();
                }
            }
        }

        try (PrintWriter summary = new PrintWriter(new FileWriter(str))) {
            summary.println("Time computing sampling distributions (secs): " + samplingSeconds);
            summary.println("Time simulating (avg) : " + (simulationSeconds / this.numberOfRuns));
            summary.println("G : " + (errorSum / this.numberOfRuns));
            summary.println("MSE : " + (mseSum / this.numberOfRuns));
            summary.println("Variance : " + varianceOfWeights());
        }
        System.out.println("Done");
    }

    /**
     * Runs {@code numberOfRuns} simulations and prints timing statistics to stdout. Error
     * statistics are also computed when {@code exactResults} is non-null. Otherwise the
     * printed G and MSE are 0.
     */
    public void propagate() {
        double errorSum = 0.0d;
        double mseSum = 0.0d;
        double simulationSeconds = 0.0d;
        double[] errors = new double[2];

        double samplingSeconds = prepareSamplingDistributionsTimed();

        for (int run = 0; run < this.numberOfRuns; run++) {
            long start = System.currentTimeMillis();
            simulate();
            normalizeResults();
            simulationSeconds += (System.currentTimeMillis() - start) / 1000.0d;
            if (this.exactResults != null) {
                computeError(errors);
            }
            errorSum += errors[0];
            mseSum += errors[1];
            // Keep the final run's statistics for the summary and saveResults.
            if (run < this.numberOfRuns - 1) {
                clearSimulationInformation();
            }
        }
        System.out.println("Time computing sampling distributions (secs): " + samplingSeconds);
        System.out.println("Time simulating (avg) : " + (simulationSeconds / this.numberOfRuns));
        System.out.println("G : " + (errorSum / this.numberOfRuns));
        System.out.println("MSE : " + (mseSum / this.numberOfRuns));
        System.out.println("Variance : " + varianceOfWeights());
        System.out.println("Done");
    }

    /**
     * Runs the preparation phase shared by both {@code propagate} variants. It loads and
     * restricts {@code initialRelations}, builds the sampling distributions and initializes
     * the simulation bookkeeping.
     *
     * @return seconds spent in {@link #getSamplingDistributions()}
     */
    private double prepareSamplingDistributionsTimed() {
        this.initialRelations = getInitialRelations();
        if (this.observations.size() > 0) {
            restrictToObservations(this.initialRelations);
        }
        long start = System.currentTimeMillis();
        System.out.println("Computing sampling distributions");
        getSamplingDistributions();
        System.out.println("Sampling distributions computed");
        double elapsed = (System.currentTimeMillis() - start) / 1000.0d;
        initSimulationInformation();
        System.out.println("Simulating");
        return elapsed;
    }
}
