package elvira.learning.classification.supervised.discrete;

import elvira.Bnet;
import elvira.Configuration;
import elvira.FiniteStates;
import elvira.Graph;
import elvira.InvalidEditException;
import elvira.Link;
import elvira.LinkList;
import elvira.Node;
import elvira.NodeList;
import elvira.Relation;
import elvira.RelationList;
import elvira.database.DataBaseCases;
import elvira.learning.classification.SizeComparableClassifier;
import elvira.learning.classification.supervised.continuous.MaximumSpanningTree;
import elvira.potential.PotentialTable;
import elvira.probabilisticDecisionGraph.tools.CasesOps;
import elvira.probabilisticDecisionGraph.tools.Measures;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import java.util.Stack;
import java.util.Vector;

/* loaded from: input_file:bayelvira-1.0-SNAPSHOT.jar:elvira/learning/classification/supervised/discrete/TreeAugmentedNaiveBayes.class */
/**
 * Tree-Augmented Naive Bayes (TAN) classifier over discrete ({@link FiniteStates}) variables.
 *
 * <p>{@link #learn} builds the model in four steps:
 * <ol>
 *   <li>It builds a complete undirected graph over the feature variables (every variable except the class).</li>
 *   <li>It weights each edge with {@code Measures.normalisedCMI}, computed on the per-class-state subsets of the
 *       data (presumably normalised conditional mutual information given the class — confirm against
 *       {@code Measures}).</li>
 *   <li>It takes a maximum spanning tree and directs it away from {@code getMaximallyConnectedNode()}.</li>
 *   <li>It adds a link from the class to every feature and estimates one conditional table per variable from the
 *       data counts.</li>
 * </ol>
 *
 * <p>The instance holds mutable state and is not synchronised.
 */
public class TreeAugmentedNaiveBayes implements SizeComparableClassifier {
    /** The learned network; {@code null} until {@link #learn} completes successfully. */
    private Bnet classifier;
    /** The class variable of the learned network; {@code null} until {@link #learn} completes successfully. */
    private FiniteStates classVar = null;

    /**
     * Returns the number of free parameters of the learned network.
     *
     * @throws IllegalStateException if {@link #learn} has not been called successfully
     */
    @Override
    public long size() {
        return requireLearned().getNumberOfFreeParameters();
    }

    /**
     * Scores every state of the class variable for the given configuration.
     *
     * <p>For each class state {@code s} (in state order), position {@code i} of {@code configuration} is set to
     * {@code s} and the network's {@code evaluateFullConfiguration} value is recorded (presumably the joint
     * probability of the full configuration — confirm in {@code Bnet}). The scores are NOT normalised here.
     *
     * <p>Side effect: {@code configuration} is modified in place; on return, position {@code i} holds the last
     * class state.
     *
     * @param configuration a full configuration of the network's variables
     * @param i position of the class variable inside {@code configuration}
     * @return one score per class state, indexed by state
     * @throws IllegalStateException if {@link #learn} has not been called successfully
     */
    @Override
    public Vector<Double> classify(Configuration configuration, int i) {
        Bnet net = requireLearned();
        if (net.getNodeList().size() != configuration.getVariables().size()) {
            System.out.println("WARNING : method classify recieved a configuration that is not full - this is probably an error!!!");
        }
        int numStates = this.classVar.getNumStates();
        Vector<Double> scores = new Vector<>(numStates);
        for (int state = 0; state < numStates; state++) {
            configuration.setValue(i, state);
            scores.add(Double.valueOf(net.evaluateFullConfiguration(configuration)));
        }
        return scores;
    }

    /** Returns the learned network, failing fast with a clear message when the model has not been learned. */
    private Bnet requireLearned() {
        if (this.classifier == null) {
            throw new IllegalStateException("TreeAugmentedNaiveBayes has not been learned yet; call learn(...) first");
        }
        return this.classifier;
    }

    /**
     * Adds an undirected link between every unordered pair of nodes in {@code graph}.
     *
     * @throws IllegalStateException if the graph refuses one of the links
     */
    private static void completeUndirectedGraph(Graph graph) {
        NodeList nodes = graph.getNodeList();
        for (int a = 0; a < nodes.size(); a++) {
            Node from = nodes.elementAt(a);
            for (int b = a + 1; b < nodes.size(); b++) {
                Node to = nodes.elementAt(b);
                try {
                    graph.createLink(from, to, false);
                } catch (InvalidEditException e) {
                    // Previously System.exit(112): a library must not terminate the host JVM.
                    throw new IllegalStateException(
                            "Unable to create undirected link between " + from.getName() + " and " + to.getName(), e);
                }
            }
        }
    }

    /**
     * Rebuilds every node's parent/child/sibling bookkeeping from the network's link list.
     *
     * <p>{@code Bnet.setLinkList} (used in {@link #learn}) only replaces the link list, so the per-node link
     * structures are reset here and repopulated from the links. Does nothing when {@code net} is {@code null}.
     */
    private static void repairNodesFromLinkList(Bnet net) {
        if (net == null) {
            return;
        }
        NodeList nodes = net.getNodeList();
        LinkList links = net.getLinkList();
        for (int k = 0; k < nodes.size(); k++) {
            Node node = nodes.elementAt(k);
            node.setParents(new LinkList());
            node.setChildren(new LinkList());
            node.setSiblings(new LinkList());
        }
        for (int k = 0; k < links.size(); k++) {
            Link link = links.elementAt(k);
            if (link.getDirected()) {
                link.getTail().addChild(link);
                link.getHead().addParent(link);
            } else {
                link.getHead().addNeighbour(link.getTail());
                link.getTail().addNeighbour(link.getHead());
            }
        }
    }

    /**
     * Replaces each undirected link of the tree {@code graph} with a link directed away from {@code root}.
     *
     * <p>The traversal is depth-first, driven by an explicit stack. Each processed link is removed from the pending
     * set, so a child never walks back to its parent. The graph kind is set to {@code 2} before traversal, exactly as
     * in the original code (NOTE(review): presumably Elvira's "mixed" kind — confirm against {@code Graph}).
     *
     * @param graph an undirected tree; modified in place
     * @param root the node the edges are directed away from; may be {@code null}, in which case no link changes
     * @throws IllegalStateException if a link cannot be removed or its directed replacement cannot be created
     *     (the latter indicates the input contained an undirected cycle)
     */
    private static void directTree(Graph graph, Node root) {
        List<Link> pending = new ArrayList<>(graph.getLinkList().getLinks());
        graph.setKindOfGraph(2);
        if (root == null) {
            // ArrayDeque rejects null; with a null root the original loop matched no links either.
            return;
        }
        Deque<Node> frontier = new ArrayDeque<>();
        frontier.push(root);
        while (!frontier.isEmpty()) {
            Node parent = frontier.pop();
            for (Iterator<Link> it = pending.iterator(); it.hasNext(); ) {
                Link link = it.next();
                Node child = null;
                if (link.getHead() == parent) {
                    child = link.getTail();
                } else if (link.getTail() == parent) {
                    child = link.getHead();
                }
                if (child == null) {
                    continue;
                }
                frontier.push(child);
                try {
                    graph.removeLink(link);
                } catch (InvalidEditException e) {
                    throw new IllegalStateException("Could not remove link!!!", e);
                }
                try {
                    graph.createLink(parent, child, true);
                } catch (InvalidEditException e) {
                    throw new IllegalStateException(
                            "It seems that the tree givin to directTree method was not really a tree - that is, it contained undirected cycles!!!", e);
                }
                // Removing via the iterator replaces the original collect-then-removeAll pass.
                it.remove();
            }
        }
    }

    /**
     * Learns the TAN structure and parameters from {@code dataBaseCases}.
     *
     * <p>The instance fields are assigned only after the whole model has been built. A failure therefore leaves any
     * previously learned model intact rather than half-initialised.
     *
     * @param dataBaseCases the training data
     * @param i index of the class variable in {@code dataBaseCases.getNewVectorOfNodes()}
     * @throws IllegalStateException if the underlying graph edits fail (previously this terminated the JVM)
     */
    @Override
    public void learn(DataBaseCases dataBaseCases, int i) {
        Vector<Node> features = dataBaseCases.getNewVectorOfNodes();
        FiniteStates classVariable = (FiniteStates) features.elementAt(i);
        features.remove(i);

        Graph tree = learnFeatureTree(dataBaseCases, classVariable, features);

        Bnet net = new Bnet();
        net.setNodeList(dataBaseCases.getVariables());
        net.setLinkList(tree.getLinkList());
        addClassLinks(net, classVariable, features);
        net.setRelationList(estimateParameters(net, dataBaseCases).getRelations());
        repairNodesFromLinkList(net);

        this.classVar = classVariable;
        this.classifier = net;
    }

    /**
     * Builds the directed maximum spanning tree over the features.
     *
     * <p>Edge weights are {@code Measures.normalisedCMI} over the per-class-state subsets of the cases. Weights are
     * listed in the same order as {@code graph.getLinkList()}; this assumes {@code MaximumSpanningTree} pairs weights
     * with links by index — TODO confirm.
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // element types expected by CasesOps/Measures/MaximumSpanningTree are not visible here
    private static Graph learnFeatureTree(DataBaseCases dataBaseCases, FiniteStates classVariable, Vector<Node> features) {
        Graph graph = new Graph();
        graph.setKindOfGraph(1);
        graph.setNodeList(new NodeList(features));
        completeUndirectedGraph(graph);

        int numClassStates = classVariable.getNumStates();
        Vector subsetsByClassState = new Vector(numClassStates);
        for (int state = 0; state < numClassStates; state++) {
            subsetsByClassState.add(CasesOps.selectFromWhere(dataBaseCases.getCases(), classVariable, state));
        }

        Vector weights = new Vector();
        LinkList links = graph.getLinkList();
        for (int k = 0; k < links.size(); k++) {
            Link link = links.elementAt(k);
            weights.add(Double.valueOf(Measures.normalisedCMI(
                    (FiniteStates) link.getHead(), (FiniteStates) link.getTail(), subsetsByClassState)));
        }

        Graph mst = new MaximumSpanningTree(graph, weights).getMST();
        directTree(mst, mst.getMaximallyConnectedNode());
        return mst;
    }

    /**
     * Adds a directed link from the class variable to every feature.
     *
     * @throws IllegalStateException if the network refuses a link (previously this terminated the JVM)
     */
    private static void addClassLinks(Bnet net, FiniteStates classVariable, Vector<Node> features) {
        for (Node feature : features) {
            try {
                net.createLink(classVariable, feature, true);
            } catch (InvalidEditException e) {
                throw new IllegalStateException("Unable to create link from Class to feature " + feature.getName() + "!!!", e);
            }
        }
    }

    /**
     * Estimates one conditional table per variable of {@code net} from data counts.
     *
     * <p>For each variable, the table over {variable + its parents} is taken from the data and {@code sum(1.0)} is
     * applied (presumably +1 Laplace smoothing of every cell — confirm in {@code PotentialTable}). The table is then
     * divided by {@code addVariable(variable)} (presumably the table with the variable summed out, giving
     * P(variable | parents) — confirm).
     */
    private static RelationList estimateParameters(Bnet net, DataBaseCases dataBaseCases) {
        RelationList relations = new RelationList();
        NodeList nodes = net.getNodeList();
        for (int k = 0; k < nodes.size(); k++) {
            FiniteStates child = (FiniteStates) nodes.elementAt(k);
            NodeList family = new NodeList();
            family.insertNode(child);
            family.join(new NodeList(net.getLinkList().getParentsInList(child)));
            PotentialTable counts = dataBaseCases.getPotentialTable(family);
            counts.sum(1.0d);
            relations.insertRelation(new Relation(counts.divide((PotentialTable) counts.addVariable(child))));
        }
        return relations;
    }

    /**
     * Writes the learned network to {@code <str>/tan.elv}.
     *
     * @param str directory to write into
     * @throws IOException if the file cannot be opened or written
     * @throws IllegalStateException if {@link #learn} has not been called successfully
     */
    @Override
    public void saveModelToFile(String str) throws IOException {
        Bnet net = requireLearned();
        // try-with-resources: the original leaked the writer when saveBnet threw.
        try (FileWriter fileWriter = new FileWriter(str + "/tan.elv")) {
            net.saveBnet(fileWriter);
        }
    }

    /** Returns the learned network, or {@code null} before {@link #learn} has completed. */
    public Bnet getClassifier() {
        return this.classifier;
    }

    /** Returns the class variable of the learned network, or {@code null} before {@link #learn} has completed. */
    public FiniteStates getClassVar() {
        return this.classVar;
    }

    /**
     * Command-line entry point: learns a TAN from the database file given as the first argument, using the
     * database's own class index.
     */
    public static void main(String[] strArr) throws Exception {
        if (strArr.length < 1) {
            System.err.println("Usage: java " + TreeAugmentedNaiveBayes.class.getName() + " <database-file>");
            return;
        }
        DataBaseCases dataBaseCases = new DataBaseCases(strArr[0]);
        new TreeAugmentedNaiveBayes().learn(dataBaseCases, dataBaseCases.getClassId());
    }
}
