package elvira.learning.classification.supervised.discrete;

import elvira.Bnet;
import elvira.CaseListMem;
import elvira.Configuration;
import elvira.FiniteStates;
import elvira.InvalidEditException;
import elvira.Node;
import elvira.NodeList;
import elvira.Relation;
import elvira.database.DataBaseCases;
import elvira.learning.classification.ClassifierValidator;
import elvira.parser.ParseException;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Vector;
import weka.core.TestInstances;

/* loaded from: input_file:bayelvira-1.0-SNAPSHOT.jar:elvira/learning/classification/supervised/discrete/WrapperSelectiveNaiveBayes.class */
/**
 * Selective Naive Bayes classifier whose variable subset is chosen with a wrapper approach.
 *
 * <p>Starting from the class variable alone, variables are added greedily (forward selection).
 * Each candidate subset is scored as {@code 1 - error} of a k-fold cross-validation of a
 * {@link Naive_Bayes} learned on the projection of the training data onto that subset. The
 * search stops when no remaining variable strictly improves the best score found so far.
 *
 * <p>The last variable of the training database is taken to be the class variable.
 */
public class WrapperSelectiveNaiveBayes extends SelectiveNaiveBayes {

    /** Number of cross-validation folds used when none is given explicitly. */
    private static final int DEFAULT_K_FOLD = 5;

    /** Number of folds of the cross-validation that scores each candidate subset. */
    private final int kFold;

    /** Creates an untrained classifier that scores subsets with 5-fold cross-validation. */
    public WrapperSelectiveNaiveBayes() {
        this(DEFAULT_K_FOLD);
    }

    /**
     * Creates an untrained classifier that scores subsets with {@code kFold}-fold
     * cross-validation.
     *
     * @param kFold number of cross-validation folds
     */
    public WrapperSelectiveNaiveBayes(int kFold) {
        this.kFold = kFold;
    }

    /**
     * Creates a classifier for {@code dataBaseCases} that scores subsets with 5-fold
     * cross-validation.
     *
     * @param dataBaseCases training data; the class variable is expected to be the last variable
     * @param laplace forwarded to the superclass constructor (presumably the Laplace-correction
     *     flag later read as {@code this.laplace} — confirm in {@code SelectiveNaiveBayes})
     * @throws InvalidEditException if the superclass constructor throws it
     */
    public WrapperSelectiveNaiveBayes(DataBaseCases dataBaseCases, boolean laplace) throws InvalidEditException {
        this(dataBaseCases, laplace, DEFAULT_K_FOLD);
    }

    /**
     * Creates a classifier for {@code dataBaseCases} that scores subsets with {@code kFold}-fold
     * cross-validation.
     *
     * @param dataBaseCases training data; the class variable is expected to be the last variable
     * @param laplace forwarded to the superclass constructor (presumably the Laplace-correction
     *     flag later read as {@code this.laplace} — confirm in {@code SelectiveNaiveBayes})
     * @param kFold number of cross-validation folds
     * @throws InvalidEditException if the superclass constructor throws it
     */
    public WrapperSelectiveNaiveBayes(DataBaseCases dataBaseCases, boolean laplace, int kFold) throws InvalidEditException {
        super(dataBaseCases, laplace);
        this.kFold = kFold;
    }

    /**
     * Projects {@code this.cases} onto the variables of {@code nodeList} plus {@code node}.
     *
     * <p>The last element of {@code nodeList} stays the last variable of the result. In
     * {@link #structuralLearning()} this is the class variable. {@code node}, when not
     * {@code null}, is inserted just before it. Every variable in the result is a copy. The
     * values of each case are looked up in the original database by variable id.
     *
     * @param nodeList currently selected variables, with the class variable last
     * @param node candidate variable to add, or {@code null} to project onto {@code nodeList} only
     * @return a new in-memory database with the same name and number of cases as
     *     {@code this.cases}
     */
    private DataBaseCases generateDbcInclude(NodeList nodeList, Node node) {
        DataBaseCases projected = new DataBaseCases();
        projected.setName(this.cases.getName());
        projected.setTitle(this.cases.getName());

        // Selected attributes first, then the candidate, then the class variable last.
        Vector<Node> variables = new Vector<Node>();
        for (int i = 0; i < nodeList.size() - 1; i++) {
            variables.addElement(nodeList.elementAt(i).copy());
        }
        if (node != null) {
            variables.addElement(node.copy());
        }
        Node classCopy = nodeList.lastElement().copy();
        variables.addElement(classCopy);

        NodeList projectedNodes = new NodeList(variables);
        projected.setNodeList(projectedNodes);

        NodeList sourceNodes = this.cases.getNodeList();
        CaseListMem projectedCases = new CaseListMem(projectedNodes);
        Configuration row = new Configuration(projectedNodes);
        Vector<?> sourceRows = ((CaseListMem) ((Relation) this.cases.getRelationList().get(0)).getValues()).getCases();
        for (int c = 0; c < this.nCases; c++) {
            int[] sourceRow = (int[]) sourceRows.elementAt(c);
            Configuration values = new Configuration(projectedNodes);
            for (int v = 0; v < projectedNodes.size(); v++) {
                FiniteStates variable = (FiniteStates) projectedNodes.getNodes().elementAt(v);
                values.putValue(variable, sourceRow[sourceNodes.getId(variable)]);
            }
            row.setValues(values.getValues());
            projectedCases.put(row);
        }

        Relation relation = new Relation();
        relation.setVariables(projectedNodes);
        relation.setValues(projectedCases);
        Vector<Relation> relations = new Vector<Relation>();
        relations.addElement(relation);
        projected.setRelationList(relations);
        projected.setNumberOfCases(this.cases.getNumberOfCases());
        return projected;
    }

    /**
     * Selects variables by greedy forward selection and trains the final Naive Bayes.
     *
     * <p>In each round, every not-yet-selected variable is tentatively added to the current
     * selection. The resulting projection is scored as {@code 1 - error} of a
     * {@code kFold}-fold cross-validation of a {@link Naive_Bayes}. The best candidate of the
     * round is kept only if it strictly beats the best score so far; otherwise the search stops.
     * Each scored candidate increments {@code this.evaluations}.
     *
     * <p>Afterwards a {@link Naive_Bayes} is trained on the selected projection. It is stored in
     * {@code this.accurateClassifier}, its network in {@code this.classifier}, and a summary is
     * printed to standard output.
     *
     * @throws InvalidEditException if building or training a model throws it
     * @throws Exception if cross-validation or training throws
     */
    @Override // elvira.learning.classification.supervised.discrete.SelectiveNaiveBayes, elvira.learning.classification.supervised.discrete.DiscreteClassifier
    public void structuralLearning() throws InvalidEditException, Exception {
        // The class variable is assumed to be the last variable of the database.
        Vector<Node> classOnly = new Vector<Node>();
        classOnly.addElement((Node) this.cases.getVariables().elementAt(this.nVariables - 1));
        NodeList selected = new NodeList(classOnly);

        Vector<Node> candidates = new Vector<Node>();
        for (int i = 0; i < this.nVariables - 1; i++) {
            candidates.addElement((Node) this.cases.getVariables().elementAt(i));
        }

        // NEGATIVE_INFINITY rather than Double.MIN_VALUE: MIN_VALUE is the smallest *positive*
        // double, so an accuracy of exactly 0 could never be recorded as the best.
        double bestAccuracy = Double.NEGATIVE_INFINITY;
        DataBaseCases bestCases = null;
        boolean finished = false;
        while (!finished) {
            double roundAccuracy = Double.NEGATIVE_INFINITY;
            DataBaseCases roundCases = null;
            Node roundNode = null;
            for (int i = 0; i < candidates.size(); i++) {
                Node candidate = candidates.elementAt(i);
                DataBaseCases trial = generateDbcInclude(selected, candidate);
                double accuracy = 1.0d - new ClassifierValidator((DiscreteClassifierDiscriminativeLearning) new Naive_Bayes(), trial, trial.getClassId()).kFoldCrossValidation(this.kFold).getError();
                this.evaluations++;
                if (accuracy > roundAccuracy) {
                    roundAccuracy = accuracy;
                    roundCases = trial;
                    roundNode = candidate;
                }
            }
            // roundAccuracy can only exceed bestAccuracy if some candidate was recorded,
            // so roundNode and roundCases are non-null inside this branch.
            if (roundAccuracy > bestAccuracy) {
                bestAccuracy = roundAccuracy;
                bestCases = roundCases;
                selected = bestCases.getNodeList();
                candidates.removeElement(roundNode);
                System.out.println(bestAccuracy + TestInstances.DEFAULT_SEPARATORS + roundNode.toString());
            } else {
                finished = true;
            }
        }

        if (bestCases == null) {
            // Nothing was ever accepted (e.g. the database holds only the class variable).
            // Train on a class-only projection instead of an empty database.
            bestCases = generateDbcInclude(selected, null);
        }

        this.accurateClassifier = new Naive_Bayes(bestCases, this.laplace);
        this.accurateClassifier.train();
        this.classifier = this.accurateClassifier.getClassifier();
        System.out.println();
        System.out.println(this.accurateClassifier.getClassifier().getNodeList().toString());
        System.out.println("    " + (this.accurateClassifier.getClassifier().getNodeList().size() - 1) + " selected variables");
        System.out.println("    " + this.evaluations + " evaluated solutions");
    }

    /**
     * Command-line entry point: trains on a .dbc file, reports test accuracy and the confusion
     * matrix, and saves the learned network.
     *
     * <p>Arguments: {@code file-train.dbc file-test.dbc file-out.elv}. With any other number of
     * arguments, a usage line is printed and the JVM exits with status 1.
     *
     * @param strArr command-line arguments
     * @throws Exception if reading, training, testing or writing fails
     */
    public static void main(String[] strArr) throws FileNotFoundException, IOException, InvalidEditException, ParseException, Exception {
        if (strArr.length != 3) {
            System.out.println("Usage: file-train.dbc file-test.dbc file-out.elv");
            // Non-zero status: a usage error is an abnormal termination.
            System.exit(1);
            return;
        }

        DataBaseCases trainCases;
        FileInputStream trainStream = new FileInputStream(strArr[0]);
        try {
            trainCases = new DataBaseCases(trainStream);
        } finally {
            trainStream.close();
        }

        WrapperSelectiveNaiveBayes wrapperSelectiveNaiveBayes = new WrapperSelectiveNaiveBayes(trainCases, true);
        wrapperSelectiveNaiveBayes.train();
        System.out.println("Classifier learned");

        DataBaseCases testCases;
        FileInputStream testStream = new FileInputStream(strArr[1]);
        try {
            testCases = new DataBaseCases(testStream);
        } finally {
            testStream.close();
        }

        System.out.println("Classifier tested. Accuracy: " + wrapperSelectiveNaiveBayes.test(testCases));
        wrapperSelectiveNaiveBayes.getConfusionMatrix().print();

        FileWriter fileWriter = new FileWriter(strArr[2]);
        try {
            wrapperSelectiveNaiveBayes.getClassifier().saveBnet(fileWriter);
        } finally {
            fileWriter.close();
        }
    }
}
