package elvira.learning.classification.supervised.discrete;

import elvira.Bnet;
import elvira.CaseListMem;
import elvira.FiniteStates;
import elvira.Graph;
import elvira.InvalidEditException;
import elvira.Link;
import elvira.LinkList;
import elvira.Node;
import elvira.NodeList;
import elvira.Relation;
import elvira.database.DataBaseCases;
import elvira.learning.classification.AuxiliarPotentialTable;
import elvira.parser.ParseException;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Collections;
import java.util.Comparator;
import java.util.Vector;
import weka.classifiers.lazy.kstar.KStarConstants;

/* loaded from: input_file:bayelvira-1.0-SNAPSHOT.jar:elvira/learning/classification/supervised/discrete/CMutInfTAN.class */
public class CMutInfTAN extends TAN {

    /* loaded from: input_file:bayelvira-1.0-SNAPSHOT.jar:elvira/learning/classification/supervised/discrete/CMutInfTAN$CMIComparator.class */
    /**
     * Orders {@link CMutInf} entries by score, largest score first, so that
     * sorting a list with this comparator puts the best-scoring pair at index 0.
     */
    private class CMIComparator implements Comparator {
        private CMIComparator() {
        }

        /**
         * Compares two {@code CMutInf} entries by descending score.
         *
         * @param obj  first entry; cast to {@code CMutInf}
         * @param obj2 second entry; cast to {@code CMutInf}
         * @return -1 when the first score is larger, 1 when it is smaller,
         *         0 otherwise
         * @throws ClassCastException if either argument is not a {@code CMutInf}
         */
        @Override // java.util.Comparator
        public int compare(Object obj, Object obj2) {
            double firstScore = ((CMutInf) obj).getScore();
            double secondScore = ((CMutInf) obj2).getScore();
            if (firstScore > secondScore) {
                return -1;
            }
            if (firstScore < secondScore) {
                return 1;
            }
            return 0;
        }
    }

    /* loaded from: input_file:bayelvira-1.0-SNAPSHOT.jar:elvira/learning/classification/supervised/discrete/CMutInfTAN$CMutInf.class */
    /**
     * A candidate pair of attribute nodes together with the score used to rank
     * it when the tree over the attributes is built.
     */
    private class CMutInf {
        /** Ranking score of the pair; larger sorts first under {@code CMIComparator}. */
        double score;
        /** First node of the pair. */
        FiniteStates node1;
        /** Second node of the pair. */
        FiniteStates node2;
        /** Database the score was computed from; only set by the scoring constructor. */
        DataBaseCases data;

        /** Creates an empty entry; score, nodes and data keep their default values. */
        public CMutInf() {
        }

        /**
         * Scores the pair ({@code finiteStates}, {@code finiteStates2}) against the
         * class variable, which is taken to be the last node of the database's
         * node list.
         *
         * <p>Three tables are filled with one {@code addCase(..., 1.0)} call per
         * case: node1 x class, node2 x class and (node1, node2) x class, where the
         * joint row index is {@code numStates2 * value1 + value2}. The score is
         * {@code S(joint) - S(node1) - S(node2)}, with {@code S} the sum of
         * {@code p * log10(p)} over the table's potentials.
         * NOTE(review): whether {@code getPotential} returns raw counts or
         * normalised values is not visible here - confirm in AuxiliarPotentialTable.
         *
         * @param dataBaseCases training cases; its first relation must hold a
         *                      {@code CaseListMem}
         * @param finiteStates  first node of the pair
         * @param finiteStates2 second node of the pair
         */
        public CMutInf(DataBaseCases dataBaseCases, FiniteStates finiteStates, FiniteStates finiteStates2) {
            this.data = dataBaseCases;
            this.node1 = finiteStates;
            this.node2 = finiteStates2;
            CaseListMem caseValues = (CaseListMem) ((Relation) this.data.getRelationList().elementAt(0)).getValues();
            NodeList variables = this.data.getNodeList();
            FiniteStates classNode = (FiniteStates) variables.lastElement();
            int states1 = this.node1.getNumStates();
            int states2 = this.node2.getNumStates();
            int classStates = classNode.getNumStates();

            AuxiliarPotentialTable table1 = new AuxiliarPotentialTable(states1, classStates);
            table1.initialize(KStarConstants.FLOOR);
            AuxiliarPotentialTable table2 = new AuxiliarPotentialTable(states2, classStates);
            table2.initialize(KStarConstants.FLOOR);
            AuxiliarPotentialTable jointTable = new AuxiliarPotentialTable(states1 * states2, classStates);
            jointTable.initialize(KStarConstants.FLOOR);

            // Column positions do not change between cases, so look them up once
            // instead of once (or more) per case.
            int column1 = variables.getId(this.node1);
            int column2 = variables.getId(this.node2);
            int classColumn = variables.getId(variables.lastElement());

            for (int caseIndex = 0; caseIndex < this.data.getNumberOfCases(); caseIndex++) {
                int value1 = (int) caseValues.getValue(caseIndex, column1);
                int value2 = (int) caseValues.getValue(caseIndex, column2);
                int classValue = (int) caseValues.getValue(caseIndex, classColumn);
                table1.addCase(value1, classValue, 1.0d);
                table2.addCase(value2, classValue, 1.0d);
                // Joint state of (node1, node2) flattened row-major into one index.
                jointTable.addCase((states2 * value1) + value2, classValue, 1.0d);
            }

            double jointSum = sumPotentialTimesLog(jointTable, states1 * states2, classStates);
            double sum1 = sumPotentialTimesLog(table1, states1, classStates);
            double sum2 = sumPotentialTimesLog(table2, states2, classStates);
            this.score = (jointSum - sum1) - sum2;
        }

        /**
         * Creates an entry with a precomputed score; {@code data} is left unset.
         *
         * @param d             score of the pair
         * @param finiteStates  first node of the pair
         * @param finiteStates2 second node of the pair
         */
        public CMutInf(double d, FiniteStates finiteStates, FiniteStates finiteStates2) {
            this.score = d;
            this.node1 = finiteStates;
            this.node2 = finiteStates2;
        }

        /**
         * Sums {@code p * log10(p)} over every cell of {@code table}, visiting rows
         * in increasing order with columns as the inner loop. Cells whose potential
         * equals {@code KStarConstants.FLOOR} are skipped.
         *
         * @param table   table whose potentials are read
         * @param rows    number of rows to visit
         * @param columns number of columns to visit
         * @return the accumulated sum
         */
        private double sumPotentialTimesLog(AuxiliarPotentialTable table, int rows, int columns) {
            double logOfTen = Math.log(10.0d);
            double total = 0.0d;
            for (int row = 0; row < rows; row++) {
                for (int column = 0; column < columns; column++) {
                    double potential = table.getPotential(row, column);
                    if (potential != KStarConstants.FLOOR) {
                        total += potential * (Math.log(potential) / logOfTen);
                    }
                }
            }
            return total;
        }

        /** @param d new score of the pair */
        public void setScore(double d) {
            this.score = d;
        }

        /** @param finiteStates new first node of the pair */
        public void setNode1(FiniteStates finiteStates) {
            this.node1 = finiteStates;
        }

        /** @param finiteStates new second node of the pair */
        public void setNode2(FiniteStates finiteStates) {
            this.node2 = finiteStates;
        }

        /** @return the score of the pair */
        public double getScore() {
            return this.score;
        }

        /** @return the first node of the pair */
        public FiniteStates getNode1() {
            return this.node1;
        }

        /** @return the second node of the pair */
        public FiniteStates getNode2() {
            return this.node2;
        }
    }

    /**
     * Creates a classifier through the no-argument constructor of {@link TAN};
     * no training data is supplied here.
     */
    public CMutInfTAN() {
        super();
    }

    /**
     * Creates a classifier for the given training cases by delegating to the
     * matching {@link TAN} constructor.
     *
     * @param dataBaseCases training cases, forwarded unchanged to {@code TAN}
     * @param z             flag forwarded unchanged to {@code TAN}; its meaning is
     *                      defined there (presumably a smoothing/correction
     *                      switch - TODO confirm)
     * @throws InvalidEditException if the superclass constructor throws it
     */
    public CMutInfTAN(DataBaseCases dataBaseCases, boolean z) throws InvalidEditException {
        super(dataBaseCases, z);
    }

    /**
     * Learns the network structure and stores it in {@code this.classifier}.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>copy the database variables and score every pair of the first
     *       {@code nVariables - 1} of them with {@link CMutInf};</li>
     *   <li>walk the pairs from best to worst score, adding a link to a working
     *       graph unless {@code makesCycle} rejects it, until
     *       {@code nVariables - 2} links exist;</li>
     *   <li>pick a random root among the first {@code nVariables - 1} nodes and
     *       turn the graph's neighbour relations into parent/child links,
     *       processing nodes in first-in-first-out order from the root;</li>
     *   <li>add a link from the node at {@code classIndex} to every other node;</li>
     *   <li>assemble the resulting nodes and links into a new {@link Bnet}.</li>
     * </ol>
     *
     * <p>NOTE(review): the pair loops and the root choice only cover indices
     * {@code 0..nVariables-2}, and {@code CMutInf} reads the class from the last
     * node, so this presumably requires {@code classIndex == nVariables - 1} -
     * confirm against callers.
     *
     * @throws InvalidEditException if a graph edit performed here throws it
     */
    @Override // elvira.learning.classification.supervised.discrete.TAN, elvira.learning.classification.supervised.discrete.DiscreteClassifier
    public void structuralLearning() throws InvalidEditException {
        this.evaluations = 1;

        // Work on copies of the database variables.
        Vector nodeCopies = new Vector();
        for (int i = 0; i < this.nVariables; i++) {
            nodeCopies.add(this.cases.getVariables().elementAt(i).copy());
        }
        NodeList nodeList = new NodeList((Vector<Node>) nodeCopies);
        Graph graph = new Graph(1);
        graph.setNodeList(nodeList);

        // Score every unordered pair among the first nVariables - 1 nodes.
        Vector<CMutInf> candidates = new Vector<CMutInf>();
        for (int first = 0; first < this.nVariables - 1; first++) {
            for (int second = first + 1; second < this.nVariables - 1; second++) {
                candidates.add(new CMutInf(this.cases, (FiniteStates) nodeList.elementAt(first), (FiniteStates) nodeList.elementAt(second)));
            }
        }
        Collections.sort(candidates, new CMIComparator());

        // Take pairs best-first. The first two links are added without calling
        // makesCycle; later ones only when makesCycle returns false.
        int linksAdded = 0;
        int candidateIndex = 0;
        while (linksAdded < this.nVariables - 2) {
            CMutInf candidate = candidates.elementAt(candidateIndex);
            if (linksAdded < 2 || !makesCycle(graph, candidate.getNode1(), candidate.getNode2())) {
                graph.createLink(candidate.getNode1(), candidate.getNode2(), false);
                linksAdded++;
            }
            candidateIndex++;
        }

        // Choose the root and attach its graph neighbours as children.
        int rootIndex = generator.nextInt(this.nVariables - 1);
        Node root = graph.getNodeList().elementAt(rootIndex);
        NodeList rootNeighbours = root.getSiblingsNodes();
        Vector<Node> visited = new Vector<Node>();
        visited.add(root);
        Vector rootChildLinks = new Vector();
        Vector allLinks = new Vector();
        Vector<Node> pending = new Vector<Node>();
        NodeList copy = nodeList.copy();
        for (int i = 0; i < rootNeighbours.size(); i++) {
            Link link = new Link(root, rootNeighbours.elementAt(i));
            pending.add(rootNeighbours.elementAt(i));
            rootChildLinks.add(link);
            allLinks.add(link);
            Vector parentLinks = new Vector();
            parentLinks.add(link);
            LinkList parentList = new LinkList();
            parentList.setLinks(parentLinks);
            // Nodes are matched between the graph's list and the copy by position.
            copy.elementAt(graph.getNodeList().getId(rootNeighbours.elementAt(i))).setParents(parentList);
        }
        LinkList rootChildren = new LinkList();
        rootChildren.setLinks(rootChildLinks);
        copy.elementAt(rootIndex).setChildren(rootChildren);

        // Process the remaining nodes in the order they were discovered.
        while (!pending.isEmpty()) {
            Node current = pending.firstElement();
            visited.add(current);
            int currentIndex = copy.getId(current);
            NodeList neighbours = graph.getNodeList().elementAt(currentIndex).getSiblingsNodes();
            Vector childLinks = new Vector();
            for (int i = 0; i < neighbours.size(); i++) {
                // Only neighbours not yet visited and still without parents become children.
                if (visited.indexOf(neighbours.elementAt(i)) == -1 && graph.getNodeList().elementAt(graph.getNodeList().getId(neighbours.elementAt(i))).getParents().size() == 0) {
                    Link link = new Link(current, neighbours.elementAt(i));
                    pending.add(neighbours.elementAt(i));
                    childLinks.add(link);
                    allLinks.add(link);
                    Vector parentLinks = new Vector();
                    parentLinks.add(link);
                    LinkList parentList = new LinkList();
                    parentList.setLinks(parentLinks);
                    copy.elementAt(graph.getNodeList().getId(neighbours.elementAt(i))).setParents(parentList);
                }
            }
            LinkList children = new LinkList();
            children.setLinks(childLinks);
            copy.elementAt(currentIndex).setChildren(children);
            pending.remove(current);
        }

        // Mark the class node and make it a parent of every other node.
        Node classNode = copy.elementAt(this.classIndex);
        classNode.setTitle(classNode.getName().concat(" ClassNode"));
        classNode.setComment("ClassNode");
        Vector classChildLinks = new Vector();
        for (int i = 0; i < this.nVariables; i++) {
            if (i != this.classIndex) {
                classChildLinks.add(new Link(classNode, copy.elementAt(i)));
                allLinks.add(new Link(classNode, copy.elementAt(i)));
            }
        }
        LinkList classChildren = new LinkList();
        classChildren.setLinks(classChildLinks);
        classNode.setChildren(classChildren);
        for (int i = 0; i < this.nVariables; i++) {
            if (i != this.classIndex) {
                Link classLink = new Link(classNode, copy.elementAt(i));
                LinkList parents = copy.elementAt(i).getParents();
                parents.getLinks().addElement(classLink);
                copy.elementAt(i).setParents(parents);
            }
        }

        // Assemble the network from the oriented nodes and the collected links.
        this.classifier = new Bnet();
        for (int i = 0; i < this.nVariables; i++) {
            this.classifier.addNode(copy.elementAt(i));
            this.classifier.addRelation(copy.elementAt(i));
        }
        LinkList networkLinks = new LinkList();
        networkLinks.setLinks(allLinks);
        this.classifier.setLinkList(networkLinks);
        Vector defaultStates = new Vector();
        defaultStates.addElement(Bnet.ABSENT);
        defaultStates.addElement(Bnet.PRESENT);
        this.classifier.setFSDefaultStates(defaultStates);
        this.classifier.setName("classifier TAN");
    }

    /**
     * Command-line entry point: trains a classifier on the first file, tests it
     * on the second, prints the accuracy and confusion matrix, and writes the
     * learned network to the third.
     *
     * <p>Every stream and writer is closed in a {@code finally} block so that a
     * failure while parsing, testing or saving does not leak the file handle.
     *
     * @param strArr exactly three arguments: training {@code .dbc} file, test
     *               {@code .dbc} file and output {@code .elv} file; otherwise a
     *               usage line is printed and the JVM exits
     * @throws FileNotFoundException if an input file cannot be opened
     * @throws IOException           if reading or writing a file fails
     * @throws InvalidEditException  if building the classifier throws it
     * @throws ParseException        if a database file cannot be parsed
     * @throws Exception             if training or testing throws anything else
     */
    public static void main(String[] strArr) throws FileNotFoundException, IOException, InvalidEditException, ParseException, Exception {
        if (strArr.length != 3) {
            System.out.println("Usage: file-train.dbc file-test.dbc file-out.elv");
            System.exit(0);
        }

        DataBaseCases trainingCases;
        FileInputStream trainingStream = new FileInputStream(strArr[0]);
        try {
            trainingCases = new DataBaseCases(trainingStream);
        } finally {
            trainingStream.close();
        }

        CMutInfTAN cMutInfTAN = new CMutInfTAN(trainingCases, true);
        cMutInfTAN.train();
        System.out.println("Classifier learned");

        DataBaseCases testCases;
        FileInputStream testStream = new FileInputStream(strArr[1]);
        try {
            testCases = new DataBaseCases(testStream);
        } finally {
            testStream.close();
        }

        System.out.println("Classifier tested. Accuracy: " + cMutInfTAN.test(testCases));
        cMutInfTAN.getConfusionMatrix().print();

        FileWriter networkWriter = new FileWriter(strArr[2]);
        try {
            cMutInfTAN.getClassifier().saveBnet(networkWriter);
        } finally {
            networkWriter.close();
        }
    }

    /**
     * Writes the learned network to the file named {@code str} followed by
     * {@code "CMutInfTAN.elv"}.
     *
     * <p>The writer is closed in a {@code finally} block so it is released even
     * when {@code saveBnet} fails.
     *
     * @param str prefix of the output file name
     * @throws IOException if the file cannot be created or written
     */
    @Override // elvira.learning.classification.supervised.discrete.DiscreteClassifier, elvira.learning.classification.Classifier
    public void saveModelToFile(String str) throws IOException {
        FileWriter fileWriter = new FileWriter(str + "CMutInfTAN.elv");
        try {
            this.classifier.saveBnet(fileWriter);
        } finally {
            fileWriter.close();
        }
    }
}
