package elvira.learning.classification.supervised.continuous;

import elvira.ContinuousCaseListMem;
import elvira.ContinuousConfiguration;
import elvira.NodeList;
import elvira.database.DataBaseCases;
import elvira.parser.ParseException;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.Vector;

/* loaded from: input_file:bayelvira-1.0-SNAPSHOT.jar:elvira/learning/classification/supervised/continuous/NaiveGaussianPredictor.class */
/**
 * Predicts a continuous class variable as the conditional mean of a Gaussian model learned from
 * a training database.
 *
 * <p>The model keeps the training mean and variance of every variable and the covariance of every
 * variable with the class variable. Covariances between two non-class variables are set to zero.
 * A prediction conditions this joint Gaussian on the observed non-class values of a case.
 */
public class NaiveGaussianPredictor {
    /** Copy of the variable list of the training database. */
    NodeList variables;

    /** Index of the class (predicted) variable; -1 when built with the no-argument constructor. */
    int classVariable;

    /** Training mean of each variable, stored as {@link Double}, one entry per variable. */
    Vector meanVector;

    /** Training variance of each variable, stored as {@link Double}, one entry per variable. */
    Vector varianceVector;

    /**
     * Model covariance matrix: variances on the diagonal, covariance with the class variable in
     * the class row and column, and zero everywhere else.
     */
    double[][] covarianceMatrix;

    /** Correlation between predicted and real class values, set by the last batch prediction. */
    double linearCorrelation;

    /**
     * Command-line entry point: learns from a training file, then predicts every case of a second
     * file and prints the per-case estimates and their correlation with the real values.
     *
     * @param strArr {@code training_file.dbc index_of_class_variable cases_to_predict.dbc}
     * @throws ParseException if a database file cannot be parsed
     * @throws IOException if a database file cannot be read
     */
    public static void main(String[] strArr) throws ParseException, IOException {
        if (strArr.length != 3) {
            System.out.println("wrong number of arguments: Usage: training_file.dbc index_of_class_variable cases_to_predict.dbc");
            // Non-zero status so calling scripts can detect the usage error.
            System.exit(1);
            return;
        }
        DataBaseCases training = readDataBase(strArr[0]);
        ContinuousCaseListMem testCases = (ContinuousCaseListMem) readDataBase(strArr[2]).getCaseListMem();
        NaiveGaussianPredictor predictor = new NaiveGaussianPredictor(training, Integer.parseInt(strArr[1]));
        Vector estimates = predictor.predictWithMean(testCases);
        System.out.println("Correlation : " + predictor.getLinearCorrelation() + "\n");
        System.out.println("Estimated values:");
        for (int i = 0; i < estimates.size(); i++) {
            System.out.println(estimates.elementAt(i));
        }
    }

    /**
     * Parses a database file and always closes the underlying stream.
     *
     * <p>NOTE(review): assumes {@code DataBaseCases} consumes the whole stream inside its
     * constructor, so closing it afterwards is safe — confirm against DataBaseCases.
     */
    private static DataBaseCases readDataBase(String path) throws ParseException, IOException {
        FileInputStream in = new FileInputStream(path);
        try {
            return new DataBaseCases(in);
        } finally {
            in.close();
        }
    }

    /** Creates an empty predictor with no variables and no class variable. */
    public NaiveGaussianPredictor() {
        this.variables = new NodeList();
        this.classVariable = -1;
        this.meanVector = new Vector();
        this.varianceVector = new Vector();
        this.covarianceMatrix = new double[1][1];
    }

    /**
     * Learns the Gaussian model from a training database of continuous cases.
     *
     * @param dataBaseCases training data; its case list must be a {@code ContinuousCaseListMem}
     * @param i index of the class variable within the database's variable list
     * @throws IllegalArgumentException if {@code i} is not a valid variable index
     */
    public NaiveGaussianPredictor(DataBaseCases dataBaseCases, int i) {
        this.variables = dataBaseCases.getVariables().copy();
        this.classVariable = i;
        int size = this.variables.size();
        if (i < 0 || i >= size) {
            throw new IllegalArgumentException(
                    "class variable index " + i + " out of range [0, " + size + ")");
        }
        this.meanVector = new Vector();
        this.varianceVector = new Vector();
        // Java zero-initialises the array, which gives the zero predictor-predictor covariances.
        this.covarianceMatrix = new double[size][size];
        ContinuousCaseListMem cases = (ContinuousCaseListMem) dataBaseCases.getCaseListMem();
        for (int row = 0; row < size; row++) {
            double variance = cases.variance(row);
            this.varianceVector.addElement(Double.valueOf(variance));
            this.meanVector.addElement(Double.valueOf(cases.mean(row)));
            this.covarianceMatrix[row][row] = variance;
        }
        // Class row/column; for row == i this also sets the class diagonal to covariance(i, i),
        // exactly as the previous implementation did.
        for (int row = 0; row < size; row++) {
            double withClass = cases.covariance(row, i);
            this.covarianceMatrix[row][i] = withClass;
            this.covarianceMatrix[i][row] = withClass;
        }
    }

    /**
     * Predicts the class value of one case as the mean of the model conditioned on the case's
     * non-class values.
     *
     * <p>Evidence is absorbed one variable at a time, using the standard Gaussian conditioning
     * step for observed variable {@code k}:
     * {@code mu[i] += S[i][k] (x_k - mu[k]) / S[k][k]} and
     * {@code S[i][j] -= S[i][k] S[j][k] / S[k][k]}.
     * Variables whose current variance is not positive carry no usable information and are
     * skipped (this also avoids dividing by zero). Prints the resulting class variance and mean.
     *
     * @param continuousConfiguration the case; value {@code k} belongs to variable {@code k}
     * @return the conditional mean of the class variable
     */
    public double predictWithMean(ContinuousConfiguration continuousConfiguration) {
        int n = this.meanVector.size();
        double[] mean = new double[n];
        double[][] cov = new double[n][n];
        for (int i = 0; i < n; i++) {
            mean[i] = ((Double) this.meanVector.elementAt(i)).doubleValue();
            System.arraycopy(this.covarianceMatrix[i], 0, cov[i], 0, n);
        }
        double[] column = new double[n];
        for (int k = continuousConfiguration.size() - 1; k >= 0; k--) {
            if (k == this.classVariable) {
                continue;
            }
            double pivot = cov[k][k];
            if (!(pivot > 0.0)) {
                // Zero, negative or NaN variance: conditioning on it is undefined.
                continue;
            }
            double residual = continuousConfiguration.getContinuousValue(k) - mean[k];
            // Snapshot column k first: the update below overwrites it.
            for (int i = 0; i < n; i++) {
                column[i] = cov[i][k];
            }
            for (int i = 0; i < n; i++) {
                if (column[i] == 0.0) {
                    continue; // row i is independent of variable k; nothing changes
                }
                mean[i] += column[i] * residual / pivot;
                for (int j = 0; j < n; j++) {
                    cov[i][j] -= column[i] * column[j] / pivot;
                }
            }
        }
        double predicted = mean[this.classVariable];
        System.out.println("Var : " + cov[this.classVariable][this.classVariable] + " Mean : " + predicted);
        return predicted;
    }

    /**
     * Predicts every case of a list and records the Pearson correlation between the predicted
     * and the real class values of those cases.
     *
     * <p>Prints one line per case, then the training class mean/sigma next to the mean/sigma of
     * the predictions. The correlation uses the moments of the real values in the same case list,
     * not the training statistics. It is {@code NaN} when the list is empty or either series has
     * zero spread.
     *
     * @param continuousCaseListMem cases to predict; each must hold the real class value
     * @return the predictions as {@link Double}s, in case order
     */
    public Vector predictWithMean(ContinuousCaseListMem continuousCaseListMem) {
        int cases = continuousCaseListMem.getNumberOfCases();
        Vector predictions = new Vector();
        double sumPredicted = 0.0;
        double sumPredictedSq = 0.0;
        double sumReal = 0.0;
        double sumRealSq = 0.0;
        double sumProduct = 0.0;
        for (int i = 0; i < cases; i++) {
            ContinuousConfiguration configuration = (ContinuousConfiguration) continuousCaseListMem.get(i);
            double predicted = predictWithMean(configuration);
            double real = configuration.getContinuousValue(this.classVariable);
            predictions.addElement(Double.valueOf(predicted));
            sumPredicted += predicted;
            sumPredictedSq += predicted * predicted;
            sumReal += real;
            sumRealSq += real * real;
            sumProduct += predicted * real;
            System.out.println("Valor estimado: " + predicted + " real: " + real);
        }
        if (cases == 0) {
            this.linearCorrelation = Double.NaN;
            return predictions;
        }
        double predictedMean = sumPredicted / cases;
        // max(0, .) absorbs tiny negative values produced by floating-point cancellation.
        double predictedSigma = Math.sqrt(Math.max(0.0, sumPredictedSq / cases - predictedMean * predictedMean));
        double realMean = sumReal / cases;
        double realSigma = Math.sqrt(Math.max(0.0, sumRealSq / cases - realMean * realMean));
        double classSigma = Math.sqrt(((Double) this.varianceVector.elementAt(this.classVariable)).doubleValue());
        double classMean = ((Double) this.meanVector.elementAt(this.classVariable)).doubleValue();
        System.out.println("Class mean " + classMean + " ; Predicted mean: " + predictedMean);
        System.out.println("Class sigma " + classSigma + " ; Predicted sigma: " + predictedSigma);
        double denominator = predictedSigma * realSigma;
        this.linearCorrelation = denominator > 0.0
                ? (sumProduct / cases - predictedMean * realMean) / denominator
                : Double.NaN;
        return predictions;
    }

    /**
     * Returns the correlation computed by the last call to
     * {@link #predictWithMean(ContinuousCaseListMem)}.
     */
    public double getLinearCorrelation() {
        return this.linearCorrelation;
    }
}
