package elvira.learning.classification.supervised.discrete;

import elvira.Bnet;
import elvira.InvalidEditException;
import elvira.LinkList;
import elvira.NodeList;
import elvira.database.DataBaseCases;
import elvira.learning.BDeMetrics;
import elvira.learning.BICMetrics;
import elvira.learning.DELearning;
import elvira.learning.DVNSSTLearning;
import elvira.learning.K2Metrics;
import elvira.learning.LPLearning;
import elvira.learning.Metrics;
import elvira.learning.ParameterLearning;
import elvira.parser.ParseException;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;

/* loaded from: input_file:bayelvira-1.0-SNAPSHOT.jar:elvira/learning/classification/supervised/discrete/LocalSearchLearning.class */
/**
 * Classifier learner whose network structure is obtained by a {@link DVNSSTLearning} local
 * search.
 *
 * <p>The search starts from a graph that contains every variable of the input data and no links
 * (see {@link #init()}). It is scored with a {@link Metrics} instance: BIC, K2 or BDe, chosen by
 * {@link #metricName} unless one is supplied through {@link #setMetrics(Metrics)}. After the
 * structure is learned, {@link #learn(DataBaseCases, int)} also estimates the parameters.
 */
public class LocalSearchLearning extends MarkovBlanketLearning {
    /** Scoring metric for the structure search; built lazily from {@link #metricName} when null. */
    Metrics metric;

    /** Metric to build when none is set: "BIC", "K2", or anything else for BDe. */
    String metricName;

    /**
     * Creates an unconfigured learner.
     *
     * <p>No metric is set, and the metric name is empty, which selects BDe when a metric is
     * eventually built.
     */
    public LocalSearchLearning() {
        this.metricName = "";
        this.metric = null;
    }

    /**
     * Creates a learner over the given data and builds the initial link-free graph.
     *
     * @param dataBaseCases training data
     * @param i index of the variable to classify, forwarded to the superclass constructor
     * @param z flag forwarded to the {@code MarkovBlanketLearning} constructor
     *     (NOTE(review): presumably the Laplace-correction switch; confirm in the superclass)
     * @param str metric name: "BIC", "K2", or anything else for BDe
     */
    public LocalSearchLearning(DataBaseCases dataBaseCases, int i, boolean z, String str) {
        super(dataBaseCases, i, z);
        setMetrics(null);
        this.metricName = str;
        init();
    }

    /**
     * Replaces the current output with a fresh {@link Bnet} that holds every node of the input
     * data and no links. This graph is the starting point of the local search.
     */
    private void init() {
        NodeList nodeList = getInput().getNodeList();
        Bnet bnet = new Bnet();
        // Graph kind 2, as in the original code. NOTE(review): confirm which Graph kind constant
        // this value denotes.
        bnet.setKindOfGraph(2);
        for (int i = 0; i < nodeList.size(); i++) {
            try {
                bnet.addNode(nodeList.elementAt(i));
            } catch (InvalidEditException e) {
                // Best effort: report the failure and keep adding the remaining nodes.
                System.out.println("Error when building the initial empty graph");
            }
        }
        setOutput(bnet);
    }

    /**
     * Builds the metric named by {@link #metricName} over the current input data.
     *
     * @return a {@link BICMetrics} for "BIC", a {@link K2Metrics} for "K2", and a
     *     {@link BDeMetrics} for any other value, including {@code null}
     */
    private Metrics createMetrics() {
        if ("BIC".equals(this.metricName)) {
            return new BICMetrics(getInput());
        }
        if ("K2".equals(this.metricName)) {
            return new K2Metrics(getInput());
        }
        return new BDeMetrics(getInput());
    }

    /**
     * Learns the structure and then the parameters of the classifier from the given data.
     *
     * <p>If a metric is already set, it is re-pointed at {@code dataBaseCases}. Otherwise a new
     * one is built from {@link #metricName}. Parameters are estimated with {@link LPLearning}
     * when {@code getIfAplyLaplaceCorrection()} is true, and with {@link DELearning} otherwise.
     *
     * @param dataBaseCases training data
     * @param i index of the variable to classify
     */
    @Override
    public void learn(DataBaseCases dataBaseCases, int i) {
        setInput(dataBaseCases);
        setVarToClassify(i);
        if (getMetrics() == null) {
            setMetrics(createMetrics());
        } else {
            getMetrics().setData(dataBaseCases);
        }
        init();
        learning();

        ParameterLearning parameterLearning;
        if (getIfAplyLaplaceCorrection()) {
            parameterLearning = new LPLearning(dataBaseCases, getOutput());
        } else {
            parameterLearning = new DELearning(dataBaseCases, getOutput());
        }
        parameterLearning.learning();
        setOutput(parameterLearning.getOutput());
    }

    /**
     * Runs the {@link DVNSSTLearning} structure search, starting from the current output network.
     * The network it finds becomes the new output.
     *
     * <p>If no metric is set yet, one is built from {@link #metricName} first.
     */
    @Override
    public void learning() {
        if (getMetrics() == null) {
            setMetrics(createMetrics());
        }
        // Numeric search settings kept as in the original code. NOTE(review): document their
        // meaning against the DVNSSTLearning constructor.
        DVNSSTLearning search = new DVNSSTLearning(getInput(), 1, 0, 0, 1, getMetrics());
        search.setInitialBnet(getOutput());
        search.learning();
        setOutput(search.getOutput());
    }

    /**
     * Sets the scoring metric used by the structure search.
     *
     * @param metrics metric to use, or {@code null} to have one built from the metric name
     */
    public void setMetrics(Metrics metrics) {
        this.metric = metrics;
    }

    /**
     * Returns the scoring metric currently configured.
     *
     * @return the metric, or {@code null} if none has been set or built yet
     */
    public Metrics getMetrics() {
        return this.metric;
    }

    /**
     * Command-line entry point.
     *
     * <p>Learns a network from a {@code .dbc} file, estimates Laplace-corrected parameters,
     * prints the KL divergence and the metric score, and saves the network. If a fifth argument
     * is given, the learned network is also compared with that reference network.
     *
     * @param strArr {@code input.dbc output.elv class metric [file.elv]}
     * @throws ParseException if an input file cannot be parsed
     * @throws IOException if a file cannot be read or written
     */
    public static void main(String[] strArr) throws ParseException, IOException {
        if (strArr.length < 4) {
            System.out.println("too few arguments: Usage: input.dbc output.elv class metric [file.elv]");
            System.out.println("\tinput.dbc : DataBaseCases file for building the bnet");
            System.out.println("\toutput.elv : For saving the result");
            System.out.println("\tclass : The number of the variable to classify if it's the first use 0.");
            System.out.println("\tmetric : Metric used to score, it can be BIC,K2,BDe");
            System.out.println("\tfile.elv: Optional. True net to be compared.");
            // A usage error is a failure; report it with a non-zero exit status.
            System.exit(1);
            return;
        }

        DataBaseCases dataBaseCases;
        FileInputStream dataStream = new FileInputStream(strArr[0]);
        try {
            dataBaseCases = new DataBaseCases(dataStream);
        } finally {
            dataStream.close();
        }

        int classIndex = Integer.parseInt(strArr[2]);

        // Separate metric instance, used only to report the score of the learned structure.
        Metrics scoringMetric;
        String metricName;
        if (strArr[3].equals("BIC")) {
            scoringMetric = new BICMetrics(dataBaseCases);
            metricName = strArr[3];
        } else if (strArr[3].equals("K2")) {
            scoringMetric = new K2Metrics(dataBaseCases);
            metricName = strArr[3];
        } else {
            scoringMetric = new BDeMetrics(dataBaseCases);
            metricName = "BDe";
        }

        LocalSearchLearning localSearchLearning =
                new LocalSearchLearning(dataBaseCases, classIndex, true, metricName);
        localSearchLearning.learning();
        LPLearning lPLearning = new LPLearning(dataBaseCases, localSearchLearning.getOutput());
        lPLearning.learning();
        double divergenceKL = dataBaseCases.getDivergenceKL(lPLearning.getOutput());
        System.out.println("KL Divergence = " + divergenceKL);
        System.out.println("Bayes Metric for the output net: " + scoringMetric.score(localSearchLearning.getOutput()));

        FileWriter fileWriter = new FileWriter(strArr[1]);
        try {
            lPLearning.getOutput().saveBnet(fileWriter);
        } finally {
            fileWriter.close();
        }

        if (strArr.length > 4) {
            Bnet bnet;
            FileInputStream trueNetStream = new FileInputStream(strArr[4]);
            try {
                bnet = new Bnet(trueNetStream);
            } finally {
                trueNetStream.close();
            }
            double divergenceKL2 = dataBaseCases.getDivergenceKL(bnet);
            System.out.println("kL Divergence for optimal net: " + divergenceKL2);
            System.out.println("Divergence between optimal and learned nets: " + (divergenceKL2 - divergenceKL));
            // Index 0 = added links, 1 = removed links, 2 = inverted links (per the labels below).
            LinkList[] compareOutput = localSearchLearning.compareOutput(bnet);
            System.out.print("\nAdded links: " + compareOutput[0].size());
            System.out.print(compareOutput[0].toString());
            System.out.print("\nRemoved links: " + compareOutput[1].size());
            System.out.print(compareOutput[1].toString());
            System.out.println("\nInverted links: " + compareOutput[2].size());
            System.out.print(compareOutput[2].toString());
        }
    }
}
