package jsat.datatransform;

import java.util.Arrays;
import jsat.DataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.distributions.discrete.UniformDiscrete;
import jsat.linear.DenseVector;
import jsat.linear.Vec;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/datatransform/NumericalToHistogram.class */
/**
 * Converts every numerical feature into a categorical one by equal-width binning.
 *
 * <p>{@link #fit(DataSet)} records, for each numerical column, the minimum and maximum
 * non-NaN value seen and splits that range into {@link #getNumberOfBins()} equal-width
 * bins. {@link #transform(DataPoint)} then emits a data point with no numerical values,
 * whose categorical values are the bin indices of the original numerical features,
 * followed by the original categorical features unchanged.
 *
 * <p>NaN inputs map to category {@code -1}. Values below the fitted minimum fall into
 * bin {@code 0}, and values above the fitted maximum fall into the last bin.
 */
public class NumericalToHistogram implements DataTransform {
    private static final long serialVersionUID = -2318706869393636074L;

    /** Number of bins each numerical feature is partitioned into; always positive. */
    private int n;

    /**
     * One row per numerical feature: {@code [0]} is the lower edge of the first bin
     * (the fitted minimum) and {@code [1]} is the bin width. {@code null} until fit.
     */
    double[][] conversionArray;

    /**
     * Categorical metadata of the transformed points: one {@code n}-category entry per
     * numerical feature, followed by the source data set's categorical variables.
     * {@code null} until fit.
     */
    CategoricalData[] newDataArray;

    /** Creates an unfit transform that uses 25 bins. */
    public NumericalToHistogram() {
        this(25);
    }

    /**
     * Creates a transform fit to {@code dataSet}, using {@code ceil(sqrt(sampleSize))}
     * bins.
     *
     * @param dataSet the data to fit the bin edges to
     */
    public NumericalToHistogram(DataSet dataSet) {
        this(dataSet, (int) Math.ceil(Math.sqrt(dataSet.getSampleSize())));
    }

    /**
     * Creates an unfit transform with the given number of bins.
     *
     * @param i the number of bins; must be positive
     * @throws IllegalArgumentException if {@code i <= 0}
     */
    public NumericalToHistogram(int i) {
        setNumberOfBins(i);
    }

    /**
     * Creates a transform with the given number of bins and fits it to {@code dataSet}.
     *
     * @param dataSet the data to fit the bin edges to
     * @param i the number of bins; must be positive
     * @throws IllegalArgumentException if {@code i <= 0}
     */
    public NumericalToHistogram(DataSet dataSet, int i) {
        this(i);
        fit(dataSet);
    }

    /**
     * Learns the per-feature range (ignoring NaN values) and builds the output
     * categorical metadata. Uses the current {@link #getNumberOfBins()}.
     *
     * @param dataSet the data to fit to
     */
    @Override // jsat.datatransform.DataTransform
    public void fit(DataSet dataSet) {
        final int numNumeric = dataSet.getNumNumericalVars();
        final int numCategorical = dataSet.getNumCategoricalVars();

        // Seed with the infinities so the first observed value always replaces them.
        // (Double.MIN_VALUE is the smallest *positive* double; using it as the initial
        // maximum produced a wrong range for columns whose values are all negative.)
        double[] mins = new double[numNumeric];
        double[] maxs = new double[numNumeric];
        Arrays.fill(mins, Double.POSITIVE_INFINITY);
        Arrays.fill(maxs, Double.NEGATIVE_INFINITY);

        for (int row = 0; row < dataSet.getSampleSize(); row++) {
            Vec values = dataSet.getDataPoint(row).getNumericalValues();
            for (int col = 0; col < numNumeric; col++) {
                double value = values.get(col);
                if (!Double.isNaN(value)) {
                    mins[col] = Math.min(mins[col], value);
                    maxs[col] = Math.max(maxs[col], value);
                }
            }
        }

        this.conversionArray = new double[numNumeric][2];
        for (int col = 0; col < numNumeric; col++) {
            if (mins[col] > maxs[col]) {
                // No non-NaN value was observed for this column: use a degenerate
                // [0, 0] range instead of leaking the infinity sentinels.
                mins[col] = 0.0;
                maxs[col] = 0.0;
            }
            this.conversionArray[col][0] = mins[col];
            this.conversionArray[col][1] = (maxs[col] - mins[col]) / this.n;
        }

        this.newDataArray = new CategoricalData[numNumeric + numCategorical];
        for (int col = 0; col < numNumeric; col++) {
            this.newDataArray[col] = new CategoricalData(this.n);
        }
        System.arraycopy(dataSet.getCategories(), 0, this.newDataArray, numNumeric, numCategorical);
    }

    /**
     * Sets the number of bins used by subsequent calls to {@link #fit(DataSet)}.
     *
     * @param i the number of bins; must be positive
     * @throws IllegalArgumentException if {@code i <= 0}
     */
    public void setNumberOfBins(int i) {
        if (i <= 0) {
            throw new IllegalArgumentException("Must partition into a positive number of groups");
        }
        this.n = i;
    }

    /**
     * Returns the number of bins each numerical feature is partitioned into.
     *
     * @return the number of bins
     */
    public int getNumberOfBins() {
        return this.n;
    }

    /**
     * Returns a distribution over candidate bin counts scaled to the data set size.
     * Presumably intended for hyper-parameter search over {@link #setNumberOfBins(int)}.
     *
     * <ul>
     *   <li>fewer than 20 samples: discrete uniform on {@code [2, N-1]};</li>
     *   <li>at least 1,000,000 samples: log-uniform on {@code [50, 1000]};</li>
     *   <li>otherwise: discrete uniform on
     *       {@code [max(sqrt(N)/3, 2), min(3*sqrt(N), N-1)]}.</li>
     * </ul>
     *
     * <p>NOTE(review): with fewer than 3 samples the first range is empty; behavior then
     * depends on {@code UniformDiscrete}.
     *
     * @param dataSet the data set whose sample size {@code N} drives the guess
     * @return the candidate-bin-count distribution
     */
    public static Distribution guessNumberOfBins(DataSet dataSet) {
        if (dataSet.getSampleSize() < 20) {
            return new UniformDiscrete(2, dataSet.getSampleSize() - 1);
        }
        if (dataSet.getSampleSize() >= 1000000) {
            return new LogUniform(50.0d, 1000.0d);
        }
        int sqrt = (int) Math.sqrt(dataSet.getSampleSize());
        return new UniformDiscrete(Math.max(sqrt / 3, 2), Math.min(sqrt * 3, dataSet.getSampleSize() - 1));
    }

    /**
     * Deep-copy constructor backing {@link #clone()}.
     *
     * @param toCopy the transform to copy
     */
    private NumericalToHistogram(NumericalToHistogram toCopy) {
        this.n = toCopy.n;
        if (toCopy.conversionArray != null) {
            // Allocate the outer array only; each row is copied below.
            this.conversionArray = new double[toCopy.conversionArray.length][];
            for (int col = 0; col < toCopy.conversionArray.length; col++) {
                this.conversionArray[col] = Arrays.copyOf(toCopy.conversionArray[col], toCopy.conversionArray[col].length);
            }
        }
        if (toCopy.newDataArray != null) {
            this.newDataArray = new CategoricalData[toCopy.newDataArray.length];
            for (int col = 0; col < toCopy.newDataArray.length; col++) {
                // NOTE(review): m480clone() looks like a decompiler rename of
                // CategoricalData.clone(); keep in sync with that class.
                this.newDataArray[col] = toCopy.newDataArray[col].m480clone();
            }
        }
    }

    /**
     * Replaces each numerical feature with its bin index and appends the original
     * categorical features. The returned point has an empty numerical vector.
     *
     * @param dataPoint the point to transform; must have the layout seen in fit
     * @return the transformed, purely categorical point
     * @throws IllegalStateException if this transform has not been fit
     */
    @Override // jsat.datatransform.DataTransform
    public DataPoint transform(DataPoint dataPoint) {
        if (this.conversionArray == null || this.newDataArray == null) {
            throw new IllegalStateException("NumericalToHistogram has not been fit to any data");
        }
        final int numNumeric = this.conversionArray.length;
        int[] categories = new int[this.newDataArray.length];
        Vec numericalValues = dataPoint.getNumericalValues();
        for (int col = 0; col < numNumeric; col++) {
            categories[col] = binIndex(numericalValues.get(col), this.conversionArray[col][0], this.conversionArray[col][1]);
        }
        System.arraycopy(dataPoint.getCategoricalValues(), 0, categories, numNumeric, dataPoint.numCategoricalValues());
        return new DataPoint(new DenseVector(0), categories, this.newDataArray);
    }

    /**
     * Maps a single value to its bin: {@code -1} for NaN, otherwise an index in
     * {@code [0, n-1]}, with out-of-range values clamped to the end bins.
     */
    private int binIndex(double value, double origin, double width) {
        double offset = value - origin;
        if (Double.isNaN(offset)) {
            return -1;
        }
        double scaled = offset / width;
        if (Double.isNaN(scaled)) {
            // 0/0: a value exactly at the origin of a zero-width (constant) column.
            return 0;
        }
        // +/-Infinity (zero width) or huge quotients saturate in the int cast and are
        // then clamped into range.
        int bin = (int) Math.floor(scaled);
        if (bin < 0) {
            return 0;
        }
        if (bin >= this.n) {
            return this.n - 1;
        }
        return bin;
    }

    /**
     * Returns a deep copy of this transform, including any fitted state.
     *
     * @return an independent copy
     */
    @Override // jsat.datatransform.DataTransform
    public NumericalToHistogram clone() {
        return new NumericalToHistogram(this);
    }
}
