package jsat.datatransform;

import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.math.OnLineStatistics;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/datatransform/LinearTransform.class */
/**
 * Linearly rescales every numerical feature of a data point into a target
 * range {@code [B, A]} (with {@code A > B}), using per-feature minimums and
 * maximums learned by {@link #fit(DataSet)}.
 * <p>
 * The forward mapping for feature {@code i} is
 * {@code x' = (x - min_i) * (A - B) / (max_i - min_i) + B}, so the observed
 * minimum maps to {@code B} and the observed maximum maps to {@code A}.
 * The inverse mapping is {@code x = (x' - B) / multiplier_i + min_i}.
 * Nominal features are never modified.
 */
public class LinearTransform implements InPlaceInvertibleTransform {
    private static final long serialVersionUID = 5580283565080452022L;

    /** Features whose observed {@code max - min} is below this are treated as constant. */
    private static final double CONSTANT_RANGE_TOLERANCE = 1.0E-6d;

    /** Upper end of the target range; always strictly greater than {@link #B}. */
    private double A;
    /** Lower end of the target range. */
    private double B;
    /** Per-feature value subtracted before scaling; {@code null} until fit. */
    private Vec mins;
    /**
     * Per-feature scale factor {@code (A - B) / (max - min)}; {@code null} until
     * fit. The misspelled name is kept deliberately: this class declares a
     * serialVersionUID, and renaming a field would change its serialized form.
     */
    private Vec mutliplyConstants;

    /** Creates an unfitted transform targeting the range {@code [0, 1]}. */
    public LinearTransform() {
        this(1.0d, 0.0d);
    }

    /**
     * Creates a transform targeting {@code [0, 1]} and fits it to {@code dataSet}.
     *
     * @param dataSet the data whose numerical column ranges are learned
     */
    public LinearTransform(DataSet dataSet) {
        this(dataSet, 1.0d, 0.0d);
    }

    /**
     * Creates an unfitted transform targeting the range spanned by the two
     * bounds, which may be given in either order.
     *
     * @param a one end of the target range
     * @param b the other end of the target range
     * @throws IllegalArgumentException if the bounds are equal, NaN, or infinite
     */
    public LinearTransform(double a, double b) {
        setRange(a, b);
    }

    /**
     * Creates a transform targeting the range spanned by the two bounds (given
     * in either order) and fits it to {@code dataSet}.
     *
     * @param dataSet the data whose numerical column ranges are learned
     * @param a one end of the target range
     * @param b the other end of the target range
     * @throws IllegalArgumentException if the bounds are equal, NaN, or infinite
     */
    public LinearTransform(DataSet dataSet, double a, double b) {
        this(a, b);
        fit(dataSet);
    }

    /**
     * Sets the target range. The bounds may be given in either order; the
     * larger becomes the upper end and the smaller the lower end.
     * <p>
     * Only the range is updated here. The per-feature scale factors are derived
     * from the range inside {@link #fit(DataSet)}, so an already fitted
     * transform must be refit for a new range to take effect.
     *
     * @param a one end of the target range
     * @param b the other end of the target range
     * @throws IllegalArgumentException if the bounds are equal, NaN, or infinite
     */
    public void setRange(double a, double b) {
        // NaN would slip past the equality check below (NaN != NaN) and, like
        // infinities, would turn every transformed value into NaN/Inf.
        if (Double.isNaN(a) || Double.isNaN(b) || Double.isInfinite(a) || Double.isInfinite(b)) {
            throw new IllegalArgumentException("Range bounds must be finite, got " + a + " and " + b);
        }
        if (a == b) {
            throw new IllegalArgumentException("Values must be different");
        }
        // Normalize so A always holds the upper bound and B the lower bound.
        this.A = Math.max(a, b);
        this.B = Math.min(a, b);
    }

    /**
     * Learns the per-feature minimum and scale factor from the numerical
     * columns of {@code dataSet}, using the current target range.
     * <p>
     * A column whose observed range is below {@value #CONSTANT_RANGE_TOLERANCE}
     * is given a minimum of 0 and a scale factor of 1, so the forward transform
     * maps its values to {@code x + B}. This is an identity only when
     * {@code B == 0}.
     *
     * @param dataSet the data to learn column ranges from
     */
    @Override
    public void fit(DataSet dataSet) {
        int numFeatures = dataSet.getNumNumericalVars();
        Vec newMins = new DenseVector(numFeatures);
        Vec ranges = new DenseVector(numFeatures);
        Vec newMultipliers = new DenseVector(numFeatures);
        OnLineStatistics[] stats = dataSet.getOnlineColumnStats(false);
        for (int i = 0; i < numFeatures; i++) {
            double min = stats[i].getMin();
            double max = stats[i].getMax();
            if (max - min < CONSTANT_RANGE_TOLERANCE) {
                // Degenerate column: avoid dividing by a (near) zero range.
                newMins.set(i, 0.0d);
                ranges.set(i, 1.0d);
                newMultipliers.set(i, 1.0d);
            } else {
                newMins.set(i, min);
                ranges.set(i, max);
                newMultipliers.set(i, this.A - this.B);
            }
        }
        // ranges currently holds max (or 1); turn it into max - min (or 1 - 0),
        // then fold it into the multiplier: multiplier = (A - B) / (max - min).
        ranges.mutableSubtract(newMins);
        newMultipliers.mutablePairwiseDivide(ranges);
        // Publish only once fully computed, so a failure above leaves any
        // previously fitted state intact.
        this.mins = newMins;
        this.mutliplyConstants = newMultipliers;
    }

    /** Copy constructor used by {@link #clone()}; copies the fitted vectors when present. */
    private LinearTransform(LinearTransform linearTransform) {
        this.A = linearTransform.A;
        this.B = linearTransform.B;
        if (linearTransform.mins != null) {
            this.mins = linearTransform.mins.mo525clone();
        }
        if (linearTransform.mutliplyConstants != null) {
            this.mutliplyConstants = linearTransform.mutliplyConstants.mo525clone();
        }
    }

    /**
     * Returns a transformed clone of {@code dataPoint}; the argument is not modified.
     *
     * @param dataPoint the point to transform
     * @return a new point with rescaled numerical values
     * @throws IllegalStateException if this transform has not been fit
     */
    @Override
    public DataPoint transform(DataPoint dataPoint) {
        DataPoint copy = dataPoint.m486clone();
        mutableTransform(copy);
        return copy;
    }

    /** Returns an independent copy of this transform, including any fitted state. */
    @Override
    public LinearTransform clone() {
        return new LinearTransform(this);
    }

    /**
     * Undoes {@link #mutableTransform(DataPoint)} in place:
     * {@code x = (x' - B) / multiplier + min}.
     *
     * @param dataPoint the point whose numerical values are restored in place
     * @throws IllegalStateException if this transform has not been fit
     */
    @Override
    public void mutableInverse(DataPoint dataPoint) {
        checkFitted();
        Vec numericalValues = dataPoint.getNumericalValues();
        numericalValues.mutableSubtract(this.B);
        numericalValues.mutablePairwiseDivide(this.mutliplyConstants);
        numericalValues.mutableAdd(this.mins);
    }

    /**
     * Rescales the numerical values of {@code dataPoint} in place:
     * {@code x' = (x - min) * multiplier + B}.
     *
     * @param dataPoint the point whose numerical values are rescaled in place
     * @throws IllegalStateException if this transform has not been fit
     */
    @Override
    public void mutableTransform(DataPoint dataPoint) {
        checkFitted();
        Vec numericalValues = dataPoint.getNumericalValues();
        numericalValues.mutableSubtract(this.mins);
        numericalValues.mutablePairwiseMultiply(this.mutliplyConstants);
        numericalValues.mutableAdd(this.B);
    }

    /** Always {@code false}: only numerical values are touched. */
    @Override
    public boolean mutatesNominal() {
        return false;
    }

    /**
     * Returns a clone of {@code dataPoint} with the transform undone; the
     * argument is not modified.
     *
     * @param dataPoint the transformed point
     * @return a new point with the original-scale numerical values
     * @throws IllegalStateException if this transform has not been fit
     */
    @Override
    public DataPoint inverse(DataPoint dataPoint) {
        DataPoint copy = dataPoint.m486clone();
        mutableInverse(copy);
        return copy;
    }

    /** Fails fast with a descriptive error instead of an NPE when used before {@link #fit}. */
    private void checkFitted() {
        if (this.mins == null || this.mutliplyConstants == null) {
            throw new IllegalStateException("LinearTransform has not been fit to a data set");
        }
    }
}
