package elvira.learning.classification.supervised.continuous;

import elvira.ContinuousCaseListMem;
import elvira.ContinuousConfiguration;
import elvira.Node;
import elvira.NodeList;
import elvira.database.DataBaseCases;
import elvira.learning.preprocessing.ProjectDBC;
import elvira.parser.ParseException;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Vector;
import org.apache.tools.ant.taskdefs.optional.ejb.GenericDeploymentTool;
import weka.classifiers.lazy.kstar.KStarConstants;

/* loaded from: input_file:bayelvira-1.0-SNAPSHOT.jar:elvira/learning/classification/supervised/continuous/GaussianPredictor.class */
public class GaussianPredictor {
    /** Model name: every pair of variables gets its empirical covariance. */
    private static final String MODEL_FULL = "full";
    /** Model name: attributes covary only with the class variable. */
    private static final String MODEL_NAIVE = "naive";
    /** Model name: class covariances plus a tree of attribute-attribute covariances. */
    private static final String MODEL_TAN = "tan";

    /** Variables of the training data; their indices index every vector/matrix below. */
    NodeList variables;
    /** Index of the variable to predict (-1 for an empty predictor). */
    int classVariable;
    /** Per-variable means (boxed {@code Double}), in variable-index order. */
    Vector meanVector;
    /** Per-variable variances (boxed {@code Double}), in variable-index order. */
    Vector varianceVector;
    /** Symmetric covariance matrix; entries excluded by the model structure stay 0. */
    double[][] covarianceMatrix;
    /** Correlation between predicted and actual class values of the last batch prediction. */
    double linearCorrelation;
    /** Mean squared prediction error of the last batch prediction. */
    double meanSquaredError = 0.0;
    /** Model structure: "naive", "full" or "tan". */
    String model = MODEL_NAIVE;

    /**
     * Weighted candidate edge between variables {@code p1} and {@code p2}, used while building
     * the TAN tree. (The decompiler notes this class was originally private.)
     */
    public class Tuple {
        public int p1;
        public int p2;
        public double value;

        public Tuple(int first, int second, double weight) {
            this.p1 = first;
            this.p2 = second;
            this.value = weight;
        }
    }

    /**
     * Orders {@link Tuple}s by decreasing {@code value}. (The decompiler notes this class was
     * originally private.)
     */
    public class TupleComparator implements Comparator {
        private TupleComparator() {
        }

        @Override // java.util.Comparator
        public int compare(Object obj, Object obj2) {
            double first = ((Tuple) obj).value;
            double second = ((Tuple) obj2).value;
            if (first < second) {
                return 1;
            }
            return first > second ? -1 : 0;
        }
    }

    /**
     * Command-line entry point: learns a predictor from a training file, predicts the class of
     * every case in a second file and prints correlation, RMSE and the predicted values.
     * Exits with status 1 on invalid arguments.
     *
     * @param strArr training.dbc, model (naive|full|tan), class variable name, test.dbc and,
     *               optionally, an attribute list plus an inverse-selection flag
     */
    public static void main(String[] strArr) throws ParseException, IOException {
        if (strArr.length != 4 && strArr.length != 6) {
            System.out.println("Wrong number of arguments. Usage GaussianPredictor\n\t training_file.dbc -- the input file --\n\t (naive|full|tan) -- the type of model to be used\n\t name-of-class-variable\n\t cases_to_predict.dbc\n\t [ list-of-attributes to filter --- i.e., 1,3,5-9,11 (non-zero based)\n\t   inverSelection (true|false) -- if true the listed attributes are deleted\n\t ]");
            System.exit(1);
        }
        // Validate the cheap argument before parsing (possibly large) data files.
        String modelName = strArr[1];
        if (!isSupportedModel(modelName)) {
            System.out.println("*** Error ***. Model " + modelName + " is not allowed. Exiting .... ");
            System.exit(1);
        }
        DataBaseCases training = new DataBaseCases(new FileInputStream(strArr[0]));
        DataBaseCases toPredict = new DataBaseCases(new FileInputStream(strArr[3]));
        if (strArr.length == 6) {
            Vector attributes = ProjectDBC.parseAttributes(strArr[4]);
            boolean inverseSelection = Boolean.parseBoolean(strArr[5]);
            training = new ProjectDBC(training, attributes, inverseSelection).doProjection();
            toPredict = new ProjectDBC(toPredict, attributes, inverseSelection).doProjection();
        }
        ContinuousCaseListMem casesToPredict = (ContinuousCaseListMem) toPredict.getCaseListMem();
        int classIndex = training.getVariables().getId(strArr[2]);
        if (classIndex == -1) {
            System.out.println("** Error **, variable " + strArr[2] + " is not included in the (projected) data base");
            System.exit(1);
        }
        GaussianPredictor predictor = new GaussianPredictor(training, classIndex, modelName);
        Vector estimates = predictor.predictWithMean(casesToPredict);
        System.out.println("Correlation : " + predictor.getLinearCorrelation() + "\n");
        System.out.println("Rmse        : " + Math.sqrt(predictor.getMeanSquaredError()) + "\n");
        System.out.println("Estimated values:");
        for (int i = 0; i < estimates.size(); i++) {
            System.out.println(estimates.elementAt(i));
        }
    }

    /** Returns true when {@code name} is one of the model structures this class can learn. */
    private static boolean isSupportedModel(String name) {
        return MODEL_NAIVE.equals(name) || MODEL_FULL.equals(name) || MODEL_TAN.equals(name);
    }

    /** Creates an empty predictor with no variables and no class variable. */
    public GaussianPredictor() {
        this.variables = new NodeList();
        this.classVariable = -1;
        this.meanVector = new Vector();
        this.varianceVector = new Vector();
        this.covarianceMatrix = new double[1][1];
    }

    /**
     * Learns a predictor from a database of continuous cases.
     *
     * @param dataBaseCases training data; its case list must be a {@link ContinuousCaseListMem}
     * @param i             index of the class variable
     * @param str           model structure: "naive", "full" or "tan"
     * @throws IllegalArgumentException if {@code str} is not a supported model
     */
    public GaussianPredictor(DataBaseCases dataBaseCases, int i, String str) {
        this(dataBaseCases.getVariables().copy(),
                (ContinuousCaseListMem) dataBaseCases.getCaseListMem(), i, str);
    }

    /**
     * Learns a predictor directly from a list of continuous cases.
     *
     * @param continuousCaseListMem training cases
     * @param i                     index of the class variable
     * @param str                   model structure: "naive", "full" or "tan"
     * @throws IllegalArgumentException if {@code str} is not a supported model
     */
    public GaussianPredictor(ContinuousCaseListMem continuousCaseListMem, int i, String str) {
        this(new NodeList((Vector<Node>) continuousCaseListMem.getVariables()),
                continuousCaseListMem, i, str);
    }

    /** Shared initialisation: allocates the parameters and dispatches to the model learner. */
    private GaussianPredictor(NodeList vars, ContinuousCaseListMem cases, int classIndex,
                              String modelName) {
        this.variables = vars;
        this.classVariable = classIndex;
        this.model = modelName;
        this.meanVector = new Vector();
        this.varianceVector = new Vector();
        int n = vars.size();
        this.covarianceMatrix = new double[n][n];
        if (MODEL_FULL.equals(modelName)) {
            learnFullModel(cases);
        } else if (MODEL_NAIVE.equals(modelName)) {
            learnNaiveModel(cases);
        } else if (MODEL_TAN.equals(modelName)) {
            learnTanModel(cases);
        } else {
            // Without a learner the predictor has no parameters and every prediction would fail.
            throw new IllegalArgumentException(
                    "Unsupported model '" + modelName + "'; expected naive, full or tan");
        }
    }

    /** Fills means, variances and every pairwise covariance from the cases. */
    public void learnFullModel(ContinuousCaseListMem continuousCaseListMem) {
        int n = this.variables.size();
        for (int i = 0; i < n; i++) {
            double variance = continuousCaseListMem.variance(i);
            this.varianceVector.addElement(Double.valueOf(variance));
            this.meanVector.addElement(Double.valueOf(continuousCaseListMem.mean(i)));
            this.covarianceMatrix[i][i] = variance;
            for (int j = 0; j < i; j++) {
                this.covarianceMatrix[i][j] = continuousCaseListMem.covariance(i, j);
                this.covarianceMatrix[j][i] = this.covarianceMatrix[i][j];
            }
        }
    }

    /**
     * Fills means and variances, and only the covariances between each attribute and the class
     * variable; attribute-attribute covariances are set to 0.
     */
    public void learnNaiveModel(ContinuousCaseListMem continuousCaseListMem) {
        int n = this.variables.size();
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                this.covarianceMatrix[i][j] = 0.0d;
            }
        }
        for (int i = 0; i < n; i++) {
            double variance = continuousCaseListMem.variance(i);
            this.varianceVector.addElement(Double.valueOf(variance));
            this.meanVector.addElement(Double.valueOf(continuousCaseListMem.mean(i)));
            this.covarianceMatrix[i][i] = variance;
            if (i != this.classVariable) {
                this.covarianceMatrix[i][this.classVariable] =
                        continuousCaseListMem.covariance(i, this.classVariable);
                this.covarianceMatrix[this.classVariable][i] = this.covarianceMatrix[i][this.classVariable];
            }
        }
    }

    /**
     * Learns a tree-augmented structure. Attribute pairs are taken greedily in decreasing
     * conditional mutual information and kept when they do not close a cycle, until the n-1
     * attributes are joined by n-2 edges. Every variable is then linked to the class, and
     * covariances are filled only for linked pairs. Means are added to {@link #meanVector} by
     * {@link #getConditionalGMI}.
     */
    public void learnTanModel(ContinuousCaseListMem continuousCaseListMem) {
        int n = this.variables.size();
        ArrayList<Tuple> candidates = getConditionalGMI(continuousCaseListMem);
        int[][] adjacency = new int[n][n];
        int edges = 0;
        // Guard against running out of candidates instead of throwing from remove(0).
        while (edges < n - 2 && !candidates.isEmpty()) {
            Tuple edge = candidates.remove(0);
            if (!provokesCycle(edge.p1, edge.p2, adjacency)) {
                adjacency[edge.p1][edge.p2] = 1;
                adjacency[edge.p2][edge.p1] = 1;
                edges++;
            }
        }
        for (int i = 0; i < n; i++) {
            adjacency[i][this.classVariable] = 1;
            adjacency[this.classVariable][i] = 1;
        }
        for (int i = 0; i < n; i++) {
            double variance = continuousCaseListMem.variance(i);
            this.varianceVector.addElement(Double.valueOf(variance));
            this.covarianceMatrix[i][i] = variance;
            for (int j = 0; j < i; j++) {
                if (adjacency[i][j] == 1) {
                    this.covarianceMatrix[i][j] = continuousCaseListMem.covariance(i, j);
                    this.covarianceMatrix[j][i] = this.covarianceMatrix[i][j];
                }
            }
        }
    }

    /**
     * Returns true if adding the undirected edge {@code from}-{@code to} to the forest
     * {@code adjacency} would create a cycle (including a duplicate edge).
     */
    private boolean provokesCycle(int from, int to, int[][] adjacency) {
        int n = this.variables.size();
        for (int neighbour = 0; neighbour < n; neighbour++) {
            if (adjacency[from][neighbour] != 1) {
                continue;
            }
            if (neighbour == to || isAccessible(neighbour, to, adjacency, from)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Depth-first search: returns true if {@code target} is reachable from {@code from} without
     * stepping straight back to {@code cameFrom}. Every branch is explored. The graph must be
     * acyclic (a forest) for the search to terminate.
     */
    public boolean isAccessible(int from, int target, int[][] adjacency, int cameFrom) {
        int n = this.variables.size();
        for (int next = 0; next < n; next++) {
            if (adjacency[from][next] != 1 || next == cameFrom) {
                continue;
            }
            if (next == target || isAccessible(next, target, adjacency, from)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Computes, for every pair of non-class variables, the Gaussian mutual information
     * conditioned on the class, -0.5 * log(1 - rho^2), where rho is the partial correlation
     * given the class. Returns the pairs sorted by decreasing value.
     *
     * <p>Side effect: appends every variable's mean to {@link #meanVector}.
     */
    public ArrayList<Tuple> getConditionalGMI(ContinuousCaseListMem continuousCaseListMem) {
        int c = this.classVariable;
        int n = this.variables.size();
        double[] means = new double[n];
        for (int i = 0; i < n; i++) {
            means[i] = continuousCaseListMem.mean(i);
            this.meanVector.addElement(Double.valueOf(means[i]));
        }
        double[][] rho = new double[n][n];
        for (int i = 0; i < n; i++) {
            rho[i][i] = 1.0d;
            for (int j = 0; j < i; j++) {
                double r = continuousCaseListMem.correlation(i, j, means[i], means[j]);
                rho[i][j] = r;
                rho[j][i] = r;
            }
        }
        // Clamp: n * (n - 2) is negative for a single variable, which ArrayList rejects.
        ArrayList<Tuple> tuples = new ArrayList<Tuple>(Math.max(0, n * (n - 2)));
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < i; j++) {
                if (i == c || j == c) {
                    continue;
                }
                double partial = (rho[i][j] - rho[i][c] * rho[j][c])
                        / Math.sqrt((1.0d - rho[i][c] * rho[i][c]) * (1.0d - rho[j][c] * rho[j][c]));
                tuples.add(new Tuple(i, j, -0.5d * Math.log(1.0d - partial * partial)));
            }
        }
        Collections.sort(tuples, new TupleComparator());
        return tuples;
    }

    /**
     * Predicts the class value of one configuration as the conditional mean of the class given
     * the other variables. Variables are conditioned on one at a time, from the highest index
     * down. Each step applies the Gaussian update mean_i += S_ik (x_k - mean_k) / S_kk and
     * S_ij -= S_ik S_jk / S_kk to the still-unobserved variables. Zero-variance variables are
     * skipped because they carry no information (and would give 0/0).
     *
     * @param continuousConfiguration case holding a value for every variable, in index order
     * @return the conditional mean of the class variable
     */
    public double predictWithMean(ContinuousConfiguration continuousConfiguration) {
        int n = this.meanVector.size();
        double[] mean = new double[n];
        double[][] cov = new double[n][n];
        for (int i = 0; i < n; i++) {
            mean[i] = ((Double) this.meanVector.elementAt(i)).doubleValue();
            System.arraycopy(this.covarianceMatrix[i], 0, cov[i], 0, n);
        }
        int[] unobserved = new int[n];
        for (int k = continuousConfiguration.size() - 1; k >= 0; k--) {
            if (k == this.classVariable) {
                continue;
            }
            double varianceK = cov[k][k];
            if (varianceK == 0.0d) {
                continue;
            }
            double residual = continuousConfiguration.getContinuousValue(k) - mean[k];
            // Still unobserved after conditioning on k: every lower index plus the class.
            int count = 0;
            for (int i = 0; i < k; i++) {
                unobserved[count++] = i;
            }
            if (this.classVariable > k) {
                unobserved[count++] = this.classVariable;
            }
            // Column k is never written below, so its values stay valid during the update.
            for (int a = 0; a < count; a++) {
                int i = unobserved[a];
                mean[i] += (cov[i][k] * residual) / varianceK;
            }
            for (int a = 0; a < count; a++) {
                int i = unobserved[a];
                double gain = cov[i][k] / varianceK;
                for (int b = 0; b < count; b++) {
                    int j = unobserved[b];
                    cov[i][j] -= gain * cov[j][k];
                }
            }
        }
        return mean[this.classVariable];
    }

    /**
     * Predicts every case in the list and records the mean squared error and the Pearson
     * correlation between predicted and actual class values. Also prints the training class
     * mean/sigma next to the predicted mean/sigma.
     *
     * @return predicted class values (boxed {@code Double}), in case order
     */
    public Vector predictWithMean(ContinuousCaseListMem continuousCaseListMem) {
        int numberOfCases = continuousCaseListMem.getNumberOfCases();
        double sumPredicted = 0.0d;
        double sumPredictedSq = 0.0d;
        double sumActual = 0.0d;
        double sumActualSq = 0.0d;
        double sumCross = 0.0d;
        double squaredError = 0.0d;
        Vector predictions = new Vector();
        for (int i = 0; i < numberOfCases; i++) {
            ContinuousConfiguration configuration = (ContinuousConfiguration) continuousCaseListMem.get(i);
            double predicted = predictWithMean(configuration);
            double actual = configuration.getContinuousValue(this.classVariable);
            predictions.addElement(Double.valueOf(predicted));
            sumPredicted += predicted;
            sumPredictedSq += predicted * predicted;
            sumActual += actual;
            sumActualSq += actual * actual;
            sumCross += predicted * actual;
            double error = predicted - actual;
            squaredError += error * error;
        }
        double predictedMean = sumPredicted / numberOfCases;
        double predictedSigma = Math.sqrt((sumPredictedSq / numberOfCases) - predictedMean * predictedMean);
        double classSigma = Math.sqrt(((Double) this.varianceVector.elementAt(this.classVariable)).doubleValue());
        double classMean = ((Double) this.meanVector.elementAt(this.classVariable)).doubleValue();
        System.out.println("Class mean " + classMean + " ; Predicted mean: " + predictedMean);
        System.out.println("Class sigma " + classSigma + " ; Predicted sigma: " + predictedSigma);
        // Pearson correlation uses the observed class statistics of these cases, not training ones.
        double actualMean = sumActual / numberOfCases;
        double actualSigma = Math.sqrt((sumActualSq / numberOfCases) - actualMean * actualMean);
        this.linearCorrelation = ((sumCross / numberOfCases) - predictedMean * actualMean)
                / (predictedSigma * actualSigma);
        // Assign rather than accumulate so repeated calls do not mix results.
        this.meanSquaredError = squaredError / numberOfCases;
        return predictions;
    }

    /** Returns the correlation computed by the last batch prediction. */
    public double getLinearCorrelation() {
        return this.linearCorrelation;
    }

    /** Returns the mean squared error computed by the last batch prediction. */
    public double getMeanSquaredError() {
        return this.meanSquaredError;
    }
}
