package jsat.datatransform.featureselection;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.datatransform.RemoveAttributeTransform;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.linear.vectorcollection.DefaultVectorCollectionFactory;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.linear.vectorcollection.VectorCollectionFactory;
import jsat.utils.FakeExecutor;
import jsat.utils.IndexTable;
import jsat.utils.IntSet;
import jsat.utils.SystemInfo;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/datatransform/featureselection/ReliefF.class */
/**
 * ReliefF feature selection for classification data.
 * <p>
 * Every numeric attribute is given a relevance weight by repeatedly drawing a
 * random training point and comparing it with its nearest neighbors in each
 * class. Differences to neighbors of the point's own class ("hits") decrease an
 * attribute's weight. Differences to neighbors of every other class ("misses")
 * increase it, scaled by {@code prior(class) / (1 - prior(own class))}.
 * Attribute differences are divided by the attribute's observed range, so a
 * constant attribute never changes its weight.
 * <p>
 * Once fitted, the transform is configured to keep {@link #getFeatureCount()}
 * numeric attributes and to remove the remaining, lower-weighted ones.
 */
public class ReliefF extends RemoveAttributeTransform {
    private static final long serialVersionUID = -3336500245613075520L;

    /** Learned weight of each numeric attribute; {@code null} until fitted. */
    private double[] w;
    /** Number of numeric attributes to keep after fitting. */
    private int featureCount;
    /** Total number of randomly sampled reference points. */
    private int iterations;
    /** Number of nearest neighbors examined per class for each sample. */
    private int neighbors;
    /** Metric used for the nearest-neighbor searches. */
    private DistanceMetric dm;
    /** Factory that builds the per-class nearest-neighbor collections. */
    private VectorCollectionFactory<Vec> vcf;

    /**
     * Creates an unfitted transform using 100 iterations, 15 neighbors,
     * Euclidean distance and the default vector collection factory.
     *
     * @param featureCount number of numeric attributes to keep
     * @throws IllegalArgumentException if {@code featureCount < 1}
     */
    public ReliefF(int featureCount) {
        this(featureCount, 100, 15, new EuclideanDistance(), new DefaultVectorCollectionFactory());
    }

    /**
     * Creates an unfitted transform using the default vector collection factory.
     *
     * @param featureCount number of numeric attributes to keep
     * @param iterations   total number of sampled reference points
     * @param neighbors    neighbors examined per class for each sample
     * @param dm           distance metric for the neighbor searches
     * @throws IllegalArgumentException if any count is less than 1
     */
    public ReliefF(int featureCount, int iterations, int neighbors, DistanceMetric dm) {
        this(featureCount, iterations, neighbors, dm, new DefaultVectorCollectionFactory());
    }

    /**
     * Creates the transform and immediately fits it to {@code cds} in the calling thread.
     *
     * @param cds          training data
     * @param featureCount number of numeric attributes to keep
     * @param iterations   total number of sampled reference points
     * @param neighbors    neighbors examined per class for each sample
     * @param dm           distance metric for the neighbor searches
     */
    public ReliefF(ClassificationDataSet cds, int featureCount, int iterations, int neighbors, DistanceMetric dm) {
        this(cds, featureCount, iterations, neighbors, dm, new DefaultVectorCollectionFactory());
    }

    /**
     * Creates the transform and immediately fits it to {@code cds} using {@code threadPool}.
     *
     * @param cds          training data
     * @param featureCount number of numeric attributes to keep
     * @param iterations   total number of sampled reference points
     * @param neighbors    neighbors examined per class for each sample
     * @param dm           distance metric for the neighbor searches
     * @param threadPool   executor for parallel work, or {@code null}
     */
    public ReliefF(ClassificationDataSet cds, int featureCount, int iterations, int neighbors, DistanceMetric dm, ExecutorService threadPool) {
        this(cds, featureCount, iterations, neighbors, dm, new DefaultVectorCollectionFactory(), threadPool);
    }

    /**
     * Creates the transform and immediately fits it to {@code cds} in the calling thread.
     *
     * @param cds          training data
     * @param featureCount number of numeric attributes to keep
     * @param iterations   total number of sampled reference points
     * @param neighbors    neighbors examined per class for each sample
     * @param dm           distance metric for the neighbor searches
     * @param vcf          factory for the per-class neighbor collections
     */
    public ReliefF(ClassificationDataSet cds, int featureCount, int iterations, int neighbors, DistanceMetric dm, VectorCollectionFactory<Vec> vcf) {
        this(cds, featureCount, iterations, neighbors, dm, vcf, null);
    }

    /**
     * Copy constructor; deep-copies the weights, metric and factory.
     *
     * @param toCopy the transform to copy
     */
    protected ReliefF(ReliefF toCopy) {
        super(toCopy);
        if (toCopy.w != null) {
            this.w = Arrays.copyOf(toCopy.w, toCopy.w.length);
        }
        this.dm = toCopy.dm == null ? null : toCopy.dm.clone();
        this.featureCount = toCopy.featureCount;
        this.iterations = toCopy.iterations;
        this.neighbors = toCopy.neighbors;
        this.vcf = toCopy.vcf == null ? null : toCopy.vcf.clone();
    }

    /**
     * Creates an unfitted transform.
     *
     * @param featureCount number of numeric attributes to keep
     * @param iterations   total number of sampled reference points
     * @param neighbors    neighbors examined per class for each sample
     * @param dm           distance metric for the neighbor searches
     * @param vcf          factory for the per-class neighbor collections
     * @throws IllegalArgumentException if any count is less than 1
     */
    public ReliefF(int featureCount, int iterations, int neighbors, DistanceMetric dm, VectorCollectionFactory<Vec> vcf) {
        setFeatureCount(featureCount);
        setIterations(iterations);
        setNeighbors(neighbors);
        setDistanceMetric(dm);
        this.vcf = vcf;
    }

    /**
     * Creates the transform and immediately fits it to {@code cds}.
     *
     * @param cds          training data
     * @param featureCount number of numeric attributes to keep
     * @param iterations   total number of sampled reference points
     * @param neighbors    neighbors examined per class for each sample
     * @param dm           distance metric for the neighbor searches
     * @param vcf          factory for the per-class neighbor collections
     * @param threadPool   executor for parallel work, or {@code null}
     */
    public ReliefF(ClassificationDataSet cds, int featureCount, int iterations, int neighbors, DistanceMetric dm, VectorCollectionFactory<Vec> vcf, ExecutorService threadPool) {
        this(featureCount, iterations, neighbors, dm, vcf);
        fit(cds, threadPool);
    }

    @Override
    public void fit(DataSet dataSet) {
        fit(dataSet, null);
    }

    /**
     * Learns attribute weights from {@code dataSet} and configures this
     * transform to drop all but {@link #getFeatureCount()} numeric attributes.
     *
     * @param dataSet    training data; must be a {@link ClassificationDataSet}
     * @param threadPool executor used to build the neighbor collections and to
     *                   run the sampling tasks, or {@code null} to build them
     *                   without an executor and run a single task on a {@link FakeExecutor}
     * @throws FailedToFitException if {@code dataSet} is not a classification
     *         data set, contains no points, or the calling thread is
     *         interrupted while waiting for the sampling tasks
     */
    public void fit(DataSet dataSet, ExecutorService threadPool) {
        if (!(dataSet instanceof ClassificationDataSet)) {
            String actual = dataSet == null ? "null" : dataSet.getClass().getSimpleName();
            throw new FailedToFitException("ReliefF only works with classification datasets, not " + actual);
        }
        final ClassificationDataSet cds = (ClassificationDataSet) dataSet;
        // Fail fast: sampling a random point from an empty set is impossible.
        if (cds.getSampleSize() == 0) {
            throw new FailedToFitException("ReliefF requires at least one data point");
        }
        super.fit(dataSet);

        final double[] weights = new double[cds.getNumNumericalVars()];
        this.w = weights;
        final List<Vec> dataVectors = cds.getDataVectors();
        final double[] ranges = featureRanges(dataVectors, weights.length);
        final double[] priors = cds.getPriors();

        TrainableDistanceMetric.trainIfNeeded(dm, cds, threadPool);
        final List<VectorCollection<Vec>> classCollections =
                buildClassCollections(cds, dataVectors, priors.length, threadPool);

        accumulateWeights(weights, cds, dataVectors, priors, ranges, classCollections, threadPool);
        removeLowestWeighted(cds, weights);
    }

    /** Computes {@code max - min} of every numeric attribute over {@code vectors}. */
    private static double[] featureRanges(List<Vec> vectors, int numFeatures) {
        double[] min = new double[numFeatures];
        double[] range = new double[numFeatures];
        Arrays.fill(min, Double.POSITIVE_INFINITY);
        Arrays.fill(range, Double.NEGATIVE_INFINITY); // holds the running max until the end
        for (Vec v : vectors) {
            for (int i = 0; i < v.length(); i++) {
                double value = v.get(i);
                min[i] = Math.min(min[i], value);
                range[i] = Math.max(range[i], value);
            }
        }
        for (int i = 0; i < numFeatures; i++) {
            range[i] -= min[i];
        }
        return range;
    }

    /**
     * Groups the data vectors by class label and builds one nearest-neighbor
     * collection per class. A class with no samples gets a {@code null} entry.
     */
    private List<VectorCollection<Vec>> buildClassCollections(ClassificationDataSet cds, List<Vec> dataVectors,
            int numClasses, ExecutorService threadPool) {
        List<List<Vec>> byClass = new ArrayList<List<Vec>>(numClasses);
        for (int c = 0; c < numClasses; c++) {
            byClass.add(new ArrayList<Vec>(cds.classSampleCount(c)));
        }
        // Group by the actual label; the data set is not guaranteed to be sorted by class.
        for (int i = 0; i < dataVectors.size(); i++) {
            byClass.get(cds.getDataPointCategory(i)).add(dataVectors.get(i));
        }

        List<VectorCollection<Vec>> collections = new ArrayList<VectorCollection<Vec>>(numClasses);
        for (List<Vec> members : byClass) {
            if (members.isEmpty()) {
                collections.add(null);
            } else if (threadPool == null) {
                collections.add(vcf.getVectorCollection(members, dm));
            } else {
                collections.add(vcf.getVectorCollection(members, dm, threadPool));
            }
        }
        return collections;
    }

    /**
     * Splits the sampling iterations across worker tasks and adds every task's
     * partial weight update into {@code weights}. Any failure inside a task is
     * rethrown here instead of leaving the caller blocked on the latch.
     */
    private void accumulateWeights(final double[] weights, final ClassificationDataSet cds,
            final List<Vec> dataVectors, final double[] priors, final double[] ranges,
            final List<VectorCollection<Vec>> classCollections, ExecutorService threadPool) {
        final int totalIterations = iterations;
        final int k = neighbors;
        final int workers = threadPool == null ? 1 : SystemInfo.LogicalCores;
        if (threadPool == null) {
            threadPool = new FakeExecutor();
        }

        final CountDownLatch latch = new CountDownLatch(workers);
        final AtomicReference<Throwable> failure = new AtomicReference<Throwable>();
        final int baseShare = totalIterations / workers;
        for (int task = 0; task < workers; task++) {
            // The first (totalIterations % workers) tasks take one extra iteration.
            final int share = task < totalIterations % workers ? baseShare + 1 : baseShare;
            threadPool.submit(new Runnable() {
                @Override
                public void run() {
                    try {
                        double[] local = sampleWeightUpdates(share, totalIterations, k, cds,
                                dataVectors, priors, ranges, classCollections);
                        synchronized (weights) {
                            for (int i = 0; i < weights.length; i++) {
                                weights[i] += local[i];
                            }
                        }
                    } catch (Throwable ex) {
                        failure.compareAndSet(null, ex);
                    } finally {
                        // Always release the waiting thread, even on failure.
                        latch.countDown();
                    }
                }
            });
        }

        try {
            latch.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new FailedToFitException("Interrupted while computing ReliefF weights");
        }

        Throwable error = failure.get();
        if (error instanceof RuntimeException) {
            throw (RuntimeException) error;
        }
        if (error instanceof Error) {
            throw (Error) error;
        }
        if (error != null) {
            throw new FailedToFitException("ReliefF worker failed: " + error);
        }
    }

    /**
     * Runs {@code samples} ReliefF iterations and returns this task's weight
     * deltas, each already divided by {@code totalIterations * k}.
     */
    private double[] sampleWeightUpdates(int samples, int totalIterations, int k, ClassificationDataSet cds,
            List<Vec> dataVectors, double[] priors, double[] ranges,
            List<VectorCollection<Vec>> classCollections) {
        double[] delta = new double[ranges.length];
        double norm = (double) totalIterations * k;
        Random rand = new Random();
        int n = dataVectors.size();
        for (int s = 0; s < samples; s++) {
            int idx = rand.nextInt(n);
            Vec x = dataVectors.get(idx);
            int label = cds.getDataPointCategory(idx);
            for (int c = 0; c < priors.length; c++) {
                VectorCollection<Vec> collection = classCollections.get(c);
                if (collection == null) {
                    continue; // class has no samples, so it contributes nothing
                }
                boolean sameClass = c == label;
                List<? extends VecPaired<?, ?>> found = collection.search(x, sameClass ? k + 1 : k);
                if (sameClass) {
                    // x is in its own class's collection, so its closest hit is
                    // itself; skip it (bounded in case the class is small).
                    found = found.subList(Math.min(1, found.size()), found.size());
                }
                // Hits subtract, misses add in proportion to their class prior.
                double scale = sameClass ? -1.0 : priors[c] / (1.0 - priors[label]);
                for (VecPaired<?, ?> neighbor : found) {
                    Vec other = neighbor.getVector();
                    for (int f = 0; f < delta.length; f++) {
                        delta[f] += (scale * diff(f, x, other, ranges)) / norm;
                    }
                }
            }
        }
        return delta;
    }

    /**
     * Configures the removal of every numeric attribute except {@code featureCount}
     * of them. IndexTable is assumed to order indices by ascending weight, so the
     * first {@code weights.length - featureCount} entries are the least relevant.
     */
    private void removeLowestWeighted(ClassificationDataSet cds, double[] weights) {
        IndexTable order = new IndexTable(weights);
        Set<Integer> toRemove = new IntSet(weights.length * 2);
        for (int i = 0; i < weights.length - featureCount; i++) {
            toRemove.add(order.index(i));
        }
        setUp(cds, Collections.<Integer>emptySet(), toRemove);
    }

    /**
     * Returns a copy of the learned weight of every numeric attribute.
     *
     * @return the weights as a new vector
     * @throws NullPointerException if the transform has not been fitted
     */
    public Vec getWeights() {
        return new DenseVector(Arrays.copyOf(w, w.length));
    }

    /**
     * Range-normalized absolute difference of attribute {@code i} between two vectors.
     *
     * @param i      attribute index
     * @param a      first vector
     * @param b      second vector
     * @param ranges per-attribute value range ({@code max - min})
     * @return {@code |a[i] - b[i]| / ranges[i]}, or 0 when the attribute range is zero
     */
    public double diff(int i, Vec a, Vec b, double[] ranges) {
        if (ranges[i] == 0.0) {
            return 0.0;
        }
        return Math.abs(a.get(i) - b.get(i)) / ranges[i];
    }

    @Override
    public ReliefF clone() {
        return new ReliefF(this);
    }

    /**
     * Sets how many numeric attributes to keep.
     *
     * @param featureCount number of attributes to keep; must be at least 1
     * @throws IllegalArgumentException if {@code featureCount < 1}
     */
    public void setFeatureCount(int featureCount) {
        if (featureCount < 1) {
            throw new IllegalArgumentException("Number of features to select must be positive, not " + featureCount);
        }
        this.featureCount = featureCount;
    }

    /** @return the number of numeric attributes to keep */
    public int getFeatureCount() {
        return this.featureCount;
    }

    /**
     * Sets the total number of randomly sampled reference points.
     *
     * @param iterations number of samples; must be at least 1
     * @throws IllegalArgumentException if {@code iterations < 1}
     */
    public void setIterations(int iterations) {
        if (iterations < 1) {
            throw new IllegalArgumentException("Number of iterations must be positive, not " + iterations);
        }
        this.iterations = iterations;
    }

    /** @return the total number of sampled reference points */
    public int getIterations() {
        return this.iterations;
    }

    /**
     * Sets how many nearest neighbors are examined per class for each sample.
     *
     * @param neighbors neighbor count; must be at least 1
     * @throws IllegalArgumentException if {@code neighbors < 1}
     */
    public void setNeighbors(int neighbors) {
        if (neighbors < 1) {
            throw new IllegalArgumentException("Number of neighbors must be positive, not " + neighbors);
        }
        this.neighbors = neighbors;
    }

    /** @return the number of neighbors examined per class */
    public int getNeighbors() {
        return this.neighbors;
    }

    /** @param dm the distance metric used for the neighbor searches */
    public void setDistanceMetric(DistanceMetric dm) {
        this.dm = dm;
    }

    /** @return the distance metric used for the neighbor searches */
    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }
}
