package elvira.learning.classification.supervised.discrete;

import elvira.Bnet;
import elvira.FiniteStates;
import elvira.Graph;
import elvira.InvalidEditException;
import elvira.Link;
import elvira.LinkList;
import elvira.Node;
import elvira.NodeList;
import elvira.database.DataBaseCases;
import elvira.learning.BDeMetrics;
import elvira.learning.BICMetrics;
import elvira.learning.DELearning;
import elvira.learning.K2Metrics;
import elvira.learning.LPLearning;
import elvira.learning.Metrics;
import elvira.learning.ParameterLearning;
import elvira.parser.ParseException;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Vector;

/* loaded from: input_file:bayelvira-1.0-SNAPSHOT.jar:elvira/learning/classification/supervised/discrete/BANLearning.class */
/**
 * Structure learner that starts from a network in which the class variable is
 * linked to every other variable and then improves it by greedy local search
 * (link removal, reversal and addition) guided by a scoring {@link Metrics}.
 *
 * <p>The metric is chosen by name: {@code "BIC"} selects {@link BICMetrics},
 * {@code "K2"} selects {@link K2Metrics}, and anything else (including
 * {@code null} or the empty string) selects {@link BDeMetrics}.
 */
public class BANLearning extends MarkovBlanketLearning {
    /** Result code of {@link #maxScore}: no candidate move was found. */
    private static final int NO_MOVE = -1;
    /** Result code of {@link #maxScore}: the best move removes an existing link. */
    private static final int REMOVE_LINK = 0;
    /** Result code of {@link #maxScore}: the best move reverses an existing link. */
    private static final int REVERSE_LINK = 1;
    /** Result code of {@link #maxScore}: the best move adds a new link. */
    private static final int ADD_LINK = 2;

    /** Metric used to score candidate structures; created lazily when {@code null}. */
    Metrics metric;
    /** Name of the metric to instantiate ("BIC", "K2"; anything else means BDe). */
    String metricName;

    /** Creates a learner with no metric and an empty metric name (BDe is used by default). */
    public BANLearning() {
        this.metricName = "";
        this.metric = null;
    }

    /**
     * Creates a learner for the given cases.
     *
     * @param dataBaseCases cases handed to the superclass constructor
     * @param i             second superclass constructor argument; {@code main} passes the
     *                      index of the class variable here
     * @param z             third superclass constructor argument
     *                      (presumably the Laplace-correction flag; confirm in the superclass)
     * @param str           metric name; see the class documentation
     */
    public BANLearning(DataBaseCases dataBaseCases, int i, boolean z, String str) {
        super(dataBaseCases, i, z);
        setMetrics(null);
        this.metricName = str;
    }

    /**
     * Learns the structure with {@link #learning()} and then estimates the
     * parameters, using {@link LPLearning} when {@code getIfAplyLaplaceCorrection()}
     * is true and {@link DELearning} otherwise. The resulting network becomes the output.
     *
     * @param dataBaseCases training cases
     * @param i             index of the class variable
     */
    @Override // elvira.learning.classification.supervised.discrete.MarkovBlanketLearning, elvira.learning.classification.Classifier
    public void learn(DataBaseCases dataBaseCases, int i) {
        setInput(dataBaseCases);
        setVarToClassify(i);
        // A fresh metric is always bound to the new input.
        setMetrics(createMetrics());
        learning();
        ParameterLearning parameterLearning;
        if (getIfAplyLaplaceCorrection()) {
            parameterLearning = new LPLearning(dataBaseCases, getOutput());
        } else {
            parameterLearning = new DELearning(dataBaseCases, getOutput());
        }
        parameterLearning.learning();
        setOutput(parameterLearning.getOutput());
    }

    /**
     * Learns the structure only: runs {@link #BANLocalSearch} on the current input
     * with the current metric (creating it from the metric name when unset) and
     * stores the result as the output.
     */
    @Override // elvira.learning.classification.supervised.discrete.MarkovBlanketLearning, elvira.learning.Learning
    public void learning() {
        NodeList variables = getInput().getVariables();
        // NOTE(review): this network is never read afterwards. It is kept because
        // addNode/createLink are applied to the input's own node objects and may
        // have side effects on them - confirm before removing.
        Bnet scratch = new Bnet();
        scratch.setKindOfGraph(2);
        for (int k = 0; k < variables.size(); k++) {
            try {
                scratch.addNode(variables.elementAt(k));
            } catch (InvalidEditException e) {
                System.out.println("Error adding a node when creating the Bnet");
            }
        }
        Node classNode = variables.elementAt(this.classvar);
        for (int k = 0; k < variables.size(); k++) {
            if (k != this.classvar) {
                try {
                    scratch.createLink(classNode, variables.elementAt(k), true);
                } catch (InvalidEditException e) {
                    System.out.println("Error adding a link when creating the Bnet");
                }
            }
        }
        if (getMetrics() == null) {
            setMetrics(createMetrics());
        }
        setOutput(BANLocalSearch(getInput(), getMetrics(), this.classvar));
    }

    /**
     * Greedy local search over network structures.
     *
     * <p>The initial network links the variable at index {@code i} to every other
     * variable. In each step {@link #maxScore} proposes the single best move; the
     * move is applied only if it improves the score of the affected families,
     * otherwise the search stops. The search also stops if applying a move fails,
     * because retrying on an unchanged network would propose the same move forever.
     *
     * @param dataBaseCases cases providing the variables
     * @param metrics       metric used to score networks and families
     * @param i             index of the class variable
     * @return the network found by the search if it scores higher than the initial
     *         one, otherwise the initial network
     */
    public Bnet BANLocalSearch(DataBaseCases dataBaseCases, Metrics metrics, int i) {
        // NOTE(review): discarded instance kept from the original code in case the
        // constructor has side effects - confirm and remove.
        new Bnet();
        NodeList nodes = dataBaseCases.getNodeList().duplicate();
        Graph graph = new Graph(0);
        for (int k = 0; k < nodes.size(); k++) {
            try {
                graph.addNode(nodes.elementAt(k));
            } catch (InvalidEditException e) {
                System.out.println("Error adding a node when creating the Bnet");
            }
        }
        Node classNode = nodes.elementAt(i);
        for (int k = 0; k < nodes.size(); k++) {
            if (k != i) {
                try {
                    graph.createLink(classNode, nodes.elementAt(k), true);
                } catch (InvalidEditException e) {
                    System.out.println("Error adding a link when creating the Bnet");
                }
            }
        }
        Bnet initial = new Bnet(graph.getNodeList());
        double initialScore = metrics.score(initial);

        // Work on a copy so the initial network can still be returned unchanged.
        Graph copy = initial.duplicate();
        Bnet current = new Bnet();
        current.setNodeList(copy.getNodeList());
        current.setLinkList(copy.getLinkList());

        boolean searching = true;
        while (searching) {
            Vector linkHolder = new Vector();
            Vector familyHolder = new Vector();
            int move = maxScore(dataBaseCases, metrics, linkHolder, familyHolder, current);
            Link link = (Link) linkHolder.elementAt(0);
            NodeList newFamily = (NodeList) familyHolder.elementAt(0);
            if (move == REMOVE_LINK) {
                searching = applyRemoval(dataBaseCases, metrics, current, link, newFamily);
            } else if (move == REVERSE_LINK) {
                searching = applyReversal(dataBaseCases, metrics, current, link, newFamily);
            } else if (move == ADD_LINK) {
                searching = applyAddition(dataBaseCases, metrics, current, link, newFamily);
            } else {
                // NO_MOVE (or an unknown code): nothing left to try.
                searching = false;
            }
        }
        if (metrics.score(current) > initialScore) {
            return new Bnet(current.duplicate().getNodeList());
        }
        return initial;
    }

    /**
     * Removes {@code link} from {@code net} if the head's family scores higher
     * without the tail than it does now.
     *
     * @return {@code true} if the link was removed and the search should continue
     */
    private boolean applyRemoval(DataBaseCases dataBaseCases, Metrics metrics, Bnet net, Link link, NodeList newFamily) {
        double scoreAfter = metrics.score(newFamily);
        double scoreBefore = metrics.score(familyOf(dataBaseCases, link.getHead(), net.parents(link.getHead())));
        if (scoreAfter > scoreBefore) {
            try {
                net.removeLink(net.getLink(link.getTail().getName(), link.getHead().getName()));
                // NOTE(review): result unused; call kept in case score() has side effects.
                metrics.score(net);
                return true;
            } catch (InvalidEditException e) {
                System.out.println("Error removing a link during the BAN local search");
            }
        }
        return false;
    }

    /**
     * Reverses {@code link} in {@code net} if the summed scores of the two affected
     * families are higher after the reversal than they are now.
     *
     * @return {@code true} if the link was reversed and the search should continue
     */
    private boolean applyReversal(DataBaseCases dataBaseCases, Metrics metrics, Bnet net, Link link, NodeList newFamily) {
        // newFamily is the old tail together with its parents plus the old head.
        double tailScoreAfter = metrics.score(newFamily);
        NodeList headParentsAfter = net.parents(link.getHead());
        headParentsAfter.removeNode(link.getTail());
        double scoreAfter = tailScoreAfter
                + metrics.score(familyOf(dataBaseCases, link.getHead(), headParentsAfter));
        double headScoreBefore = metrics.score(familyOf(dataBaseCases, link.getHead(), net.parents(link.getHead())));
        double tailScoreBefore = metrics.score(familyOf(dataBaseCases, link.getTail(), net.parents(link.getTail())));
        if (scoreAfter > headScoreBefore + tailScoreBefore) {
            try {
                net.removeLink(net.getLink(link.getTail().getName(), link.getHead().getName()));
                net.createLink(link.getHead(), link.getTail(), true);
                // NOTE(review): result unused; call kept in case score() has side effects.
                metrics.score(net);
                return true;
            } catch (InvalidEditException e) {
                System.out.println("Error reversing a link during the BAN local search");
            }
        }
        return false;
    }

    /**
     * Adds {@code link} to {@code net} if the head's family scores higher with the
     * tail as an extra parent than it does now.
     *
     * @return {@code true} if the link was added and the search should continue
     */
    private boolean applyAddition(DataBaseCases dataBaseCases, Metrics metrics, Bnet net, Link link, NodeList newFamily) {
        double scoreAfter = metrics.score(newFamily);
        double scoreBefore = metrics.score(familyOf(dataBaseCases, link.getHead(), net.parents(link.getHead())));
        if (scoreAfter > scoreBefore) {
            try {
                net.createLink(link.getTail(), link.getHead(), true);
                // NOTE(review): result unused; call kept in case score() has side effects.
                metrics.score(net);
                return true;
            } catch (InvalidEditException e) {
                System.out.println("Error adding a link during the BAN local search");
            }
        }
        return false;
    }

    /**
     * Builds the list {@code node + parents} and maps it onto the database's own
     * variables via {@code intersectionNames(...).sortNames(...)}, which is the
     * form the family-scoring calls in this class are given.
     */
    private NodeList familyOf(DataBaseCases dataBaseCases, Node node, NodeList parents) {
        NodeList family = new NodeList();
        family.insertNode(node);
        family.join(parents);
        return dataBaseCases.getNodeList().intersectionNames(family).sortNames(family);
    }

    /** Instantiates the metric named by {@link #metricName} for the current input (null-safe). */
    private Metrics createMetrics() {
        if ("BIC".equals(this.metricName)) {
            return new BICMetrics(getInput());
        }
        if ("K2".equals(this.metricName)) {
            return new K2Metrics(getInput());
        }
        return new BDeMetrics(getInput());
    }

    /**
     * Finds the single move with the largest score gain over all ordered node pairs.
     *
     * <p>For a pair that is already linked, removal and (when it would not create a
     * directed cycle) reversal are evaluated; for an unlinked pair, addition is
     * evaluated when it would not create a directed cycle. The gain is the score of
     * the affected families after the move minus their score before it.
     *
     * @param vector  out-parameter: receives the link of the best move (or {@code null})
     * @param vector2 out-parameter: receives the changed family of the best move
     *                (or {@code null})
     * @return {@link #REMOVE_LINK}, {@link #REVERSE_LINK}, {@link #ADD_LINK}, or
     *         {@link #NO_MOVE} when no candidate was found
     */
    private int maxScore(DataBaseCases dataBaseCases, Metrics metrics, Vector vector, Vector vector2, Bnet bnet) {
        int bestMove = NO_MOVE;
        Link bestLink = null;
        NodeList bestFamily = null;
        double bestGain = Double.NEGATIVE_INFINITY;
        for (int t = 0; t < bnet.getNodeList().size(); t++) {
            FiniteStates tail = (FiniteStates) bnet.getNodeList().elementAt(t);
            for (int h = 0; h < bnet.getNodeList().size(); h++) {
                FiniteStates head = (FiniteStates) bnet.getNodeList().elementAt(h);
                if (t == h) {
                    continue;
                }
                Link existing = bnet.getLink(tail.getName(), head.getName());
                if (existing != null) {
                    // --- Candidate: remove tail -> head ---
                    NodeList headParents = bnet.parents(head);
                    double headScore = metrics.score(familyOf(dataBaseCases, head, headParents));
                    headParents.removeNode(tail);
                    NodeList headWithoutTail = familyOf(dataBaseCases, head, headParents);
                    double headScoreWithoutTail = metrics.score(headWithoutTail);
                    double removalGain = headScoreWithoutTail - headScore;
                    // Links leaving the database's last variable are never removed.
                    // NOTE(review): this presumably protects the class variable, which
                    // assumes the class is the last variable rather than the index given
                    // to BANLocalSearch - confirm.
                    if (tail.equals(dataBaseCases.getNodeList().lastElement())) {
                        removalGain = Double.NEGATIVE_INFINITY;
                    }
                    if (removalGain > bestGain) {
                        bestGain = removalGain;
                        bestMove = REMOVE_LINK;
                        bestLink = existing;
                        bestFamily = new NodeList();
                        bestFamily.join(headWithoutTail);
                    }

                    // --- Candidate: reverse tail -> head ---
                    // Temporarily drop the link to see whether head is still reachable
                    // from tail by another path; if so, reversing would create a cycle.
                    try {
                        bnet.removeLink(existing);
                    } catch (InvalidEditException e) {
                        System.out.println("Error temporarily removing a link while scoring a reversal");
                    }
                    Vector descendants = bnet.directedDescendants(existing.getTail());
                    try {
                        bnet.createLink(existing.getTail(), existing.getHead(), true);
                    } catch (InvalidEditException e) {
                        System.out.println("Error restoring a link while scoring a reversal");
                    }
                    if (descendants.indexOf(existing.getHead()) == -1) {
                        NodeList tailParents = bnet.parents(tail);
                        double scoreBefore = metrics.score(familyOf(dataBaseCases, tail, tailParents)) + headScore;
                        tailParents.insertNode(head);
                        NodeList tailWithHead = familyOf(dataBaseCases, tail, tailParents);
                        double reversalGain = (metrics.score(tailWithHead) + headScoreWithoutTail) - scoreBefore;
                        // Same protection as for removal (see the note above).
                        if (tail.equals(dataBaseCases.getNodeList().lastElement())) {
                            reversalGain = Double.NEGATIVE_INFINITY;
                        }
                        if (reversalGain > bestGain) {
                            bestGain = reversalGain;
                            bestMove = REVERSE_LINK;
                            bestLink = existing;
                            bestFamily = new NodeList();
                            bestFamily.join(tailWithHead);
                        }
                    }
                } else if (bnet.directedDescendants(head).indexOf(tail) == -1) {
                    // --- Candidate: add tail -> head (tail is not a descendant of head) ---
                    NodeList headParents = bnet.parents(head);
                    double headScore = metrics.score(familyOf(dataBaseCases, head, headParents));
                    headParents.insertNode(tail);
                    NodeList headWithTail = familyOf(dataBaseCases, head, headParents);
                    double additionGain = metrics.score(headWithTail) - headScore;
                    if (additionGain > bestGain) {
                        bestGain = additionGain;
                        bestMove = ADD_LINK;
                        bestLink = new Link(tail, head);
                        bestFamily = new NodeList();
                        bestFamily.join(headWithTail);
                    }
                }
            }
        }
        vector.addElement(bestLink);
        vector2.addElement(bestFamily);
        return bestMove;
    }

    /** Sets the metric used to score structures; {@code null} makes {@link #learning()} create one. */
    public void setMetrics(Metrics metrics) {
        this.metric = metrics;
    }

    /** Returns the metric currently used to score structures, possibly {@code null}. */
    public Metrics getMetrics() {
        return this.metric;
    }

    /**
     * Command-line entry point: learns a network from a cases file, saves it and
     * prints its KL divergence and score; optionally compares it with a reference net.
     *
     * @param strArr {@code input.dbc output.elv class metric [file.elv]}
     * @throws ParseException if an input file cannot be parsed
     * @throws IOException    if a file cannot be read or written
     */
    public static void main(String[] strArr) throws ParseException, IOException {
        if (strArr.length < 4) {
            System.out.println("too few arguments: Usage: input.dbc output.elv class metric [file.elv]");
            System.out.println("\tinput.dbc : DataBaseCases file for building the bnet");
            System.out.println("\toutput.elv : For saving the result");
            System.out.println("\tclass : The number of the variable to classify if it's the first use 0.");
            System.out.println("\tmetric : Metric used to score, it can be BIC,K2,BDe");
            System.out.println("\tfile.elv: Optional. True net to be compared.");
            // Exit status kept at 0 for backward compatibility with existing callers.
            System.exit(0);
        }

        DataBaseCases cases;
        FileInputStream casesStream = new FileInputStream(strArr[0]);
        try {
            cases = new DataBaseCases(casesStream);
        } finally {
            casesStream.close();
        }

        int classIndex = Integer.parseInt(strArr[2]);

        Metrics reportMetric;
        String metricName;
        if (strArr[3].equals("BIC")) {
            reportMetric = new BICMetrics(cases);
            metricName = strArr[3];
        } else if (strArr[3].equals("K2")) {
            reportMetric = new K2Metrics(cases);
            metricName = strArr[3];
        } else {
            reportMetric = new BDeMetrics(cases);
            metricName = "BDe";
        }

        BANLearning learner = new BANLearning(cases, classIndex, true, metricName);
        learner.learning();
        LPLearning parameterLearning = new LPLearning(cases, learner.getOutput());
        parameterLearning.learning();

        double learnedDivergence = cases.getDivergenceKL(parameterLearning.getOutput());
        System.out.println("KL Divergence = " + learnedDivergence);
        System.out.println("Bayes Metric for the output net: " + reportMetric.score(learner.getOutput()));

        FileWriter writer = new FileWriter(strArr[1]);
        try {
            parameterLearning.getOutput().saveBnet(writer);
        } finally {
            writer.close();
        }

        if (strArr.length > 4) {
            Bnet reference;
            FileInputStream referenceStream = new FileInputStream(strArr[4]);
            try {
                reference = new Bnet(referenceStream);
            } finally {
                referenceStream.close();
            }
            double referenceDivergence = cases.getDivergenceKL(reference);
            System.out.println("kL Divergence for optimal net: " + referenceDivergence);
            System.out.println("Divergence between optimal and learned nets: " + (referenceDivergence - learnedDivergence));
            LinkList[] differences = learner.compareOutput(reference);
            System.out.print("\nAdded links: " + differences[0].size());
            System.out.print(differences[0].toString());
            System.out.print("\nRemoved links: " + differences[1].size());
            System.out.print(differences[1].toString());
            System.out.println("\nInverted links: " + differences[2].size());
            System.out.print(differences[2].toString());
        }
    }
}
