package org.cytoscape.bayelviraapp.internal;

import com.csvreader.CsvReader;
import elvira.Bnet;
import elvira.CaseListMem;
import elvira.Configuration;
import elvira.Continuous;
import elvira.FiniteStates;
import elvira.InvalidEditException;
import elvira.Link;
import elvira.Node;
import elvira.NodeList;
import elvira.database.DataBaseCases;
import elvira.gui.explication.macroExplanation;
import elvira.learning.BDeMetrics;
import elvira.learning.BICMetrics;
import elvira.learning.K2Metrics;
import elvira.learning.classification.ConfusionMatrix;
import elvira.learning.classification.supervised.discrete.CMutInfKDB;
import elvira.learning.classification.supervised.discrete.CMutInfTAN;
import elvira.learning.classification.supervised.discrete.ClassTreeNaive;
import elvira.learning.classification.supervised.discrete.DiscreteClassifier;
import elvira.learning.classification.supervised.discrete.Naive_Bayes;
import elvira.learning.classification.supervised.discrete.WrapperSelectiveNaiveBayes;
import elvira.learning.classification.supervised.discrete.WrapperSemiNaiveBayes;
import elvira.learning.classification.supervised.mixed.Gaussian_Naive_Bayes;
import elvira.learning.classification.supervised.mixed.Selective_GNB;
import java.awt.Color;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.Vector;
import org.cytoscape.model.CyEdge;
import org.cytoscape.model.CyNetwork;
import org.cytoscape.model.CyNode;
import org.cytoscape.model.CyTable;
import org.cytoscape.service.util.CyServiceRegistrar;
import org.cytoscape.task.AbstractNetworkViewTask;
import org.cytoscape.view.layout.CyLayoutAlgorithm;
import org.cytoscape.view.layout.CyLayoutAlgorithmManager;
import org.cytoscape.view.model.CyNetworkView;
import org.cytoscape.view.presentation.property.NodeShapeVisualProperty;
import org.cytoscape.view.presentation.property.values.NodeShape;
import org.cytoscape.work.Task;
import org.cytoscape.work.TaskMonitor;
import org.cytoscape.work.Tunable;
import org.cytoscape.work.util.ListSingleSelection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ArffLoader;
import weka.core.xml.XMLInstances;

/**
 * Cytoscape task that builds a Bayesian network into the current network view.
 *
 * <p>The input file is either a case database ({@code .arff}, {@code .csv} or
 * Elvira {@code .dbc}) from which a classifier of the selected type is learned
 * on a train/test split of the cases, or an Elvira network ({@code .elv}) that
 * is displayed as-is. Nodes receive a {@value #ATT_NODE_TYPE} column, edges
 * receive {@value #ATT_EDGE_DIST} and {@value #ATT_EDGE_COMPARE} columns, and
 * learned classifiers additionally write score and confusion-matrix metrics to
 * the network table. A layout task (when available) and a
 * {@code VisualStyleTask} are scheduled to run after this task.
 */
public class ImportBDTask extends AbstractNetworkViewTask {

    /** Input file; its extension selects the parser (csv, arff, dbc or elv). */
    @Tunable(description = "Database file (.arff, .csv, .dbc or .elv)", groups = {"Input data"}, params = "fileCategory=table;input=true")
    public File data;

    /** Database (and network) name; when blank, the input file name without extension is used. */
    @Tunable(description = "Database name", groups = {"Input data"})
    public String name;

    /** Classifier type to learn; one of {@link #ALGORITHMS}. */
    @Tunable(description = "Bayes net type", groups = {"Input data"})
    public ListSingleSelection<String> algorithms;

    /** Class attribute name (matched case-insensitively); blank or null selects the last attribute. */
    @Tunable(description = "Class attribute name (default: last attribute)", groups = {"Input data"})
    public String classAtt;

    /** Whether the resulting network is written to {@link #outBnet}. */
    @Tunable(description = "Export calculated bayes net in .elv format", groups = {"Output"})
    public boolean export;

    /** Destination of the exported network; the ".elv" suffix is appended when missing. */
    @Tunable(description = "Output file", groups = {"Output"}, params = "fileCategory=unspecified", dependsOn = "export=true")
    public File outBnet;

    /** Layout scheduled after network creation; null when the "force-directed" layout is unavailable. */
    CyLayoutAlgorithm layout;
    CyServiceRegistrar serviceRegistrarRef;

    public static final String CSV_EXT = "csv";
    public static final String ARFF_EXT = "arff";
    public static final String DBC_EXT = "dbc";
    public static final String ELV_EXT = "elv";
    public static final String ATT_NODE_TYPE = "node_type";
    public static final String ATT_NODE_LABEL = "name";
    public static final String ATT_EDGE_DIST = "distance";
    public static final String ATT_EDGE_COMPARE = "compare";
    private static final String DEFAULT_LAYOUT = "force-directed";
    /** Multiplier applied to a link's influence to obtain the {@value #ATT_EDGE_DIST} value. */
    private static final double EDGE_WIDE_MAX = 4.0d;
    /** Passed to {@code divideIntoTrainAndTest}; presumably the fraction of cases used for training. */
    private static final double TRAIN_TEST_RATE = 0.7d;

    /** Classifier learned by the last {@link #run}; null until a database has been processed. */
    DiscreteClassifier classif;

    protected static final String ALG_NB = "Naive Bayes (discrete)";
    protected static final String ALG_SEMINB = "Semi Naive Bayes (discrete)";
    protected static final String ALG_KDB = "KDB (discrete)";
    protected static final String ALG_TAN = "TAN (discrete)";
    protected static final String ALG_SELNB = "Selective Naive Bayes (discrete)";
    protected static final String ALG_CLTREE = "Class Tree Naive (discrete)";
    protected static final String ALG_GAUS_NB = "Gaussian Naive Bayes (mixed)";
    protected static final String ALG_SEL_GNB = "Selective Gaussian Naive Bayes (mixed)";
    protected static final String[] ALGORITHMS = {ALG_NB, ALG_SEMINB, ALG_KDB, ALG_TAN, ALG_SELNB, ALG_CLTREE, ALG_GAUS_NB, ALG_SEL_GNB};

    /** Color associated with each {@value #ATT_EDGE_COMPARE} value. */
    public static final HashMap<String, Color> ATT_COMPARE_MAP = new HashMap<>();
    /** Node shape associated with each {@value #ATT_NODE_TYPE} value. */
    public static final HashMap<String, NodeShape> ATT_NODE_TYPE_MAP = new HashMap<>();

    /**
     * Creates the task with Naive Bayes preselected and the default layout looked up.
     *
     * @param cyServiceRegistrar registrar used to look up the layout manager
     * @param cyNetworkView      view whose network is populated
     */
    public ImportBDTask(CyServiceRegistrar cyServiceRegistrar, CyNetworkView cyNetworkView) {
        super(cyNetworkView);
        this.export = false;
        this.classif = null;
        this.serviceRegistrarRef = cyServiceRegistrar;
        List<String> algorithmNames = new ArrayList<>(Arrays.asList(ALGORITHMS));
        this.algorithms = new ListSingleSelection<>(algorithmNames);
        this.algorithms.setSelectedValue(ALG_NB);
        CyLayoutAlgorithmManager layoutManager = cyServiceRegistrar.getService(CyLayoutAlgorithmManager.class);
        this.layout = layoutManager.getLayout(DEFAULT_LAYOUT);
    }

    /**
     * Instantiates the Elvira classifier matching an entry of {@link #ALGORITHMS}.
     *
     * @param str           selected algorithm label
     * @param dataBaseCases training cases
     * @return the classifier, or null when {@code str} is null or not a known label
     * @throws Exception propagated from the Elvira constructors
     */
    protected DiscreteClassifier createClassifier(String str, DataBaseCases dataBaseCases) throws Exception {
        if (str == null) {
            return null;
        }
        switch (str) {
            case ALG_NB:
                return new Naive_Bayes(dataBaseCases, true);
            case ALG_SEMINB:
                return new WrapperSemiNaiveBayes(dataBaseCases, true);
            case ALG_TAN:
                return new CMutInfTAN(dataBaseCases, true);
            case ALG_KDB:
                return new CMutInfKDB(dataBaseCases, true, 5);
            case ALG_SELNB:
                return new WrapperSelectiveNaiveBayes(dataBaseCases, true);
            case ALG_CLTREE:
                return new ClassTreeNaive(dataBaseCases);
            case ALG_GAUS_NB:
                return new Gaussian_Naive_Bayes(dataBaseCases, true, getClassIndex(dataBaseCases));
            case ALG_SEL_GNB:
                return new Selective_GNB(dataBaseCases, true, getClassIndex(dataBaseCases));
            default:
                return null;
        }
    }

    /**
     * Loads the input file, learns (or reads) the network, fills the view's
     * network, optionally exports it, and schedules layout and styling tasks.
     *
     * @param taskMonitor receives progress and status messages
     * @throws Exception if reading, learning or exporting fails
     */
    @Override
    public void run(TaskMonitor taskMonitor) throws Exception {
        if (this.view == null) {
            return;
        }
        taskMonitor.setTitle("Create bayesian network.");
        Logger logger = LoggerFactory.getLogger("CyUserMessages");
        if (this.data == null) {
            reportError(taskMonitor, logger, "No database file selected.");
            return;
        }
        CyNetwork network = this.view.getModel();

        String fileName = this.data.getName().toLowerCase();
        DataBaseCases dataBaseCases = null;
        Bnet bnet = null;
        if (fileName.endsWith(CSV_EXT)) {
            dataBaseCases = parseCSV();
        } else if (fileName.endsWith(ARFF_EXT)) {
            dataBaseCases = parseArff();
        } else if (fileName.endsWith(DBC_EXT)) {
            dataBaseCases = parseDbc();
        } else if (fileName.endsWith(ELV_EXT)) {
            bnet = new Bnet(this.data.getAbsolutePath());
        }

        if (dataBaseCases != null) {
            DataBaseCases trainCases = new DataBaseCases();
            DataBaseCases testCases = new DataBaseCases();
            dataBaseCases.divideIntoTrainAndTest(trainCases, testCases, TRAIN_TEST_RATE);
            this.classif = createClassifier(this.algorithms.getSelectedValue(), trainCases);
            if (this.classif == null) {
                reportError(taskMonitor, logger, "Cannot determine network type.");
                return;
            }
            if (!setClassVar()) {
                reportError(taskMonitor, logger, "Class attribute not found or not discrete.");
                return;
            }
            taskMonitor.setProgress(0.1d);
            taskMonitor.setStatusMessage("Performing training and test.");
            this.classif.train();
            this.classif.test(testCases);
            taskMonitor.setProgress(0.7d);
            taskMonitor.setStatusMessage("Creating network.");
            createNetwork(null, this.classif, network);
            taskMonitor.setProgress(0.9d);
            taskMonitor.setStatusMessage("Calculating metrics.");
            calcMetrics(this.classif, trainCases, network);
        } else if (bnet == null) {
            reportError(taskMonitor, logger, "Bad table file format.");
            return;
        } else {
            taskMonitor.setProgress(0.7d);
            taskMonitor.setStatusMessage("Creating network.");
            createNetwork(bnet, null, network);
        }

        if (this.export) {
            exportBnet(dataBaseCases != null ? this.classif.getClassifier() : bnet);
        }
        if (this.layout != null) {
            insertTasksAfterCurrentTask(this.layout.createTaskIterator(this.view, this.layout.createLayoutContext(),
                    new HashSet<>(this.view.getNodeViews()), (String) null));
        }
        insertTasksAfterCurrentTask(new Task[]{new VisualStyleTask(this.serviceRegistrarRef, this.view)});
    }

    /** Shows {@code message} on the task monitor and logs it as an error. */
    private static void reportError(TaskMonitor taskMonitor, Logger logger, String message) {
        taskMonitor.setStatusMessage(message);
        logger.error(message);
    }

    /**
     * Writes {@code bnet} in Elvira format to {@link #outBnet} when export is enabled.
     * The ".elv" suffix is appended if the chosen file name lacks it.
     *
     * @param bnet network to save; nothing is written when null
     * @throws IOException if the file cannot be written
     */
    protected void exportBnet(Bnet bnet) throws IOException {
        if (!this.export || this.outBnet == null || bnet == null) {
            return;
        }
        if (!this.outBnet.getName().endsWith("." + ELV_EXT)) {
            this.outBnet = new File(this.outBnet.getAbsolutePath() + '.' + ELV_EXT);
        }
        // Closing here is safe even if saveBnet closes the writer itself: Writer.close is idempotent.
        try (FileWriter writer = new FileWriter(this.outBnet)) {
            bnet.saveBnet(writer);
        }
    }

    /** Reads an Elvira .dbc database from {@link #data}, closing the stream afterwards. */
    private DataBaseCases parseDbc() throws Exception {
        try (FileInputStream in = new FileInputStream(this.data)) {
            return new DataBaseCases(in);
        }
    }

    /** Returns {@link #name} when non-blank, otherwise the input file name without its extension. */
    private String databaseName() {
        if (this.name != null && !this.name.trim().isEmpty()) {
            return this.name;
        }
        String fileName = this.data.getName();
        int dot = fileName.lastIndexOf('.');
        return dot > 0 ? fileName.substring(0, dot) : fileName;
    }

    /**
     * Parses {@link #data} as a CSV file with a header row and integer-coded cells.
     * Each column becomes a {@link FiniteStates} variable whose number of states
     * is the number of distinct values seen in that column.
     *
     * <p>NOTE(review): raw cell values are used directly as state indices, so this
     * assumes each column is coded as 0..k-1 — confirm against the expected input.
     *
     * @return the database of cases
     * @throws IOException           if the file cannot be read or has no columns/rows
     * @throws InvalidEditException  propagated from Elvira
     * @throws NumberFormatException if a cell is not an integer
     */
    protected DataBaseCases parseCSV() throws IOException, InvalidEditException {
        String[] headers;
        List<int[]> rows = new ArrayList<>();
        CsvReader reader = new CsvReader(new BufferedReader(new FileReader(this.data)));
        try {
            reader.readHeaders();
            headers = reader.getHeaders();
            int columns = headers == null ? 0 : headers.length;
            while (reader.readRecord()) {
                int[] row = new int[columns];
                for (int col = 0; col < columns; col++) {
                    row[col] = Integer.parseInt(reader.get(col));
                }
                rows.add(row);
            }
        } finally {
            reader.close();
        }
        if (headers == null || headers.length == 0 || rows.isEmpty()) {
            throw new IOException("CSV file has no columns or no data rows: " + this.data);
        }

        Vector<Node> variables = new Vector<>();
        for (int col = 0; col < headers.length; col++) {
            Set<Integer> distinctValues = new HashSet<>();
            for (int[] row : rows) {
                distinctValues.add(row[col]);
            }
            FiniteStates variable = new FiniteStates(distinctValues.size());
            variable.setName(headers[col]);
            variables.add(variable);
        }

        CaseListMem caseList = new CaseListMem(variables);
        for (int[] row : rows) {
            Configuration configuration = new Configuration(variables);
            for (int col = 0; col < row.length; col++) {
                configuration.setValue(col, row[col]);
            }
            caseList.put(configuration);
        }
        return new DataBaseCases(databaseName(), new NodeList(variables), caseList);
    }

    /**
     * Parses {@link #data} as an ARFF file, closing the input stream afterwards.
     *
     * @return the database of cases
     * @throws IOException          if the file cannot be read
     * @throws InvalidEditException propagated from Elvira
     */
    protected DataBaseCases parseArff() throws IOException, InvalidEditException {
        try (InputStream in = new FileInputStream(this.data)) {
            return parseArff(in, databaseName());
        }
    }

    /**
     * Converts an ARFF stream into an Elvira database. Each attribute becomes a
     * {@link FiniteStates} variable with {@code numValues()} states, and each
     * instance value is truncated with an {@code (int)} cast to a state index.
     *
     * <p>NOTE(review): numeric attributes (numValues() == 0) and missing values
     * (NaN casts to 0) are not handled specially — confirm inputs are nominal.
     *
     * @param inputStream ARFF content; not closed by this method
     * @param str         database name
     * @return the database of cases
     * @throws IOException          if the stream cannot be parsed
     * @throws InvalidEditException propagated from Elvira
     */
    protected static DataBaseCases parseArff(InputStream inputStream, String str) throws IOException, InvalidEditException {
        ArffLoader arffLoader = new ArffLoader();
        arffLoader.setSource(inputStream);
        Instances dataSet = arffLoader.getDataSet();

        Vector<Node> variables = new Vector<>();
        for (int att = 0; att < dataSet.numAttributes(); att++) {
            FiniteStates variable = new FiniteStates(dataSet.attribute(att).numValues());
            variable.setName(dataSet.attribute(att).name());
            variables.add(variable);
        }

        CaseListMem caseList = new CaseListMem(variables);
        for (int i = 0; i < dataSet.numInstances(); i++) {
            Configuration configuration = new Configuration(variables);
            Instance instance = dataSet.instance(i);
            for (int att = 0; att < dataSet.numAttributes(); att++) {
                configuration.setValue(att, (int) instance.value(att));
            }
            caseList.put(configuration);
        }
        return new DataBaseCases(str, new NodeList(variables), caseList);
    }

    /**
     * Adds one Cytoscape node per network variable and one directed edge per link.
     *
     * @param bnet                network to render; when null the classifier's network is used
     * @param discreteClassifier  classifier providing the network/name when {@code bnet} is null
     *                            and used to mark the class node; may be null when {@code bnet} is set
     * @param cyNetwork           Cytoscape network to populate
     */
    protected static void createNetwork(Bnet bnet, DiscreteClassifier discreteClassifier, CyNetwork cyNetwork) {
        Logger logger = LoggerFactory.getLogger("CyUserMessages");
        Bnet net;
        String networkName;
        if (bnet == null) {
            net = discreteClassifier.getClassifier();
            networkName = discreteClassifier.getDataBaseCases().getName();
        } else {
            net = bnet;
            networkName = bnet.getName();
        }
        cyNetwork.getRow(cyNetwork).set("name", networkName);

        CyTable nodeTable = cyNetwork.getDefaultNodeTable();
        ensureColumn(nodeTable, ATT_NODE_TYPE, String.class);
        Map<String, CyNode> nodesByName = new TreeMap<>();
        List<Node> variables = net.getNodeList().getNodes();
        for (Node variable : variables) {
            CyNode cyNode = cyNetwork.addNode();
            cyNetwork.getRow(cyNode).set("name", variable.getName());
            nodesByName.put(variable.getName(), cyNode);
            nodeTable.getRow(cyNode.getSUID()).set(ATT_NODE_TYPE, getNodeType(variable, discreteClassifier));
        }

        CyTable edgeTable = cyNetwork.getDefaultEdgeTable();
        ensureColumn(edgeTable, ATT_EDGE_DIST, Double.class);
        ensureColumn(edgeTable, ATT_EDGE_COMPARE, String.class);
        logger.debug("Total nodes = {}", net.getNodeList().size());
        for (Object item : net.getLinkList().getLinks()) {
            Link link = (Link) item;
            // NOTE(review): the edge source is the link's head and the target its tail;
            // confirm this matches Elvira's link orientation and the visual style's arrows.
            CyEdge edge = cyNetwork.addEdge(nodesByName.get(link.getHead().getName()),
                    nodesByName.get(link.getTail().getName()), true);
            calcLinkStyles(net, link, edge, cyNetwork);
        }
    }

    /** Creates {@code columnName} in {@code table} unless a column with that name already exists. */
    private static void ensureColumn(CyTable table, String columnName, Class<?> type) {
        if (table.getColumn(columnName) == null) {
            table.createColumn(columnName, type, true);
        }
    }

    /**
     * Returns "class" for the classifier's class variable and
     * {@code XMLInstances.TAG_ATTRIBUTE} for every other node (or when there is no classifier).
     */
    protected static String getNodeType(Node node, DiscreteClassifier discreteClassifier) {
        return (discreteClassifier != null && node.getName().equals(discreteClassifier.getClassVar().getName())) ? "class" : XMLInstances.TAG_ATTRIBUTE;
    }

    /**
     * Fills the {@value #ATT_EDGE_COMPARE} and {@value #ATT_EDGE_DIST} values of an edge.
     * Links touching a {@link Continuous} node or entering a super-value node stay
     * "noncomparable" with influence 0.2; otherwise {@code macroExplanation.compare}
     * codes 0/1/3 map to "greater"/"less"/"equals", and for 0/1 the influence is taken
     * from {@code macroExplanation.influence}. The stored distance is influence × {@value #EDGE_WIDE_MAX}.
     */
    protected static void calcLinkStyles(Bnet bnet, Link link, CyEdge cyEdge, CyNetwork cyNetwork) {
        String comparison = "noncomparable";
        double influence = 0.2d;
        boolean involvesContinuous = link.getHead() instanceof Continuous || link.getTail() instanceof Continuous;
        if (!involvesContinuous && !isIncomingToSVNode(link)) {
            double[][][] distributions = macroExplanation.greaterdist(bnet, link.getHead(), link.getTail());
            int compare = macroExplanation.compare(distributions);
            switch (compare) {
                case 0:
                    comparison = "greater";
                    break;
                case 1:
                    comparison = "less";
                    break;
                case 3:
                    comparison = "equals";
                    break;
                default:
                    break;
            }
            if (compare == 0 || compare == 1) {
                influence = macroExplanation.influence(distributions);
            }
        }
        cyNetwork.getDefaultEdgeTable().getRow(cyEdge.getSUID()).set(ATT_EDGE_DIST, Double.valueOf(influence * EDGE_WIDE_MAX));
        cyNetwork.getDefaultEdgeTable().getRow(cyEdge.getSUID()).set(ATT_EDGE_COMPARE, comparison);
    }

    /** True when the link's head has node kind 3 — presumably Elvira's super-value kind; confirm against Node constants. */
    private static boolean isIncomingToSVNode(Link link) {
        return link.getHead().getKindOfNode() == 3;
    }

    /**
     * Writes quality metrics of a trained classifier to the network table:
     * BIC/K2/BDe scores (only when no variable is {@link Continuous}), accuracy,
     * case count, error, every confusion-matrix cell as "i_classif_as_j", and
     * sensitivity/specificity for binary problems.
     *
     * @param discreteClassifier trained and tested classifier
     * @param dataBaseCases      cases used to compute the scores
     * @param cyNetwork          network whose table receives the values
     */
    public static void calcMetrics(DiscreteClassifier discreteClassifier, DataBaseCases dataBaseCases, CyNetwork cyNetwork) {
        if (!hasContinuous(dataBaseCases)) {
            Bnet classifier = discreteClassifier.getClassifier();
            setNetworkValue(cyNetwork, "BIC score", Double.class, new BICMetrics(dataBaseCases).score(classifier));
            setNetworkValue(cyNetwork, "K2 score", Double.class, new K2Metrics(dataBaseCases).score(classifier));
            setNetworkValue(cyNetwork, "BDe score", Double.class, new BDeMetrics(dataBaseCases).score(classifier));
        }
        ConfusionMatrix confusionMatrix = discreteClassifier.getConfusionMatrix();
        setNetworkValue(cyNetwork, "Accuracy", Double.class, confusionMatrix.getAccuracy());
        setNetworkValue(cyNetwork, "#Cases", Integer.class, confusionMatrix.getCases());
        setNetworkValue(cyNetwork, "Error", Double.class, confusionMatrix.getError());
        int dimension = confusionMatrix.getDimension();
        for (int actual = 0; actual < dimension; actual++) {
            for (int predicted = 0; predicted < dimension; predicted++) {
                setNetworkValue(cyNetwork, actual + "_classif_as_" + predicted, Integer.class,
                        (int) confusionMatrix.getValue(actual, predicted));
            }
        }
        if (dimension == 2) {
            // Class index 1 is treated as the positive class; a zero denominator yields NaN.
            double sensitivity = confusionMatrix.getValue(1, 1) / (confusionMatrix.getValue(1, 1) + confusionMatrix.getValue(1, 0));
            setNetworkValue(cyNetwork, "Sensitivity", Double.class, sensitivity);
            double specificity = confusionMatrix.getValue(0, 0) / (confusionMatrix.getValue(0, 0) + confusionMatrix.getValue(0, 1));
            setNetworkValue(cyNetwork, "Specificity", Double.class, specificity);
        }
    }

    /** Ensures a network-table column exists and stores {@code value} in the network's row. */
    private static void setNetworkValue(CyNetwork cyNetwork, String column, Class<?> type, Object value) {
        ensureColumn(cyNetwork.getDefaultNetworkTable(), column, type);
        cyNetwork.getRow(cyNetwork).set(column, value);
    }

    /** True when {@link #classAtt} names an attribute (non-null and not blank). */
    private boolean isClassAttSet() {
        return this.classAtt != null && !this.classAtt.trim().isEmpty();
    }

    /** Index of the node whose name equals {@link #classAtt} ignoring case and surrounding blanks, or -1. */
    private int indexOfClassAtt(List<Node> nodes) {
        String wanted = this.classAtt.trim();
        for (int i = 0; i < nodes.size(); i++) {
            if (nodes.get(i).getName().equalsIgnoreCase(wanted)) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Sets the classifier's class variable: the attribute named by {@link #classAtt},
     * or the last attribute when none is given.
     *
     * @return false when there are no attributes, the named attribute does not exist,
     *         or the selected attribute is not a {@link FiniteStates} variable
     */
    private boolean setClassVar() {
        List<Node> nodes = this.classif.getDataBaseCases().getNodeList().getNodes();
        if (nodes.isEmpty()) {
            return false;
        }
        int index = isClassAttSet() ? indexOfClassAtt(nodes) : nodes.size() - 1;
        if (index < 0) {
            return false;
        }
        Node classNode = nodes.get(index);
        if (!(classNode instanceof FiniteStates)) {
            return false;
        }
        this.classif.setClassVar((FiniteStates) classNode);
        return true;
    }

    /** Index of the class attribute in {@code dataBaseCases}; defaults to the last attribute when not named or not found. */
    private int getClassIndex(DataBaseCases dataBaseCases) {
        List<Node> nodes = dataBaseCases.getNodeList().getNodes();
        int index = isClassAttSet() ? indexOfClassAtt(nodes) : -1;
        return index >= 0 ? index : nodes.size() - 1;
    }

    /** True when any variable of the database is {@link Continuous}. */
    private static boolean hasContinuous(DataBaseCases dataBaseCases) {
        List<Node> nodes = dataBaseCases.getNodeList().getNodes();
        for (Node node : nodes) {
            if (node instanceof Continuous) {
                return true;
            }
        }
        return false;
    }

    static {
        ATT_COMPARE_MAP.put("greater", Color.RED);
        ATT_COMPARE_MAP.put("less", Color.BLUE);
        ATT_COMPARE_MAP.put("equals", new Color(204, 102, 255));
        ATT_COMPARE_MAP.put("noncomparable", Color.BLACK);
        ATT_NODE_TYPE_MAP.put(XMLInstances.TAG_ATTRIBUTE, NodeShapeVisualProperty.ROUND_RECTANGLE);
        ATT_NODE_TYPE_MAP.put("class", NodeShapeVisualProperty.ELLIPSE);
    }
}
