package elvira.learning.classification.supervised.discrete;

import elvira.Bnet;
import elvira.Configuration;
import elvira.FiniteStates;
import elvira.LinkList;
import elvira.Node;
import elvira.NodeList;
import elvira.database.DataBaseCases;
import elvira.learning.DELearning;
import elvira.learning.LPLearning;
import elvira.learning.Learning;
import elvira.learning.ParameterLearning;
import elvira.learning.classification.ClassifierValidator;
import elvira.learning.classification.ConfusionMatrix;
import elvira.learning.classification.SizeComparableClassifier;
import elvira.parser.ParseException;
import elvira.potential.PotentialTable;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Vector;

/* loaded from: input_file:bayelvira-1.0-SNAPSHOT.jar:elvira/learning/classification/supervised/discrete/MarkovBlanketLearning.class */
/**
 * Base class for classifiers that score each state of a class variable using
 * the Markov blanket of that variable in a Bayesian network: the class node's
 * own potential given its parents, times the potential of every child given
 * its parents.
 *
 * <p>Subclasses provide the structure-learning step in {@link #learning()};
 * {@link #learn(DataBaseCases, int)} then fits the parameters of the resulting
 * network.
 */
public abstract class MarkovBlanketLearning extends Learning implements SizeComparableClassifier {
    /** Database of cases the model is learned from / classifies against. */
    DataBaseCases input;
    /** Index of the variable to classify. */
    int classvar;
    /** Whether parameter learning uses {@code LPLearning} (true) or {@code DELearning} (false). */
    boolean laplace;

    /** Creates an unconfigured learner: no input, no output, Laplace correction on, class variable 0. */
    public MarkovBlanketLearning() {
        setInput(null);
        setOutput(null);
        setIfAplyLaplaceCorrection(true);
        setVarToClassify(0);
    }

    /**
     * Creates a learner bound to a database.
     *
     * @param dataBaseCases the cases to learn from
     * @param i index of the variable to classify
     * @param z whether to use Laplace-corrected parameter learning
     */
    public MarkovBlanketLearning(DataBaseCases dataBaseCases, int i, boolean z) {
        setVarToClassify(i);
        setInput(dataBaseCases);
        setIfAplyLaplaceCorrection(z);
    }

    /**
     * Structure-learning step. {@link #learn(DataBaseCases, int)} calls this and
     * then fits parameters on {@code getOutput()}, so implementations are expected
     * to leave the learned network in the output.
     */
    @Override // elvira.learning.Learning
    public abstract void learning();

    public void setInput(DataBaseCases dataBaseCases) {
        this.input = dataBaseCases;
    }

    public DataBaseCases getInput() {
        return this.input;
    }

    public void setVarToClassify(int i) {
        this.classvar = i;
    }

    public int getVarToClassify() {
        return this.classvar;
    }

    public void setIfAplyLaplaceCorrection(boolean z) {
        this.laplace = z;
    }

    public boolean getIfAplyLaplaceCorrection() {
        return this.laplace;
    }

    /**
     * Learns the structure (via {@link #learning()}) and then the parameters of
     * the classifier from the given cases.
     *
     * @param dataBaseCases the training cases
     * @param i index of the variable to classify
     */
    public void learn(DataBaseCases dataBaseCases, int i) {
        setInput(dataBaseCases);
        setVarToClassify(i);
        learning();
        ParameterLearning parameterLearning;
        if (this.laplace) {
            parameterLearning = new LPLearning(dataBaseCases, getOutput());
        } else {
            // NOTE(review): presumably plain frequency-based estimation — confirm in DELearning.
            parameterLearning = new DELearning(dataBaseCases, getOutput());
        }
        parameterLearning.learning();
        setOutput(parameterLearning.getOutput());
    }

    /**
     * Scores every state of the class node for one case.
     *
     * <p>For each class state the score is the class node's potential evaluated
     * at that state and the observed parent states, multiplied by each child's
     * potential evaluated at the child's observed state, its other parents'
     * observed states and the candidate class state. Scores are not normalized.
     *
     * @param configuration the case; its values are read positionally in the
     *     order of the input database's node list
     * @param i index of the class node in the output network's node list
     * @return a {@code Vector} of {@code Double}, one score per class state
     */
    @Override // elvira.learning.classification.Classifier
    public Vector classify(Configuration configuration, int i) {
        NodeList inputNodes = getInput().getNodeList();
        int numVars = inputNodes.size();
        // Observed state of every variable, indexed by its position in the input node list.
        int[] observed = new int[numVars];
        Vector values = configuration.getValues();
        for (int v = 0; v < numVars; v++) {
            observed[v] = ((Integer) values.elementAt(v)).intValue();
        }

        Node classNode = getOutput().getNodeList().elementAt(i);
        PotentialTable classTable = (PotentialTable) getOutput().getRelation(classNode).getValues();
        int numStates = ((FiniteStates) classNode).getNumStates();
        NodeList classParents = classNode.getParentNodes();
        NodeList classChildren = classNode.getChildrenNodes();

        Vector scores = new Vector();
        for (int state = 0; state < numStates; state++) {
            // Class node's own factor: class = state, parents = observed values.
            Vector vars = new Vector();
            Vector states = new Vector();
            vars.addElement(classNode);
            states.addElement(Integer.valueOf(state));
            for (int p = 0; p < classParents.size(); p++) {
                Node parent = classParents.elementAt(p);
                vars.addElement(parent);
                // Index via the input list: `observed` is laid out in input order.
                states.addElement(Integer.valueOf(observed[inputNodes.getId(parent)]));
            }
            double score = classTable.getValue(new Configuration(vars, states));

            // One factor per child: child and its non-class parents observed, class = state.
            for (int c = 0; c < classChildren.size(); c++) {
                FiniteStates child = (FiniteStates) classChildren.elementAt(c);
                PotentialTable childTable = (PotentialTable) getOutput().getRelation(child).getValues();
                NodeList childParents = child.getParentNodes();
                Vector childVars = new Vector();
                Vector childStates = new Vector();
                childVars.addElement(child);
                childStates.addElement(Integer.valueOf(observed[inputNodes.getId(child)]));
                for (int p = 0; p < childParents.size(); p++) {
                    Node parent = childParents.elementAt(p);
                    if (!parent.equals(classNode)) {
                        childVars.addElement(parent);
                        childStates.addElement(Integer.valueOf(observed[inputNodes.getId(parent)]));
                    }
                }
                childVars.addElement(classNode);
                childStates.addElement(Integer.valueOf(state));
                score *= childTable.getValue(new Configuration(childVars, childStates));
            }
            scores.addElement(Double.valueOf(score));
        }
        return scores;
    }

    /** Reads a Bayesian network from an Elvira file, closing the stream even on failure. */
    private static Bnet readBnet(String path) throws ParseException, IOException {
        FileInputStream in = new FileInputStream(path);
        try {
            return new Bnet(in);
        } finally {
            in.close();
        }
    }

    /** Reads a database of cases from a file, closing the stream even on failure. */
    private static DataBaseCases readDataBaseCases(String path) throws ParseException, IOException {
        FileInputStream in = new FileInputStream(path);
        try {
            return new DataBaseCases(in);
        } finally {
            in.close();
        }
    }

    /**
     * Command-line entry point: evaluates a given network as a Markov-blanket
     * classifier on a database and, optionally, compares it with a reference net.
     *
     * @param strArr {@code input.elv input.dbc class [file.elv]}
     */
    public static void main(String[] strArr) throws ParseException, IOException {
        if (strArr.length < 3) {
            System.out.println("too few arguments: Usage: input.elv input.dbc class [file.elv]");
            System.out.println("\tinput.elv : Bayesian network to test and compare");
            System.out.println("\tinput.dbc : DataBaseCases file for test the bnet");
            System.out.println("\tclass : The number of the variable to classify if it's the first use 0.");
            System.out.println("\tfile.elv: Optional. True net to be compared.");
            System.exit(0);
        }
        Bnet bnet = readBnet(strArr[0]);
        DataBaseCases dataBaseCases = readDataBaseCases(strArr[1]);
        int intValue = Integer.parseInt(strArr[2]);

        // The network is supplied ready-made, so the structure-learning step is a no-op.
        MarkovBlanketLearning markovBlanketLearning = new MarkovBlanketLearning() {
            @Override // elvira.learning.classification.supervised.discrete.MarkovBlanketLearning, elvira.learning.Learning
            public void learning() {
            }
        };
        markovBlanketLearning.setInput(dataBaseCases);
        markovBlanketLearning.setOutput(bnet);
        markovBlanketLearning.setVarToClassify(intValue);
        ConfusionMatrix confusionMatrix = new ClassifierValidator(markovBlanketLearning, dataBaseCases, intValue).confusionMatrix(markovBlanketLearning, dataBaseCases, intValue);
        System.out.println("");
        confusionMatrix.print();
        System.out.println("Accuracy=" + confusionMatrix.getAccuracy() + "% ");
        System.out.println("");
        if (strArr.length > 3) {
            Bnet bnet2 = readBnet(strArr[3]);
            double divergenceKL = dataBaseCases.getDivergenceKL(bnet);
            double divergenceKL2 = dataBaseCases.getDivergenceKL(bnet2);
            // Previously printed divergenceKL2 here, reporting the optimal net's value twice.
            System.out.println("kL Divergence for input net: " + divergenceKL);
            System.out.println("kL Divergence for optimal net: " + divergenceKL2);
            System.out.println("Divergence between optimal and learned nets: " + (divergenceKL2 - divergenceKL));
            LinkList[] compareOutput = markovBlanketLearning.compareOutput(bnet2);
            System.out.println("\nAdded links: " + compareOutput[0].size());
            System.out.print(compareOutput[0].toString());
            System.out.println("\nRemoved links: " + compareOutput[1].size());
            System.out.print(compareOutput[1].toString());
            System.out.println("\nInverted links: " + compareOutput[2].size());
            System.out.print(compareOutput[2].toString());
        }
    }

    /** Returns the number of free parameters of the learned network. */
    @Override // elvira.learning.classification.SizeComparableClassifier
    public long size() {
        return getOutput().getNumberOfFreeParameters();
    }

    /**
     * Writes the learned network to {@code str + "markovBlanketLearning.elv"}.
     *
     * @param str path prefix (e.g. a directory ending in a separator)
     * @throws IOException if the file cannot be written
     */
    @Override // elvira.learning.classification.Classifier
    public void saveModelToFile(String str) throws IOException {
        FileWriter fileWriter = new FileWriter(str + "markovBlanketLearning.elv");
        try {
            getOutput().saveBnet(fileWriter);
        } finally {
            fileWriter.close();
        }
    }
}
