package elvira.inference.clustering;

import elvira.Bnet;
import elvira.Evidence;
import elvira.FiniteStates;
import elvira.Node;
import elvira.NodeList;
import elvira.Relation;
import elvira.RelationList;
import elvira.inference.Propagation;
import elvira.parser.ParseException;
import elvira.potential.Potential;
import elvira.potential.PotentialTree;
import elvira.potential.ProbabilityTree;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Date;
import java.util.Hashtable;
import java.util.Vector;

/* loaded from: input_file:bayelvira-1.0-SNAPSHOT.jar:elvira/inference/clustering/SimplePenniless.class */
/**
 * Approximate ("Penniless") propagation over a binary join tree.
 *
 * <p>The constructor does the following:
 * <ul>
 *   <li>builds a {@link JoinTree} for the network;</li>
 *   <li>restricts the initial relations to the evidence and prunes them with
 *       {@code limitBound};</li>
 *   <li>hands the relations to {@code JoinTree.Leaves} and stores the returned map in
 *       {@link #marginalCliques};</li>
 *   <li>converts the tree into a binary tree and labels its nodes.</li>
 * </ul>
 *
 * <p>{@link #propagate(String)} performs an upward and a downward message pass rooted at the
 * first node of the tree. Each message is pruned with {@code limitBound} and, optionally,
 * reduced with {@code sortAndBound}. Marginals are then computed for every variable that has a
 * node in {@link #marginalCliques}.
 */
public class SimplePenniless extends Propagation {
    /** Value of {@link #kindOfApprPruning} selected by the string {@code "AVERAGE"}. */
    private static final int APPR_PRUNING_AVERAGE = 0;
    /** Value of {@link #kindOfApprPruning} selected by the string {@code "ZERO"}. */
    private static final int APPR_PRUNING_ZERO = 1;

    /** Join tree (made binary in the constructor) on which messages are passed. */
    JoinTree binTree;
    /** Size bound passed to {@code sortAndBound} when {@code sortAndBound} is enabled. */
    public int maximumSize;
    /** Whether potentials are additionally reduced with {@code sortAndBound}. */
    public boolean sortAndBound;
    /** Threshold used when pruning messages (stored already transformed by {@link #setLimitForPruning}). */
    public double limitForPruning;
    /** Sum threshold passed to {@code limitBound} together with {@link #limitForPruning}. */
    public double limitSumForPruning;
    /** Threshold used when pruning the initial relations (transformed by {@link #setLowLimitForPruning}). */
    public double lowLimitForPruning;
    /** Map from a variable to the join-tree node used to compute its marginal. */
    public Hashtable marginalCliques;
    /** Approximation mode passed to {@code limitBound}: {@code 0} = AVERAGE, {@code 1} = ZERO. */
    int kindOfApprPruning;

    /**
     * Command-line driver: compiles the network, propagates, writes results and statistics.
     *
     * <p>Arguments: ElviraFile OutputFile OutputErrorFile InputExactResultsFile
     * kindOfApprPruning(AVERAGE|ZERO) LimitForPruning LowLimitForPruning LimitSumForPruning
     * MaxTreeSize SortAndBound(true|false) TriangulationMethod(0|1|2) [EvidenceFile].
     * Passing {@code NORESULTS} as InputExactResultsFile skips the error computation.
     *
     * @param strArr command-line arguments, as listed above
     * @throws ParseException if the network or evidence file cannot be parsed
     * @throws IOException    if any input or output file cannot be read or written
     */
    public static void main(String[] strArr) throws ParseException, IOException {
        if (strArr.length < 11) {
            System.out.print("Too few arguments. Arguments are: ElviraFile");
            System.out.print(" OutputFile OutputErrorFile InputExactResultsFile");
            System.out.print(" kindOfApprPruning(AVERAGE|ZERO) LimitForPruning LowLimitForPruning LimitSumForPruning MaxTreeSize");
            System.out.println(" SortAndBound(true|false) TriangulationMethod(0|1|2) [EvidenceFile]");
            return;
        }
        // The parsers consume the streams inside the constructors, so the streams can be closed
        // immediately afterwards instead of being leaked.
        Bnet bnet;
        try (FileInputStream networkStream = new FileInputStream(strArr[0])) {
            bnet = new Bnet(networkStream);
        }
        Evidence evidence;
        if (strArr.length == 12) {
            try (FileInputStream evidenceStream = new FileInputStream(strArr[11])) {
                evidence = new Evidence(evidenceStream, bnet.getNodeList());
            }
            System.out.println("Evidence file" + strArr[11]);
        } else {
            evidence = new Evidence();
        }
        double limit = Double.parseDouble(strArr[5]);
        double lowLimit = Double.parseDouble(strArr[6]);
        double limitSum = Double.parseDouble(strArr[7]);
        int triangulationMethod = Integer.parseInt(strArr[10]);
        System.out.println("limit for pruning: " + limit);
        System.out.println("limit sum for pruning: " + limitSum);
        int maxTreeSize = Integer.parseInt(strArr[8]);
        long compileStart = System.currentTimeMillis();
        boolean sortAndBound = Boolean.parseBoolean(strArr[9]);
        SimplePenniless propagation = new SimplePenniless(bnet, evidence, limit, lowLimit, limitSum,
                maxTreeSize, sortAndBound, triangulationMethod);
        propagation.setKindOfApprPruning(strArr[4]);
        double compileSeconds = elapsedSeconds(compileStart);
        long propagateStart = System.currentTimeMillis();
        propagation.propagate(strArr[1]);
        double propagateSeconds = elapsedSeconds(propagateStart);
        // try-with-resources guarantees the output file is closed even if a step below fails.
        try (PrintWriter printWriter = new PrintWriter(new FileWriter(strArr[2]))) {
            printWriter.println("Low limit: " + limit);
            printWriter.println("Low limit for pruning: " + lowLimit);
            printWriter.println("Limit sum for pruning: " + limitSum);
            printWriter.println("Max size when bound: " + maxTreeSize);
            printWriter.println("Sort and Bound: " + sortAndBound);
            printWriter.println("Triangulation method: " + triangulationMethod);
            printWriter.println("Time compiling (secs) : " + compileSeconds);
            printWriter.println("Time propagating (secs) : " + propagateSeconds);
            if (!strArr[3].equals("NORESULTS")) {
                System.out.println("Reading exact results");
                propagation.readExactResults(strArr[3]);
                System.out.println("Exact results read");
                System.out.println("Computing errors");
                double[] errors = new double[2];
                propagation.computeError(errors);
                double maxAbsoluteError = propagation.computeMaxAbsoluteError();
                // computeKLError reuses the same array, so save the first results beforehand.
                double g = errors[0];
                double mse = errors[1];
                propagation.computeKLError(errors);
                printWriter.println("G : " + g);
                printWriter.println("MSE : " + mse);
                printWriter.println("Max absoulte error : " + maxAbsoluteError);
                printWriter.println("KL error : " + errors[0]);
                printWriter.println("Std. deviation of KL-error : " + errors[1]);
            }
            propagation.binTree.calculateStatistics();
            propagation.binTree.saveStatistics(printWriter);
        }
        System.out.println("Done");
    }

    /** Creates an unconfigured instance using the AVERAGE approximation mode. */
    public SimplePenniless() {
        this.kindOfApprPruning = APPR_PRUNING_AVERAGE;
    }

    /**
     * Builds the binary join tree and attaches the pruned initial relations.
     *
     * @param bnet                network to propagate over
     * @param evidence            observations; the initial relations are restricted to them
     * @param limit               raw limit for pruning messages (see {@link #setLimitForPruning})
     * @param lowLimit            raw limit for pruning initial relations (see {@link #setLowLimitForPruning})
     * @param limitSum            sum threshold for {@code limitBound}
     * @param maxSize             size bound for {@code sortAndBound}
     * @param sortAndBound        whether to apply {@code sortAndBound}
     * @param triangulationMethod {@code 2} uses {@code JoinTree(Bnet)}; any other value is
     *                            passed to {@code JoinTree(Bnet, Evidence, int)}
     */
    SimplePenniless(Bnet bnet, Evidence evidence, double limit, double lowLimit, double limitSum,
            int maxSize, boolean sortAndBound, int triangulationMethod) {
        this();
        setMaximumSize(maxSize);
        setSortAndBound(sortAndBound);
        setLimitForPruning(limit);
        setLowLimitForPruning(lowLimit);
        setLimitSumForPruning(limitSum);
        this.observations = evidence;
        this.network = bnet;
        this.positions = new Hashtable();
        if (triangulationMethod == 2) {
            this.binTree = new JoinTree(bnet);
        } else {
            this.binTree = new JoinTree(bnet, evidence, triangulationMethod);
        }
        RelationList initialRelations = getInitialRelations();
        initialRelations.restrictToObservations(this.observations);
        for (int i = 0; i < initialRelations.size(); i++) {
            initialRelations.elementAt(i).getValues().limitBound(this.lowLimitForPruning);
        }
        this.marginalCliques = this.binTree.Leaves(initialRelations);
        this.binTree.binTree();
        this.binTree.setLabels();
    }

    /**
     * Selects the approximation mode used when pruning messages.
     *
     * <p>Any value other than {@code "AVERAGE"} or {@code "ZERO"} prints an error and
     * terminates the JVM with exit status 1.
     *
     * @param str {@code "AVERAGE"} or {@code "ZERO"}
     */
    public void setKindOfApprPruning(String str) {
        if (str.equals("AVERAGE")) {
            this.kindOfApprPruning = APPR_PRUNING_AVERAGE;
        } else if (str.equals("ZERO")) {
            this.kindOfApprPruning = APPR_PRUNING_ZERO;
        } else {
            System.out.println("Error in SimplePenniless.setKindOfApprPruning: ilegal value for kind=" + str);
            System.exit(1);
        }
    }

    /**
     * Stores the message-pruning threshold as {@code 1 - H(0.5 - d, 0.5 + d)}, where H is the
     * binary entropy in bits.
     *
     * @param d raw limit; meaningful for {@code |d| <= 0.5} (larger values yield NaN)
     */
    public void setLimitForPruning(double d) {
        this.limitForPruning = entropyThreshold(d);
    }

    /**
     * Sets the sum threshold passed unchanged to {@code limitBound}.
     *
     * @param d sum threshold
     */
    public void setLimitSumForPruning(double d) {
        this.limitSumForPruning = d;
    }

    /**
     * Stores the initial-relation pruning threshold as {@code 1 - H(0.5 - d, 0.5 + d)}, where
     * H is the binary entropy in bits (the same transformation as {@link #setLimitForPruning}).
     *
     * @param d raw limit; meaningful for {@code |d| <= 0.5} (larger values yield NaN)
     */
    public void setLowLimitForPruning(double d) {
        this.lowLimitForPruning = entropyThreshold(d);
    }

    /**
     * Sets the size bound used by {@code sortAndBound}.
     *
     * @param i maximum size
     */
    public void setMaximumSize(int i) {
        this.maximumSize = i;
    }

    /**
     * Enables or disables the extra {@code sortAndBound} reduction.
     *
     * @param z {@code true} to enable it
     */
    public void setSortAndBound(boolean z) {
        this.sortAndBound = z;
    }

    /**
     * Gives every node without a potential a unit potential, and resets both directions of
     * every neighbour message to unit potentials.
     */
    public void initMessages() {
        for (int i = 0; i < this.binTree.getJoinTreeNodes().size(); i++) {
            NodeJoinTree node = this.binTree.elementAt(i);
            Relation nodeRelation = node.getNodeRelation();
            if (nodeRelation.getValues() == null) {
                nodeRelation.setValues(makeUnitPotential(nodeRelation.getVariables()));
            }
            NeighbourTreeList neighbours = node.getNeighbourList();
            for (int j = 0; j < neighbours.size(); j++) {
                Relation message = neighbours.elementAt(j).getMessage();
                message.setValues(makeUnitPotential(message.getVariables()));
                message.setOtherValues(makeUnitPotential(message.getVariables()));
            }
        }
    }

    /**
     * Runs the full propagation and saves the resulting marginals.
     *
     * @param str destination passed to {@code saveResults}
     * @throws ParseException propagated from {@code saveResults}
     * @throws IOException    propagated from {@code saveResults}
     */
    public void propagate(String str) throws ParseException, IOException {
        this.binTree.setLabels();
        System.out.println("Initializing messages");
        long initStart = System.currentTimeMillis();
        initMessages();
        System.out.println("Time Initializing messages: " + elapsedSeconds(initStart));
        System.out.println("Starting propagation");
        navigate(this.binTree.elementAt(0));
        System.out.println("Propagation done");
        System.out.println("Computing marginals");
        long marginalsStart = System.currentTimeMillis();
        computeMarginals();
        System.out.println("Time computeMarginals: " + elapsedSeconds(marginalsStart));
        System.out.println("Done");
        saveResults(str);
    }

    /**
     * Performs the upward pass and then the downward pass rooted at {@code nodeJoinTree},
     * printing the time spent in each.
     *
     * @param nodeJoinTree root of both passes
     */
    protected void navigate(NodeJoinTree nodeJoinTree) {
        long upStart = System.currentTimeMillis();
        navigateUp(nodeJoinTree);
        System.out.println("Time navigateUp: " + elapsedSeconds(upStart));
        long downStart = System.currentTimeMillis();
        navigateDown(nodeJoinTree);
        System.out.println("Time navigateDown: " + elapsedSeconds(downStart));
    }

    /**
     * Collects messages towards {@code nodeJoinTree} from each of its subtrees.
     *
     * @param nodeJoinTree root of the upward pass
     */
    public void navigateUp(NodeJoinTree nodeJoinTree) {
        NeighbourTreeList neighbours = nodeJoinTree.getNeighbourList();
        for (int i = 0; i < neighbours.size(); i++) {
            navigateUp(nodeJoinTree, neighbours.elementAt(i).getNeighbour());
        }
    }

    /**
     * Recursively processes the subtree rooted at {@code nodeJoinTree2}, which is entered from
     * {@code nodeJoinTree}, and then sends the subtree's message back to {@code nodeJoinTree}.
     * Upward messages use {@code isDownward = false}.
     */
    private void navigateUp(NodeJoinTree nodeJoinTree, NodeJoinTree nodeJoinTree2) {
        NeighbourTreeList neighbours = nodeJoinTree2.getNeighbourList();
        for (int i = 0; i < neighbours.size(); i++) {
            NodeJoinTree neighbour = neighbours.elementAt(i).getNeighbour();
            // Labels identify nodes; skip the node we came from.
            if (neighbour.getLabel() != nodeJoinTree.getLabel()) {
                navigateUp(nodeJoinTree2, neighbour);
            }
        }
        sendMessage(nodeJoinTree2, nodeJoinTree, false);
    }

    /**
     * Distributes messages from {@code nodeJoinTree} into each of its subtrees.
     *
     * @param nodeJoinTree root of the downward pass
     */
    public void navigateDown(NodeJoinTree nodeJoinTree) {
        NeighbourTreeList neighbours = nodeJoinTree.getNeighbourList();
        for (int i = 0; i < neighbours.size(); i++) {
            NodeJoinTree neighbour = neighbours.elementAt(i).getNeighbour();
            sendMessage(nodeJoinTree, neighbour, true);
            navigateDown(nodeJoinTree, neighbour);
        }
    }

    /**
     * Continues the downward pass inside the subtree rooted at {@code nodeJoinTree2}, which was
     * entered from {@code nodeJoinTree}.
     *
     * @param nodeJoinTree  node the pass came from (not revisited)
     * @param nodeJoinTree2 node whose other neighbours receive messages next
     */
    public void navigateDown(NodeJoinTree nodeJoinTree, NodeJoinTree nodeJoinTree2) {
        NeighbourTreeList neighbours = nodeJoinTree2.getNeighbourList();
        for (int i = 0; i < neighbours.size(); i++) {
            NodeJoinTree neighbour = neighbours.elementAt(i).getNeighbour();
            if (neighbour.getLabel() != nodeJoinTree.getLabel()) {
                sendMessage(nodeJoinTree2, neighbour, true);
                navigateDown(nodeJoinTree2, neighbour);
            }
        }
    }

    /**
     * Copies every relation of the network, converting its values to a {@link PotentialTree}
     * when they are not one already.
     *
     * @return a new list with one converted relation per network relation
     */
    @Override // elvira.inference.Propagation
    public RelationList getInitialRelations() {
        RelationList relationList = new RelationList();
        for (int i = 0; i < this.network.getRelationList().size(); i++) {
            Relation original = (Relation) this.network.getRelationList().elementAt(i);
            Relation copy = new Relation();
            copy.setVariables(original.getVariables().copy());
            copy.setValues(convertPotential(original.getValues()));
            copy.setKind(original.getKind());
            relationList.insertRelation(copy);
        }
        return relationList;
    }

    /**
     * Computes marginals for the network variables.
     *
     * <p>For each network variable that has a node in {@link #marginalCliques}, the method:
     * <ul>
     *   <li>combines that node's potential with all incoming messages;</li>
     *   <li>marginalizes the result to the variable and normalizes it;</li>
     *   <li>appends it to {@code results};</li>
     *   <li>records its index in {@code positions}.</li>
     * </ul>
     */
    public void computeMarginals() {
        // The return value is unused; the call is kept in case getLeaves() has side effects.
        this.binTree.getLeaves();
        NodeList nodeList = this.network.getNodeList();
        int size = nodeList.size();
        int position = 0;
        for (int v = 0; v < size; v++) {
            FiniteStates variable = (FiniteStates) nodeList.elementAt(v);
            NodeJoinTree clique = (NodeJoinTree) this.marginalCliques.get(variable);
            if (clique == null) {
                continue;
            }
            NeighbourTreeList neighbours = clique.getNeighbourList();
            Potential values = clique.getNodeRelation().getValues();
            for (int j = 0; j < neighbours.size(); j++) {
                values = values.combine(neighbours.elementAt(j).getMessage().getOtherValues());
            }
            Vector vector = new Vector();
            vector.addElement(variable);
            Potential marginal = values.marginalizePotential(vector);
            this.results.addElement(marginal.normalize(marginal));
            this.positions.put(variable, Integer.valueOf(position));
            position++;
        }
    }

    /**
     * Sends a message from {@code nodeJoinTree} to its neighbour {@code nodeJoinTree2}.
     *
     * <p>Steps:
     * <ul>
     *   <li>Combine the source potential with the messages from all its other neighbours.</li>
     *   <li>Marginalize the result to the separator variables.</li>
     *   <li>Prune the marginal with {@code limitBound} and, if enabled, {@code sortAndBound}.</li>
     *   <li>Store the marginal as the values of the source-side message and as the other-values
     *       of the target-side message.</li>
     * </ul>
     *
     * <p>Downward messages also prune the combined potential before marginalizing and use the
     * configured approximation mode. Upward messages always use the AVERAGE mode. If
     * {@code nodeJoinTree2} is not a neighbour of {@code nodeJoinTree}, the result is stored in
     * fresh, unattached {@code Relation} objects.
     *
     * @param nodeJoinTree  sender
     * @param nodeJoinTree2 receiver
     * @param z             {@code true} for the downward pass, {@code false} for the upward pass
     */
    public void sendMessage(NodeJoinTree nodeJoinTree, NodeJoinTree nodeJoinTree2, boolean z) {
        Relation targetSide = new Relation();
        Relation sourceSide = new Relation();
        Vector<Node> separator = new Vector<>();
        Potential values = nodeJoinTree.getNodeRelation().getValues();
        NeighbourTreeList neighbours = nodeJoinTree.getNeighbourList();
        for (int i = 0; i < neighbours.size(); i++) {
            NeighbourTree link = neighbours.elementAt(i);
            Relation message = link.getMessage();
            if (link.getNeighbour().getLabel() != nodeJoinTree2.getLabel()) {
                values = values.combine(message.getOtherValues());
            } else {
                sourceSide = message;
                separator = message.getVariables().getNodes();
                targetSide = nodeJoinTree2.getNeighbourList().getMessage(nodeJoinTree);
            }
        }
        if (z) {
            // NOTE(review): when the receiver is the sender's only neighbour, no combine happens
            // and `values` is still the sender's stored potential. If limitBound works in place,
            // this also prunes that stored potential. Confirm this is intended.
            values.limitBound(this.kindOfApprPruning, this.limitForPruning, this.limitSumForPruning);
            if (this.sortAndBound) {
                values = values.sortAndBound(this.maximumSize);
            }
        }
        Potential marginal = values.marginalizePotential(separator);
        int mode = z ? this.kindOfApprPruning : APPR_PRUNING_AVERAGE;
        marginal.limitBound(mode, this.limitForPruning, this.limitSumForPruning);
        if (this.sortAndBound) {
            marginal = marginal.sortAndBound(this.maximumSize);
        }
        sourceSide.setValues(marginal);
        targetSide.setOtherValues(marginal);
    }

    /**
     * Returns {@code potential} unchanged if it is already a {@code PotentialTree}; otherwise
     * returns it wrapped in a new {@link PotentialTree}.
     */
    Potential convertPotential(Potential potential) {
        return potential.getClassName().equals("PotentialTree") ? potential : new PotentialTree(potential);
    }

    /** Builds a {@link PotentialTree} over {@code nodeList} whose tree is the unit tree. */
    Potential makeUnitPotential(NodeList nodeList) {
        PotentialTree potentialTree = new PotentialTree(nodeList);
        potentialTree.setTree(ProbabilityTree.unitTree());
        return potentialTree;
    }

    /**
     * Computes {@code 1 + f(0.5 - d) + f(0.5 + d)} with {@code f(x) = x * log2(x)}, i.e. one
     * minus the binary entropy of {@code (0.5 - d, 0.5 + d)}.
     *
     * <p>Uses the convention {@code 0 * log 0 = 0}, so {@code d = ±0.5} yields 1 instead of NaN.
     */
    private static double entropyThreshold(double d) {
        return 1.0d + xLog2X(0.5d - d) + xLog2X(0.5d + d);
    }

    /** Returns {@code x * log2(x)}, defined as 0 at {@code x == 0} (entropy convention). */
    private static double xLog2X(double x) {
        return x == 0.0d ? 0.0d : (x * Math.log(x)) / Math.log(2.0d);
    }

    /** Seconds elapsed since {@code startMillis}, a value of {@link System#currentTimeMillis()}. */
    private static double elapsedSeconds(long startMillis) {
        return (System.currentTimeMillis() - startMillis) / 1000.0d;
    }
}
