package jsat.datatransform;

import java.util.Random;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.linear.DenseMatrix;
import jsat.linear.Matrix;
import jsat.linear.RandomMatrix;
import jsat.utils.random.XORWOW;

/**
 * Random projection of the numeric features of a data point into a
 * {@code k}-dimensional space, in the style of the Johnson-Lindenstrauss lemma.
 * <p>
 * {@link #fit(DataSet)} builds a {@code k x d} random matrix, where {@code d} is the
 * number of numerical variables of the data set. Its entries are drawn according to
 * the selected {@link TransformMode} and scaled by {@code 1/sqrt(k)}.
 * {@link #transform(DataPoint)} multiplies a point's numeric vector by that matrix.
 * Categorical values and the weight of the point are carried over unchanged.
 */
public class JLTransform extends DataTransformBase {
    private static final long serialVersionUID = -8621368067861343912L;

    /** Distribution used for the entries of the projection matrix. */
    private TransformMode mode;
    /** The {@code k x d} projection matrix; {@code null} until {@link #fit(DataSet)} is called. */
    private Matrix R;
    /** Target (projected) dimension, i.e. the number of rows of {@link #R}. */
    private int k;
    /** If true, {@link #fit(DataSet)} materializes the projection matrix as a {@link DenseMatrix}. */
    private boolean inMemory;

    /**
     * Random matrix whose entries follow the distribution selected by a {@link TransformMode}.
     * <p>
     * It is constructed from a seed; presumably {@link RandomMatrix} regenerates entries from
     * that seed through {@link #getVal(Random)} on access (TODO confirm against RandomMatrix).
     */
    private static class RandomMatrixJL extends RandomMatrix {
        private static final long serialVersionUID = 2009377824896155918L;
        /** Magnitude of every non-zero entry (standard deviation for GAUSS). */
        private double cnst;
        private TransformMode mode;

        /**
         * @param rows number of rows, i.e. the projected dimension {@code k}
         * @param cols number of columns, i.e. the original numeric dimension
         * @param seed seed passed to {@link RandomMatrix}
         * @param mode distribution of the entries
         */
        public RandomMatrixJL(int rows, int cols, long seed, TransformMode mode) {
            super(rows, cols, seed);
            this.mode = mode;
            if (mode == TransformMode.GAUSS || mode == TransformMode.BINARY) {
                this.cnst = 1.0 / Math.sqrt(rows);
            } else if (mode == TransformMode.SPARSE) {
                // Only 1/3 of the entries are non-zero, so they are scaled up by sqrt(3)
                // to keep the same per-entry variance (1/rows) as the other modes.
                this.cnst = Math.sqrt(3.0) / Math.sqrt(rows);
            }
        }

        /**
         * Draws a single matrix entry.
         *
         * @param random source of randomness supplied by {@link RandomMatrix}
         * @return the sampled entry
         * @throws RuntimeException if the mode is not one of the known values
         */
        @Override // jsat.linear.RandomMatrix
        protected double getVal(Random random) {
            if (mode == TransformMode.GAUSS) {
                return random.nextGaussian() * cnst;
            }
            if (mode == TransformMode.BINARY) {
                return random.nextBoolean() ? -cnst : cnst;
            }
            if (mode != TransformMode.SPARSE) {
                throw new RuntimeException("BUG: Please report");
            }
            // -cnst with probability 1/6, +cnst with probability 1/6, 0 otherwise
            int draw = random.nextInt(6);
            if (draw == 0) {
                return -cnst;
            }
            if (draw == 1) {
                return cnst;
            }
            return 0.0;
        }
    }

    /** Distribution used to fill the projection matrix. */
    public enum TransformMode {
        /** Entries are N(0, 1/k). */
        GAUSS,
        /** Entries are +/- 1/sqrt(k), each with probability 1/2. */
        BINARY,
        /** Entries are +/- sqrt(3/k) with probability 1/6 each, and 0 with probability 2/3. */
        SPARSE
    }

    /**
     * Copy constructor: copies all configuration and, if present, a clone of the fitted matrix.
     *
     * @param toCopy the transform to copy
     */
    protected JLTransform(JLTransform toCopy) {
        this.mode = toCopy.mode;
        this.k = toCopy.k;
        this.inMemory = toCopy.inMemory;
        // An unfitted transform has no matrix yet; cloning it must not fail.
        this.R = toCopy.R == null ? null : toCopy.R.mo641clone();
    }

    /** Creates a SPARSE, in-memory transform projecting to 50 dimensions. */
    public JLTransform() {
        this(50);
    }

    /**
     * Creates a SPARSE, in-memory transform.
     *
     * @param k the projected dimension; must be positive
     */
    public JLTransform(int k) {
        this(k, TransformMode.SPARSE);
    }

    /**
     * Creates an in-memory transform.
     *
     * @param k    the projected dimension; must be positive
     * @param mode distribution of the projection matrix entries
     */
    public JLTransform(int k, TransformMode mode) {
        this(k, mode, true);
    }

    /**
     * @param k        the projected dimension; must be positive
     * @param mode     distribution of the projection matrix entries
     * @param inMemory whether to materialize the projection matrix as a dense matrix
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public JLTransform(int k, TransformMode mode, boolean inMemory) {
        if (k < 1) {
            throw new IllegalArgumentException("Projected dimension must be positive, not " + k);
        }
        this.mode = mode;
        this.k = k;
        this.inMemory = inMemory;
    }

    /**
     * Builds a new {@code k x d} projection matrix, where {@code d} is the number of
     * numerical variables in {@code dataSet}. The matrix uses a seed drawn from a freshly
     * constructed {@link XORWOW} generator.
     *
     * @param dataSet the data set whose numeric dimension defines the matrix width
     */
    @Override // jsat.datatransform.DataTransform
    public void fit(DataSet dataSet) {
        int numNumericalVars = dataSet.getNumNumericalVars();
        RandomMatrixJL randomMatrix = new RandomMatrixJL(k, numNumericalVars, new XORWOW().nextLong(), mode);
        if (inMemory) {
            // Copy the random entries into a dense matrix (added onto a new, presumably
            // zero-filled, DenseMatrix).
            Matrix dense = new DenseMatrix(k, numNumericalVars);
            dense.mutableAdd(randomMatrix);
            this.R = dense;
        } else {
            this.R = randomMatrix;
        }
    }

    /** @param mode distribution to use for the next call to {@link #fit(DataSet)} */
    public void setMode(TransformMode mode) {
        this.mode = mode;
    }

    /** @return the distribution used for the projection matrix entries */
    public TransformMode getMode() {
        return mode;
    }

    /** @param inMemory whether the next {@link #fit(DataSet)} should build a dense matrix */
    public void setInMemory(boolean inMemory) {
        this.inMemory = inMemory;
    }

    /** @return whether {@link #fit(DataSet)} builds a dense projection matrix */
    public boolean isInMemory() {
        return inMemory;
    }

    /**
     * @param k the projected dimension to use for the next {@link #fit(DataSet)}; must be positive
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public void setProjectedDimension(int k) {
        if (k < 1) {
            throw new IllegalArgumentException("Projected dimension must be positive, not " + k);
        }
        this.k = k;
    }

    /** @return the projected dimension */
    public int getProjectedDimension() {
        return k;
    }

    /**
     * Suggests a range of projected dimensions for a data set.
     *
     * @param dataSet the data set to inspect
     * @return log-uniform over [100, 1000] when the data set has more than 10000
     *         numerical variables, otherwise log-uniform over [10, 100]
     */
    public static Distribution guessProjectedDimension(DataSet dataSet) {
        if (dataSet.getNumNumericalVars() > 10000) {
            return new LogUniform(100.0, 1000.0);
        }
        return new LogUniform(10.0, 100.0);
    }

    /**
     * Projects the numeric values of {@code dataPoint}. Categorical values, categorical
     * metadata and the weight are passed through unchanged.
     *
     * @param dataPoint the point to transform
     * @return a new point with {@code k} numeric features
     * @throws IllegalStateException if {@link #fit(DataSet)} has not been called
     */
    @Override // jsat.datatransform.DataTransform
    public DataPoint transform(DataPoint dataPoint) {
        if (R == null) {
            throw new IllegalStateException("JLTransform must be fit to a data set before transforming");
        }
        return new DataPoint(R.multiply(dataPoint.getNumericalValues()),
                dataPoint.getCategoricalValues(), dataPoint.getCategoricalData(), dataPoint.getWeight());
    }

    @Override // jsat.datatransform.DataTransformBase
    public DataTransform clone() {
        return new JLTransform(this);
    }
}
