package jsat.regression;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import jsat.DataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.SparseVector;
import jsat.linear.Vec;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/regression/RegressionDataSet.class */
/**
 * A data set for regression problems: every sample is a {@link DataPoint} paired with a
 * {@code double} target value (stored boxed inside a {@link DataPointPair}).
 */
public class RegressionDataSet extends DataSet<RegressionDataSet> {
    /** The samples of this data set, each paired with its target value. */
    protected List<DataPointPair<Double>> dataPoints;
    /** Shared empty categorical-value array for points that carry no categorical features. */
    private static final int[] EMPTY_INT = new int[0];

    /**
     * Creates an empty regression data set with the given feature layout.
     *
     * @param i the number of numerical features every added data point must have
     * @param categoricalDataArr the categorical feature descriptions; stored by reference, not copied
     */
    public RegressionDataSet(int i, CategoricalData[] categoricalDataArr) {
        this.numNumerVals = i;
        this.categories = categoricalDataArr;
        this.dataPoints = new ArrayList<>();
        this.numericalVariableNames = new ArrayList<>(getNumNumericalVars());
        setUpGenericNumericNames();
    }

    /**
     * Creates a regression data set from plain data points by removing one numerical
     * attribute from each point and using it as that point's target value.
     *
     * <p>The feature layout is taken from the first element of {@code list}. Each new data
     * point keeps the source's categorical values and weight; its numerical vector omits the
     * target column, with the columns after it shifted down by one, and is sparse exactly when
     * the source vector is sparse. A target that the source vector's iterator does not visit
     * (e.g. an implicit zero in a sparse vector) is taken to be {@code 0.0}.
     *
     * @param list the source data points; must not be empty
     * @param i the index of the numerical attribute to use as the regression target
     * @throws IllegalArgumentException if {@code i} is not a valid numerical index of the first point
     */
    public RegressionDataSet(List<DataPoint> list, int i) {
        DataPoint first = list.get(0);
        int sourceNumerical = first.numNumericalValues();
        // Without this guard an out-of-range index on sparse input could silently yield 0.0 targets.
        if (i < 0 || i >= sourceNumerical) {
            throw new IllegalArgumentException(
                    "Target index " + i + " is not in the range [0, " + sourceNumerical + ")");
        }
        this.categories = new CategoricalData[first.numCategoricalValues()];
        System.arraycopy(first.getCategoricalData(), 0, this.categories, 0, this.categories.length);
        this.numNumerVals = sourceNumerical - 1;
        this.dataPoints = new ArrayList<>(list.size());
        for (DataPoint source : list) {
            Vec numericalValues = source.getNumericalValues();
            double target = 0.0d;
            // One column (the target) is removed, so the feature vector is one shorter.
            Vec features = numericalValues.isSparse()
                    ? new SparseVector(numericalValues.length() - 1, numericalValues.nnz())
                    : new DenseVector(numericalValues.length() - 1);
            Iterator<IndexValue> it = numericalValues.iterator();
            while (it.hasNext()) {
                IndexValue next = it.next();
                int index = next.getIndex();
                if (index < i) {
                    features.set(index, next.getValue());
                } else if (index == i) {
                    target = next.getValue();
                } else {
                    // Columns past the target shift left to close the gap.
                    features.set(index - 1, next.getValue());
                }
            }
            DataPoint stripped =
                    new DataPoint(features, source.getCategoricalValues(), this.categories, source.getWeight());
            this.dataPoints.add(new DataPointPair<>(stripped, Double.valueOf(target)));
        }
        this.numericalVariableNames = new ArrayList<>(getNumNumericalVars());
        setUpGenericNumericNames();
    }

    /**
     * Creates a regression data set holding clones of the data points in {@code list}.
     * The feature layout is taken from the first element, whose categorical descriptions are
     * passed through {@code CategoricalData.copyOf}.
     *
     * @param list the data point / target pairs to copy; must not be empty
     */
    public RegressionDataSet(List<DataPointPair<Double>> list) {
        this.numNumerVals = list.get(0).getDataPoint().numNumericalValues();
        this.numericalVariableNames = new ArrayList<>(getNumNumericalVars());
        setUpGenericNumericNames();
        this.categories = CategoricalData.copyOf(list.get(0).getDataPoint().getCategoricalData());
        this.dataPoints = new ArrayList<>(list.size());
        for (DataPointPair<Double> dataPointPair : list) {
            this.dataPoints.add(
                    new DataPointPair<>(dataPointPair.getDataPoint().m486clone(), dataPointPair.getPair()));
        }
    }

    /**
     * Fills {@code numericalVariableNames} with "Numeric Input 1", "Numeric Input 2", ...
     * When there are more than 100 numerical features no names are generated and the list
     * is left as it was.
     */
    private void setUpGenericNumericNames() {
        if (getNumNumericalVars() > 100) {
            return;
        }
        for (int i = 0; i < getNumNumericalVars(); i++) {
            this.numericalVariableNames.add("Numeric Input " + (i + 1));
        }
    }

    /** Bare constructor used by {@link #usingDPPList(List)}, which fills in every field itself. */
    private RegressionDataSet() {
    }

    /**
     * Builds a data set from every set in {@code list} except the one at index {@code i}.
     * The layout is taken from the excluded set. The pairs are added by reference, so the
     * result shares {@link DataPointPair} objects with the source sets.
     *
     * @param list the data sets to combine
     * @param i the index of the set to leave out
     * @return a new data set containing the other sets' samples
     */
    public static RegressionDataSet comineAllBut(List<RegressionDataSet> list, int i) {
        RegressionDataSet excluded = list.get(i);
        RegressionDataSet regressionDataSet =
                new RegressionDataSet(excluded.getNumNumericalVars(), excluded.getCategories());
        for (int i2 = 0; i2 < list.size(); i2++) {
            if (i2 != i) {
                regressionDataSet.dataPoints.addAll(list.get(i2).dataPoints);
            }
        }
        return regressionDataSet;
    }

    /**
     * Adds a sample that has only numerical features.
     *
     * @param vec the numerical feature values
     * @param d the target value
     */
    public void addDataPoint(Vec vec, double d) {
        addDataPoint(vec, EMPTY_INT, d);
    }

    /**
     * Adds a sample with unit weight built from the given numerical and categorical values.
     *
     * @param vec the numerical feature values; must have exactly {@code getNumNumericalVars()} entries
     * @param iArr the categorical values, one per categorical feature; negative values are accepted
     *     without category validation
     * @param d the target value
     * @throws IllegalArgumentException if the vector length or categorical count does not match,
     *     or a non-negative categorical value is invalid for its category
     * @throws ArithmeticException if {@code d} is NaN or infinite
     */
    public void addDataPoint(Vec vec, int[] iArr, double d) {
        if (vec.length() != this.numNumerVals) {
            throw new IllegalArgumentException("Data point does not contain enough numerical data points");
        }
        // Compare against the data set's categories (the original compared iArr with itself).
        if (iArr.length != this.categories.length) {
            throw new IllegalArgumentException("Data point does not contain enough categorical data points");
        }
        for (int i = 0; i < iArr.length; i++) {
            if (!this.categories[i].isValidCategory(iArr[i]) && iArr[i] >= 0) {
                throw new IllegalArgumentException("Categoriy value given is invalid");
            }
        }
        addDataPoint(new DataPoint(vec, iArr, this.categories), d);
    }

    /**
     * Adds a sample and clears the cached column vectors.
     *
     * @param dataPoint the features; stored by reference
     * @param d the target value
     * @throws IllegalArgumentException if the point's feature counts differ from this data set's
     * @throws ArithmeticException if {@code d} is NaN or infinite
     */
    public void addDataPoint(DataPoint dataPoint, double d) {
        if (dataPoint.numNumericalValues() != getNumNumericalVars()
                || dataPoint.numCategoricalValues() != getNumCategoricalVars()) {
            throw new IllegalArgumentException(
                    "The added data point does not match the number of values and categories for the data set");
        }
        if (Double.isInfinite(d) || Double.isNaN(d)) {
            throw new ArithmeticException("Unregressiable value " + d + " given for regression");
        }
        this.dataPoints.add(new DataPointPair<>(dataPoint, Double.valueOf(d)));
        this.columnVecCache.clear();
    }

    /**
     * Adds a pair by reference and clears the cached column vectors. Unlike
     * {@link #addDataPoint(DataPoint, double)}, neither the feature layout nor the target is validated.
     *
     * @param dataPointPair the sample / target pair to add
     */
    public void addDataPointPair(DataPointPair<Double> dataPointPair) {
        this.dataPoints.add(dataPointPair);
        this.columnVecCache.clear();
    }

    /** Returns the features of sample {@code i} (the stored object, not a copy). */
    @Override
    public DataPoint getDataPoint(int i) {
        return this.dataPoints.get(i).getDataPoint();
    }

    /** Returns the stored pair for sample {@code i} (not a copy). */
    public DataPointPair<Double> getDataPointPair(int i) {
        return this.dataPoints.get(i);
    }

    /**
     * Returns a new list of new pairs whose data points are clones of this set's data points.
     *
     * @return an independent copy of the samples and targets
     */
    public List<DataPointPair<Double>> getAsDPPList() {
        List<DataPointPair<Double>> copy = new ArrayList<>(this.dataPoints.size());
        for (DataPointPair<Double> dataPointPair : this.dataPoints) {
            copy.add(new DataPointPair<>(dataPointPair.getDataPoint().m486clone(), dataPointPair.getPair()));
        }
        return copy;
    }

    /**
     * Returns a new list holding this set's pairs by reference; modifying a pair is visible
     * in this data set, but adding to or removing from the list is not.
     */
    public List<DataPointPair<Double>> getDPPList() {
        return new ArrayList<>(this.dataPoints);
    }

    /** Replaces the features of sample {@code i} and clears the cached column vectors. */
    @Override
    public void setDataPoint(int i, DataPoint dataPoint) {
        this.dataPoints.get(i).setDataPoint(dataPoint);
        this.columnVecCache.clear();
    }

    /**
     * Replaces the target value of sample {@code i}.
     *
     * @throws ArithmeticException if {@code d} is NaN or infinite
     */
    public void setTargetValue(int i, double d) {
        if (Double.isInfinite(d) || Double.isNaN(d)) {
            throw new ArithmeticException("Can not predict a " + d + " value");
        }
        this.dataPoints.get(i).setPair(Double.valueOf(d));
    }

    /**
     * Builds a data set containing the samples at the given indices, in list order. The new set
     * shares this set's category descriptions and {@link DataPoint} objects.
     *
     * @param list the indices of the samples to include
     * @return the subset
     */
    @Override
    protected RegressionDataSet getSubset(List<Integer> list) {
        RegressionDataSet regressionDataSet = new RegressionDataSet(this.numNumerVals, this.categories);
        for (int index : list) {
            regressionDataSet.addDataPoint(getDataPoint(index), getTargetValue(index));
        }
        return regressionDataSet;
    }

    /** Returns the number of samples in this data set. */
    @Override
    public int getSampleSize() {
        return this.dataPoints.size();
    }

    /** Returns a new dense vector of all target values, in sample order. */
    public Vec getTargetValues() {
        DenseVector denseVector = new DenseVector(getSampleSize());
        for (int i = 0; i < getSampleSize(); i++) {
            denseVector.set(i, this.dataPoints.get(i).getPair().doubleValue());
        }
        return denseVector;
    }

    /** Returns the target value of sample {@code i}. */
    public double getTargetValue(int i) {
        return this.dataPoints.get(i).getPair().doubleValue();
    }

    /**
     * Wraps an existing list as a data set without copying it: the returned set stores
     * {@code list} itself, so later changes to the list are visible through the data set.
     * A generic name is generated for every numerical feature (no 100-feature cap), and the
     * categories of the first element are passed through {@code CategoricalData.copyOf}.
     *
     * @param list the backing list; must not be empty
     * @return a data set backed by {@code list}
     */
    public static RegressionDataSet usingDPPList(List<DataPointPair<Double>> list) {
        RegressionDataSet regressionDataSet = new RegressionDataSet();
        regressionDataSet.dataPoints = list;
        regressionDataSet.numNumerVals = list.get(0).getDataPoint().numNumericalValues();
        regressionDataSet.numericalVariableNames = new ArrayList<>(regressionDataSet.getNumNumericalVars());
        for (int i = 0; i < regressionDataSet.getNumNumericalVars(); i++) {
            regressionDataSet.numericalVariableNames.add("Numeric Input " + (i + 1));
        }
        regressionDataSet.categories = CategoricalData.copyOf(list.get(0).getDataPoint().getCategoricalData());
        return regressionDataSet;
    }

    /**
     * Returns a copy with new {@link DataPointPair} objects that share this set's data points,
     * targets and categories, and a copy of the current column-vector cache entries.
     * (The decompiler renamed this from {@code shallowClone}; the name is kept for callers.)
     */
    @Override
    public DataSet<RegressionDataSet> shallowClone2() {
        RegressionDataSet regressionDataSet = new RegressionDataSet(this.numNumerVals, this.categories);
        for (DataPointPair<Double> dataPointPair : this.dataPoints) {
            regressionDataSet.dataPoints.add(new DataPointPair<>(dataPointPair.getDataPoint(), dataPointPair.getPair()));
        }
        regressionDataSet.columnVecCache.putAll(this.columnVecCache);
        return regressionDataSet;
    }

    /** Delegates to {@code DataSet.getTwiceShallowClone()} and narrows the return type. */
    @Override
    public RegressionDataSet getTwiceShallowClone() {
        return (RegressionDataSet) super.getTwiceShallowClone();
    }
}
