package elvira.inference.approximate;

import elvira.Bnet;
import elvira.Configuration;
import elvira.Evidence;
import elvira.FiniteStates;
import elvira.Node;
import elvira.NodeList;
import elvira.PairTable;
import elvira.Relation;
import elvira.RelationList;
import elvira.parser.ParseException;
import elvira.potential.CanonicalPotential;
import elvira.potential.PotentialTable;
import elvira.potential.PotentialTree;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Date;
import java.util.Hashtable;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.lazy.kstar.KStarConstants;

/* loaded from: input_file:bayelvira-1.0-SNAPSHOT.jar:elvira/inference/approximate/ImportanceSamplingTable.class */
/**
 * Importance-sampling propagation whose sampling distributions are built from
 * {@link PotentialTable}s.
 *
 * <p>Sampling distributions are obtained by eliminating the unobserved variables one by one
 * (order chosen by {@link PairTable#nextToRemove()}). Inside each bucket, pairs of potentials
 * are combined greedily while the combined size stays within {@code limitSize}. Configurations
 * are then sampled in the reverse of the elimination order. Each configuration is weighted by
 * the product of the (evidence-restricted) initial potentials divided by the probability of
 * having sampled it.
 */
public class ImportanceSamplingTable extends ImportanceSampling {

    /**
     * Command-line entry point.
     *
     * <p>Two forms are accepted, distinguished by the argument count:
     * <ul>
     *   <li>4 or 5 arguments: {@code network output limitSize sampleSize [evidence]}. Runs one
     *       simulation and saves the approximate results to {@code output}.</li>
     *   <li>6 or more arguments: {@code limitSize sampleSize numberOfRuns network statsOutput
     *       exactResults [evidence]}. Reads exact results, runs {@code numberOfRuns}
     *       simulations and writes timing/error statistics to {@code statsOutput}.</li>
     * </ul>
     * The evidence file is only read when exactly 5 (first form) or 7 (second form)
     * arguments are given; otherwise empty evidence is used.
     */
    public static void main(String[] strArr) throws ParseException, IOException {
        if (strArr.length < 4) {
            System.out.println("Too few arguments");
            return;
        }
        if (strArr.length < 6) {
            Bnet bnet = readNetwork(strArr[0]);
            Evidence evidence = strArr.length == 5 ? readEvidence(strArr[4], bnet) : new Evidence();
            ImportanceSamplingTable propagation = new ImportanceSamplingTable(bnet, evidence,
                    Integer.parseInt(strArr[2]), Integer.parseInt(strArr[3]), 1);
            propagation.propagate();
            propagation.saveResults(strArr[1]);
            return;
        }
        Bnet bnet = readNetwork(strArr[3]);
        Evidence evidence = strArr.length == 7 ? readEvidence(strArr[6], bnet) : new Evidence();
        ImportanceSamplingTable propagation = new ImportanceSamplingTable(bnet, evidence,
                Integer.parseInt(strArr[0]), Integer.parseInt(strArr[1]), Integer.parseInt(strArr[2]));
        System.out.println("Reading exact results");
        propagation.readExactResults(strArr[5]);
        System.out.println("Done");
        propagation.propagate(strArr[4]);
    }

    /** Parses a network file, always closing the underlying stream. */
    private static Bnet readNetwork(String path) throws ParseException, IOException {
        FileInputStream in = new FileInputStream(path);
        try {
            return new Bnet(in);
        } finally {
            in.close();
        }
    }

    /** Parses an evidence file against the nodes of {@code bnet}, always closing the stream. */
    private static Evidence readEvidence(String path, Bnet bnet) throws ParseException, IOException {
        FileInputStream in = new FileInputStream(path);
        try {
            return new Evidence(in, bnet.getNodeList());
        } finally {
            in.close();
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public ImportanceSamplingTable() {
    }

    /**
     * Creates a propagation for {@code bnet} given {@code evidence}. The size, sample and run
     * parameters are left at whatever the superclass initialises them to.
     */
    ImportanceSamplingTable(Bnet bnet, Evidence evidence) {
        this.observations = evidence;
        this.network = bnet;
        this.positions = new Hashtable(20);
    }

    /**
     * Creates a fully configured propagation.
     *
     * @param bnet         the network to propagate over
     * @param evidence     observed variables
     * @param limitSize    maximum size allowed when combining potentials in a bucket
     * @param sampleSize   number of configurations simulated per run
     * @param numberOfRuns number of independent simulation runs
     */
    public ImportanceSamplingTable(Bnet bnet, Evidence evidence, int limitSize, int sampleSize, int numberOfRuns) {
        this(bnet, evidence);
        setLimitSize(limitSize);
        setSampleSize(sampleSize);
        setNumberOfRuns(numberOfRuns);
    }

    /**
     * Copies the active relations of the network, converting their values to
     * {@link PotentialTable}s when they are exactly a {@link PotentialTree} or a
     * {@link CanonicalPotential}. Values of any other class are shared, not copied.
     *
     * @return a new list with one relation per active network relation
     */
    @Override // elvira.inference.Propagation
    public RelationList getInitialRelations() {
        RelationList relationList = new RelationList();
        for (int i = 0; i < this.network.getRelationList().size(); i++) {
            Relation relation = (Relation) this.network.getRelationList().elementAt(i);
            if (!relation.getActive()) {
                continue;
            }
            Relation copy = new Relation();
            copy.setVariables(relation.getVariables().copy());
            Class valuesClass = relation.getValues().getClass();
            // Exact class match on purpose, mirroring the original name comparison: subclasses
            // of these potential types are passed through unchanged.
            if (valuesClass == PotentialTree.class) {
                copy.setValues(new PotentialTable((PotentialTree) relation.getValues()));
            } else if (valuesClass == CanonicalPotential.class) {
                copy.setValues(((CanonicalPotential) relation.getValues()).getCPT());
            } else {
                copy.setValues(relation.getValues());
            }
            relationList.insertRelation(copy);
        }
        return relationList;
    }

    /**
     * Builds {@code deletionSequence}, {@code samplingDistributions} and {@code positions}.
     *
     * <p>For every unobserved variable (the k-th one removed, 0-based, out of n), this method:
     * <ol>
     *   <li>records position {@code n-1-k} for it in {@code positions};</li>
     *   <li>takes all current relations that mention it (its bucket) and stores a copy of the
     *       bucket as its sampling distribution;</li>
     *   <li>greedily combines bucket potentials within {@code limitSize};</li>
     *   <li>if more variables remain, removes the variable from each multi-variable bucket
     *       potential and puts the result back into the pool.</li>
     * </ol>
     */
    public void getSamplingDistributions() {
        NodeList unobserved = new NodeList();
        PairTable pairTable = new PairTable();
        this.deletionSequence = new NodeList();
        this.samplingDistributions = new Vector();
        for (int i = 0; i < this.network.getNodeList().size(); i++) {
            FiniteStates variable = (FiniteStates) this.network.getNodeList().elementAt(i);
            if (!this.observations.isObserved(variable)) {
                unobserved.insertNode(variable);
                pairTable.addElement(variable);
            }
        }
        RelationList relations = getInitialRelations();
        if (this.observations.size() > 0) {
            restrictToObservations(relations);
        }
        for (int i = 0; i < relations.size(); i++) {
            pairTable.addRelation(relations.elementAt(i));
        }
        for (int remaining = unobserved.size(); remaining > 0; remaining--) {
            Node next = pairTable.nextToRemove();
            this.positions.put(next, Integer.valueOf(remaining - 1));
            unobserved.removeNode(next);
            pairTable.removeVariable(next);
            this.deletionSequence.insertNode(next);
            RelationList bucket = relations.getRelationsOfAndRemove(next);
            for (int i = 0; i < bucket.size(); i++) {
                pairTable.removeRelation(bucket.elementAt(i));
            }
            // NOTE(review): if RelationList.copy() is shallow, the in-place setValues() in
            // sumOutAndReinsert would also alter this stored distribution — confirm copy() depth.
            this.samplingDistributions.addElement(bucket.copy());
            combinePairsWithinLimit(bucket);
            if (remaining > 1) {
                sumOutAndReinsert(bucket, next, relations, pairTable);
            }
        }
    }

    /**
     * Repeatedly replaces the pair of potentials in {@code bucket} whose combination grows the
     * least by their combination. Growth is {@code totalSize()} of the pair minus the larger
     * table. Stops when fewer than two potentials remain or the best pair's
     * {@code totalSize()} exceeds {@code limitSize}. Ties keep the first pair found.
     */
    private void combinePairsWithinLimit(RelationList bucket) {
        while (true) {
            int count = bucket.size();
            int first = -1;
            int second = -1;
            double bestGrowth = 1.0E20d;
            double bestTotal = 0.0d;
            for (int a = 0; a < count - 1; a++) {
                RelationList pair = new RelationList();
                Relation left = bucket.elementAt(a);
                pair.insertRelation(left);
                double leftSize = ((PotentialTable) left.getValues()).getValues().length;
                for (int b = a + 1; b < count; b++) {
                    Relation right = bucket.elementAt(b);
                    double rightSize = ((PotentialTable) right.getValues()).getValues().length;
                    pair.insertRelation(right);
                    double total = pair.totalSize();
                    double growth = total - Math.max(leftSize, rightSize);
                    if (growth < bestGrowth) {
                        first = a;
                        second = b;
                        bestGrowth = growth;
                        bestTotal = total;
                    }
                    pair.removeRelationAt(1);
                }
            }
            if (first < 0 || bestTotal > this.limitSize) {
                return;
            }
            Relation firstRelation = bucket.elementAt(first);
            Relation secondRelation = bucket.elementAt(second);
            // Remove the higher index first so that 'first' still points at the right element.
            bucket.removeRelationAt(second);
            bucket.removeRelationAt(first);
            PotentialTable product = ((PotentialTable) firstRelation.getValues())
                    .combine((PotentialTable) secondRelation.getValues());
            Relation merged = new Relation();
            merged.setKind(1);
            merged.getVariables().setNodes((Vector) product.getVariables().clone());
            merged.setValues(product);
            bucket.insertRelation(merged);
        }
    }

    /**
     * For each bucket relation with more than one variable, replaces its values (in place) by
     * {@code addVariable(variable)} applied to them. The relation is then added back to
     * {@code relations} and {@code pairTable} so later eliminations see it.
     */
    private void sumOutAndReinsert(RelationList bucket, Node variable, RelationList relations, PairTable pairTable) {
        for (int i = 0; i < bucket.size(); i++) {
            Relation relation = bucket.elementAt(i);
            if (relation.getVariables().size() > 1) {
                PotentialTable reduced = (PotentialTable) ((PotentialTable) relation.getValues()).addVariable(variable);
                relation.setKind(1);
                relation.getVariables().setNodes((Vector) reduced.getVariables().clone());
                relation.setValues(reduced);
                relations.insertRelation(relation);
                pairTable.addRelation(relation);
            }
        }
    }

    /**
     * Samples one configuration by visiting variables in reverse elimination order and
     * writing each sampled state into {@code currentConf}. Each step also divides
     * {@code currentWeight} by the probability of the sampled state.
     *
     * @return {@code false} if some variable had an all-zero sampling distribution
     */
    public boolean simulateConfiguration(Random random) {
        int last = this.samplingDistributions.size() - 1;
        for (int i = last; i >= 0; i--) {
            int position = last - i;
            int value = simulateValue((FiniteStates) this.deletionSequence.elementAt(i), position,
                    (RelationList) this.samplingDistributions.elementAt(i), random);
            if (value == -1) {
                return false;
            }
            this.currentConf[position] = value;
        }
        return true;
    }

    /**
     * Samples a state of {@code finiteStates} proportionally to the product of
     * {@code relationList} evaluated at the current configuration. Position {@code i} of
     * {@code currentConf} is used as scratch space while evaluating.
     *
     * @return the sampled state index, or -1 if every state has zero (or non-finite) weight
     */
    public int simulateValue(FiniteStates finiteStates, int i, RelationList relationList, Random random) {
        int numStates = finiteStates.getNumStates();
        double[] weights = new double[numStates];
        double total = 0.0d;
        for (int state = 0; state < numStates; state++) {
            this.currentConf[i] = state;
            weights[state] = evaluate(relationList);
            total += weights[state];
        }
        // '!(total > 0)' also rejects a NaN total, which would otherwise select no state.
        if (!(total > 0.0d)) {
            System.out.println("Zero valuation");
            return -1;
        }
        int chosen = selectState(weights, total, random.nextDouble());
        if (chosen < 0) {
            return -1;
        }
        this.currentWeight /= weights[chosen] / total;
        return chosen;
    }

    /**
     * Inverse-CDF selection of a state for {@code u} in [0, 1). The strict {@code <} never
     * selects a zero-weight state. If rounding leaves the cumulative sum slightly below
     * {@code u}, the last state with positive weight is returned instead of no state.
     */
    private static int selectState(double[] weights, double total, double u) {
        double cumulative = 0.0d;
        int lastPositive = -1;
        for (int state = 0; state < weights.length; state++) {
            if (weights[state] > 0.0d) {
                lastPositive = state;
            }
            cumulative += weights[state] / total;
            if (u < cumulative) {
                return state;
            }
        }
        return lastPositive;
    }

    /**
     * Draws {@code sampleSize} weighted configurations. Each successful one is passed to
     * {@code updateSimulationInformation()}; configurations that hit a zero valuation are
     * discarded.
     */
    public void simulate() {
        Random random = new Random();
        this.currentConf = new int[this.network.getNodeList().size()];
        for (int sample = 0; sample < this.sampleSize; sample++) {
            this.currentWeight = 1.0d;
            if (simulateConfiguration(random)) {
                this.currentWeight *= evaluate();
                updateSimulationInformation();
            }
        }
    }

    /** Product of all potentials in {@code relationList} evaluated at {@code configuration}. */
    public double evaluate(RelationList relationList, Configuration configuration) {
        double product = 1.0d;
        for (int i = 0; i < relationList.size(); i++) {
            product *= ((PotentialTable) relationList.elementAt(i).getValues()).getValue(configuration);
        }
        return product;
    }

    /**
     * Product of the non-constant potentials in {@code relationList}, evaluated at
     * {@code currentConf} via {@code positions}.
     */
    public double evaluate(RelationList relationList) {
        return productAtCurrentConfiguration(relationList);
    }

    /** Product of the non-constant initial potentials evaluated at {@code configuration}. */
    public double evaluate(Configuration configuration) {
        double product = 1.0d;
        for (int i = 0; i < this.initialRelations.size(); i++) {
            PotentialTable table = (PotentialTable) this.initialRelations.elementAt(i).getValues();
            if (table.getVariables().size() > 0) {
                product *= table.getValue(configuration);
            }
        }
        return product;
    }

    /** Product of the non-constant initial potentials evaluated at {@code currentConf}. */
    public double evaluate() {
        return productAtCurrentConfiguration(this.initialRelations);
    }

    /**
     * Shared implementation of the {@code currentConf}-based evaluations. Potentials with no
     * variables are skipped.
     */
    private double productAtCurrentConfiguration(RelationList relations) {
        double product = 1.0d;
        for (int i = 0; i < relations.size(); i++) {
            PotentialTable table = (PotentialTable) relations.elementAt(i).getValues();
            if (table.getVariables().size() > 0) {
                product *= table.getValue(this.positions, this.currentConf);
            }
        }
        return product;
    }

    /**
     * Restricts every potential in {@code relationList} to the observed values (in place) and
     * updates each relation's variable list to match.
     */
    public void restrictToObservations(RelationList relationList) {
        for (int i = 0; i < relationList.size(); i++) {
            Relation relation = relationList.elementAt(i);
            relation.setValues(((PotentialTable) relation.getValues()).restrictVariable(this.observations));
            relation.getVariables().setNodes(relation.getValues().getVariables());
        }
    }

    /**
     * Runs the full experiment, always computing the error against the exact results, and
     * writes the timing/error summary to the file {@code str}.
     *
     * @throws IOException if the summary file cannot be created or written
     */
    public void propagate(String str) throws IOException {
        String[] summary = runExperiment(true);
        PrintWriter out = new PrintWriter(new FileWriter(str));
        try {
            for (String line : summary) {
                out.println(line);
            }
            // PrintWriter swallows I/O errors; surface them instead of silently losing output.
            if (out.checkError()) {
                throw new IOException("Error writing simulation statistics to " + str);
            }
        } finally {
            out.close();
        }
        System.out.println("Done");
    }

    /**
     * Runs the full experiment and prints the timing/error summary to standard output. The
     * error is only computed when exact results have been loaded.
     */
    public void propagate() {
        String[] summary = runExperiment(false);
        for (String line : summary) {
            System.out.println(line);
        }
        System.out.println("Done");
    }

    /**
     * Shared body of both {@code propagate} variants. Prepares the initial relations, builds
     * the sampling distributions and performs {@code numberOfRuns} simulations.
     *
     * @param alwaysComputeError call {@code computeError} on every run even if
     *                           {@code exactResults} is null
     * @return the five summary lines (timings, average G, average MSE, weight variance)
     */
    private String[] runExperiment(boolean alwaysComputeError) {
        double totalG = 0.0d;
        double totalMse = 0.0d;
        double simulationSeconds = 0.0d;
        double[] errors = new double[2];
        this.initialRelations = getInitialRelations();
        if (this.observations.size() > 0) {
            restrictToObservations(this.initialRelations);
        }
        double start = System.currentTimeMillis();
        System.out.println("Computing sampling distributions");
        getSamplingDistributions();
        System.out.println("Sampling distributions computed");
        double distributionSeconds = (System.currentTimeMillis() - start) / 1000.0d;
        initSimulationInformation();
        System.out.println("Simulating");
        for (int run = 0; run < this.numberOfRuns; run++) {
            double runStart = System.currentTimeMillis();
            simulate();
            normalizeResults();
            simulationSeconds += (System.currentTimeMillis() - runStart) / 1000.0d;
            if (alwaysComputeError || this.exactResults != null) {
                computeError(errors);
            }
            totalG += errors[0];
            totalMse += errors[1];
            // The last run's information is kept so results can be saved afterwards.
            if (run < this.numberOfRuns - 1) {
                clearSimulationInformation();
            }
        }
        return new String[] {
            "Time computing sampling distributions (secs): " + distributionSeconds,
            "Time simulating (avg) : " + (simulationSeconds / this.numberOfRuns),
            "G : " + (totalG / this.numberOfRuns),
            "MSE : " + (totalMse / this.numberOfRuns),
            "Variance : " + varianceOfWeights()
        };
    }
}
