package elvira.learning.classification.supervised.continuous;

import elvira.ContinuousCaseListMem;
import elvira.Node;
import elvira.NodeList;
import elvira.database.DataBaseCases;
import elvira.learning.preprocessing.ProjectDBC;
import elvira.parser.ParseException;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.Random;
import java.util.Vector;
import org.apache.tools.ant.taskdefs.optional.ejb.GenericDeploymentTool;

/* loaded from: input_file:bayelvira-1.0-SNAPSHOT.jar:elvira/learning/classification/supervised/continuous/SelectiveGaussianPredictor.class */
/**
 * Wrapper-style attribute selection for a {@link GaussianPredictor} regression model.
 *
 * <p>Attributes are first ranked by the bivariate-Gaussian mutual information with the class,
 * {@code -0.5 * log(1 - rho^2)}, where {@code rho} is their linear correlation. A forward pass
 * then adds attributes in rank order, keeping each one only if the 5-fold cross-validated mean
 * linear correlation of the predictor improves by more than the factor {@code m_epsilon}. The
 * pass tolerates up to {@code m_lookahead} consecutive rejections before stopping. A backward
 * pass finally tries to drop each kept attribute (except the last one accepted), keeping the
 * removal whenever the cross-validated score does not decrease.
 *
 * <p>Every projection handed to {@link GaussianPredictor} places the class variable last, and
 * the predictor is told so by passing {@code size() - 1} as the class index.
 */
public class SelectiveGaussianPredictor {
    /** Number of folds used by every cross-validated evaluation of a candidate subset. */
    private static final int CV_FOLDS = 5;

    /** Copy of the variables of the training data base. */
    NodeList m_variables;
    /** Index of the class variable within {@link #m_variables}. */
    int m_classVariable;
    /** Model type forwarded to {@link GaussianPredictor} ("naive", "full" or "tan" from main). */
    String m_model;
    /** Training cases; shuffled once in the constructor (CV folds depend on this order). */
    ContinuousCaseListMem m_cases;
    /** Number of consecutive rejected attributes tolerated before the forward pass stops. */
    int m_lookahead;
    /** Improvement factor an attribute's CV score must exceed (score > epsilon * best). */
    double m_epsilon;
    /** Attributes chosen by {@link #doSelection()} (without the class variable once it returns). */
    NodeList m_selected = new NodeList();
    /** Indexes in {@link #m_variables} of the attributes in {@link #m_selected}, in the same order. */
    Vector<Integer> m_selectedIndexes = new Vector<Integer>();
    /** Model trained on the selected attributes by {@link #learnSelectiveModel()}. */
    GaussianPredictor selectivePredictor = null;

    /**
     * Command-line entry point. It selects attributes on a training file, learns the selective
     * model, and prints the correlation and RMSE obtained on a second file of cases.
     *
     * <p>Arguments: training .dbc file, model type (naive|full|tan), class variable name,
     * .dbc file of cases to predict, selection threshold, lookahead, and optionally an
     * attribute list plus an inversion flag that are passed to {@link ProjectDBC} for both
     * files.
     *
     * @throws ParseException if either data base cannot be parsed
     * @throws IOException if either data base file cannot be read
     */
    public static void main(String[] args) throws ParseException, IOException {
        if (args.length != 6 && args.length != 8) {
            System.out.println("Wrong number of arguments. Usage SelectiveGaussianPredictor\n\t training_file.dbc -- the input file --\n\t (naive|full|tan) -- the type of model to be used\n\t name-of-class-variable\n\t cases_to_predict.dbc\n\t threshod for selection (1.0, 1.01, 1.001, ...)\n\t lookahead value (0, 1, 2, ...)\n\t [ list-of-attributes to filter --- i.e., 1,3,5-9,11 (non-zero based)\n\t   inverSelection (true|fasle) -- if true the listed attributed are deleted\n\t ]");
            System.exit(0);
        }
        // NOTE(review): the input streams are never closed; confirm whether DataBaseCases
        // consumes them eagerly before wrapping them in try/finally.
        DataBaseCases training = new DataBaseCases(new FileInputStream(args[0]));
        DataBaseCases test = new DataBaseCases(new FileInputStream(args[3]));
        double epsilon = Double.parseDouble(args[4]);
        int lookahead = Integer.parseInt(args[5]);
        if (args.length == 8) {
            Vector attributes = ProjectDBC.parseAttributes(args[6]);
            boolean invert = Boolean.valueOf(args[7]).booleanValue();
            training = new ProjectDBC(training, attributes, invert).doProjection();
            test = new ProjectDBC(test, attributes, invert).doProjection();
        }
        ContinuousCaseListMem testCases = (ContinuousCaseListMem) test.getCaseListMem();
        int classIndex = training.getVariables().getId(args[2]);
        if (classIndex == -1) {
            System.out.println("** Error **, variable " + args[2] + " is not included in the (projected) data base");
            System.exit(0);
        }
        String model = args[1];
        if (!model.equals("naive") && !model.equals(GenericDeploymentTool.ANALYZER_FULL) && !model.equals("tan")) {
            System.out.println("*** Error ***. Model " + model + " is not allowed. Exiting .... ");
            System.exit(0);
        }
        SelectiveGaussianPredictor predictor =
                new SelectiveGaussianPredictor(training, classIndex, model, epsilon, lookahead);
        predictor.doSelection();
        predictor.learnSelectiveModel();
        predictor.predictWithMean(testCases);
        System.out.println("\nCorrelation : " + predictor.getLinearCorrelation());
        System.out.println("Rmse        : " + Math.sqrt(predictor.getMeanSquaredError()) + "\n");
    }

    /**
     * Creates a selector over the cases of {@code dataBaseCases}. The cases are shuffled with
     * an unseeded {@link Random}, so results vary between runs.
     *
     * @param dataBaseCases training data; its case list must be a {@link ContinuousCaseListMem}
     * @param classVariable index of the class variable among the data base's variables
     * @param model model type forwarded to {@link GaussianPredictor}
     * @param epsilon improvement factor required to accept an attribute
     * @param lookahead consecutive rejections tolerated by the forward pass
     * @throws IllegalArgumentException if {@code classVariable} is not a valid variable index
     */
    public SelectiveGaussianPredictor(DataBaseCases dataBaseCases, int classVariable, String model,
                                      double epsilon, int lookahead) {
        this.m_variables = dataBaseCases.getVariables().copy();
        if (classVariable < 0 || classVariable >= this.m_variables.size()) {
            throw new IllegalArgumentException("class variable index " + classVariable
                    + " is out of range [0, " + this.m_variables.size() + ")");
        }
        this.m_classVariable = classVariable;
        this.m_model = model;
        this.m_epsilon = epsilon;
        this.m_lookahead = lookahead;
        this.m_cases = (ContinuousCaseListMem) dataBaseCases.getCaseListMem();
        this.m_cases.randomize(new Random());
    }

    /**
     * Trains {@link #selectivePredictor} on all training cases projected onto the selected
     * attributes plus the class variable (placed last).
     */
    public void learnSelectiveModel() {
        NodeList nodes = selectedWithClass();
        this.selectivePredictor = new GaussianPredictor(this.m_cases.projection(nodes), nodes.size() - 1, this.m_model);
    }

    /**
     * Predicts {@code continuousCaseListMem} with the learned model, after projecting it onto
     * the selected attributes plus the class variable.
     *
     * @throws IllegalStateException if {@link #learnSelectiveModel()} has not been called
     */
    public void predictWithMean(ContinuousCaseListMem continuousCaseListMem) {
        GaussianPredictor model = requireModel();
        model.predictWithMean(continuousCaseListMem.projection(selectedWithClass()));
    }

    /**
     * @return the linear correlation reported by the learned model after prediction
     * @throws IllegalStateException if {@link #learnSelectiveModel()} has not been called
     */
    public double getLinearCorrelation() {
        return requireModel().getLinearCorrelation();
    }

    /**
     * @return the mean squared error reported by the learned model after prediction
     * @throws IllegalStateException if {@link #learnSelectiveModel()} has not been called
     */
    public double getMeanSquaredError() {
        return requireModel().getMeanSquaredError();
    }

    /**
     * Runs the forward selection pass followed by the backward elimination pass. It fills
     * {@link #m_selected} and {@link #m_selectedIndexes} and prints a summary after each pass.
     * Any previous selection is discarded first.
     */
    public void doSelection() {
        this.m_selected = new NodeList();
        this.m_selectedIndexes.clear();

        Node classNode = this.m_variables.elementAt(this.m_classVariable);
        int[] ranking = getMIRanking();
        double bestScore = -1.0d;
        int consecutiveFailures = 0;
        int evaluated = 0;

        // Forward pass: try attributes in decreasing mutual-information order. Every rank
        // position is a candidate (the ranking already excludes the class variable).
        for (int rank = 0; rank < ranking.length; rank++) {
            int candidate = ranking[rank];
            if (candidate == this.m_classVariable) {
                continue; // defensive: the ranking should never contain the class
            }
            this.m_selected.insertNode(this.m_variables.elementAt(candidate));
            this.m_selected.insertNode(classNode); // class must be last for doCVPrediction
            double score = doCVPrediction(this.m_cases.projection(this.m_selected), CV_FOLDS);
            evaluated++;
            this.m_selected.removeNode(this.m_selected.size() - 1); // drop the class again
            if (score <= this.m_epsilon * bestScore) {
                this.m_selected.removeNode(this.m_selected.size() - 1); // reject the candidate
                if (consecutiveFailures >= this.m_lookahead) {
                    break;
                }
                consecutiveFailures++;
            } else {
                bestScore = score;
                this.m_selectedIndexes.add(Integer.valueOf(candidate));
                consecutiveFailures = 0;
            }
        }
        System.out.println("\n\nNum. of selected variables: " + this.m_selected.size());
        printSelectedIndexes();
        System.out.println("\nNum. of evaluated subsets: " + evaluated);
        System.out.println("Correlation: " + bestScore + "\n");

        // Backward pass: with the class appended, try removing each kept attribute from the
        // second-to-last one down to the first. The last accepted attribute (index size()-2)
        // is not re-tested; it was accepted because adding it raised the score above
        // epsilon times the score of the set without it.
        this.m_selected.insertNode(classNode);
        for (int position = this.m_selected.size() - 3; position >= 0; position--) {
            Node removed = this.m_selected.elementAt(position);
            this.m_selected.removeNode(position);
            double score = doCVPrediction(this.m_cases.projection(this.m_selected), CV_FOLDS);
            evaluated++;
            if (score >= bestScore) {
                bestScore = score;
                this.m_selectedIndexes.remove(position); // int overload: removes by index
            } else {
                // Put the attribute back at its original position.
                Vector nodes = this.m_selected.toVector();
                nodes.insertElementAt(removed, position);
                this.m_selected = new NodeList((Vector<Node>) nodes);
            }
        }
        this.m_selected.removeNode(this.m_selected.size() - 1); // drop the class
        System.out.println("\n\nNum. of selected variables: " + this.m_selected.size());
        printSelectedIndexes();
        System.out.println("\nNum. of evaluated subsets: " + evaluated);
        System.out.println("Correlation: " + bestScore);
    }

    /**
     * Scores a projected case list by k-fold cross-validation.
     *
     * @param continuousCaseListMem cases whose last variable is the class
     * @param folds number of folds
     * @return the mean, over folds, of the linear correlation of a {@link GaussianPredictor}
     *         trained on the other folds
     */
    private double doCVPrediction(ContinuousCaseListMem continuousCaseListMem, int folds) {
        double meanCorrelation = 0.0d;
        int classIndex = continuousCaseListMem.getVariables().size() - 1;
        for (int fold = 0; fold < folds; fold++) {
            GaussianPredictor foldPredictor = new GaussianPredictor(continuousCaseListMem.trainCV(fold, folds), classIndex, this.m_model);
            foldPredictor.predictWithMean(continuousCaseListMem.testCV(fold, folds));
            meanCorrelation += foldPredictor.getLinearCorrelation() / folds;
        }
        return meanCorrelation;
    }

    /**
     * Ranks all non-class variables by decreasing {@code -0.5 * log(1 - rho^2)}, where
     * {@code rho} is their correlation with the class on the training cases.
     *
     * @return variable indexes (length {@code m_variables.size() - 1}), best first; ties keep
     *         ascending index order
     */
    public int[] getMIRanking() {
        int numVariables = this.m_variables.size();
        double[] means = new double[numVariables];
        for (int v = 0; v < numVariables; v++) {
            means[v] = this.m_cases.mean(v);
        }
        double[] mutualInfo = new double[numVariables];
        for (int v = 0; v < numVariables; v++) {
            if (v == this.m_classVariable) {
                continue; // never ranked; its self-correlation would give +infinity anyway
            }
            double rho = this.m_cases.correlation(v, this.m_classVariable, means[v], means[this.m_classVariable]);
            mutualInfo[v] = (-0.5d) * Math.log(1.0d - Math.pow(rho, 2.0d));
        }
        return rankDescending(mutualInfo, this.m_classVariable);
    }

    /**
     * Insertion-sorts the indexes of {@code scores}, except {@code excluded}, into decreasing
     * score order. A new index goes before the first entry it strictly beats, so ties (and
     * NaN scores) keep ascending index order.
     */
    private static int[] rankDescending(double[] scores, int excluded) {
        int[] ranking = new int[Math.max(0, scores.length - 1)];
        int count = 0;
        for (int v = 0; v < scores.length; v++) {
            if (v == excluded) {
                continue;
            }
            int pos = 0;
            while (pos < count && !(scores[v] > scores[ranking[pos]])) {
                pos++;
            }
            System.arraycopy(ranking, pos, ranking, pos + 1, count - pos);
            ranking[pos] = v;
            count++;
        }
        return ranking;
    }

    /** Prints the "Indexes of the selected attributes" line (without a trailing newline). */
    private void printSelectedIndexes() {
        System.out.print("Indexes of the selected attributes:  ");
        for (int k = 0; k < this.m_selected.size(); k++) {
            System.out.print(this.m_selectedIndexes.elementAt(k).intValue() + ",");
        }
    }

    /** Returns a new list with the selected attributes followed by the class variable. */
    private NodeList selectedWithClass() {
        NodeList nodes = new NodeList();
        nodes.join(this.m_selected);
        nodes.insertNode(this.m_variables.elementAt(this.m_classVariable));
        return nodes;
    }

    /** Returns the learned model, failing clearly if {@link #learnSelectiveModel()} was not run. */
    private GaussianPredictor requireModel() {
        if (this.selectivePredictor == null) {
            throw new IllegalStateException("learnSelectiveModel() must be called first");
        }
        return this.selectivePredictor;
    }
}
