package elvira.learning.classification.supervised.continuous;

import elvira.Bnet;
import elvira.Continuous;
import elvira.ContinuousCaseListMem;
import elvira.ContinuousConfiguration;
import elvira.Evidence;
import elvira.FiniteStates;
import elvira.Node;
import elvira.NodeList;
import elvira.database.DataBaseCases;
import elvira.inference.clustering.MTESimplePenniless;
import elvira.inference.elimination.VariableElimination;
import elvira.learning.MTELearning;
import elvira.parser.ParseException;
import elvira.potential.ContinuousProbabilityTree;
import elvira.potential.PotentialContinuousPT;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.Vector;
import weka.classifiers.lazy.kstar.KStarConstants;

/* loaded from: input_file:bayelvira-1.0-SNAPSHOT.jar:elvira/learning/classification/supervised/continuous/NaiveMTEPredictorMissingData.class */
/**
 * Naive Bayes MTE regression model learned from a database that may contain missing values.
 *
 * <p>Missing values are first imputed by sampling from a univariate MTE density learned for each
 * variable. A {@link NaiveMTEPredictor} is learned from that completed database. The model is then
 * re-learned repeatedly: each round re-imputes the missing values using the current model. This
 * continues for as long as the RMSE of the (bias-corrected) mean prediction of the class variable
 * on the training cases keeps decreasing. The network with the lowest RMSE is kept.
 */
public class NaiveMTEPredictorMissingData {

    /** Sentinel "previous RMSE" used before any model has been evaluated. */
    private static final double INITIAL_RMSE = 1.0E9d;

    /** Command-line usage text for {@link #main(String[])}. */
    private static final String USAGE =
            "Usage: NaiveMTEPredictorMissingData <trainFile> <testFile> <classIndex> <predictorOption>\n"
            + "   or: NaiveMTEPredictorMissingData <dataFile> CV <classIndex> <predictorOption> <folds>";

    /** Not read or written by this class. */
    NodeList variables;
    /** Index of the class (target) variable within the database's variable list. */
    int classVariable;
    /** Network used for prediction; after construction, the network with the lowest RMSE found. */
    Bnet net;
    /** Training database, possibly with missing values; this class only reads it and works on copies. */
    DataBaseCases dbCases;

    /**
     * Learns the model, iteratively re-imputing missing values while the training RMSE improves.
     *
     * @param dataBaseCases training data, possibly containing undefined (missing) values
     * @param i index of the class variable in {@code dataBaseCases.getVariables()}
     * @param i2 forwarded unchanged to every {@link NaiveMTEPredictor} constructed here
     */
    public NaiveMTEPredictorMissingData(DataBaseCases dataBaseCases, int i, int i2) {
        this.dbCases = dataBaseCases;
        this.classVariable = i;

        NaiveMTEPredictor candidate =
                new NaiveMTEPredictor(fillValuesLearnUnivariate(), this.classVariable, i2);
        // computeRmse evaluates this.net, so the candidate must be installed first.
        this.net = candidate.net;
        double candidateRmse = computeRmse(dataBaseCases);

        double bestRmse = INITIAL_RMSE;
        Bnet bestNet = null;
        // "<" (not "!(>=)") so that a NaN RMSE terminates the loop instead of spinning forever.
        while (candidateRmse < bestRmse) {
            System.out.println("     Improving rmse: " + bestRmse + " --> " + candidateRmse);
            bestRmse = candidateRmse;
            bestNet = this.net;
            // At this point this.net == candidate.net, which fillSimulatedPredictedValues relies on.
            candidate = new NaiveMTEPredictor(
                    fillSimulatedPredictedValues(candidate), this.classVariable, i2);
            this.net = candidate.net;
            candidateRmse = computeRmse(dataBaseCases);
        }
        // The last candidate did not improve: fall back to the best network seen.
        if (bestNet != null) {
            this.net = bestNet;
        }
    }

    /**
     * Predicts {@code node} from the other values of a configuration, without bias correction.
     *
     * <p>Equivalent to {@code predictWithMeanMissingValues(continuousConfiguration, node, 0.0)}.
     * Note that {@code node} is removed from {@code continuousConfiguration}.
     *
     * @return a Vector of Doubles: [posterior mean, posterior variance, posterior median,
     *     value of {@code node} read from the configuration]
     */
    public Vector predictWithMeanMissingValues(ContinuousConfiguration continuousConfiguration, Node node) {
        return predictWithMeanMissingValues(continuousConfiguration, node, 0.0d);
    }

    /**
     * Predicts every case of a database, without bias correction.
     *
     * <p>Equivalent to {@code predictWithMeanMissingValues(dataBaseCases, 0.0)}.
     *
     * @return a Vector of four Vectors of Doubles, one entry per case:
     *     [means, variances, medians, actual class values]
     */
    public Vector predictWithMeanMissingValues(DataBaseCases dataBaseCases) {
        return predictWithMeanMissingValues(dataBaseCases, 0.0d);
    }

    /**
     * Predicts the class variable for every case of a database.
     *
     * <p>Each case is copied before use, so the database's stored configurations are not modified.
     * Undefined values are dropped from the copy, so they act as unobserved.
     *
     * @param dataBaseCases cases to predict; its case list must be a {@code ContinuousCaseListMem}
     * @param d bias subtracted from each predicted mean and median
     * @return a Vector of four Vectors of Doubles, one entry per case:
     *     [means - d, variances, medians - d, actual class values]
     */
    public Vector predictWithMeanMissingValues(DataBaseCases dataBaseCases, double d) {
        Node classNode = dataBaseCases.getVariables().elementAt(this.classVariable);
        ContinuousCaseListMem cases = (ContinuousCaseListMem) dataBaseCases.getCaseListMem();
        Vector means = new Vector();
        Vector variances = new Vector();
        Vector medians = new Vector();
        Vector actuals = new Vector();
        int numberOfCases = cases.getNumberOfCases();
        for (int caseIndex = 0; caseIndex < numberOfCases; caseIndex++) {
            // Work on a copy: the prediction removes values from the configuration it is given.
            ContinuousConfiguration stored = (ContinuousConfiguration) cases.get(caseIndex);
            ContinuousConfiguration configuration = (ContinuousConfiguration) stored.copy();
            configuration.removeUndefinedValues();
            Vector prediction = predictWithMeanMissingValues(configuration, classNode, d);
            means.addElement(prediction.elementAt(0));
            variances.addElement(prediction.elementAt(1));
            medians.addElement(prediction.elementAt(2));
            actuals.addElement(prediction.elementAt(3));
        }
        Vector result = new Vector();
        result.addElement(means);
        result.addElement(variances);
        result.addElement(medians);
        result.addElement(actuals);
        return result;
    }

    /**
     * Predicts {@code node} from the other values of a configuration using {@link #net}.
     *
     * <p>Side effect: {@code node} is removed from {@code continuousConfiguration}. Callers that
     * need the configuration afterwards must pass a copy.
     *
     * @param continuousConfiguration evidence; must contain a value for {@code node}
     * @param node variable to predict; must be a {@link Continuous}
     * @param d bias subtracted from the predicted mean and median
     * @return a Vector of Doubles: [posterior mean - d, posterior variance,
     *     posterior median - d, value of {@code node} read from the configuration]
     */
    public Vector predictWithMeanMissingValues(ContinuousConfiguration continuousConfiguration, Node node, double d) {
        double actualValue = continuousConfiguration.getValue((Continuous) node);
        continuousConfiguration.remove(node);
        Evidence evidence = new Evidence(continuousConfiguration);
        // NOTE(review): these four arguments resolve to weka's KStarConstants.FLOOR, which looks
        // like a decompiler substitution for a literal; confirm the intended values.
        MTESimplePenniless propagation = new MTESimplePenniless(this.net, evidence,
                KStarConstants.FLOOR, KStarConstants.FLOOR, KStarConstants.FLOOR, KStarConstants.FLOOR, 0);
        propagation.propagate(evidence);
        // NOTE(review): assumes the first result is the posterior of `node`; other unobserved
        // variables may also yield results — confirm the result ordering.
        ContinuousProbabilityTree posterior =
                ((PotentialContinuousPT) propagation.getResults().elementAt(0)).getTree();
        double mean = posterior.firstOrderMoment() - d;
        double variance = posterior.Variance();
        double median = posterior.median() - d;
        Vector result = new Vector();
        result.addElement(Double.valueOf(mean));
        result.addElement(Double.valueOf(variance));
        result.addElement(Double.valueOf(median));
        result.addElement(Double.valueOf(actualValue));
        return result;
    }

    /**
     * Computes the RMSE of the bias-corrected mean prediction of the class on the cases whose class
     * value is observed.
     *
     * <p>Works on a copy of {@code dataBaseCases}, so that database is not modified.
     */
    private double computeRmse(DataBaseCases dataBaseCases) {
        DataBaseCases evaluationCases = dataBaseCases.copy();
        evaluationCases.removeCasesMissingValue(this.classVariable);
        Vector rawPredictions = predictWithMeanMissingValues(evaluationCases);
        double bias = NaiveMTEPredictor.computeBias(
                (Vector) rawPredictions.elementAt(0), (Vector) rawPredictions.elementAt(3));
        Vector corrected = predictWithMeanMissingValues(evaluationCases, bias);
        Vector meanErrors = NaiveMTEPredictor.computeErrors(
                (Vector) corrected.elementAt(0), (Vector) corrected.elementAt(3));
        return ((Double) meanErrors.elementAt(0)).doubleValue();
    }

    /**
     * Returns a copy of {@link #dbCases} whose undefined values are replaced by samples.
     *
     * <p>For each variable in turn, a univariate MTE density is learned from the cases where that
     * variable is observed. The variable's missing entries are then filled with values simulated
     * from that density.
     */
    private DataBaseCases fillValuesLearnUnivariate() {
        DataBaseCases filled = this.dbCases.copy();
        MTELearning learner = new MTELearning();
        int numberOfVariables = filled.getVariables().size();
        int numberOfCases = filled.getNumberOfCases();
        for (int var = 0; var < numberOfVariables; var++) {
            Node node = this.dbCases.getVariables().elementAt(var);

            // Data restricted to this variable, with the cases where it is missing dropped.
            DataBaseCases observed = filled.copy();
            NodeList projection = new NodeList();
            projection.insertNode(node);
            observed.projection(projection);
            int kept = 0;
            for (int c = 0; c < numberOfCases; c++) {
                if (Continuous.isUndefined(filled.getCaseListMem().getValue(c, var))) {
                    observed.getCaseListMem().getCases().removeElementAt(kept);
                    observed.getCaseListMem().setNumberOfCases(observed.getCaseListMem().getNumberOfCases() - 1);
                    observed.setNumberOfCases(observed.getNumberOfCases() - 1);
                } else {
                    kept++;
                }
            }

            ContinuousProbabilityTree marginal =
                    learner.learnConditional(node, new NodeList(), observed, 4, 4);
            for (int c = 0; c < numberOfCases; c++) {
                if (Continuous.isUndefined(filled.getCaseListMem().getValue(c, var))) {
                    filled.getCaseListMem().setValue(c, var, marginal.simulateValue());
                }
            }
        }
        return filled;
    }

    /**
     * Returns a copy of {@link #dbCases} whose undefined values are re-imputed with a model.
     *
     * <p>A missing class value is replaced by its predicted posterior mean (computed with
     * {@link #net}). Any other missing value is replaced by a sample from its posterior under
     * {@code naiveMTEPredictor.net`, rounded for discrete variables.
     */
    private DataBaseCases fillSimulatedPredictedValues(NaiveMTEPredictor naiveMTEPredictor) {
        DataBaseCases completed = this.dbCases.copy();
        int numberOfCases = this.dbCases.getNumberOfCases();
        int numberOfVariables = this.dbCases.getVariables().size();
        for (int i = 0; i < numberOfCases; i++) {
            ContinuousConfiguration filledCase =
                    (ContinuousConfiguration) this.dbCases.getCaseListMem().get(i).copy();
            for (int j = 0; j < numberOfVariables; j++) {
                // `completed` is not replaced until the end of this case, so this reads the
                // original missingness.
                if (!Continuous.isUndefined(completed.getCaseListMem().getValue(i, j))) {
                    continue;
                }
                Node node = this.dbCases.getVariables().elementAt(j);
                // Evidence comes from a copy so that filledCase keeps all of its entries.
                // NOTE(review): the copy still holds undefined entries for later missing variables;
                // confirm whether Evidence ignores them (the database predictors call
                // removeUndefinedValues() first).
                ContinuousConfiguration evidenceCase = (ContinuousConfiguration) filledCase.copy();
                if (j == this.classVariable) {
                    Vector prediction = predictWithMeanMissingValues(evidenceCase, node);
                    filledCase.putValue((Continuous) node, ((Double) prediction.elementAt(0)).doubleValue());
                } else {
                    evidenceCase.remove(node);
                    VariableElimination elimination =
                            new VariableElimination(naiveMTEPredictor.net, new Evidence(evidenceCase));
                    NodeList interest = new NodeList();
                    interest.insertNode(node);
                    elimination.setInterest(interest);
                    elimination.propagate();
                    ContinuousProbabilityTree posterior =
                            ((PotentialContinuousPT) elimination.getResults().elementAt(0)).getTree();
                    if (node instanceof FiniteStates) {
                        filledCase.putValue((FiniteStates) node, Math.round((float) posterior.simulateValue()));
                    } else {
                        filledCase.putValue((Continuous) node, posterior.simulateValue());
                    }
                }
            }
            ((ContinuousCaseListMem) completed.getCaseListMem()).replaceCase(filledCase, i);
        }
        return completed;
    }

    /**
     * Command-line entry point.
     *
     * <p>Hold-out mode: {@code <trainFile> <testFile> <classIndex> <predictorOption>}. It learns a
     * plain {@link NaiveMTEPredictor}, saves it to {@code NB.elv} and reports test errors.
     *
     * <p>Cross-validation mode: {@code <dataFile> CV <classIndex> <predictorOption> <folds>}. It
     * learns this missing-data model per fold and prints the averaged errors.
     *
     * @throws IllegalArgumentException if the arguments are missing or the fold count is not positive
     */
    public static void main(String[] strArr) throws ParseException, IOException {
        if (strArr.length < 4) {
            throw new IllegalArgumentException(USAGE);
        }
        int classIndex = Integer.parseInt(strArr[2]);
        int predictorOption = Integer.parseInt(strArr[3]);
        if (!"CV".equals(strArr[1])) {
            runHoldOut(strArr[0], strArr[1], classIndex, predictorOption);
            return;
        }
        if (strArr.length < 5) {
            throw new IllegalArgumentException(USAGE);
        }
        int folds = Integer.parseInt(strArr[4]);
        if (folds < 1) {
            throw new IllegalArgumentException("Number of folds must be positive: " + folds + "\n" + USAGE);
        }
        runCrossValidation(strArr[0], classIndex, predictorOption, folds);
    }

    /** Hold-out evaluation of a plain {@link NaiveMTEPredictor}; both input files are closed on exit. */
    private static void runHoldOut(String trainFile, String testFile, int classIndex, int predictorOption)
            throws ParseException, IOException {
        FileInputStream trainStream = new FileInputStream(trainFile);
        try {
            FileInputStream testStream = new FileInputStream(testFile);
            try {
                DataBaseCases train = new DataBaseCases(trainStream);
                NaiveMTEPredictor predictor = new NaiveMTEPredictor(train, classIndex, predictorOption);
                predictor.saveNetwork("NB.elv");
                Vector trainPredictions = predictor.predictWithMean(train);
                DataBaseCases test = new DataBaseCases(testStream);
                double bias = NaiveMTEPredictor.computeBias(
                        (Vector) trainPredictions.elementAt(0), (Vector) trainPredictions.elementAt(3));
                Vector testPredictions = predictor.predictWithMean(test, bias);
                System.out.println("Hold-out validation");
                NaiveMTEPredictor.computeErrors(
                        (Vector) testPredictions.elementAt(0), (Vector) testPredictions.elementAt(3));
                NaiveMTEPredictor.computeErrors(
                        (Vector) testPredictions.elementAt(2), (Vector) testPredictions.elementAt(3));
            } finally {
                testStream.close();
            }
        } finally {
            trainStream.close();
        }
    }

    /** k-fold cross-validation of this missing-data model; prints the fold-averaged errors. */
    private static void runCrossValidation(String dataFile, int classIndex, int predictorOption, int folds)
            throws ParseException, IOException {
        DataBaseCases data;
        FileInputStream dataStream = new FileInputStream(dataFile);
        try {
            data = new DataBaseCases(dataStream);
        } finally {
            dataStream.close();
        }

        double rmseMean = 0.0d;
        double lccMean = 0.0d;
        double rmseMedian = 0.0d;
        double lccMedian = 0.0d;
        for (int fold = 0; fold < folds; fold++) {
            System.out.println("ITERATION " + fold);
            DataBaseCases trainCV = data.getTrainCV(fold, folds);
            DataBaseCases testCV = data.getTestCV(fold, folds);
            testCV.removeCasesMissingValue(classIndex);
            // The model is learned from all training cases, including those with a missing class.
            NaiveMTEPredictorMissingData model =
                    new NaiveMTEPredictorMissingData(trainCV, classIndex, predictorOption);
            // The bias can only be estimated on training cases whose class value is observed.
            trainCV.removeCasesMissingValue(classIndex);
            Vector trainPredictions = model.predictWithMeanMissingValues(trainCV);
            double bias = NaiveMTEPredictor.computeBias(
                    (Vector) trainPredictions.elementAt(0), (Vector) trainPredictions.elementAt(3));
            Vector testPredictions = model.predictWithMeanMissingValues(testCV, bias);
            Vector meanErrors = NaiveMTEPredictor.computeErrors(
                    (Vector) testPredictions.elementAt(0), (Vector) testPredictions.elementAt(3));
            rmseMean += ((Double) meanErrors.elementAt(0)).doubleValue();
            lccMean += ((Double) meanErrors.elementAt(1)).doubleValue();
            Vector medianErrors = NaiveMTEPredictor.computeErrors(
                    (Vector) testPredictions.elementAt(2), (Vector) testPredictions.elementAt(3));
            rmseMedian += ((Double) medianErrors.elementAt(0)).doubleValue();
            lccMedian += ((Double) medianErrors.elementAt(1)).doubleValue();
        }
        System.out.println("-------------------------------------------------");
        System.out.println(folds + "-fold cross validation.");
        System.out.println("\nFinal results:");
        System.out.println("rmse_mean,lcc_mean,rmse_median,lcc_median");
        System.out.println((rmseMean / folds) + "," + (lccMean / folds) + "," + (rmseMedian / folds) + "," + (lccMedian / folds));
        System.out.println("\n");
    }
}
