package elvira.inference.clustering;

import elvira.Bnet;
import elvira.Evidence;
import elvira.Node;
import elvira.NodeList;
import elvira.Relation;
import elvira.RelationList;
import elvira.parser.ParseException;
import elvira.potential.ListPotential;
import elvira.potential.Potential;
import elvira.tools.FactorisationTools;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Date;
import java.util.Hashtable;
import java.util.Vector;
import weka.classifiers.lazy.kstar.KStarConstants;

/* loaded from: input_file:bayelvira-1.0-SNAPSHOT.jar:elvira/inference/clustering/FactorisedSLP.class */
/**
 * Variant of {@link SimpleLazyPenniless} propagation in which potentials can be factorised,
 * as configured through a {@link FactorisationTools} instance.
 *
 * <p>Factorisation can be applied to the initial relations while compiling. It can also be
 * applied during message passing, via {@code ListPotential.factorMarginalizePotential}.
 */
public class FactorisedSLP extends SimpleLazyPenniless {

    /** Number of mandatory command-line arguments of {@link #main}; an optional evidence file may follow. */
    private static final int MANDATORY_ARGS = 18;

    /**
     * Factorisation settings (errors, methods, phase) and the counters/statistics collected
     * while compiling and propagating.
     */
    FactorisationTools factorisationParam;

    /**
     * Command-line driver: compiles the network, propagates, and writes a report.
     *
     * <p>The report contains the parameters, the timings, optional errors against exact
     * results, and factorisation statistics.
     *
     * <p>Positional arguments:
     * {@code ElviraFile OutputFile OutputErrorFile InputExactResultsFile MethodToJoinPotentials
     * LimitForPruning LowLimitForPruning LimitSumForPruning HeuristicToJoinInCreatePots
     * TriangulationMethod FactorisationMethod FactorisationPhase ApproximationMethod
     * DistanceTreesMethod FactorisationError_Nodes FactorisationError_Trees ProportChildren
     * maxLevel [EvidenceFile]}.
     *
     * <p>Passing {@code NORESULTS} as InputExactResultsFile skips the error computation.
     * The report is written to OutputErrorFile; the propagation results go to OutputFile.
     *
     * @param strArr the command-line arguments described above
     * @throws ParseException propagated from parsing the network or evidence file
     * @throws IOException if an input file cannot be read or the report cannot be written
     */
    public static void main(String[] strArr) throws ParseException, IOException {
        if (strArr.length < MANDATORY_ARGS) {
            printUsage();
            return;
        }

        // The streams are only needed while parsing, so they are closed instead of leaked.
        // NOTE(review): relies on Bnet/Evidence consuming the stream in their constructors
        // (the original never touched the streams again) -- confirm.
        Bnet bnet;
        try (FileInputStream networkIn = new FileInputStream(strArr[0])) {
            bnet = new Bnet(networkIn);
        }
        Evidence evidence;
        if (strArr.length == MANDATORY_ARGS + 1) {
            String evidenceFile = strArr[MANDATORY_ARGS];
            try (FileInputStream evidenceIn = new FileInputStream(evidenceFile)) {
                evidence = new Evidence(evidenceIn, bnet.getNodeList());
            }
            System.out.println("Evidence file: " + evidenceFile);
        } else {
            evidence = new Evidence();
        }

        int methodToJoin = Integer.parseInt(strArr[4]);
        double limitForPruning = Double.parseDouble(strArr[5]);
        double lowLimitForPruning = Double.parseDouble(strArr[6]);
        double limitSumForPruning = Double.parseDouble(strArr[7]);
        int heuristicToJoin = Integer.parseInt(strArr[8]);
        int triangulationMethod = Integer.parseInt(strArr[9]);
        int factorisationMethod = Integer.parseInt(strArr[10]);
        int factorisationPhase = Integer.parseInt(strArr[11]);
        int approximationMethod = Integer.parseInt(strArr[12]);
        int divergenceMethod = Integer.parseInt(strArr[13]);
        double nodeError = Double.parseDouble(strArr[14]);
        double treeError = Double.parseDouble(strArr[15]);
        double proportionChildren = Double.parseDouble(strArr[16]);
        double maxLevel = Double.parseDouble(strArr[17]);

        // Without a tree divergence method there is no tree-level error. The node error must
        // then be non-negative.
        if (divergenceMethod < 0) {
            treeError = -1.0d;
            if (nodeError < KStarConstants.FLOOR) {
                System.out.println("->ERROR: No Divergence (Trees) method selected. Setting the FactorisationError_Nodes = 0");
                nodeError = 0.0d;
            }
        }
        if (treeError < KStarConstants.FLOOR && nodeError < KStarConstants.FLOOR) {
            System.out.println("->ERROR: any factorisation error has been selected");
            // Non-zero status so that scripts can detect the invalid configuration.
            System.exit(1);
        }

        // An error of exactly FLOOR at either level normalises the settings. The node error
        // becomes 0; the tree error, approximation and divergence get negative "none"
        // sentinels (cf. "-1:none" in the usage text).
        boolean treeErrorNone = false;
        if (treeError == KStarConstants.FLOOR || nodeError == KStarConstants.FLOOR) {
            treeErrorNone = true;
            nodeError = 0.0d;
            treeError = -1.0d;
            approximationMethod = -1;
            divergenceMethod = -2;
        }
        // Methods above 2 (0:split|1:fact|2:split&fact) have no factorisation phase.
        if (factorisationMethod > 2) {
            factorisationPhase = -1;
        }

        // Parameter report shared by the console and the report file.
        String[] settings = {
            "Method to join potentials: " + methodToJoin,
            "Low limit: " + limitForPruning,
            "Low limit for pruning: " + lowLimitForPruning,
            "Limit sum for pruning: " + limitSumForPruning,
            "Heuristic to JoinInCreatePots: " + heuristicToJoin,
            "Triangulation method: " + triangulationMethod,
            FactorisationTools.printFactoriMethod(factorisationMethod),
            factorisationMethod > 2
                    ? "Factorisation Phase: None"
                    : "Factorisation Phase: " + factorisationPhase,
            FactorisationTools.printApproxMethod(approximationMethod),
            FactorisationTools.printDivergenceMethod(divergenceMethod),
            "Factorisation error (between nodes): " + nodeError,
            treeErrorNone
                    ? "Factorisation error (between trees): None"
                    : "Factorisation error (between trees): " + treeError,
            "Proportional Children: " + proportionChildren
        };
        for (String line : settings) {
            System.out.println(line);
        }
        System.out.println("Max Factorisation level: " + maxLevel);

        long compileStart = System.currentTimeMillis();
        FactorisedSLP factorisedSLP = new FactorisedSLP(bnet, evidence, methodToJoin, limitForPruning,
                lowLimitForPruning, limitSumForPruning, heuristicToJoin, triangulationMethod, nodeError,
                treeError, proportionChildren, maxLevel, factorisationPhase, factorisationMethod,
                approximationMethod, divergenceMethod);
        double compileSeconds = (System.currentTimeMillis() - compileStart) / 1000.0d;

        long propagateStart = System.currentTimeMillis();
        factorisedSLP.propagate(strArr[1]);
        double propagateSeconds = (System.currentTimeMillis() - propagateStart) / 1000.0d;

        // Closing the FileWriter (not the PrintWriter) keeps close() failures visible as IOException.
        try (FileWriter reportFile = new FileWriter(strArr[2])) {
            PrintWriter report = new PrintWriter(reportFile);
            for (String line : settings) {
                report.println(line);
            }
            report.println("Factorisation level: " + maxLevel);
            report.println("");
            report.println("\nTime compiling (secs) : " + compileSeconds);
            report.println("Time propagating (secs) : " + propagateSeconds);

            if (!strArr[3].equals("NORESULTS")) {
                System.out.println("Reading exact results");
                factorisedSLP.readExactResults(strArr[3]);
                System.out.println("Exact results read");
                System.out.println("Computing errors");
                double[] errors = new double[2];
                factorisedSLP.computeError(errors);
                double maxAbsoluteError = factorisedSLP.computeMaxAbsoluteError();
                double gError = errors[0];
                double mse = errors[1];
                // The same array is reused for the KL error and its standard deviation.
                factorisedSLP.computeKLError(errors);
                report.println("G : " + gError);
                report.println("MSE : " + mse);
                report.println("Max absolute error : " + maxAbsoluteError);
                report.println("KL error : " + errors[0]);
                report.println("Std. deviation of KL-error : " + errors[1]);
            }

            factorisedSLP.binTree.calculateStatistics();
            factorisedSLP.binTree.saveStatistics(report);

            FactorisationTools params = factorisedSLP.factorisationParam;
            if (params.sizesPot) {
                Vector sizeStats = params.getClassStatistic(0);
                if (sizeStats != null) {
                    report.println(".....These are the statistics about the sizes:.....");
                    writeStatistics(report, "Mean :", sizeStats);
                    System.out.println(".....These are the statistics about the sizes:.....");
                    FactorisationTools.printStatistics(sizeStats);
                }
            }
            if (factorisationMethod > 0) {
                Vector approxStats = params.getClassStatistic(1);
                if (approxStats != null) {
                    report.println(".....These are the statistics about the probability factors in the approximations: .......");
                    report.print("Number of approximations :" + params.vecDistApproxim.size());
                    report.println(" (" + params.getCounterFCompil() + " in compilation)");
                    writeStatistics(report, "Mean (of max of probability factors) :", approxStats);
                    System.out.println(".....These are the statistics about the probability factors in the approximations: ......");
                    System.out.print("Number of approximations :" + params.vecDistApproxim.size());
                    System.out.println(" (" + params.getCounterFCompil() + " in compilation)");
                    FactorisationTools.printStatistics(approxStats);
                } else {
                    report.println("Number of approximations : NONE");
                    System.out.println(".....No factorisation made ......");
                }
            }
            if (factorisationMethod != 1) {
                report.println("Number of split operations : " + params.getNumSplit());
                System.out.println(".....Number of split operations: " + params.getNumSplit());
            }
            if (factorisedSLP.doubleCache) {
                report.println("\n\n Size of cache 1: " + factorisedSLP.cache1.size());
                report.println("Size of cache 2: " + factorisedSLP.cache2.size());
                report.println("Size of cache 1 marg: " + factorisedSLP.cache1M.size());
                report.println("Size of cache 2 marg: " + factorisedSLP.cache2M.size());
            }
            report.flush();
        }
        System.out.println("Done.");
    }

    /** Prints the command-line synopsis of {@link #main}. */
    private static void printUsage() {
        System.out.println("Too few arguments. Arguments are:");
        System.out.println("");
        System.out.println("ElviraFile OutputFile OutputErrorFile InputExactResultsFile");
        System.out.println("MethodToJoinPotentials(0|1|2|3|4|5) LimitForPruning LowLimitForPruning LimitSumForPruning");
        System.out.print("HeuristicToJoinInCreatePots(0|1|2|3|4|5) ");
        System.out.println("TriangulationMethod(0|1|2) ");
        System.out.println("FactorisationMethod(0:split|1:fact|2:split&fact) ");
        System.out.println("FactorisationPhase(0:compil|1:propag|2:compil&propag) ");
        System.out.println("ApproximationMethod(0:aver|1:WeigAver|2:Chi|3:MSE|4:WMSE|5:KL|6:WP|7:Hel) ");
        System.out.println("DistanceTreesMethod(-1:none|1:Chi|2:NormChi|3:MSE|4:WMSE|5:KL|6:MAD|7:Hel) ");
        System.out.println("FactorisationError_Nodes(-1 for none)  FactorisationError_Trees(-1 for none)  ProportChildren  maxLevel ");
        System.out.println("[EvidenceFile]");
    }

    /**
     * Writes the first four entries of a statistics vector as mean, standard deviation, max
     * and min, one per line.
     *
     * @param out destination writer
     * @param meanLabel label printed before the mean value
     * @param stats statistics vector whose elements 0..3 are {@link Double}s
     */
    private static void writeStatistics(PrintWriter out, String meanLabel, Vector<?> stats) {
        out.println(meanLabel + ((Double) stats.elementAt(0)).doubleValue());
        out.println("Standard Deviation : " + ((Double) stats.elementAt(1)).doubleValue());
        out.println("Max : " + ((Double) stats.elementAt(2)).doubleValue());
        out.println("Min : " + ((Double) stats.elementAt(3)).doubleValue());
    }

    /**
     * Compiles {@code bnet} for factorised lazy-penniless propagation.
     *
     * <p>The steps are:
     * <ol>
     *   <li>configure pruning and factorisation;</li>
     *   <li>build the join tree;</li>
     *   <li>restrict the initial relations to {@code evidence};</li>
     *   <li>factorise them when the phase includes compilation (0 or 2);</li>
     *   <li>reset the pruning limits to {@code KStarConstants.FLOOR} and binarise the tree.</li>
     * </ol>
     *
     * @param bnet network to compile
     * @param evidence observations the initial relations are restricted to
     * @param methodToJoin value for {@code setWhenJoinPotentials} (see {@link #sendMessage})
     * @param limitForPruning pruning limit used during compilation
     * @param lowLimitForPruning bound applied to every initial potential via {@code limitBound}
     * @param limitSumForPruning sum limit used during compilation
     * @param heuristicToJoin value for {@code setHeuristicToJoin}
     * @param triangulationMethod 2 selects {@code new JoinTree(bnet)}; any other value is passed
     *     to {@code new JoinTree(bnet, evidence, triangulationMethod)}
     * @param nodeError factorisation error between nodes
     * @param treeError factorisation error between trees (negative for none)
     * @param proportionChildren proportional-children setting for {@link FactorisationTools}
     * @param maxLevel maximum factorisation level
     * @param factorisationPhase 0: compilation, 1: propagation, 2: both (negative for none)
     * @param factorisationMethod 0: split, 1: fact, 2: split and fact
     * @param approximationMethod approximation method index for {@link FactorisationTools}
     * @param divergenceMethod tree-distance method index for {@link FactorisationTools}
     */
    public FactorisedSLP(Bnet bnet, Evidence evidence, int methodToJoin, double limitForPruning,
            double lowLimitForPruning, double limitSumForPruning, int heuristicToJoin,
            int triangulationMethod, double nodeError, double treeError, double proportionChildren,
            double maxLevel, int factorisationPhase, int factorisationMethod, int approximationMethod,
            int divergenceMethod) {
        setHeuristicToJoin(heuristicToJoin);
        setSortAndBound(this.sortAndBound);
        setLimitForPruning(limitForPruning);
        setLowLimitForPruning(lowLimitForPruning);
        setLimitSumForPruning(limitSumForPruning);
        this.factorisationParam = new FactorisationTools(nodeError, treeError, factorisationMethod,
                approximationMethod, divergenceMethod, proportionChildren, maxLevel, factorisationPhase);
        setFirstTour(true);
        setUseCache(false);
        setSortAndBound(false);
        this.observations = evidence;
        this.network = bnet;
        setWhenJoinPotentials(methodToJoin);
        this.positions = new Hashtable();
        if (triangulationMethod == 2) {
            this.binTree = new JoinTree(bnet);
        } else {
            this.binTree = new JoinTree(bnet, evidence, triangulationMethod);
        }

        RelationList initialRelations = getInitialRelations();
        initialRelations.restrictToObservations(this.observations);
        boolean factoriseWhileCompiling = factorisationPhase == 0 || factorisationPhase == 2;
        // setcompilPhase() brackets the compilation-time factorisation below.
        this.factorisationParam.setcompilPhase();
        for (int i = 0; i < initialRelations.size(); i++) {
            Relation relation = initialRelations.elementAt(i);
            ListPotential potentials = (ListPotential) relation.getValues();
            potentials.limitBound(this.lowLimitForPruning);
            if (factoriseWhileCompiling) {
                Potential factorised = potentials.factorisePotentialAllVbles(this.factorisationParam);
                if (factorised != null) {
                    relation.setValues(factorised);
                }
            }
        }
        this.factorisationParam.setCounterFCompil();
        this.factorisationParam.setcompilPhase();

        // The pruning limits only apply to compilation; propagation starts from FLOOR.
        setLimitForPruning(KStarConstants.FLOOR);
        setLowLimitForPruning(KStarConstants.FLOOR);
        setLimitSumForPruning(KStarConstants.FLOOR);
        this.marginalCliques = this.binTree.Leaves(initialRelations);
        this.binTree.binTree();
        this.binTree.setLabels();
    }

    /**
     * Sends the message from {@code sender} to {@code recipient}.
     *
     * <p>The steps are:
     * <ol>
     *   <li>combine the sender's clique potential with {@code getOtherValues()} of every link
     *       message except the one toward {@code recipient};</li>
     *   <li>optionally join potentials according to {@code whenJoinPotentials};</li>
     *   <li>marginalise onto the variables of the message toward {@code recipient};</li>
     *   <li>store the result as {@code values} on the sender's copy of the link message and as
     *       {@code otherValues} on the recipient's copy.</li>
     * </ol>
     *
     * <p>Unknown {@code whenJoinPotentials} values terminate the JVM with status 1. So does
     * value 4 combined with the double cache.
     */
    @Override // elvira.inference.clustering.SimpleLazyPenniless
    public void sendMessage(NodeJoinTree sender, NodeJoinTree recipient) {
        // Both stay fresh Relations if recipient is not among the sender's neighbours.
        Relation senderSide = new Relation();
        Relation recipientSide = new Relation();
        Vector<Node> messageVars = new Vector<>();
        Potential values = sender.getNodeRelation().getValues();

        NeighbourTreeList neighbours = sender.getNeighbourList();
        int recipientLabel = recipient.getLabel();
        for (int i = 0; i < neighbours.size(); i++) {
            NeighbourTree neighbour = neighbours.elementAt(i);
            Relation message = neighbour.getMessage();
            if (neighbour.getNeighbour().getLabel() == recipientLabel) {
                senderSide = message;
                messageVars = message.getVariables().getNodes();
                recipientSide = recipient.getNeighbourList().getMessage(sender);
            } else {
                values = values.combine(message.getOtherValues());
            }
        }

        switch (this.whenJoinPotentials) {
            case 0:
            case 5:
                // No joining step (5 marginalises directly below).
                break;
            case 1:
            case 2: {
                NodeList varsToCombine = varsToCombine(sender, recipient, this.whenJoinPotentials);
                ListPotential list = (ListPotential) values;
                for (int i = 0; i < varsToCombine.size(); i++) {
                    if (this.doubleCache) {
                        list.combinePotentialsOf(varsToCombine.elementAt(i), this.limitForPruning,
                                this.limitSumForPruning, this.cache1, this.cache2, this.firstTour,
                                this.heuristicToJoin);
                    } else {
                        list.combinePotentialsOf(varsToCombine.elementAt(i), this.limitForPruning,
                                this.limitSumForPruning);
                    }
                }
                break;
            }
            case 3: {
                ListPotential list = (ListPotential) values;
                Potential joined = this.doubleCache
                        ? list.createPotential(this.limitForPruning, this.limitSumForPruning,
                                this.cache1, this.cache2, this.firstTour, this.heuristicToJoin)
                        : list.createPotential(this.limitForPruning, this.limitSumForPruning);
                if (joined != null) {
                    values = new ListPotential(joined);
                }
                break;
            }
            case 4:
                if (this.doubleCache) {
                    System.out.println("ERROR: whenJoinPotentials == 4 cannot be used with cache");
                    System.exit(1);
                } else {
                    ((ListPotential) values).joinPotentials(this.limitForPruning, this.limitSumForPruning);
                }
                break;
            default:
                System.out.println("Error in SimpleLazyPenniless.sendMessage(): wrong value for whenJoinPotentials=" + this.whenJoinPotentials);
                System.exit(1);
                break;
        }

        if (this.whenJoinPotentials == 5) {
            values = ((ListPotential) values).marginalizePotential(messageVars, this.limitForPruning,
                    this.limitSumForPruning);
        } else if (((ListPotential) values).getListSize() > 0) {
            if (this.doubleCache) {
                values = ((ListPotential) values).marginalizePotential(messageVars, this.limitForPruning,
                        this.limitSumForPruning, this.cache1, this.cache2, this.cache1M, this.cache2M,
                        this.firstTour, this.heuristicToJoin);
            } else {
                // With a compilation-only phase (0), method 3 is selected for propagation.
                // Presumably this disables factorisation here, since main treats methods > 2 as
                // "no factorisation" -- confirm in FactorisationTools.
                if (this.factorisationParam.getFacPhase() == 0) {
                    this.factorisationParam.setFactMethod(3);
                }
                values = ((ListPotential) values).factorMarginalizePotential(messageVars,
                        this.limitForPruning, this.heuristicToJoin, this.limitSumForPruning,
                        this.factorisationParam);
            }
        }

        if (this.sortAndBound) {
            ListPotential unsorted = (ListPotential) values;
            ListPotential bounded = new ListPotential();
            bounded.setVariables(unsorted.getVariables());
            int count = unsorted.getListSize();
            for (int i = 0; i < count; i++) {
                bounded.insertPotential(unsorted.getPotentialAt(i).sortAndBound(this.maximumSize));
            }
            values = bounded;
        }
        senderSide.setValues(values);
        recipientSide.setOtherValues(values);
    }
}
