package elvira.learning;

import elvira.Bnet;
import elvira.Graph;
import elvira.InvalidEditException;
import elvira.Link;
import elvira.Node;
import elvira.NodeList;
import elvira.database.DataBaseCases;
import elvira.parser.ParseException;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Date;
import java.util.Enumeration;
import java.util.Vector;

/* loaded from: input_file:bayelvira-1.0-SNAPSHOT.jar:elvira/learning/RPDAGLearning.class */
/**
 * Structure learning for Bayesian networks by greedy search over RPDAG-style graphs
 * (presumably "restricted PDAGs", as the {@code Bnet} helpers {@code extendRPDAG} and
 * {@code removeLinkrepairRPDAG} suggest).
 *
 * <p>{@link #learning()} repeatedly asks {@link #maxScore} for the best operator (add an
 * undirected link, add a directed link, add a v-structure, or delete an existing link),
 * applies it while the accumulated score strictly improves, then extends the result with
 * {@code Bnet.extendRPDAG()} and stores it with {@code setOutput}.
 */
public class RPDAGLearning extends Learning {
    /** Network being searched; its links are edited in place by {@link #learning()}. */
    Bnet currentBnet;
    /** Scoring metric used to evaluate the network and candidate operators. */
    Metrics metric;
    /** Cases the structure is learned from. */
    DataBaseCases input;
    /** When true, {@link #main} learns parameters with {@code LPLearning}, otherwise {@code DELearning}. */
    boolean laplace;

    /** Sets the scoring metric used by the search. */
    public void setMetrics(Metrics metrics) {
        this.metric = metrics;
    }

    /** Returns the scoring metric used by the search. */
    public Metrics getMetrics() {
        return this.metric;
    }

    /** Sets the data base of cases to learn from. */
    public void setInput(DataBaseCases dataBaseCases) {
        this.input = dataBaseCases;
    }

    /** Returns the data base of cases to learn from. */
    public DataBaseCases getInput() {
        return this.input;
    }

    /** Sets the Laplace-correction flag (see {@link #laplace}). */
    public void setIfAplyLaplaceCorrection(boolean z) {
        this.laplace = z;
    }

    /** Returns the Laplace-correction flag (see {@link #laplace}). */
    public boolean getIfAplyLaplaceCorrection() {
        return this.laplace;
    }

    /**
     * Creates a learner with no input, no metric and an empty network. Input and metric must
     * be set before {@link #learning()} is called.
     */
    public RPDAGLearning() {
        this.input = null;
        this.metric = null;
        setOutput(null);
        this.currentBnet = new Bnet();
        // Graph-kind code 2 as defined by Bnet; presumably the mixed (RPDAG) kind — confirm.
        this.currentBnet.setKindOfGraph(2);
    }

    /**
     * Creates a learner whose initial network contains every variable of the data base and
     * no links.
     *
     * @param dataBaseCases cases to learn from; its node list seeds the network
     * @param z             Laplace-correction flag (see {@link #laplace})
     * @param metrics       scoring metric
     */
    public RPDAGLearning(DataBaseCases dataBaseCases, boolean z, Metrics metrics) {
        setIfAplyLaplaceCorrection(z);
        setInput(dataBaseCases);
        this.metric = metrics;
        NodeList nodeList = dataBaseCases.getNodeList();
        this.currentBnet = new Bnet();
        for (int i = 0; i < nodeList.size(); i++) {
            try {
                this.currentBnet.addNode(nodeList.elementAt(i));
            } catch (InvalidEditException ignored) {
                // Best effort: a node the network rejects is skipped.
            }
        }
        // Graph-kind code 2 as defined by Bnet; presumably the mixed (RPDAG) kind — confirm.
        this.currentBnet.setKindOfGraph(2);
    }

    /**
     * Depth-first search along sibling (undirected) links from {@code node2} looking for
     * {@code node}, never stepping straight back to the node it came from ({@code node3}).
     *
     * <p>Uses and sets the nodes' visited flags; callers reset them first with
     * {@code setVisitedAll(false)}. It must be called with {@code node2 == node3}. When a path
     * is found, {@code vector} receives the node preceding {@code node} on the path, followed
     * by the first node reached from {@code node2}.
     *
     * @return true if such a path exists
     */
    private boolean hasUndirectedCycle(Graph graph, Node node, Node node2, Node node3, Vector vector) {
        if (node.equals(node2)) {
            vector.addElement(node3);
            return true;
        }
        Enumeration elements = graph.siblings(node2).elements();
        boolean z = false;
        while (!z && elements.hasMoreElements()) {
            Node node4 = (Node) elements.nextElement();
            if (!node4.equals(node3) && !node4.getVisited()) {
                node4.setVisited(true);
                z = hasUndirectedCycle(graph, node, node4, node2, vector);
            }
            // Only true in the outermost frame, where node2 == node3.
            if (z && node2.equals(node3)) {
                vector.addElement(node4);
            }
        }
        return z;
    }

    /**
     * Depth-first search from {@code node2} to {@code node} following sibling links and
     * links to children, never stepping straight back to {@code node3}.
     *
     * <p>Uses and sets visited flags like {@link #hasUndirectedCycle}, and must be called with
     * {@code node2 == node3}. When a path is found, {@code vector} receives the first node
     * reached from {@code node2}.
     *
     * @return true if such a path exists
     */
    private boolean hasMixedCycle(Graph graph, Node node, Node node2, Node node3, Vector vector) {
        if (node.equals(node2)) {
            return true;
        }
        Enumeration elements = graph.siblings(node2).elements();
        Enumeration elements2 = graph.children(node2).elements();
        boolean z = false;
        while (!z && elements.hasMoreElements()) {
            Node node4 = (Node) elements.nextElement();
            if (!node4.equals(node3) && !node4.getVisited()) {
                node4.setVisited(true);
                z = hasMixedCycle(graph, node, node4, node2, vector);
            }
            if (z && node2.equals(node3)) {
                vector.addElement(node4);
            }
        }
        while (!z && elements2.hasMoreElements()) {
            Node node5 = (Node) elements2.nextElement();
            if (!node5.equals(node3) && !node5.getVisited()) {
                node5.setVisited(true);
                z = hasMixedCycle(graph, node, node5, node2, vector);
            }
            if (z && node2.equals(node3)) {
                vector.addElement(node5);
            }
        }
        return z;
    }

    /**
     * Runs the greedy structure search on {@link #currentBnet} and stores the result with
     * {@code setOutput}.
     *
     * <p>Each iteration adds the score difference of the best operator found by
     * {@link #maxScore} to a running score and applies that operator only while the running
     * score strictly improves. The search also stops when there is no candidate operator at
     * all (for example a network with fewer than two nodes) or when a link selected for
     * deletion cannot be found.
     */
    @Override // elvira.learning.Learning
    public void learning() {
        System.out.println("With the bnet: " + this.currentBnet.getLinkList().toString());
        double bestScore = this.metric.score(this.currentBnet);
        double runningScore = bestScore;
        System.out.println("initial fitness :" + bestScore);
        boolean improving = true;
        while (improving) {
            // op[0]: non-zero = addition, 0 = deletion; op[1]: operator code; op[2]: operator flag.
            int[] op = new int[3];
            Vector chosen = new Vector();
            double delta = maxScore(op, chosen, this.currentBnet);
            if (chosen.size() < 2) {
                // No candidate operator: without this guard elementAt(0) would throw.
                break;
            }
            Node from = (Node) chosen.elementAt(0);
            Node to = (Node) chosen.elementAt(1);
            runningScore += delta;
            if (runningScore > bestScore) {
                bestScore = runningScore;
                if (op[0] != 0) {
                    applyAddition(op[1], op[2], from, to, chosen);
                } else {
                    Link doomed = this.currentBnet.getLink(from, to);
                    if (doomed != null) {
                        this.currentBnet.removeLinkrepairRPDAG(doomed);
                    } else {
                        System.out.println("link to delete not found!!!!");
                        improving = false;
                    }
                }
            } else {
                improving = false;
            }
        }
        System.out.println("***FitnessBnet Final: " + bestScore);
        System.out.println("Number of arcs " + this.currentBnet.getLinkList().size());
        this.currentBnet.extendRPDAG();
        setOutput(this.currentBnet);
    }

    /**
     * Applies an addition operator chosen by {@link #maxScore} to {@link #currentBnet}.
     *
     * <p>Rejected edits ({@code InvalidEditException}) are skipped, as in the original search.
     *
     * @param code   absolute operator code: 0 undirected link, 1 directed link, 2 directed
     *               link followed by {@code orientInCascade}, otherwise a v-structure
     * @param flag   for v-structures, 1 requests {@code orientInCascade} on the head
     * @param from   tail of the new link
     * @param to     head of the new link
     * @param chosen candidate vector; element 2 is the head's sibling for v-structures
     */
    private void applyAddition(int code, int flag, Node from, Node to, Vector chosen) {
        switch (code) {
            case 0:
                try {
                    this.currentBnet.createLink(from, to, false);
                } catch (InvalidEditException ignored) {
                    // Best effort: a rejected edit is skipped.
                }
                break;
            case 1:
                try {
                    this.currentBnet.createLink(from, to, true);
                } catch (InvalidEditException ignored) {
                    // Best effort: a rejected edit is skipped.
                }
                break;
            case 2:
                try {
                    this.currentBnet.createLink(from, to, true);
                } catch (InvalidEditException ignored) {
                    // Best effort: a rejected edit is skipped.
                }
                this.currentBnet.orientInCascade(to);
                break;
            default:
                // V-structure: from -> to <- sibling, where sibling--to was undirected.
                Node sibling = (Node) chosen.elementAt(2);
                Link undirected = this.currentBnet.getLink(sibling, to);
                try {
                    this.currentBnet.createLink(from, to, true);
                    this.currentBnet.removeLink(undirected);
                    this.currentBnet.createLink(sibling, to, true);
                } catch (InvalidEditException ignored) {
                    // Best effort: a rejected edit is skipped.
                }
                if (flag == 1) {
                    this.currentBnet.orientInCascade(to);
                }
                break;
        }
    }

    /**
     * Finds the single best-scoring operator for {@code bnet}.
     *
     * <p>For every unordered pair of nodes: if no link joins them, the addition operators
     * encoded by {@link #rpdagNeighb} are evaluated, skipping those rejected by the cycle
     * checks. If a link exists, its deletion is evaluated. Each operator is scored as the
     * metric of the affected head's family with the change minus the metric without it.
     *
     * @param iArr   output of length >= 3, written only when a better candidate is found.
     *               {@code iArr[0]} is 1 for an addition and 0 for a deletion. For additions,
     *               {@code iArr[1]} is the absolute operator code and {@code iArr[2]} its flag.
     * @param vector output, cleared and filled with the link's tail and head (plus the
     *               head's sibling for v-structures). It is untouched if no candidate is found.
     * @param bnet   network to evaluate. Its visited flags are reset, and links may be
     *               temporarily removed and re-created during cycle checks.
     * @return the best score difference, or {@code Double.NEGATIVE_INFINITY} if none was found
     */
    private double maxScore(int[] iArr, Vector vector, Bnet bnet) {
        int i;
        NodeList nodeList;
        Node node;
        Node node2;
        // Operator buffer filled by rpdagNeighb: [0] = end index, then (code, flag) pairs.
        int[] iArr2 = new int[7];
        double d = Double.NEGATIVE_INFINITY;
        int size = bnet.getNodeList().size();
        for (int i2 = 0; i2 < size; i2++) {
            Node elementAt = bnet.getNodeList().elementAt(i2);
            for (int i3 = i2 + 1; i3 < size; i3++) {
                Node elementAt2 = bnet.getNodeList().elementAt(i3);
                // i == 1: the pair is unlinked (try additions); i == 0: a link exists (try deletion).
                Link link = bnet.getLink(elementAt, elementAt2);
                Link link2 = link;
                if (link == null) {
                    Link link3 = bnet.getLink(elementAt2, elementAt);
                    link2 = link3;
                    i = link3 == null ? 1 : 0;
                } else {
                    i = 0;
                }
                if (i == 1) {
                    rpdagNeighb(bnet, elementAt, elementAt2, iArr2);
                    // Set once a cycle check has pinned the sibling for the following
                    // v-structure codes; they are then not enumerated over all siblings.
                    boolean z = false;
                    int i4 = 1;
                    while (i4 < iArr2[0]) {
                        // A negative code means the new link goes elementAt2 -> elementAt.
                        if (iArr2[i4] < 0) {
                            node = elementAt2;
                            node2 = elementAt;
                        } else {
                            node = elementAt;
                            node2 = elementAt2;
                        }
                        boolean z2 = true; // false when the operator is rejected
                        Node node3 = null;
                        double d2 = 0.0d;
                        NodeList nodeList2 = new NodeList();
                        switch (Math.abs(iArr2[i4])) {
                            case 0:
                                // Undirected link; with flag 2 reject it if an undirected path
                                // already joins the nodes and pin v-structures on that path.
                                Vector vector2 = new Vector();
                                bnet.setVisitedAll(false);
                                if (iArr2[i4 + 1] == 2 && hasUndirectedCycle(bnet, node, node2, node2, vector2)) {
                                    z2 = false;
                                    int id = bnet.siblings(node).getId((Node) vector2.elementAt(0));
                                    iArr2[i4 + 2] = 3 + bnet.siblings(node2).getId((Node) vector2.elementAt(1));
                                    iArr2[i4 + 3] = node2.getSiblings().size() > 1 ? 1 : 0;
                                    iArr2[i4 + 4] = (-3) - id;
                                    iArr2[i4 + 5] = node.getSiblings().size() > 1 ? 1 : 0;
                                    z = true;
                                }
                                if (z2) {
                                    nodeList2.insertNode(node2);
                                    d2 = this.metric.score(nodeList2);
                                    break;
                                }
                                break;
                            case 1:
                                // Directed link node -> node2; with flag 2 reject it if node is
                                // already a directed descendant of node2.
                                if (iArr2[i4 + 1] == 2) {
                                    if (bnet.directedDescendants(node2).indexOf(node) != -1) {
                                        z2 = false;
                                    }
                                }
                                if (z2) {
                                    NodeList parents = bnet.parents(node2);
                                    nodeList2.insertNode(node2);
                                    nodeList2.join(parents);
                                    d2 = this.metric.score(nodeList2);
                                    break;
                                }
                                break;
                            case 2:
                                // Directed link plus cascade orientation; reject if a mixed path
                                // exists. Setting iArr2[0] = 1 ends the operator loop for this pair.
                                Vector vector3 = new Vector();
                                bnet.setVisitedAll(false);
                                if (hasMixedCycle(bnet, node, node2, node2, vector3)) {
                                    z2 = false;
                                    Node node4 = (Node) vector3.elementAt(0);
                                    if (node2.getChildrenNodes().getId(node4) != -1) {
                                        iArr2[0] = 1;
                                    } else {
                                        // The path starts through a sibling: drop that link
                                        // temporarily and check whether another path remains.
                                        Link link4 = bnet.getLink(node2.getName(), node4.getName());
                                        try {
                                            bnet.removeLink(link4);
                                        } catch (InvalidEditException ignored) {
                                            // Best effort: the check proceeds either way.
                                        }
                                        Vector vector4 = new Vector();
                                        bnet.setVisitedAll(false);
                                        boolean hasMixedCycle = hasMixedCycle(bnet, node, node2, node2, vector4);
                                        try {
                                            bnet.createLink(link4.getTail(), link4.getHead(), false);
                                        } catch (InvalidEditException ignored) {
                                            // Best effort: restore the undirected link.
                                        }
                                        if (hasMixedCycle) {
                                            iArr2[0] = 1;
                                        } else {
                                            int id2 = bnet.siblings(node2).getId(node4);
                                            if (iArr2[i4 + 2] > 0) {
                                                iArr2[i4 + 2] = 3 + id2;
                                            } else {
                                                iArr2[i4 + 2] = (-3) - id2;
                                            }
                                            z = true;
                                        }
                                    }
                                }
                                if (z2) {
                                    nodeList2.insertNode(node2);
                                    d2 = this.metric.score(nodeList2);
                                    break;
                                }
                                break;
                            default:
                                // V-structure: head node2 with its sibling at index |code| - 3.
                                NodeList nodeList3 = new NodeList();
                                node3 = bnet.siblings(node2).elementAt(Math.abs(iArr2[i4]) - 3);
                                nodeList3.insertNode(node3);
                                nodeList2.insertNode(node2);
                                nodeList2.join(nodeList3);
                                d2 = this.metric.score(nodeList2);
                                break;
                        }
                        // NOTE: evaluated even for rejected operators (z2 == false); kept as-is
                        // because the metric may track how many statistics it evaluates.
                        nodeList2.insertNode(node);
                        double score = this.metric.score(nodeList2) - d2;
                        if (z2 && score > d) {
                            vector.clear();
                            vector.addElement(node);
                            vector.addElement(node2);
                            if (node3 != null) {
                                vector.addElement(node3);
                            }
                            iArr[0] = i;
                            iArr[1] = Math.abs(iArr2[i4]);
                            iArr[2] = iArr2[i4 + 1];
                            d = score;
                        }
                        // Codes below 4 are evaluated once. Larger (v-structure) codes are
                        // re-evaluated with the sibling index lowered by one until code 3,
                        // unless a cycle check pinned the sibling (z).
                        if (Math.abs(iArr2[i4]) < 4) {
                            i4 += 2;
                        } else if (z) {
                            i4 += 2;
                        } else if (iArr2[i4] < 0) {
                            int i5 = i4;
                            iArr2[i5] = iArr2[i5] + 1;
                        } else {
                            int i6 = i4;
                            iArr2[i6] = iArr2[i6] - 1;
                        }
                    }
                } else {
                    // Deletion: score of head's family without tail minus score with tail.
                    Node tail = link2.getTail();
                    Node head = link2.getHead();
                    if (link2.getDirected()) {
                        // NOTE(review): removeNode below mutates this list; assumes
                        // bnet.parents returns a fresh NodeList — confirm.
                        nodeList = bnet.parents(head);
                    } else {
                        nodeList = new NodeList();
                        nodeList.insertNode(tail);
                    }
                    NodeList nodeList4 = new NodeList();
                    nodeList4.insertNode(head);
                    nodeList4.join(nodeList);
                    double score2 = this.metric.score(nodeList4);
                    nodeList.removeNode(tail);
                    NodeList nodeList5 = new NodeList();
                    nodeList5.insertNode(head);
                    nodeList5.join(nodeList);
                    double score3 = this.metric.score(nodeList5) - score2;
                    if (score3 > d) {
                        d = score3;
                        vector.clear();
                        vector.addElement(tail);
                        vector.addElement(head);
                        iArr[0] = i;
                    }
                }
            }
        }
        return d;
    }

    /**
     * Encodes the candidate addition operators for the unlinked pair ({@code node},
     * {@code node2}) into {@code iArr}.
     *
     * <p>The encoding is laid out as follows:
     * <ul>
     *   <li>{@code iArr[0]} is the exclusive end index of the encoding.</li>
     *   <li>From index 1 on, {@code iArr} holds (code, flag) pairs.</li>
     *   <li>A negative code means the new link goes {@code node2 -> node}; a non-negative
     *       code means {@code node -> node2}.</li>
     * </ul>
     *
     * <p>{@link #maxScore} and {@link #learning()} interpret {@code |code|} as follows:
     * <ul>
     *   <li>0 adds an undirected link.</li>
     *   <li>1 adds a directed link.</li>
     *   <li>2 adds a directed link followed by {@code orientInCascade}.</li>
     *   <li>{@code |code| >= 3} adds a v-structure with the head's sibling at index
     *       {@code |code| - 3}. The encoding starts from the last sibling.</li>
     * </ul>
     *
     * <p>For codes 0 and 1, a flag of 2 requests a cycle check. For v-structure codes, a flag
     * of 1 (head has more than one sibling) requests {@code orientInCascade}.
     *
     * @param iArr output buffer; must hold at least 7 elements
     */
    public void rpdagNeighb(Bnet bnet, Node node, Node node2, int[] iArr) {
        int i;
        int size = bnet.siblings(node).size();
        int size2 = bnet.siblings(node2).size();
        int size3 = bnet.parents(node).size();
        int size4 = bnet.parents(node2).size();
        if (size == 0 && size2 == 0) {
            if (size3 == 0 && size4 == 0) {
                i = 1 + 1;
                iArr[1] = 0;
                iArr[i] = 0;
            } else {
                int i2 = 1 + 1;
                iArr[1] = 1;
                iArr[i2] = (size3 == 0 || bnet.children(node2).size() == 0) ? 0 : 2;
                int i3 = i2 + 1;
                i = i3 + 1;
                iArr[i3] = -1;
                iArr[i] = (size4 == 0 || bnet.children(node).size() == 0) ? 0 : 2;
            }
        } else if (size3 == 0 && size4 == 0) {
            if (size != 0 && size2 != 0) {
                int i4 = 1 + 1;
                iArr[1] = 0;
                iArr[i4] = 2;
                int i5 = i4 + 1;
                int i6 = i5 + 1;
                iArr[i5] = 2 + size2;
                iArr[i6] = size2 > 1 ? 1 : 0;
                int i7 = i6 + 1;
                i = i7 + 1;
                iArr[i7] = (-2) - size;
                iArr[i] = size > 1 ? 1 : 0;
            } else if (size == 0) {
                int i8 = 1 + 1;
                iArr[1] = 0;
                iArr[i8] = 0;
                int i9 = i8 + 1;
                i = i9 + 1;
                iArr[i9] = 2 + size2;
                iArr[i] = size2 > 1 ? 1 : 0;
            } else {
                int i10 = 1 + 1;
                iArr[1] = 0;
                iArr[i10] = 0;
                int i11 = i10 + 1;
                i = i11 + 1;
                iArr[i11] = (-2) - size;
                iArr[i] = size > 1 ? 1 : 0;
            }
        } else if (size == 0) {
            int i12 = 1 + 1;
            iArr[1] = -1;
            iArr[i12] = 0;
            int i13 = i12 + 1;
            int i14 = i13 + 1;
            iArr[i13] = 2;
            iArr[i14] = 1;
            int i15 = i14 + 1;
            i = i15 + 1;
            iArr[i15] = 2 + size2;
            iArr[i] = size2 > 1 ? 1 : 0;
        } else {
            int i16 = 1 + 1;
            iArr[1] = 1;
            iArr[i16] = 0;
            int i17 = i16 + 1;
            int i18 = i17 + 1;
            iArr[i17] = -2;
            iArr[i18] = 1;
            int i19 = i18 + 1;
            i = i19 + 1;
            iArr[i19] = (-2) - size;
            iArr[i] = size > 1 ? 1 : 0;
        }
        iArr[0] = i + 1;
    }

    /**
     * Command-line entry point: learns a network from a case file and saves it.
     *
     * <p>Usage: {@code file.dbc file.elv nCases Metric(K2,BIC,BDe)}.
     * <ul>
     *   <li>{@code nCases > 0} restricts the number of cases used.</li>
     *   <li>Any metric name other than {@code K2} or {@code BIC} selects BDe.</li>
     *   <li>The learner is created with the Laplace flag set, so parameters are learned with
     *       {@code LPLearning}.</li>
     *   <li>The network is written to {@code file.elv} and named after that path's last
     *       '/'-separated component.</li>
     * </ul>
     */
    public static void main(String[] strArr) throws ParseException, IOException {
        if (strArr.length < 4) {
            System.out.println("too few arguments: Usage: file.dbc file.elv nCases  Metric(K2,BIC,BDe) ");
            // Non-zero status: a usage error is not a successful run.
            System.exit(1);
        }
        DataBaseCases dataBaseCases = new DataBaseCases(new FileInputStream(strArr[0]));
        int nCases = Integer.parseInt(strArr[2]);
        if (nCases > 0) {
            dataBaseCases.setNumberOfCases(nCases);
        }
        String metricName = strArr[3];
        Metrics metrics;
        if (metricName.equals("K2")) {
            metrics = new K2Metrics(dataBaseCases);
        } else if (metricName.equals("BIC")) {
            metrics = new BICMetrics(dataBaseCases);
        } else {
            metrics = new BDeMetrics(dataBaseCases);
        }
        long start = System.currentTimeMillis();
        RPDAGLearning rPDAGLearning = new RPDAGLearning(dataBaseCases, true, metrics);
        rPDAGLearning.learning();
        double structureSeconds = (System.currentTimeMillis() - start) / 1000.0d;
        System.out.println("Tiempo consumido: " + structureSeconds);
        System.out.println("Estadisticos evaluados: " + metrics.getTotalStEval());
        System.out.println("Total de estadisticos: " + metrics.getTotalSt());
        System.out.println("Numero medio de var en St: " + metrics.getAverageNVars());
        ParameterLearning dELearning;
        if (rPDAGLearning.laplace) {
            dELearning = new LPLearning(dataBaseCases, rPDAGLearning.getOutput());
        } else {
            dELearning = new DELearning(dataBaseCases, rPDAGLearning.getOutput());
        }
        dELearning.learning();
        String outPath = strArr[1];
        String netName = outPath.substring(outPath.lastIndexOf('/') + 1);
        Bnet output = dELearning.getOutput();
        output.setName(netName);
        output.setComment("learned with RPDAG from " + strArr[0] + " with " + dataBaseCases.getNumberOfCases() + " samples");
        FileWriter fileWriter = new FileWriter(outPath);
        try {
            output.saveBnet(fileWriter);
        } finally {
            // Close even if saving fails.
            fileWriter.close();
        }
        // Total elapsed time (structure + parameter learning + saving), measured from start.
        System.out.println("Tempos consumido: " + ((System.currentTimeMillis() - start) / 1000.0d));
    }
}
