package jsat.datatransform.featureselection;

import java.util.Iterator;
import java.util.Random;
import java.util.Set;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.datatransform.DataTransform;
import jsat.datatransform.RemoveAttributeTransform;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.IntSet;
import jsat.utils.ListUtils;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/datatransform/featureselection/LRS.class */
/**
 * Plus-L Minus-R (LRS) feature selection: one round of sequential forward selection
 * ({@code SFS}) combined with one round of sequential backward selection ({@code SBS}).
 *
 * <p>When {@code L > R} the search starts from an empty feature set, adds {@code L} features
 * and then removes {@code R}. When {@code L < R} it starts from the full feature set, removes
 * {@code R} features and then adds {@code L} back. Candidate subsets are scored with the
 * supplied {@link Classifier} or {@link Regressor} using the configured number of CV folds.
 *
 * <p>NOTE(review): when {@code L == R} neither branch of the search runs. In that case the
 * resulting transform removes nothing, while {@link #getSelectedCategorical()} and
 * {@link #getSelectedNumerical()} report empty sets.
 */
public class LRS implements DataTransform {
    private static final long serialVersionUID = 3065300352046535656L;

    /** Drops every feature not chosen by the last search; {@code null} until a search has run. */
    private RemoveAttributeTransform finalTransform;
    /** Indices of the categorical features kept by the last search. */
    private Set<Integer> catSelected;
    /** Indices of the numerical features kept by the last search. */
    private Set<Integer> numSelected;
    /** Number of features to add (the "L" in LRS). */
    private int L;
    /** Number of features to remove (the "R" in LRS). */
    private int R;
    /** A {@link Classifier} or {@link Regressor} used to score candidate feature subsets. */
    private Object evaluater;
    /** Number of cross-validation folds used when scoring subsets. */
    private int folds;

    /**
     * Copy constructor backing {@link #clone()}.
     *
     * <p>Fitted state (transform and selected-index sets) is deep-copied when present. The
     * evaluator reference is shared, not cloned.
     */
    private LRS(LRS toCopy) {
        this.L = toCopy.L;
        this.R = toCopy.R;
        this.folds = toCopy.folds;
        this.evaluater = toCopy.evaluater;
        if (toCopy.catSelected != null) {
            this.finalTransform = toCopy.finalTransform.clone();
            this.catSelected = new IntSet(toCopy.catSelected);
            this.numSelected = new IntSet(toCopy.numSelected);
        }
    }

    /**
     * Creates an unfitted LRS selector for classification problems.
     *
     * @param featuresToAdd    number of features to add (L), must be at least 1
     * @param featuresToRemove number of features to remove (R), must be at least 1
     * @param evaluator        classifier used to score candidate subsets
     * @param folds            number of CV folds, must be positive
     * @throws IllegalArgumentException if any count is out of range
     */
    public LRS(int featuresToAdd, int featuresToRemove, Classifier evaluator, int folds) {
        setFeaturesToAdd(featuresToAdd);
        setFeaturesToRemove(featuresToRemove);
        setFolds(folds);
        setEvaluator(evaluator);
    }

    /**
     * Creates an LRS selector and immediately runs the search on {@code cds}.
     *
     * @param featuresToAdd    number of features to add (L), must be at least 1
     * @param featuresToRemove number of features to remove (R), must be at least 1
     * @param cds              data set to select features from
     * @param evaluator        classifier used to score candidate subsets
     * @param folds            number of CV folds, must be positive
     * @throws IllegalArgumentException if any count is out of range
     */
    public LRS(int featuresToAdd, int featuresToRemove, ClassificationDataSet cds,
            Classifier evaluator, int folds) {
        // Chain like the regression overload so the parameters are validated and stored.
        // Otherwise fit(), clone() and the getters would see L = R = folds = 0 and no evaluator.
        this(featuresToAdd, featuresToRemove, evaluator, folds);
        search(cds, featuresToAdd, featuresToRemove, evaluator, folds);
    }

    /**
     * Creates an unfitted LRS selector for regression problems.
     *
     * @param featuresToAdd    number of features to add (L), must be at least 1
     * @param featuresToRemove number of features to remove (R), must be at least 1
     * @param evaluator        regressor used to score candidate subsets
     * @param folds            number of CV folds, must be positive
     * @throws IllegalArgumentException if any count is out of range
     */
    public LRS(int featuresToAdd, int featuresToRemove, Regressor evaluator, int folds) {
        setFeaturesToAdd(featuresToAdd);
        setFeaturesToRemove(featuresToRemove);
        setFolds(folds);
        setEvaluator(evaluator);
    }

    /**
     * Creates an LRS selector and immediately runs the search on {@code rds}.
     *
     * @param featuresToAdd    number of features to add (L), must be at least 1
     * @param featuresToRemove number of features to remove (R), must be at least 1
     * @param rds              data set to select features from
     * @param evaluator        regressor used to score candidate subsets
     * @param folds            number of CV folds, must be positive
     * @throws IllegalArgumentException if any count is out of range
     */
    public LRS(int featuresToAdd, int featuresToRemove, RegressionDataSet rds,
            Regressor evaluator, int folds) {
        this(featuresToAdd, featuresToRemove, evaluator, folds);
        search(rds, featuresToAdd, featuresToRemove, evaluator, folds);
    }

    /** Applies the fitted attribute-removal transform to {@code dataPoint}. */
    @Override
    public DataPoint transform(DataPoint dataPoint) {
        return finalTransform.transform(dataPoint);
    }

    @Override
    public LRS clone() {
        return new LRS(this);
    }

    /** Returns a copy of the categorical feature indices kept by the last search. */
    public Set<Integer> getSelectedCategorical() {
        return new IntSet(catSelected);
    }

    /** Returns a copy of the numerical feature indices kept by the last search. */
    public Set<Integer> getSelectedNumerical() {
        return new IntSet(numSelected);
    }

    /** Runs the LRS search on {@code dataSet} using the currently configured L, R, evaluator and folds. */
    @Override
    public void fit(DataSet dataSet) {
        search(dataSet, L, R, evaluater, folds);
    }

    /**
     * Performs one Plus-L Minus-R pass and stores the resulting selection and transform.
     *
     * <p>Feature indices in the {@code available} set are encoded as categorical indices
     * {@code 0..nCat-1} followed by numerical indices offset by {@code nCat}.
     */
    private void search(DataSet dataSet, int featuresToAdd, int featuresToRemove,
            Object evaluator, int folds) {
        final int totalFeatures = dataSet.getNumFeatures();
        final int numCategorical = dataSet.getNumCategoricalVars();
        final int numNumerical = totalFeatures - numCategorical;

        catSelected = new IntSet(numCategorical);
        numSelected = new IntSet(numNumerical);
        IntSet catToRemove = new IntSet(numCategorical);
        IntSet numToRemove = new IntSet(numNumerical);
        IntSet available = new IntSet(totalFeatures);
        ListUtils.addRange(available, 0, totalFeatures, 1);
        Random rand = new Random();
        // Single-element array so the SFS/SBS helpers can update the best score in place.
        double[] bestScore = {Double.POSITIVE_INFINITY};

        if (featuresToAdd > featuresToRemove) {
            // Start empty: every feature begins in the "to remove" sets.
            ListUtils.addRange(catToRemove, 0, numCategorical, 1);
            ListUtils.addRange(numToRemove, 0, numNumerical, 1);
            for (int step = 0; step < featuresToAdd; step++) {
                SFS.SFSSelectFeature(available, dataSet, catToRemove, numToRemove,
                        catSelected, numSelected, evaluator, folds, rand, bestScore, featuresToAdd);
            }
            // Backward step may only consider features that are currently selected.
            resetAvailable(available, catSelected, numSelected, numCategorical);
            for (int step = 0; step < featuresToRemove; step++) {
                SBS.SBSRemoveFeature(available, dataSet, catToRemove, numToRemove,
                        catSelected, numSelected, evaluator, folds, rand,
                        featuresToAdd - featuresToRemove, bestScore, 0.0d);
            }
        } else if (featuresToAdd < featuresToRemove) {
            // Start full: every feature begins selected.
            ListUtils.addRange(catSelected, 0, numCategorical, 1);
            ListUtils.addRange(numSelected, 0, numNumerical, 1);
            for (int step = 0; step < featuresToRemove; step++) {
                SBS.SBSRemoveFeature(available, dataSet, catToRemove, numToRemove,
                        catSelected, numSelected, evaluator, folds, rand,
                        totalFeatures - featuresToRemove, bestScore, 0.0d);
            }
            // Forward step may only consider features that were removed.
            resetAvailable(available, catToRemove, numToRemove, numCategorical);
            for (int step = 0; step < featuresToAdd; step++) {
                SFS.SFSSelectFeature(available, dataSet, catToRemove, numToRemove,
                        catSelected, numSelected, evaluator, folds, rand, bestScore,
                        featuresToRemove - featuresToAdd);
            }
        }
        finalTransform = new RemoveAttributeTransform(dataSet, catToRemove, numToRemove);
    }

    /**
     * Replaces the contents of {@code available} with the given categorical indices plus the
     * given numerical indices shifted by {@code numericalOffset}.
     */
    private static void resetAvailable(IntSet available, Set<Integer> categorical,
            Set<Integer> numerical, int numericalOffset) {
        available.clear();
        available.addAll(categorical);
        for (int idx : numerical) {
            available.add(idx + numericalOffset);
        }
    }

    /**
     * Sets L, the number of features to add.
     *
     * @throws IllegalArgumentException if {@code featuresToAdd < 1}
     */
    public void setFeaturesToAdd(int featuresToAdd) {
        if (featuresToAdd < 1) {
            throw new IllegalArgumentException("Number of features to add must be positive, not " + featuresToAdd);
        }
        this.L = featuresToAdd;
    }

    /** Returns L, the number of features to add. */
    public int getFeaturesToAdd() {
        return L;
    }

    /**
     * Sets R, the number of features to remove.
     *
     * @throws IllegalArgumentException if {@code featuresToRemove < 1}
     */
    public void setFeaturesToRemove(int featuresToRemove) {
        if (featuresToRemove < 1) {
            throw new IllegalArgumentException("Number of features to remove must be positive, not " + featuresToRemove);
        }
        this.R = featuresToRemove;
    }

    /** Returns R, the number of features to remove. */
    public int getFeaturesToRemove() {
        return R;
    }

    /**
     * Sets the number of cross-validation folds used to score feature subsets.
     *
     * @throws IllegalArgumentException if {@code folds <= 0}
     */
    public void setFolds(int folds) {
        if (folds <= 0) {
            throw new IllegalArgumentException("Number of CV folds must be positive, not " + folds);
        }
        this.folds = folds;
    }

    /** Returns the number of cross-validation folds. */
    public int getFolds() {
        return folds;
    }

    /** Stores the classifier or regressor used to score feature subsets. */
    private void setEvaluator(Object evaluator) {
        this.evaluater = evaluator;
    }
}
