package jsat;

import java.lang.ref.SoftReference;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.datatransform.DataTransform;
import jsat.datatransform.InPlaceTransform;
import jsat.linear.ConstantVector;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Matrix;
import jsat.linear.MatrixOfVecs;
import jsat.linear.MatrixStatistics;
import jsat.linear.SparseVector;
import jsat.linear.Vec;
import jsat.math.OnLineStatistics;
import jsat.utils.FakeExecutor;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.SystemInfo;
import jsat.utils.random.XORWOW;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/DataSet.class */
/**
 * Base class for a collection of {@link DataPoint}s that share one numeric/categorical feature
 * layout.
 *
 * <p>Subclasses provide storage through {@link #getDataPoint(int)},
 * {@link #setDataPoint(int, DataPoint)}, {@link #getSampleSize()}, {@link #getSubset(List)} and
 * {@link #shallowClone2()}. This class builds column views, summary statistics, random splits and
 * transform application on top of those primitives.
 *
 * @param <Type> the concrete data set type returned by subset/split operations
 */
public abstract class DataSet<Type extends DataSet> {
    /** Number of numeric features per data point. */
    protected int numNumerVals;
    /** Descriptions of the categorical features, one entry per categorical variable. */
    protected CategoricalData[] categories;
    /** Names of the numeric features; may be {@code null} when the data set keeps no names. */
    protected List<String> numericalVariableNames;
    /**
     * Softly-referenced cache of numeric columns keyed by column index. Cleared by
     * {@link #applyTransform(DataTransform, boolean, ExecutorService)} and
     * {@link #replaceNumericFeatures(List)} because they change the numeric values.
     */
    protected Map<Integer, SoftReference<Vec>> columnVecCache = new HashMap<>();

    /** Mean per-row density below which column vectors are materialized as sparse vectors. */
    private static final double SPARSE_COLUMN_THRESHOLD = 0.6;

    /**
     * Sets the name of a numeric feature. The name is stored lower-cased and must not already be
     * used by another numeric feature.
     *
     * @param name the new name
     * @param index the numeric feature index
     * @return {@code true} if the name was set; {@code false} if this data set keeps no names, the
     *     index is out of range, or the (lower-cased) name is already in use
     */
    public boolean setNumericName(String name, int index) {
        if (numericalVariableNames == null || index < 0 || index >= getNumNumericalVars()) {
            return false;
        }
        // Locale.ROOT keeps lower-casing independent of the JVM default locale (e.g. Turkish 'I').
        String lowered = name.toLowerCase(Locale.ROOT);
        if (numericalVariableNames.contains(lowered)) {
            return false;
        }
        numericalVariableNames.set(index, lowered);
        return true;
    }

    /**
     * Returns the name of a numeric feature.
     *
     * @param i the numeric feature index
     * @return the stored name, or {@code null} if this data set keeps no numeric names
     * @throws IndexOutOfBoundsException if {@code i} is not a valid numeric feature index
     */
    public String getNumericName(int i) {
        if (i >= getNumNumericalVars() || i < 0) {
            throw new IndexOutOfBoundsException("Can not acces variable for invalid index  " + i);
        }
        if (numericalVariableNames == null) {
            return null;
        }
        return numericalVariableNames.get(i);
    }

    /**
     * Returns the name of a categorical feature.
     *
     * @param i the categorical feature index
     * @return the name reported by the feature's {@link CategoricalData}
     * @throws IndexOutOfBoundsException if {@code i} is not a valid categorical feature index
     */
    public String getCategoryName(int i) {
        if (i >= getNumCategoricalVars() || i < 0) {
            throw new IndexOutOfBoundsException("Can not acces variable for invalid index  " + i);
        }
        return categories[i].getCategoryName();
    }

    /**
     * Applies a transform to every data point, replacing each point with the transformed result.
     *
     * @param dataTransform the transform to apply
     */
    public void applyTransform(DataTransform dataTransform) {
        applyTransform(dataTransform, false);
    }

    /**
     * Applies a transform to every data point, replacing each point, using the given pool.
     * A {@code null} pool or a {@link FakeExecutor} runs on the calling thread.
     *
     * @param dataTransform the transform to apply
     * @param executorService the pool to use; may be {@code null}
     */
    public void applyTransform(DataTransform dataTransform, ExecutorService executorService) {
        if (executorService == null || executorService instanceof FakeExecutor) {
            applyTransform(dataTransform);
        } else {
            applyTransform(dataTransform, false, executorService);
        }
    }

    /**
     * Applies a transform to every data point on the calling thread.
     *
     * @param dataTransform the transform to apply
     * @param mutate if {@code true} and the transform is an {@link InPlaceTransform}, points are
     *     modified in place instead of being replaced
     */
    public void applyTransform(DataTransform dataTransform, boolean mutate) {
        applyTransform(dataTransform, mutate, new FakeExecutor());
    }

    /**
     * Applies a transform to every data point, splitting the work across
     * {@code SystemInfo.LogicalCores} tasks submitted to {@code executorService}. Afterwards the
     * column cache is cleared, the feature counts are re-read from the first data point, and any
     * numeric names are reset to "TN1", "TN2", ...
     *
     * @param dataTransform the transform to apply
     * @param mutate if {@code true} and the transform is an {@link InPlaceTransform}, points are
     *     modified in place instead of being replaced
     * @param executorService the pool to use; {@code null} means the calling thread
     * @throws RuntimeException (or {@link Error}) rethrown from the first task that failed
     */
    public void applyTransform(final DataTransform dataTransform, boolean mutate,
            ExecutorService executorService) {
        ExecutorService pool = executorService == null ? new FakeExecutor() : executorService;
        // Use the in-place path only when requested and supported by the transform.
        final InPlaceTransform inPlace = mutate && dataTransform instanceof InPlaceTransform
                ? (InPlaceTransform) dataTransform
                : null;
        final int workers = SystemInfo.LogicalCores;
        final CountDownLatch latch = new CountDownLatch(workers);
        final AtomicReference<Throwable> failure = new AtomicReference<>();
        for (int w = 0; w < workers; w++) {
            final int start = w;
            pool.submit(new Runnable() {
                @Override
                public void run() {
                    try {
                        // Task w handles indices w, w + workers, w + 2 * workers, ...
                        for (int i = start; i < getSampleSize(); i += workers) {
                            if (inPlace != null) {
                                inPlace.mutableTransform(getDataPoint(i));
                            } else {
                                setDataPoint(i, dataTransform.transform(getDataPoint(i)));
                            }
                        }
                    } catch (RuntimeException | Error e) {
                        failure.compareAndSet(null, e);
                    } finally {
                        // Always count down so await() cannot hang when a task fails.
                        latch.countDown();
                    }
                }
            });
        }
        try {
            latch.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            Logger.getLogger(DataSet.class.getName()).log(Level.SEVERE, (String) null, e);
            return;
        } finally {
            // Data points may have changed (possibly only partially), so cached columns are stale.
            columnVecCache.clear();
        }

        Throwable error = failure.get();
        if (error instanceof RuntimeException) {
            throw (RuntimeException) error;
        }
        if (error instanceof Error) {
            throw (Error) error;
        }

        if (getSampleSize() > 0) {
            DataPoint first = getDataPoint(0);
            numNumerVals = first.numNumericalValues();
            categories = first.getCategoricalData();
            resetNumericNames();
        }
    }

    /**
     * Replaces the numeric part of every data point, keeping categorical values and weights.
     *
     * @param list the new numeric vectors, one per data point, in data set order
     * @throws IllegalArgumentException if {@code list.size()} differs from the sample size
     */
    public void replaceNumericFeatures(List<Vec> list) {
        if (getSampleSize() != list.size()) {
            throw new IllegalArgumentException("Input list does not have the same not of dataums as the dataset");
        }
        for (int i = 0; i < list.size(); i++) {
            DataPoint old = getDataPoint(i);
            setDataPoint(i, new DataPoint(list.get(i), old.getCategoricalValues(),
                    old.getCategoricalData(), old.getWeight()));
        }
        // The numeric values changed, so any cached column vector is now stale.
        columnVecCache.clear();
        if (!list.isEmpty()) {
            numNumerVals = getDataPoint(0).numNumericalValues();
            resetNumericNames();
        }
    }

    /** Replaces existing numeric names (if names are kept) with "TN1", "TN2", ... */
    private void resetNumericNames() {
        if (numericalVariableNames == null) {
            return;
        }
        numericalVariableNames.clear();
        for (int i = 0; i < getNumNumericalVars(); i++) {
            numericalVariableNames.add("TN" + (i + 1));
        }
    }

    /**
     * Returns the data point at an index.
     *
     * @param i the index, in {@code [0, getSampleSize())}
     * @return the data point stored at {@code i}
     */
    public abstract DataPoint getDataPoint(int i);

    /**
     * Replaces the data point at an index.
     *
     * @param i the index, in {@code [0, getSampleSize())}
     * @param dataPoint the new data point
     */
    public abstract void setDataPoint(int i, DataPoint dataPoint);

    /**
     * Computes per-column statistics of the numeric features. NaN values are treated as missing
     * and excluded. Weight not accounted for by observed or missing values is attributed to the
     * value 0 (entries the vector iterator did not yield).
     *
     * @param useWeights if {@code true} each point contributes its weight, otherwise weight 1
     * @return one {@link OnLineStatistics} per numeric column
     */
    public OnLineStatistics[] getOnlineColumnStats(boolean useWeights) {
        OnLineStatistics[] stats = new OnLineStatistics[numNumerVals];
        for (int col = 0; col < stats.length; col++) {
            stats[col] = new OnLineStatistics();
        }
        double totalWeight = 0.0;
        double[] missingWeight = new double[numNumerVals];
        Iterator<DataPoint> points = getDataPointIterator();
        while (points.hasNext()) {
            DataPoint dp = points.next();
            double weight = useWeights ? dp.getWeight() : 1.0;
            totalWeight += weight;
            Iterator<IndexValue> values = dp.getNumericalValues().iterator();
            while (values.hasNext()) {
                IndexValue iv = values.next();
                if (Double.isNaN(iv.getValue())) {
                    missingWeight[iv.getIndex()] += weight;
                } else {
                    stats[iv.getIndex()].add(iv.getValue(), weight);
                }
            }
        }
        for (int col = 0; col < stats.length; col++) {
            double zeroWeight = (totalWeight - stats[col].getSumOfWeights()) - missingWeight[col];
            // When every entry was seen, rounding can leave a tiny negative remainder; skip it.
            if (zeroWeight > 0) {
                stats[col].add(0.0, zeroWeight);
            }
        }
        return stats;
    }

    /**
     * Returns statistics of the per-point fraction {@code nnz / getNumNumericalVars()}.
     *
     * @return density statistics over all data points
     */
    public OnLineStatistics getOnlineDenseStats() {
        OnLineStatistics stats = new OnLineStatistics();
        double numNumericalVars = getNumNumericalVars();
        for (int i = 0; i < getSampleSize(); i++) {
            stats.add(getDataPoint(i).getNumericalValues().nnz() / numNumericalVars);
        }
        return stats;
    }

    /**
     * Computes the column means and the diagonal of the covariance via {@link MatrixStatistics}.
     *
     * @return a two-element array: {@code [means, covarianceDiagonal]}
     */
    public Vec[] getColumnMeanVariance() {
        int d = getNumNumericalVars();
        Vec means = new DenseVector(d);
        Vec variances = new DenseVector(d);
        MatrixStatistics.meanVector(means, this);
        MatrixStatistics.covarianceDiag(means, variances, this);
        return new Vec[] {means, variances};
    }

    /**
     * Returns a read-only iterator over the data points, bounded by the sample size at the time
     * of the call.
     *
     * @return an iterator whose {@code remove()} is unsupported
     */
    public Iterator<DataPoint> getDataPointIterator() {
        return new Iterator<DataPoint>() {
            private int cur = 0;
            private final int to = getSampleSize();

            @Override
            public boolean hasNext() {
                return cur < to;
            }

            @Override
            public DataPoint next() {
                if (!hasNext()) {
                    throw new NoSuchElementException("No more data points (size " + to + ")");
                }
                return getDataPoint(cur++);
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException("This operation is not supported for DataSet");
            }
        };
    }

    /**
     * Returns the number of data points.
     *
     * @return the number of data points in this data set
     */
    public abstract int getSampleSize();

    /** Returns the number of categorical features. */
    public int getNumCategoricalVars() {
        return categories.length;
    }

    /** Returns the number of numeric features. */
    public int getNumNumericalVars() {
        return numNumerVals;
    }

    /** Returns the categorical feature descriptions (the internal array, not a copy). */
    public CategoricalData[] getCategories() {
        return categories;
    }

    /**
     * Creates a data set holding the data points at the given indices.
     *
     * @param list indices into this data set
     * @return the subset
     */
    protected abstract Type getSubset(List<Integer> list);

    /**
     * Returns the subset of data points with no missing value, where missing means a NaN numeric
     * value or a negative categorical value.
     *
     * @return the subset without missing values
     */
    public Type getMissingDropped() {
        IntList complete = new IntList();
        for (int i = 0; i < getSampleSize(); i++) {
            if (!hasMissingValue(getDataPoint(i))) {
                complete.add(i);
            }
        }
        return getSubset(complete);
    }

    /** Returns true if {@code dp} has a NaN numeric value or a negative categorical value. */
    private static boolean hasMissingValue(DataPoint dp) {
        if (dp.getNumericalValues().countNaNs() > 0) {
            return true;
        }
        for (int category : dp.getCategoricalValues()) {
            if (category < 0) {
                return true;
            }
        }
        return false;
    }

    /**
     * Randomly partitions the data points into consecutive splits of the shuffled index order.
     *
     * @param random source of randomness for the shuffle
     * @param dArr the fraction of points for each split; must be non-negative and sum to at most
     *     1 (with a 0.001 tolerance)
     * @return one subset per fraction, in the same order
     * @throws IllegalArgumentException if {@code dArr} is empty, contains a negative or NaN
     *     fraction, or sums past the tolerance
     */
    public List<Type> randomSplit(Random random, double... dArr) {
        if (dArr.length < 1) {
            throw new IllegalArgumentException("Input array of split fractions must be non-empty");
        }
        IntList order = new IntList(getSampleSize());
        ListUtils.addRange(order, 0, getSampleSize(), 1);
        Collections.shuffle(order, random);
        int n = order.size();

        int[] ends = new int[dArr.length];
        double cumulative = 0.0;
        for (int i = 0; i < dArr.length; i++) {
            // !(x >= 0) also rejects NaN, which would otherwise produce an invalid sub-list bound.
            if (!(dArr[i] >= 0)) {
                throw new IllegalArgumentException("Split fractions must be non-negative, but index "
                        + i + " was " + dArr[i]);
            }
            cumulative += dArr[i];
            if (cumulative >= 1.001d) {
                throw new IllegalArgumentException("Input splits sum is greater than 1 by index " + i + " reaching a sum of " + cumulative);
            }
            // The tolerance lets the sum slightly exceed 1, so clamp to stay within the list.
            ends[i] = (int) Math.min(Math.round(cumulative * n), n);
        }

        List<Type> splits = new ArrayList<>(dArr.length);
        int start = 0;
        for (int end : ends) {
            splits.add(getSubset(order.subList(start, end)));
            start = end;
        }
        return splits;
    }

    /**
     * Randomly partitions the data points using a fresh {@link XORWOW} generator.
     *
     * @param dArr the fraction of points for each split
     * @return one subset per fraction
     */
    public List<Type> randomSplit(double... dArr) {
        return randomSplit(new XORWOW(), dArr);
    }

    /**
     * Splits the data into {@code i} random folds of (nearly) equal size.
     *
     * @param i the number of folds
     * @param random source of randomness
     * @return the folds
     */
    public List<Type> cvSet(int i, Random random) {
        double[] fractions = new double[i];
        Arrays.fill(fractions, 1.0d / i);
        return randomSplit(random, fractions);
    }

    /**
     * Splits the data into {@code i} random folds using a fresh {@link XORWOW} generator.
     *
     * @param i the number of folds
     * @return the folds
     */
    public List<Type> cvSet(int i) {
        return cvSet(i, new XORWOW());
    }

    /** Returns a new list containing every data point in index order. */
    public List<DataPoint> getDataPoints() {
        List<DataPoint> points = new ArrayList<>(getSampleSize());
        for (int i = 0; i < getSampleSize(); i++) {
            points.add(getDataPoint(i));
        }
        return points;
    }

    /** Returns a new list of every data point's numeric vector (not copies) in index order. */
    public List<Vec> getDataVectors() {
        List<Vec> vectors = new ArrayList<>(getSampleSize());
        for (int i = 0; i < getSampleSize(); i++) {
            vectors.add(getDataPoint(i).getNumericalValues());
        }
        return vectors;
    }

    /**
     * Returns one numeric column as a vector, using and populating the soft column cache. The
     * column is sparse when the mean per-row density is below {@value #SPARSE_COLUMN_THRESHOLD}.
     *
     * @param i the numeric column index
     * @return the column (possibly a shared cached instance)
     * @throws IndexOutOfBoundsException if {@code i} is not a valid numeric column
     */
    public Vec getNumericColumn(int i) {
        if (i < 0 || i >= getNumNumericalVars()) {
            throw new IndexOutOfBoundsException("There is no index for column " + i);
        }
        SoftReference<Vec> ref = columnVecCache.get(i);
        Vec cached = ref == null ? null : ref.get();
        if (cached != null) {
            return cached;
        }
        Vec column = new DenseVector(getSampleSize());
        for (int row = 0; row < getSampleSize(); row++) {
            column.set(row, getDataPoint(row).getNumericalValues().get(i));
        }
        if (useSparseColumns()) {
            column = new SparseVector(column);
        }
        columnVecCache.put(i, new SoftReference<>(column));
        return column;
    }

    /** Returns true when the mean per-row density is low enough to store columns sparsely. */
    private boolean useSparseColumns() {
        return getSparsityStats().getMean() < SPARSE_COLUMN_THRESHOLD;
    }

    /**
     * Counts missing values: NaN numeric values plus negative categorical values.
     *
     * @return the total number of missing values across all data points
     */
    public long countMissingValues() {
        long missing = 0;
        for (int i = 0; i < getSampleSize(); i++) {
            DataPoint dp = getDataPoint(i);
            missing += dp.getNumericalValues().countNaNs();
            for (int category : dp.getCategoricalValues()) {
                if (category < 0) {
                    missing++;
                }
            }
        }
        return missing;
    }

    /** Returns every numeric column as a vector; see {@link #getNumericColumns(Set)}. */
    public Vec[] getNumericColumns() {
        return getNumericColumns(Collections.<Integer>emptySet());
    }

    /**
     * Returns the numeric columns as vectors, reusing cached columns where available and caching
     * newly built ones.
     *
     * @param set column indices to skip; their entries in the result are {@code null}
     * @return an array with one vector per numeric column (or {@code null} for skipped columns)
     */
    public Vec[] getNumericColumns(Set<Integer> set) {
        boolean sparse = useSparseColumns();
        int n = getSampleSize();
        Vec[] columns = new Vec[getNumNumericalVars()];
        // true when the column came from the cache and is therefore already populated
        boolean[] fromCache = new boolean[columns.length];
        for (int col = 0; col < columns.length; col++) {
            if (set.contains(col)) {
                continue;
            }
            SoftReference<Vec> ref = columnVecCache.get(col);
            Vec cached = ref == null ? null : ref.get();
            if (cached != null) {
                columns[col] = cached;
                fromCache[col] = true;
            } else {
                Vec fresh = sparse ? new SparseVector(n) : new DenseVector(n);
                columns[col] = fresh;
                columnVecCache.put(col, new SoftReference<>(fresh));
            }
        }
        // Fill the newly created columns row by row from the entries each vector yields.
        for (int row = 0; row < n; row++) {
            Iterator<IndexValue> values = getDataPoint(row).getNumericalValues().iterator();
            while (values.hasNext()) {
                IndexValue iv = values.next();
                int col = iv.getIndex();
                if (columns[col] != null && !fromCache[col]) {
                    columns[col].set(row, iv.getValue());
                }
            }
        }
        return columns;
    }

    /**
     * Copies the numeric features into a new dense matrix (rows = data points).
     *
     * @return a {@code getSampleSize() x getNumNumericalVars()} matrix
     */
    public Matrix getDataMatrix() {
        DenseMatrix matrix = new DenseMatrix(getSampleSize(), getNumNumericalVars());
        for (int row = 0; row < getSampleSize(); row++) {
            // Copy only the entries each vector's iterator yields instead of probing every index,
            // which is costly for sparse vectors. Assumes a new DenseMatrix is zero-filled.
            Iterator<IndexValue> values = getDataPoint(row).getNumericalValues().iterator();
            while (values.hasNext()) {
                IndexValue iv = values.next();
                matrix.set(row, iv.getIndex(), iv.getValue());
            }
        }
        return matrix;
    }

    /** Returns a {@link MatrixOfVecs} built over the data points' numeric vectors. */
    public Matrix getDataMatrixView() {
        return new MatrixOfVecs(getDataVectors());
    }

    /** Returns the total number of features (categorical plus numeric). */
    public int getNumFeatures() {
        return getNumCategoricalVars() + getNumNumericalVars();
    }

    /**
     * Returns a shallow clone of this data set; exact sharing semantics are defined by the
     * subclass.
     *
     * @return the clone
     */
    /* renamed from: shallowClone */
    public abstract DataSet<Type> shallowClone2();

    /**
     * Returns a {@link #shallowClone2()} whose data points are replaced by new {@link DataPoint}
     * objects that reuse the original numeric vectors, categorical values, categorical data and
     * weights.
     *
     * @return the clone
     */
    public DataSet getTwiceShallowClone() {
        DataSet<Type> clone = shallowClone2();
        for (int i = 0; i < clone.getSampleSize(); i++) {
            DataPoint dp = getDataPoint(i);
            // Pass the weight explicitly so the copy keeps it.
            clone.setDataPoint(i, new DataPoint(dp.getNumericalValues(), dp.getCategoricalValues(),
                    dp.getCategoricalData(), dp.getWeight()));
        }
        return clone;
    }

    /**
     * Returns statistics of per-point density: {@code nnz / length} for sparse vectors and 1.0
     * for non-sparse ones.
     *
     * @return density statistics over all data points
     */
    public OnLineStatistics getSparsityStats() {
        OnLineStatistics stats = new OnLineStatistics();
        for (int i = 0; i < getSampleSize(); i++) {
            Vec v = getDataPoint(i).getNumericalValues();
            // Force floating-point division; integer division would truncate the fraction to 0.
            stats.add(v.isSparse() ? v.nnz() / (double) v.length() : 1.0d);
        }
        return stats;
    }

    /**
     * Returns every data point's weight as a vector. A {@link ConstantVector} is returned when
     * all weights are equal; otherwise a {@link DenseVector}.
     *
     * @return the weights in index order
     */
    public Vec getDataWeights() {
        int n = getSampleSize();
        if (n == 0) {
            return new DenseVector(0);
        }
        double firstWeight = getDataPoint(0).getWeight();
        double[] weights = null; // allocated lazily at the first differing weight
        for (int i = 1; i < n; i++) {
            double w = getDataPoint(i).getWeight();
            if (weights == null && w == firstWeight) {
                continue;
            }
            if (weights == null) {
                weights = new double[n];
                Arrays.fill(weights, 0, i, firstWeight);
            }
            weights[i] = w;
        }
        return weights == null ? new ConstantVector(firstWeight, n) : new DenseVector(weights);
    }
}
