package elvira.learning.classification.supervised.mixed;

import elvira.Bnet;
import elvira.FiniteStates;
import elvira.Graph;
import elvira.InvalidEditException;
import elvira.Link;
import elvira.LinkList;
import elvira.Node;
import elvira.NodeList;
import elvira.Relation;
import elvira.database.DataBaseCases;
import elvira.learning.MTELearning;
import elvira.learning.classification.ClassifierValidator;
import elvira.learning.classification.supervised.continuous.MaximumSpanningTree;
import elvira.learning.classification.supervised.continuous.NaiveMTEPredictor;
import elvira.potential.ContinuousProbabilityTree;
import elvira.potential.PotentialContinuousPT;
import elvira.tools.VectorManipulator;
import java.io.FileInputStream;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import java.util.Stack;
import java.util.Vector;

/* loaded from: input_file:bayelvira-1.0-SNAPSHOT.jar:elvira/learning/classification/supervised/mixed/TANMTEClassifier.class */
/**
 * Tree-augmented naive Bayes (TAN) classifier whose conditional densities are MTE potentials
 * ({@link ContinuousProbabilityTree}) learned with {@link MTELearning}.
 *
 * <p>Structure learning ({@link #structuralLearning()}):
 * <ol>
 *   <li>connect every pair of non-class variables with an undirected link;</li>
 *   <li>weight each link by an estimate of the conditional mutual information I(Xi, Xj | C);</li>
 *   <li>keep the maximum spanning tree of that weighted graph ({@link MaximumSpanningTree});</li>
 *   <li>direct the tree away from a root chosen from the {@link NaiveMTEPredictor} mutual
 *       information list;</li>
 *   <li>add a link from the class variable to every other variable.</li>
 * </ol>
 * Parameter learning ({@link #parametricLearning()}) then fits one MTE conditional per variable.
 */
public class TANMTEClassifier extends MixedClassifier {

    /**
     * Last argument passed to every {@code MTELearning.learnConditional} call in this class.
     * NOTE(review): its meaning is defined by MTELearning; confirm there before changing it.
     */
    private static final int MTE_FIT_PARAMETER = 4;

    /**
     * Last argument passed to {@code ContinuousProbabilityTree.estimateConditionalMutualInformation}.
     * NOTE(review): presumably the number of simulated samples; confirm in that method.
     */
    private static final int CMI_SIMULATION_SIZE = 5000;

    /** Forwarded as the "intervals" argument to NaiveMTEPredictor and MTELearning.learnConditional. */
    int intervals;

    /**
     * Creates a TAN-MTE classifier over the given cases.
     *
     * @param dataBaseCases training cases
     * @param classIndex    position of the class variable, forwarded to {@code MixedClassifier}
     *                      (presumably its index among the variables of {@code dataBaseCases})
     * @param intervals     value stored in {@link #intervals}
     * @throws InvalidEditException propagated from the superclass constructor
     */
    public TANMTEClassifier(DataBaseCases dataBaseCases, int classIndex, int intervals) throws InvalidEditException {
        super(dataBaseCases, true, classIndex);
        this.intervals = intervals;
    }

    /**
     * Adds an undirected link between every pair of nodes of {@code graph}.
     *
     * @throws IllegalStateException if the graph rejects a link. This replaces the former
     *         {@code System.exit}, which terminated the whole JVM (e.g. during cross-validation).
     */
    private void fullyConnectUndirected(Graph graph) {
        NodeList nodes = graph.getNodeList();
        for (int i = 0; i < nodes.size(); i++) {
            Node first = nodes.elementAt(i);
            for (int j = i + 1; j < nodes.size(); j++) {
                Node second = nodes.elementAt(j);
                try {
                    graph.createLink(first, second, false);
                } catch (InvalidEditException e) {
                    throw new IllegalStateException("Could not link " + first.getName() + " and "
                            + second.getName() + " while building the fully connected graph", e);
                }
            }
        }
    }

    /**
     * Estimates I(head, tail | C) for every link of {@code graph}, in link-list order.
     *
     * @return one estimate per link, aligned with {@code graph.getLinkList()}
     */
    private Vector<Double> conditionalMutualInformation(Graph graph) {
        MTELearning learner = new MTELearning(this.cases);
        LinkList links = graph.getLinkList();
        Vector<Double> estimates = new Vector<>();
        System.out.println("\nEstimating the CMI for each link...\n");

        // f(C) does not depend on the link, so learn it once instead of once per link.
        ContinuousProbabilityTree classMarginal =
                learner.learnConditional(this.classVar, new NodeList(), this.cases, this.intervals, MTE_FIT_PARAMETER);

        for (int i = 0; i < links.size(); i++) {
            Link link = links.elementAt(i);
            Node head = link.getHead();
            Node tail = link.getTail();

            // Separate conditioning sets instead of mutating one shared NodeList mid-loop.
            NodeList givenClass = new NodeList();
            givenClass.insertNode(this.classVar);
            NodeList givenClassAndTail = new NodeList();
            givenClassAndTail.insertNode(this.classVar);
            givenClassAndTail.insertNode(tail);

            // Same learning order as before: head|C, tail|C, head|C,tail.
            ContinuousProbabilityTree headGivenClass =
                    learner.learnConditional(head, givenClass, this.cases, this.intervals, MTE_FIT_PARAMETER);
            ContinuousProbabilityTree tailGivenClass =
                    learner.learnConditional(tail, givenClass, this.cases, this.intervals, MTE_FIT_PARAMETER);
            ContinuousProbabilityTree headGivenClassAndTail =
                    learner.learnConditional(head, givenClassAndTail, this.cases, this.intervals, MTE_FIT_PARAMETER);

            double cmi = ContinuousProbabilityTree.estimateConditionalMutualInformation(
                    classMarginal, tailGivenClass, headGivenClassAndTail, headGivenClass, CMI_SIMULATION_SIZE);
            estimates.addElement(Double.valueOf(cmi));

            System.out.print("\nLINK: " + link.toString());
            System.out.println("---------------------------------------------");
            System.out.println("     I(Xi,Xj|C)=" + cmi);
            System.out.println("---------------------------------------------");
        }
        return estimates;
    }

    /**
     * Adds a directed link from the class node to every other node of {@link #classifier}.
     * Failures are reported and skipped, as before, but the cause is no longer discarded.
     */
    private void linkClassToFeatures() {
        NodeList nodes = this.classifier.getNodeList();
        Node classNode = nodes.elementAt(this.classIndex);
        for (int i = 0; i < this.nVariables; i++) {
            if (i == this.classIndex) {
                continue;
            }
            try {
                this.classifier.createLink(classNode, nodes.elementAt(i), true);
            } catch (Exception e) {
                e.printStackTrace();
                System.out.println("Problems to create the link");
            }
        }
    }

    /**
     * Orients the undirected tree {@code graph} in place, pointing every link away from {@code root}.
     * Nodes are explored depth-first. Each undirected link is replaced by a directed
     * parent-to-child link.
     *
     * @param graph undirected graph expected to be a tree
     * @param root  node the orientation starts from (must not be {@code null})
     */
    private void directTree(Graph graph, Node root) {
        // Snapshot the undirected links before the graph is modified below.
        LinkList undirected = graph.getLinkList();
        List<Link> pending = new ArrayList<>(undirected.size());
        for (int i = 0; i < undirected.size(); i++) {
            pending.add(undirected.elementAt(i));
        }

        graph.setKindOfGraph(2);
        Deque<Node> toVisit = new ArrayDeque<>();
        toVisit.push(root);

        while (!toVisit.isEmpty()) {
            Node parent = toVisit.pop();
            for (Iterator<Link> it = pending.iterator(); it.hasNext(); ) {
                Link link = it.next();
                Node child;
                if (link.getHead() == parent) {
                    child = link.getTail();
                } else if (link.getTail() == parent) {
                    child = link.getHead();
                } else {
                    continue;
                }
                toVisit.push(child);
                try {
                    graph.removeLink(link);
                } catch (InvalidEditException e) {
                    e.printStackTrace();
                    System.out.println("Could not remove link!!!");
                }
                try {
                    graph.createLink(parent, child, true);
                } catch (InvalidEditException e) {
                    e.printStackTrace();
                    System.out.println("It seems that the tree givin to directTree method was not really a tree - that is, it contained undirected cycles!!!");
                }
                // Each undirected link is oriented exactly once.
                it.remove();
            }
        }

        if (!pending.isEmpty()) {
            // Previously silent: happens when the root is not part of the tree or the tree is disconnected.
            System.out.println("Warning: " + pending.size() + " link(s) could not be oriented from root "
                    + root.getName());
        }
    }

    /**
     * Learns the TAN structure and stores it in {@link #classifier}.
     * The structure is an MST over the features weighted by I(Xi, Xj | C), directed from a root
     * feature, plus class-to-feature links.
     */
    @Override
    public void structuralLearning() {
        System.out.println("\nSTRUCTURAL LEARNING ...\n");
        NaiveMTEPredictor naivePredictor = new NaiveMTEPredictor(this.cases, this.classIndex, this.intervals);
        System.out.println(naivePredictor.getListOfMututlinformation().toString());
        int rootIndex = VectorManipulator.findMaxDoubles(naivePredictor.getListOfMututlinformation());

        NodeList features = new NodeList(this.cases.getVariables().getNodes());
        // NOTE(review): rootIndex indexes the list *before* the class node is removed. Confirm that
        // getListOfMututlinformation() is aligned with all variables (class included). If it only
        // lists features, the root is off by one whenever rootIndex >= classIndex.
        Node root = features.elementAt(rootIndex);
        features.removeNode(this.classIndex);

        Graph undirected = new Graph(1);
        undirected.setNodeList(features);
        fullyConnectUndirected(undirected);
        Graph tree = new MaximumSpanningTree(undirected, conditionalMutualInformation(undirected)).getMST();
        directTree(tree, root);

        this.classifier = new Bnet();
        this.classifier.setNodeList(features);
        this.classifier.setLinkList(tree.getLinkList());
        // Re-insert the class node at its original position. This mutates the same NodeList
        // instance handed to the classifier above, so node indices match this.cases again.
        features.getNodes().insertElementAt(this.classVar, this.classIndex);
        linkClassToFeatures();
    }

    /**
     * Fits one MTE conditional per variable given its parents in {@link #classifier}, plus the
     * class marginal, and installs them as the classifier's relation list.
     */
    @Override
    public void parametricLearning() {
        MTELearning learner = new MTELearning(this.cases);
        Vector<Relation> relations = new Vector<>();
        System.out.println("\n\n====> Learning TAN <=======");
        for (int i = 0; i < this.nVariables; i++) {
            if (i == this.classIndex) {
                continue;
            }
            Node feature = this.classifier.getNodeList().elementAt(i);
            NodeList parents = feature.getParentNodes();
            System.out.println("   Learning " + feature.getName() + " ...");
            ContinuousProbabilityTree conditional =
                    learner.learnConditional(feature, parents, this.cases, this.intervals, MTE_FIT_PARAMETER);

            // The relation's variables are the child first, followed by its parents.
            NodeList family = new NodeList();
            family.insertNode(feature);
            for (int j = 0; j < parents.size(); j++) {
                family.insertNode(parents.elementAt(j));
            }
            relations.addElement(buildRelation(family, conditional));
        }

        FiniteStates classVariable = this.classVar;
        System.out.println("   Learning " + classVariable.getName() + " ... (CLASS VARIABLE)");
        ContinuousProbabilityTree classMarginal =
                learner.learnConditional(classVariable, new NodeList(), this.cases, this.intervals, MTE_FIT_PARAMETER);
        NodeList classOnly = new NodeList();
        classOnly.insertNode(this.classVar);
        relations.addElement(buildRelation(classOnly, classMarginal));

        this.classifier.setRelationList(relations);
    }

    /** Wraps {@code tree} in a {@link PotentialContinuousPT} and returns it as a relation over {@code variables}. */
    private static Relation buildRelation(NodeList variables, ContinuousProbabilityTree tree) {
        PotentialContinuousPT potential = new PotentialContinuousPT(variables, tree);
        Relation relation = new Relation();
        relation.setVariables(variables);
        relation.setValues(potential);
        return relation;
    }

    /**
     * Command-line entry point: trains on {@code args[0]} and saves the model as "TAN".
     * It then tests on the file {@code args[1]}, or runs k-fold cross-validation when
     * {@code args[1]} is {@code CV}.
     *
     * <p>Usage: {@code <training.dbc> <test.dbc | CV> <classIndex> <intervals> [folds]}
     * ({@code folds} is required with {@code CV}).
     *
     * @throws IllegalArgumentException if the arguments are missing
     */
    public static void main(String[] args) throws Exception {
        boolean crossValidate = args.length > 1 && "CV".equals(args[1]);
        if (args.length < 4 || (crossValidate && args.length < 5)) {
            throw new IllegalArgumentException(
                    "Usage: TANMTEClassifier <training.dbc> <test.dbc | CV> <classIndex> <intervals> [folds]");
        }
        int classIndex = Integer.parseInt(args[2]);
        int intervals = Integer.parseInt(args[3]);

        // The streams stay open until all work using them is done, then are closed (previously leaked).
        try (FileInputStream trainingStream = new FileInputStream(args[0])) {
            DataBaseCases trainingCases = new DataBaseCases(trainingStream);
            TANMTEClassifier classifier = new TANMTEClassifier(trainingCases, classIndex, intervals);
            classifier.structuralLearning();
            classifier.parametricLearning();
            classifier.saveModelToFile("TAN");

            if (!crossValidate) {
                try (FileInputStream testStream = new FileInputStream(args[1])) {
                    // NOTE(review): this evaluates on the separate file args[1], yet the label says
                    // "Train accuracy". Kept verbatim; confirm the intended wording.
                    System.out.println("Classifier tested. Train accuracy: "
                            + classifier.test(new DataBaseCases(testStream), classIndex));
                }
            } else {
                int folds = Integer.parseInt(args[4]);
                double error = new ClassifierValidator(classifier, trainingCases, classIndex)
                        .kFoldCrossValidation(folds).getError();
                System.out.println(folds + "-folds Cross-Validation. Accuracy=" + (1.0d - error) + "\n\n");
            }
        }
    }
}
