package elvira.learning.classification.unsupervised.mixed;

import elvira.Bnet;
import elvira.ContinuousCaseListMem;
import elvira.ContinuousConfiguration;
import elvira.Evidence;
import elvira.FiniteStates;
import elvira.Link;
import elvira.LinkList;
import elvira.Node;
import elvira.NodeList;
import elvira.Relation;
import elvira.database.DataBaseCases;
import elvira.inference.elimination.VariableElimination;
import elvira.learning.MTELearning;
import elvira.parser.ParseException;
import elvira.potential.ContinuousProbabilityTree;
import elvira.potential.MixtExpDensity;
import elvira.potential.PotentialContinuousPT;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Vector;

/* loaded from: input_file:bayelvira-1.0-SNAPSHOT.jar:elvira/learning/classification/unsupervised/mixed/UnsupervisedMTENaiveBayes.class */
/**
 * Unsupervised learner for a naive-Bayes-shaped network whose root is a discrete hidden
 * variable and whose conditional distributions are continuous probability trees learnt
 * through {@link MTELearning}.
 *
 * <p>Typical use: {@link #setTrain(DataBaseCases)}, {@link #setTest(DataBaseCases)},
 * {@link #learnModel()} and finally {@link #saveNetwork(String)}. The train database is used to
 * learn the distributions; the test database is used to score candidate networks by
 * log-likelihood.
 */
public class UnsupervisedMTENaiveBayes {
    /** Maximum number of improving iterations performed by {@link #EMAlgorithm(Bnet)}. */
    private static final int MAX_EM_ITERATIONS = 100;

    /** {@link #learnModel()} keeps adding components while the hidden variable has fewer states than this. */
    private static final int HIDDEN_STATES_LIMIT = 51;

    /**
     * {@link #learnModel()} only stops on a likelihood decrease once its component counter
     * exceeds this value.
     */
    private static final int MIN_COMPONENTS_BEFORE_STOP = 5;

    Bnet classifier = new Bnet();
    DataBaseCases train = new DataBaseCases();
    DataBaseCases test = new DataBaseCases();

    /**
     * Iteratively refines the distributions of {@code bnet}.
     *
     * <p>Each iteration (1) draws a state of the hidden variable for every train and test case
     * from the posterior computed by variable elimination given the remaining values of the
     * case, (2) re-learns every relation of the network from the completed train database and
     * (3) evaluates the log-likelihood of the test database. The loop ends as soon as the
     * likelihood does not strictly improve, or after {@value #MAX_EM_ITERATIONS} improving
     * iterations.
     *
     * <p>The hidden variable is taken to be the last variable of the train database. Both
     * databases and the relations of {@code bnet} are modified in place.
     *
     * @param bnet the network to refine
     * @throws IOException declared by the original method; raised by the underlying Elvira calls
     */
    public void EMAlgorithm(Bnet bnet) throws IOException {
        MTELearning learner = new MTELearning(this.train);
        double bestLogLikelihood = this.test.logLikelihood(bnet);
        NodeList trainVariables = this.train.getVariables();
        FiniteStates hidden = (FiniteStates) trainVariables.elementAt(trainVariables.size() - 1);
        double[] posterior = new double[hidden.getNumStates()];
        ContinuousCaseListMem trainCases = (ContinuousCaseListMem) this.train.getCaseListMem();
        ContinuousCaseListMem testCases = (ContinuousCaseListMem) this.test.getCaseListMem();

        int iteration = 0;
        boolean improved = true;
        while (improved && iteration < MAX_EM_ITERATIONS) {
            System.out.println("Iteration " + iteration);
            imputeHiddenValues(bnet, this.train, trainCases, hidden, posterior);
            imputeHiddenValues(bnet, this.test, testCases, hidden, posterior);

            NodeList netNodes = bnet.getNodeList();
            NodeList hiddenAsParent = new NodeList();
            hiddenAsParent.insertNode(hidden);
            for (int n = 0; n < netNodes.size(); n++) {
                Node node = netNodes.elementAt(n);
                Relation relation = bnet.getRelation(node);
                NodeList relationVars = relation.getVariables();
                // A relation over a single variable has no parent; every other relation is
                // learnt conditional on the hidden variable.
                NodeList parents = relationVars.size() == 1 ? new NodeList() : hiddenAsParent;
                relation.setValues(new PotentialContinuousPT(relationVars, learnTree(learner, node, parents)));
            }

            double newLogLikelihood = this.test.logLikelihood(bnet);
            improved = newLogLikelihood > bestLogLikelihood;
            if (improved) {
                bestLogLikelihood = newLogLikelihood;
                iteration++;
            }
            System.out.println("New likelihood " + newLogLikelihood);
        }
    }

    /**
     * Adds one state to the hidden variable and extends every relation of {@code bnet}
     * accordingly.
     *
     * <p>In the relation of the hidden variable itself, the value stored for the previous last
     * state is split evenly between that state and the new one. In every other relation the
     * subtree of the previous last state is copied for the new state. {@code finiteStates}
     * itself is also given the extra state before the method returns.
     *
     * @param bnet the network whose relations are extended in place
     * @param finiteStates the current hidden variable
     * @throws IOException declared by the original method; raised by the underlying Elvira calls
     */
    public void addComponent(Bnet bnet, FiniteStates finiteStates) throws IOException {
        int newNumStates = finiteStates.getNumStates() + 1;
        String newStateName = Integer.toString(newNumStates - 1);

        // Build the enlarged copy of the hidden variable.
        FiniteStates extended = new FiniteStates();
        extended.setNumStates(newNumStates);
        extended.setName(finiteStates.getName());
        extended.setStates((Vector) finiteStates.getStates().clone());
        Vector extendedStates = extended.getStates();
        extendedStates.addElement(newStateName);
        extended.setStates(extendedStates);
        System.out.println("Adding component " + newNumStates);

        // finiteStates is not modified until after the loop, so this index is loop-invariant.
        int lastState = finiteStates.getNumStates() - 1;
        NodeList netNodes = bnet.getNodeList();
        for (int n = 0; n < netNodes.size(); n++) {
            Node node = netNodes.elementAt(n);
            Relation relation = bnet.getRelation(node);
            ContinuousProbabilityTree oldTree = ((PotentialContinuousPT) relation.getValues()).getTree();
            ContinuousProbabilityTree newTree = oldTree.copy();
            newTree.setVar(extended);
            if (node.equals(finiteStates)) {
                // Split the value of the last state in two equal halves.
                MixtExpDensity lastDensity = oldTree.getChild(lastState).getProb();
                MixtExpDensity newHalf = new MixtExpDensity(lastDensity.getIndependent() / 2.0d);
                MixtExpDensity oldHalf = new MixtExpDensity(lastDensity.getIndependent() / 2.0d);
                newTree.setChild(new ContinuousProbabilityTree(oldHalf), lastState);
                newTree.insertChild(new ContinuousProbabilityTree(newHalf));
                netNodes.setElementAt(extended, n);
            } else {
                // The new component starts as a copy of the last existing one.
                newTree.insertChild(oldTree.getChild(lastState).copy());
            }
            NodeList relationVars = relation.getVariables();
            relationVars.setElementAt(extended, relationVars.getId(finiteStates));
            relation.setValues(new PotentialContinuousPT(relationVars, newTree));
        }

        // Keep the caller's variable in step with the enlarged copy.
        finiteStates.setNumStates(newNumStates);
        Vector originalStates = finiteStates.getStates();
        originalStates.addElement(newStateName);
        finiteStates.setStates(originalStates);
        System.out.println("Component added");
        System.out.println("New hidden var states" + extended.getStates().size());
        System.out.println("\nHidden var states" + finiteStates.getStates().size());
    }

    /**
     * Builds the starting network: a two-state variable named "Hidden" linked to every variable
     * of the train database.
     *
     * <p>For each variable a distribution without parents is learnt from the train database and
     * used for both hidden states; the hidden variable gets the value 0.5 for each state. As a
     * side effect the hidden variable is appended to the node lists of the train and test
     * databases and every case receives the value 0 for it.
     *
     * @return the initial network, with the hidden variable as its first node
     */
    public Bnet learnInitialModel() {
        Vector stateNames = new Vector();
        FiniteStates hidden = new FiniteStates();
        hidden.setNumStates(2);
        hidden.setName("Hidden");
        hidden.setTitle("Hidden");
        stateNames.addElement(Integer.toString(0));
        stateNames.addElement(Integer.toString(1));
        hidden.setStates(stateNames);

        NodeList dataVariables = this.train.getVariables().copy();
        NodeList noParents = new NodeList();
        MTELearning learner = new MTELearning(this.train);
        NodeList netNodes = new NodeList();
        netNodes.insertNode(hidden);
        Vector relations = new Vector();
        LinkList links = new LinkList();

        for (int v = 0; v < dataVariables.size(); v++) {
            Node feature = dataVariables.elementAt(v);
            ContinuousProbabilityTree marginal = learnTree(learner, feature, noParents);
            netNodes.insertNode(feature);

            // Both hidden states start from the same learnt distribution.
            ContinuousProbabilityTree conditional = new ContinuousProbabilityTree(hidden);
            conditional.setChild(marginal.copy(), 0);
            conditional.setChild(marginal, 1);

            NodeList relationVars = new NodeList();
            relationVars.insertNode(feature);
            relationVars.insertNode(hidden);
            links.insertLink(new Link(hidden, feature));

            Relation relation = new Relation();
            relation.setVariables(relationVars);
            relation.setValues(new PotentialContinuousPT(relationVars, conditional));
            relations.addElement(relation);
        }

        // Relation of the hidden variable: 0.5 for each of the two states.
        MixtExpDensity firstPrior = new MixtExpDensity(0.5d);
        MixtExpDensity secondPrior = new MixtExpDensity(0.5d);
        ContinuousProbabilityTree priorTree = new ContinuousProbabilityTree(hidden);
        priorTree.setChild(new ContinuousProbabilityTree(firstPrior), 0);
        priorTree.setChild(new ContinuousProbabilityTree(secondPrior), 1);
        NodeList priorVars = new NodeList();
        priorVars.insertNode(hidden);
        Relation priorRelation = new Relation();
        priorRelation.setVariables(priorVars);
        priorRelation.setValues(new PotentialContinuousPT(priorVars, priorTree));
        relations.addElement(priorRelation);

        Bnet bnet = new Bnet();
        bnet.setRelationList(relations);
        bnet.setNodeList(netNodes);
        bnet.setLinkList(links);

        NodeList withoutHidden = dataVariables.copy();
        dataVariables.insertNode(hidden);
        appendHiddenVariable(this.train, withoutHidden, dataVariables, hidden);
        appendHiddenVariable(this.test, withoutHidden, dataVariables, hidden);
        return bnet;
    }

    /**
     * Learns the final classifier by growing the number of states of the hidden variable.
     *
     * <p>Starting from {@link #learnInitialModel()}, each round builds a candidate from the best
     * network so far, adds a component (except in the first round), refines it with
     * {@link #EMAlgorithm(Bnet)}, saves it as {@code bestnet<k>comp.elv} and scores it on the
     * test database. A candidate whose log-likelihood is not lower than the best one becomes
     * the new best network. The search stops when the hidden variable reaches
     * {@value #HIDDEN_STATES_LIMIT} states, or when a candidate is worse and the component
     * counter exceeds {@value #MIN_COMPONENTS_BEFORE_STOP}. The best network ends up in
     * {@code classifier}. The initial network is also written to {@code initial.elv}.
     *
     * @throws IOException if a network file cannot be written or an Elvira call fails
     */
    public void learnModel() throws IOException {
        System.out.println("\nLearning initial model...\n");
        Bnet best = learnInitialModel();
        System.out.println("\nInitial model learnt.\n");
        writeNetwork(best, "initial.elv");
        double bestLogLikelihood = this.test.logLikelihood(best);
        System.out.println("Initial likelihood: " + bestLogLikelihood);

        FiniteStates hidden = (FiniteStates) best.getNodeList().elementAt(0);
        int componentCounter = hidden.getNumStates();
        boolean firstRound = true;
        boolean stop = false;
        while (!stop && hidden.getNumStates() < HIDDEN_STATES_LIMIT) {
            Bnet candidate = new Bnet();
            candidate.setNodeList(best.getNodeList().copy());
            candidate.setLinkList(best.getLinkList().copy());
            // NOTE(review): the relation list is passed on without copying, exactly as in the
            // original code, so candidate and best appear to share relation objects - confirm
            // that this is intended.
            candidate.setRelationList(best.getRelationList());
            if (firstRound) {
                // The first round refines the initial model as it is.
                firstRound = false;
            } else {
                addComponent(candidate, hidden);
            }
            System.out.println("\nNumber of components " + hidden.getNumStates());
            EMAlgorithm(candidate);
            this.classifier = candidate;
            saveNetwork("bestnet" + hidden.getNumStates() + "comp.elv");
            double candidateLogLikelihood = this.test.logLikelihood(candidate);
            System.out.println("Likelihood best net with " + hidden.getNumStates() + " components: " + candidateLogLikelihood);
            if (candidateLogLikelihood >= bestLogLikelihood) {
                bestLogLikelihood = candidateLogLikelihood;
                best = candidate;
            } else if (componentCounter > MIN_COMPONENTS_BEFORE_STOP) {
                stop = true;
            }
            componentCounter++;
        }
        this.classifier = best;
    }

    /**
     * Sets the database the distributions are learnt from.
     *
     * @param dataBaseCases the train database
     */
    public void setTrain(DataBaseCases dataBaseCases) {
        this.train = dataBaseCases;
    }

    /**
     * Sets the database used to score networks by log-likelihood.
     *
     * @param dataBaseCases the test database
     */
    public void setTest(DataBaseCases dataBaseCases) {
        this.test = dataBaseCases;
    }

    /**
     * Writes the current classifier to a file.
     *
     * @param str path of the file to write
     * @throws IOException if the file cannot be opened, written or closed
     */
    public void saveNetwork(String str) throws IOException {
        writeNetwork(this.classifier, str);
    }

    /**
     * Command-line entry point.
     *
     * @param strArr {@code strArr[0]} train database file, {@code strArr[1]} file the learnt
     *     network is written to, {@code strArr[2]} test database file
     * @throws ParseException if a database file cannot be parsed
     * @throws IOException if a file cannot be read or written
     * @throws IllegalArgumentException if fewer than three arguments are given
     */
    public static void main(String[] strArr) throws ParseException, IOException {
        if (strArr == null || strArr.length < 3) {
            throw new IllegalArgumentException(
                    "Usage: UnsupervisedMTENaiveBayes <train file> <output network file> <test file>");
        }
        FileInputStream trainInput = new FileInputStream(strArr[0]);
        try {
            FileInputStream testInput = new FileInputStream(strArr[2]);
            try {
                // The streams stay open until learning has finished, so the databases can rely
                // on them for as long as the original code allowed.
                DataBaseCases trainCases = new DataBaseCases(trainInput);
                DataBaseCases testCases = new DataBaseCases(testInput);
                UnsupervisedMTENaiveBayes learner = new UnsupervisedMTENaiveBayes();
                learner.setTrain(trainCases);
                learner.setTest(testCases);
                learner.learnModel();
                learner.saveNetwork(strArr[1]);
            } finally {
                testInput.close();
            }
        } finally {
            trainInput.close();
        }
    }

    /**
     * Learns the tree of {@code target} given {@code parents} from the train database.
     * The two trailing integer arguments (4, 4) are the settings used throughout this class;
     * their meaning is defined by {@code MTELearning.learnConditional}.
     */
    private ContinuousProbabilityTree learnTree(MTELearning learner, Node target, NodeList parents) {
        return learner.learnConditional(target, parents, this.train, 4, 4);
    }

    /**
     * Replaces the hidden value of every case of {@code cases} with a state drawn from the
     * posterior of {@code hidden} given the other values of the case.
     *
     * @param posterior scratch buffer with one slot per state of {@code hidden}
     */
    private void imputeHiddenValues(Bnet bnet, DataBaseCases cases, ContinuousCaseListMem caseList,
            FiniteStates hidden, double[] posterior) throws IOException {
        int numStates = hidden.getNumStates();
        for (int c = 0; c < cases.getNumberOfCases(); c++) {
            ContinuousConfiguration current = (ContinuousConfiguration) caseList.get(c);
            // Evidence is the case without its current hidden value.
            ContinuousConfiguration observed = (ContinuousConfiguration) current.copy();
            observed.remove(hidden);
            VariableElimination inference = new VariableElimination(bnet, new Evidence(observed));
            NodeList interest = new NodeList();
            interest.insertNode(hidden);
            inference.setInterest(interest);
            inference.propagate();
            ContinuousProbabilityTree posteriorTree =
                    ((PotentialContinuousPT) inference.getResults().elementAt(0)).getTree();
            for (int state = 0; state < numStates; state++) {
                ContinuousConfiguration point = new ContinuousConfiguration();
                point.insert(hidden, state);
                posterior[state] = posteriorTree.getProb(point).getIndependent();
            }
            current.putValue(hidden, sampleState(posterior, numStates, Math.random()));
            caseList.replaceCase(current, c);
        }
    }

    /**
     * Returns the first state whose cumulative probability reaches {@code draw}; the last state
     * is returned when none does, so probabilities that do not sum to one cannot run past the
     * end of the array.
     */
    private static int sampleState(double[] probabilities, int numStates, double draw) {
        int lastState = numStates - 1;
        double cumulative = 0.0d;
        for (int state = 0; state < lastState; state++) {
            cumulative += probabilities[state];
            if (draw <= cumulative) {
                return state;
            }
        }
        return lastState;
    }

    /**
     * Appends {@code hidden}, with value 0, to every case of {@code cases} and installs
     * {@code withHidden} as the node list of the database.
     */
    private void appendHiddenVariable(DataBaseCases cases, NodeList withoutHidden, NodeList withHidden,
            FiniteStates hidden) {
        ContinuousCaseListMem caseList = (ContinuousCaseListMem) cases.getCaseListMem();
        cases.setNodeList(withHidden);
        for (int c = 0; c < caseList.getNumberOfCases(); c++) {
            // The variable list is switched back before reading and forward before writing,
            // presumably so each case is read with the old layout and stored with the new one.
            caseList.setVariables(withoutHidden.getNodes());
            ContinuousConfiguration extendedCase = (ContinuousConfiguration) caseList.get(c);
            extendedCase.insert(hidden, 0);
            caseList.setVariables(withHidden.getNodes());
            caseList.replaceCase(extendedCase, c);
        }
    }

    /** Writes {@code network} to {@code path}, closing the writer even when saving fails. */
    private static void writeNetwork(Bnet network, String path) throws IOException {
        FileWriter writer = new FileWriter(path);
        try {
            network.saveBnet(writer);
        } finally {
            writer.close();
        }
    }
}
