package jsat.datatransform.featureselection;

import java.util.Iterator;
import java.util.Random;
import java.util.Set;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.ClassificationModelEvaluation;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.datatransform.DataTransform;
import jsat.datatransform.RemoveAttributeTransform;
import jsat.regression.RegressionDataSet;
import jsat.regression.RegressionModelEvaluation;
import jsat.regression.Regressor;
import jsat.utils.IntSet;
import jsat.utils.ListUtils;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/datatransform/featureselection/SFS.class */
/**
 * Sequential Forward Selection (SFS) wrapper feature selector.
 *
 * <p>Starting from an empty selection, each round tries every remaining feature.
 * It scores the data set restricted to "already selected + candidate" with
 * cross-validation of the wrapped {@link Classifier} or {@link Regressor}, and adds
 * the candidate with the lowest score. Rounds stop when {@code maxFeatures} features
 * are selected or when {@link #SFSSelectFeature} reports that no acceptable feature
 * can be added.
 *
 * <p>Features use a single global index space. Indices {@code [0, numCategorical)}
 * are categorical attributes. Larger indices are numerical attributes, offset by the
 * number of categorical attributes (see {@link #addFeature} / {@link #removeFeature}).
 */
public class SFS implements DataTransform {
    private static final long serialVersionUID = 140187978708131002L;

    /** Transform built from the unselected feature indices once a search has run; null before. */
    private RemoveAttributeTransform finalTransform;
    /** Selected categorical feature indices; null until a search has run. */
    private Set<Integer> catSelected;
    /** Selected numerical feature indices (relative to the numerical block); null until a search has run. */
    private Set<Integer> numSelected;
    /**
     * Value supplied via {@link #setMaxIncrease}.
     * NOTE(review): this value is not consulted anywhere in this class. The stopping
     * tolerance in {@link #SFSSelectFeature} is a fixed 1e-3. Confirm whether it
     * should be wired in.
     */
    private double maxIncrease;
    /** Learner used when the data set is a {@link ClassificationDataSet}. */
    private Classifier classifier;
    /** Learner used otherwise (e.g. for a {@link RegressionDataSet}). */
    private Regressor regressor;
    private int minFeatures;
    private int maxFeatures;
    /** Number of cross-validation folds used to score each candidate subset. */
    private int folds;
    private Object evaluator;

    /**
     * Copy constructor used by {@link #clone()}.
     *
     * <p>Clones the classifier/regressor and any fitted state. The {@code evaluator}
     * reference is shared with {@code toCopy}, not cloned.
     */
    private SFS(SFS toCopy) {
        if (toCopy.catSelected != null) {
            this.finalTransform = toCopy.finalTransform.clone();
            this.catSelected = new IntSet(toCopy.catSelected);
            this.numSelected = new IntSet(toCopy.numSelected);
        }
        this.maxIncrease = toCopy.maxIncrease;
        this.folds = toCopy.folds;
        this.minFeatures = toCopy.minFeatures;
        this.maxFeatures = toCopy.maxFeatures;
        this.evaluator = toCopy.evaluator;
        if (toCopy.classifier != null) {
            this.classifier = toCopy.classifier.m574clone();
        }
        if (toCopy.regressor != null) {
            this.regressor = toCopy.regressor.m574clone();
        }
    }

    /**
     * Creates an unfitted SFS that will use 3 CV folds and a clone of the given classifier.
     *
     * @param minFeatures minimum number of features to select
     * @param maxFeatures maximum number of features to select
     * @param evaluater   classifier used to score candidate subsets (cloned)
     * @param maxIncrease value stored via {@link #setMaxIncrease}; must be {@code >= 0}
     */
    public SFS(int minFeatures, int maxFeatures, Classifier evaluater, double maxIncrease) {
        this(minFeatures, maxFeatures, evaluater.m574clone(), 3, maxIncrease);
    }

    /**
     * Creates an SFS and immediately runs the selection on {@code dataSet}.
     *
     * @param minFeatures minimum number of features to select
     * @param maxFeatures maximum number of features to select
     * @param dataSet     data used for the search
     * @param evaluater   classifier used to score candidate subsets (cloned)
     * @param folds       number of cross-validation folds; must be positive
     * @param maxIncrease value stored via {@link #setMaxIncrease}; must be {@code >= 0}
     */
    public SFS(int minFeatures, int maxFeatures, ClassificationDataSet dataSet, Classifier evaluater,
            int folds, double maxIncrease) {
        this(minFeatures, maxFeatures, evaluater.m574clone(), folds, maxIncrease);
        search(minFeatures, maxFeatures, dataSet, folds);
    }

    /**
     * Creates an unfitted SFS that will use 3 CV folds and a clone of the given regressor.
     *
     * @param minFeatures minimum number of features to select
     * @param maxFeatures maximum number of features to select
     * @param evaluater   regressor used to score candidate subsets (cloned)
     * @param maxIncrease value stored via {@link #setMaxIncrease}; must be {@code >= 0}
     */
    public SFS(int minFeatures, int maxFeatures, Regressor evaluater, double maxIncrease) {
        this(minFeatures, maxFeatures, evaluater.m574clone(), 3, maxIncrease);
    }

    /**
     * Creates an SFS and immediately runs the selection on {@code dataSet}.
     *
     * @param minFeatures minimum number of features to select
     * @param maxFeatures maximum number of features to select
     * @param dataSet     data used for the search
     * @param evaluater   regressor used to score candidate subsets (cloned)
     * @param folds       number of cross-validation folds; must be positive
     * @param maxIncrease value stored via {@link #setMaxIncrease}; must be {@code >= 0}
     */
    public SFS(int minFeatures, int maxFeatures, RegressionDataSet dataSet, Regressor evaluater,
            int folds, double maxIncrease) {
        this(minFeatures, maxFeatures, evaluater.m574clone(), folds, maxIncrease);
        search(minFeatures, maxFeatures, dataSet, folds);
    }

    /** Shared initialisation: validates and stores the configuration and the learner. */
    private SFS(int minFeatures, int maxFeatures, Object evaluater, int folds, double maxIncrease) {
        setMinFeatures(minFeatures);
        setMaxFeatures(maxFeatures);
        setFolds(folds);
        setMaxIncrease(maxIncrease);
        setEvaluator(evaluater);
    }

    /**
     * Runs the forward selection on {@code dataSet} using the configured
     * min/max feature counts and number of CV folds.
     */
    @Override
    public void fit(DataSet dataSet) {
        // The last argument is the CV fold count. It previously received minFeatures by mistake.
        search(this.minFeatures, this.maxFeatures, dataSet, this.folds);
    }

    /**
     * Greedy forward search.
     *
     * <p>Fills {@link #catSelected} / {@link #numSelected} and builds
     * {@link #finalTransform} from the indices left unselected.
     *
     * @param minFeatures minimum number of features before the stop criteria may trigger
     * @param maxFeatures hard upper bound on the number of selected features
     * @param dataSet     data to select features from
     * @param folds       number of cross-validation folds used for scoring
     */
    private void search(int minFeatures, int maxFeatures, DataSet dataSet, int folds) {
        Random rand = new Random();
        int numFeatures = dataSet.getNumFeatures();
        int numCategorical = dataSet.getNumCategoricalVars();

        // Every feature (global index space) starts out as a candidate.
        IntSet available = new IntSet();
        ListUtils.addRange(available, 0, numFeatures, 1);
        this.catSelected = new IntSet(dataSet.getNumCategoricalVars());
        this.numSelected = new IntSet(dataSet.getNumNumericalVars());

        // Complement of the selection: every not-yet-selected feature is dropped when scoring.
        IntSet catToRemove = new IntSet(dataSet.getNumCategoricalVars());
        IntSet numToRemove = new IntSet(dataSet.getNumNumericalVars());
        ListUtils.addRange(catToRemove, 0, numCategorical, 1);
        ListUtils.addRange(numToRemove, 0, numFeatures - numCategorical, 1);

        // One-element array so SFSSelectFeature can carry the best score between rounds.
        double[] bestScore = {Double.POSITIVE_INFINITY};

        Object learner = (dataSet instanceof ClassificationDataSet) ? this.classifier : this.regressor;

        while (this.catSelected.size() + this.numSelected.size() < maxFeatures) {
            int added = SFSSelectFeature(available, dataSet, catToRemove, numToRemove,
                    this.catSelected, this.numSelected, learner, folds, rand, bestScore, minFeatures);
            if (added < 0) {
                break; // no acceptable feature could be added this round
            }
        }
        this.finalTransform = new RemoveAttributeTransform(dataSet, catToRemove, numToRemove);
    }

    /**
     * Adds a feature, given by its global index, to the categorical or numerical set.
     *
     * @param feature        global feature index
     * @param numCategorical number of categorical features; indices at or above it are numerical
     * @param catSet         receives {@code feature} when it is categorical
     * @param numSet         receives {@code feature - numCategorical} when it is numerical
     */
    public static void addFeature(int feature, int numCategorical, Set<Integer> catSet, Set<Integer> numSet) {
        if (feature >= numCategorical) {
            numSet.add(Integer.valueOf(feature - numCategorical));
        } else {
            catSet.add(Integer.valueOf(feature));
        }
    }

    /**
     * Removes a feature, given by its global index, from the categorical or numerical set.
     * This is the inverse of {@link #addFeature}.
     *
     * @param feature        global feature index
     * @param numCategorical number of categorical features; indices at or above it are numerical
     * @param catSet         loses {@code feature} when it is categorical
     * @param numSet         loses {@code feature - numCategorical} when it is numerical
     */
    public static void removeFeature(int feature, int numCategorical, Set<Integer> catSet, Set<Integer> numSet) {
        if (feature >= numCategorical) {
            numSet.remove(Integer.valueOf(feature - numCategorical));
        } else {
            catSet.remove(Integer.valueOf(feature));
        }
    }

    /**
     * Applies the transform built by the last search.
     * Requires that {@link #fit} or a data-set constructor has run; the transform is null before that.
     */
    @Override
    public DataPoint transform(DataPoint dataPoint) {
        return this.finalTransform.transform(dataPoint);
    }

    @Override
    public SFS clone() {
        return new SFS(this);
    }

    /**
     * Returns a copy of the selected categorical feature indices.
     * Only meaningful after a search has run; the underlying field is null before that.
     */
    public Set<Integer> getSelectedCategorical() {
        return new IntSet(this.catSelected);
    }

    /**
     * Returns a copy of the selected numerical feature indices (relative to the numerical block).
     * Only meaningful after a search has run; the underlying field is null before that.
     */
    public Set<Integer> getSelectedNumerical() {
        return new IntSet(this.numSelected);
    }

    /**
     * Performs one forward-selection round.
     *
     * <p>Each candidate in {@code available} is scored by temporarily un-removing it and
     * cross-validating {@code evaluater} on the reduced data. The best (lowest-score)
     * candidate is then either accepted or rejected. It is rejected (returning -1) when
     * at least {@code minFeatures} features are already selected and either:
     * <ul>
     *   <li>both the previous and the new best score are {@code <= 1e-14}, or</li>
     *   <li>the new best is not lower than the previous best and differs from it
     *       by at least 1e-3.</li>
     * </ul>
     * On acceptance, the feature is moved from {@code available} and the "to remove" sets
     * into the selected sets, and {@code bestScore[0]} is updated.
     *
     * @param available    candidate global feature indices (mutated on acceptance)
     * @param dataSet      data to score against
     * @param catToRemove  categorical indices currently excluded (mutated on acceptance)
     * @param numToRemove  numerical indices currently excluded (mutated on acceptance)
     * @param catSelected  selected categorical indices (mutated on acceptance)
     * @param numSelected  selected numerical indices (mutated on acceptance)
     * @param evaluater    {@link Classifier} or {@link Regressor} used by {@link #getScore}
     * @param folds        number of cross-validation folds
     * @param rand         randomness source for cross-validation
     * @param bestScore    one-element array holding the best score so far (updated on acceptance)
     * @param minFeatures  minimum number of features before the stop criteria apply
     * @return the accepted global feature index, or -1 if no feature was added
     */
    public static int SFSSelectFeature(Set<Integer> available, DataSet dataSet, Set<Integer> catToRemove,
            Set<Integer> numToRemove, Set<Integer> catSelected, Set<Integer> numSelected, Object evaluater,
            int folds, Random rand, double[] bestScore, int minFeatures) {
        int numCategorical = dataSet.getNumCategoricalVars();
        int curBest = -1;
        double curBestScore = Double.POSITIVE_INFINITY;
        for (int feature : available) {
            // Temporarily keep the candidate in the data set while scoring it.
            removeFeature(feature, numCategorical, catToRemove, numToRemove);
            DataSet workOn = dataSet.shallowClone2();
            workOn.applyTransform(new RemoveAttributeTransform(workOn, catToRemove, numToRemove));
            double score = getScore(workOn, evaluater, folds, rand);
            if (score < curBestScore) {
                curBestScore = score;
                curBest = feature;
            }
            addFeature(feature, numCategorical, catToRemove, numToRemove);
        }

        // No candidate scored below +infinity (none available, or every score was infinite/NaN).
        // Without this guard the bogus index -1 would be inserted into the selected sets.
        if (curBest < 0) {
            return -1;
        }

        int selectedCount = catSelected.size() + numSelected.size();
        double previousBest = bestScore[0];
        // Already (near-)perfect and enough features selected: nothing to gain.
        if (curBestScore <= 1.0E-14d && previousBest <= 1.0E-14d && selectedCount >= minFeatures) {
            return -1;
        }
        // No improvement beyond the fixed 1e-3 tolerance, and enough features selected.
        if (curBestScore >= previousBest && selectedCount >= minFeatures
                && Math.abs(previousBest - curBestScore) >= 0.001d) {
            return -1;
        }

        bestScore[0] = curBestScore;
        addFeature(curBest, numCategorical, catSelected, numSelected);
        removeFeature(curBest, numCategorical, catToRemove, numToRemove);
        available.remove(Integer.valueOf(curBest));
        return curBest;
    }

    /**
     * Cross-validates {@code evaluater} on {@code dataSet}.
     *
     * @return the error rate for a {@link ClassificationDataSet}, the mean error for a
     *         {@link RegressionDataSet}, and {@link Double#POSITIVE_INFINITY} for any other data set
     */
    public static double getScore(DataSet dataSet, Object evaluater, int folds, Random rand) {
        if (dataSet instanceof ClassificationDataSet) {
            ClassificationModelEvaluation cme =
                    new ClassificationModelEvaluation((Classifier) evaluater, (ClassificationDataSet) dataSet);
            cme.evaluateCrossValidation(folds, rand);
            return cme.getErrorRate();
        }
        if (!(dataSet instanceof RegressionDataSet)) {
            return Double.POSITIVE_INFINITY;
        }
        RegressionModelEvaluation rme =
                new RegressionModelEvaluation((Regressor) evaluater, (RegressionDataSet) dataSet);
        rme.evaluateCrossValidation(folds, rand);
        return rme.getMeanError();
    }

    /**
     * Sets the maximum-increase value.
     *
     * @param maxIncrease non-negative value
     * @throws IllegalArgumentException if {@code maxIncrease < 0}
     */
    public void setMaxIncrease(double maxIncrease) {
        if (maxIncrease < 0.0d) {
            throw new IllegalArgumentException("Max increase must be a non-negative value, not " + maxIncrease);
        }
        this.maxIncrease = maxIncrease;
    }

    public double getMaxIncrease() {
        return this.maxIncrease;
    }

    /** Sets the minimum number of features to select before the stop criteria may apply. */
    public void setMinFeatures(int minFeatures) {
        this.minFeatures = minFeatures;
    }

    public int getMinFeatures() {
        return this.minFeatures;
    }

    /** Sets the upper bound on the number of features to select. */
    public void setMaxFeatures(int maxFeatures) {
        this.maxFeatures = maxFeatures;
    }

    public int getMaxFeatures() {
        return this.maxFeatures;
    }

    /**
     * Sets the number of cross-validation folds used to score candidate subsets.
     *
     * @throws IllegalArgumentException if {@code folds <= 0}
     */
    public void setFolds(int folds) {
        if (folds <= 0) {
            throw new IllegalArgumentException("Number of CV folds must be positive, not " + folds);
        }
        this.folds = folds;
    }

    public int getFolds() {
        return this.folds;
    }

    /**
     * Stores the learner.
     *
     * <p>The learner also populates {@link #classifier} and/or {@link #regressor},
     * depending on which interfaces it implements.
     */
    private void setEvaluator(Object evaluater) {
        this.evaluator = evaluater;
        if (evaluater instanceof Classifier) {
            this.classifier = (Classifier) evaluater;
        }
        if (evaluater instanceof Regressor) {
            this.regressor = (Regressor) evaluater;
        }
    }
}
