package jsat.io;

import com.google.common.primitives.UnsignedBytes;
import com.itextpdf.text.pdf.BaseFont;
import com.itextpdf.text.xml.xmp.XmpWriter;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.EnumSet;
import java.util.Iterator;
import jsat.DataSet;
import jsat.SimpleDataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.SparseVector;
import jsat.linear.Vec;
import jsat.regression.RegressionDataSet;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/io/JSATData.class */
/**
 * Reads and writes data sets in JSAT's compact binary format.
 * <p>
 * A file starts with {@link #MAGIC_NUMBER}, followed by:
 * <ul>
 * <li>one byte holding the {@link DatasetTypeMarker} ordinal;</li>
 * <li>one byte holding the {@link FloatStorageMethod} ordinal used for every floating point value;</li>
 * <li>three ints: the number of numeric features, the number of categorical features and the
 * number of data points.</li>
 * </ul>
 * For regression data the target is counted as one extra numeric feature. For classification data
 * the predicted category is counted as one extra categorical feature.
 * <p>
 * The categorical feature descriptions follow, with the predicted category last. After them comes
 * one record per data point.
 * <p>
 * Because enum ordinals are stored in the file, the declaration order of the
 * {@link DatasetTypeMarker} and {@link FloatStorageMethod} constants must never change.
 */
public class JSATData {

    /** The bytes every file starts with: {@code "JSAT_00"} in ASCII. */
    public static final byte[] MAGIC_NUMBER = {74, 83, 65, 84, 95, 48, 48};

    /** String encoding marker: one byte per character, used when every character is in 1..255. */
    public static final byte STRING_ENCODING_ASCII = 0;

    /** String encoding marker: the string's bytes in the {@code "UTF-16"} charset. */
    public static final byte STRING_ENCODING_UTF_16 = 1;

    /** The kind of data set stored in a file. The ordinal is written to the file. */
    public enum DatasetTypeMarker {
        /** Neither labels nor targets; loaded back as a {@link SimpleDataSet}. */
        STANDARD,
        /** A {@link RegressionDataSet}; every data point carries a numeric target. */
        REGRESSION,
        /** A {@link ClassificationDataSet}; every data point carries a category index. */
        CLASSIFICATION
    }

    /**
     * How floating point values (weights, numeric feature values and regression targets) are
     * encoded. The ordinal of the chosen method is written once to the file header, so every value
     * in a file uses the same encoding.
     */
    public enum FloatStorageMethod {
        /**
         * Not an encoding: asks {@link #getMethod} to pick the most compact method that stores
         * every value of the data set without loss.
         */
        AUTO {
            @Override
            protected void writeFP(double d, DataOutputStream dataOutputStream) throws IOException {
                throw new UnsupportedOperationException("Not supported .");
            }

            @Override
            protected double readFP(DataInputStream dataInputStream) throws IOException {
                throw new UnsupportedOperationException("Not supported .");
            }

            @Override
            protected boolean noLoss(double d) {
                return true;
            }
        },
        /** 64-bit IEEE double; lossless for every value. */
        FP64 {
            @Override
            protected void writeFP(double d, DataOutputStream dataOutputStream) throws IOException {
                dataOutputStream.writeDouble(d);
            }

            @Override
            protected double readFP(DataInputStream dataInputStream) throws IOException {
                return dataInputStream.readDouble();
            }

            @Override
            protected boolean noLoss(double d) {
                return true;
            }
        },
        /** 32-bit IEEE float. */
        FP32 {
            @Override
            protected void writeFP(double d, DataOutputStream dataOutputStream) throws IOException {
                dataOutputStream.writeFloat((float) d);
            }

            @Override
            protected double readFP(DataInputStream dataInputStream) throws IOException {
                return dataInputStream.readFloat();
            }

            @Override
            protected boolean noLoss(double d) {
                // Round-trip comparison; infinities survive it. NaN never equals itself, but a
                // float NaN widens back to a double NaN, so accept it explicitly.
                return (double) (float) d == d || Double.isNaN(d);
            }
        },
        /** Signed 16-bit integer; values are truncated toward zero and clamped to [-32768, 32767]. */
        SHORT {
            @Override
            protected void writeFP(double d, DataOutputStream dataOutputStream) throws IOException {
                dataOutputStream.writeShort(Math.min(Math.max((int) d, Short.MIN_VALUE), Short.MAX_VALUE));
            }

            @Override
            protected double readFP(DataInputStream dataInputStream) throws IOException {
                return dataInputStream.readShort();
            }

            @Override
            protected boolean noLoss(double d) {
                return Short.MIN_VALUE <= d && d <= Short.MAX_VALUE && d == Math.rint(d);
            }
        },
        /** Signed 8-bit integer; values are truncated toward zero and clamped to [-128, 127]. */
        BYTE {
            @Override
            protected void writeFP(double d, DataOutputStream dataOutputStream) throws IOException {
                dataOutputStream.writeByte(Math.min(Math.max((int) d, Byte.MIN_VALUE), Byte.MAX_VALUE));
            }

            @Override
            protected double readFP(DataInputStream dataInputStream) throws IOException {
                return dataInputStream.readByte();
            }

            @Override
            protected boolean noLoss(double d) {
                return Byte.MIN_VALUE <= d && d <= Byte.MAX_VALUE && d == Math.rint(d);
            }
        },
        /** Unsigned 8-bit integer; values are truncated toward zero and clamped to [0, 255]. */
        U_BYTE {
            @Override
            protected void writeFP(double d, DataOutputStream dataOutputStream) throws IOException {
                dataOutputStream.writeByte(Math.min(Math.max((int) d, 0), 255));
            }

            @Override
            protected double readFP(DataInputStream dataInputStream) throws IOException {
                return dataInputStream.readUnsignedByte();
            }

            @Override
            protected boolean noLoss(double d) {
                return 0.0d <= d && d <= 255.0d && d == Math.rint(d);
            }
        };

        /** Writes {@code d} to the stream using this encoding. */
        protected abstract void writeFP(double d, DataOutputStream dataOutputStream) throws IOException;

        /** Reads one value previously written with {@link #writeFP}. */
        protected abstract double readFP(DataInputStream dataInputStream) throws IOException;

        /** Returns {@code true} if {@code d} can be stored with this encoding without losing its value. */
        protected abstract boolean noLoss(double d);

        /**
         * Resolves the storage method to use when writing {@code dataSet}.
         *
         * @param dataSet the data that will be written
         * @param floatStorageMethod the requested method; anything other than {@link #AUTO} is
         *        returned unchanged
         * @return {@code floatStorageMethod}. For {@link #AUTO}, returns the first of BYTE, U_BYTE,
         *         SHORT, FP32 that stores every weight, numeric feature value and regression target
         *         without loss, or FP64 if none does.
         */
        public static <Type extends DataSet<Type>> FloatStorageMethod getMethod(DataSet<Type> dataSet, FloatStorageMethod floatStorageMethod) {
            if (floatStorageMethod != AUTO) {
                return floatStorageMethod;
            }
            // FP64 is lossless for every value, so it is never removed and the set never empties.
            // Once it is the only candidate left, further scanning cannot change the answer.
            EnumSet<FloatStorageMethod> candidates = EnumSet.complementOf(EnumSet.of(AUTO));
            for (int i = 0; i < dataSet.getSampleSize() && candidates.size() > 1; i++) {
                DataPoint dataPoint = dataSet.getDataPoint(i);
                removeLossy(candidates, dataPoint.getNumericalValues().iterator());
                removeLossy(candidates, dataPoint.getWeight());
            }
            if (dataSet instanceof RegressionDataSet) {
                removeLossy(candidates, ((RegressionDataSet) dataSet).getTargetValues().iterator());
            }
            // Most compact encodings first.
            for (FloatStorageMethod preferred : new FloatStorageMethod[] {BYTE, U_BYTE, SHORT, FP32}) {
                if (candidates.contains(preferred)) {
                    return preferred;
                }
            }
            return FP64;
        }

        /** Removes from {@code candidates} every method that cannot store {@code value} without loss. */
        private static void removeLossy(EnumSet<FloatStorageMethod> candidates, double value) {
            Iterator<FloatStorageMethod> iter = candidates.iterator();
            while (iter.hasNext()) {
                if (!iter.next().noLoss(value)) {
                    iter.remove();
                }
            }
        }

        /** Prunes {@code candidates} by each value, stopping once a single candidate remains. */
        private static void removeLossy(EnumSet<FloatStorageMethod> candidates, Iterator<IndexValue> values) {
            while (candidates.size() > 1 && values.hasNext()) {
                removeLossy(candidates, values.next().getValue());
            }
        }
    }

    private JSATData() {
    }

    /**
     * Writes {@code dataSet} using the most compact lossless float encoding
     * ({@link FloatStorageMethod#AUTO}). The stream is closed afterwards.
     */
    public static <Type extends DataSet<Type>> void writeData(DataSet<Type> dataSet, OutputStream outputStream) throws IOException {
        writeData(dataSet, outputStream, FloatStorageMethod.AUTO);
    }

    /**
     * Writes {@code dataSet} to {@code outputStream}. The stream is closed afterwards, also when
     * writing fails.
     *
     * @param dataSet the data to write
     * @param outputStream the destination
     * @param floatStorageMethod the float encoding, or {@link FloatStorageMethod#AUTO} to choose one
     * @throws IOException if writing fails
     */
    public static <Type extends DataSet<Type>> void writeData(DataSet<Type> dataSet, OutputStream outputStream, FloatStorageMethod floatStorageMethod) throws IOException {
        FloatStorageMethod method = FloatStorageMethod.getMethod(dataSet, floatStorageMethod);
        boolean isRegression = dataSet instanceof RegressionDataSet;
        boolean isClassification = dataSet instanceof ClassificationDataSet;
        DataOutputStream out = new DataOutputStream(outputStream);
        try {
            out.write(MAGIC_NUMBER);
            int numNumericalVars = dataSet.getNumNumericalVars();
            int numCategoricalVars = dataSet.getNumCategoricalVars();
            DatasetTypeMarker marker = DatasetTypeMarker.STANDARD;
            if (isRegression) {
                numNumericalVars++; // the target is counted as an extra numeric feature
                marker = DatasetTypeMarker.REGRESSION;
            }
            if (isClassification) {
                numCategoricalVars++; // the label is counted as an extra categorical feature
                marker = DatasetTypeMarker.CLASSIFICATION;
            }
            out.writeByte(marker.ordinal());
            out.writeByte(method.ordinal());
            out.writeInt(numNumericalVars);
            out.writeInt(numCategoricalVars);
            out.writeInt(dataSet.getSampleSize());
            for (CategoricalData category : dataSet.getCategories()) {
                writeCategory(category, out);
            }
            if (isClassification) {
                writeCategory(((ClassificationDataSet) dataSet).getPredicting(), out);
            }

            for (int i = 0; i < dataSet.getSampleSize(); i++) {
                DataPoint dataPoint = dataSet.getDataPoint(i);
                method.writeFP(dataPoint.getWeight(), out);
                for (int categoryValue : dataPoint.getCategoricalValues()) {
                    out.writeInt(categoryValue);
                }
                if (isClassification) {
                    out.writeInt(((ClassificationDataSet) dataSet).getDataPointCategory(i));
                }

                Vec numericalValues = dataPoint.getNumericalValues();
                boolean sparse = numericalValues.isSparse();
                out.writeBoolean(sparse);
                if (sparse) {
                    // For regression one extra (index, value) pair is announced. It is the
                    // (vector length, target) trailer written below, which lets
                    // load(in, true) read the target as the last numeric feature.
                    out.writeInt(marker == DatasetTypeMarker.REGRESSION ? numericalValues.nnz() + 1 : numericalValues.nnz());
                    Iterator<IndexValue> nonZeros = numericalValues.iterator();
                    while (nonZeros.hasNext()) {
                        IndexValue entry = nonZeros.next();
                        out.writeInt(entry.getIndex());
                        method.writeFP(entry.getValue(), out);
                    }
                } else {
                    for (int j = 0; j < numericalValues.length(); j++) {
                        method.writeFP(numericalValues.get(j), out);
                    }
                }

                if (isRegression) {
                    if (sparse) {
                        out.writeInt(numericalValues.length());
                    }
                    method.writeFP(((RegressionDataSet) dataSet).getTargetValue(i), out);
                }
            }
            out.flush();
        } finally {
            out.close();
        }
    }

    /** Reads a data set of whatever type is stored in the stream. The stream is closed afterwards. */
    public static DataSet<?> load(InputStream inputStream) throws IOException {
        return load(inputStream, false);
    }

    /**
     * Reads any stored data set as a {@link SimpleDataSet}. A classification label becomes the last
     * categorical feature and a regression target the last numeric feature. The stream is closed
     * afterwards.
     */
    public static SimpleDataSet loadSimple(InputStream inputStream) throws IOException {
        return (SimpleDataSet) load(inputStream, true);
    }

    /**
     * Reads a stored classification data set. The stream is closed afterwards.
     *
     * @throws ClassCastException if the stream holds a different kind of data set
     */
    public static ClassificationDataSet loadClassification(InputStream inputStream) throws IOException {
        return (ClassificationDataSet) load(inputStream);
    }

    /**
     * Reads a stored regression data set. The stream is closed afterwards.
     *
     * @throws ClassCastException if the stream holds a different kind of data set
     */
    public static RegressionDataSet loadRegression(InputStream inputStream) throws IOException {
        return (RegressionDataSet) load(inputStream);
    }

    /**
     * Reads a data set written by {@link #writeData}. The stream is closed afterwards, also when
     * reading fails.
     *
     * @param inputStream the stream to read from
     * @param forceStandard if {@code true}, load the data as a {@link SimpleDataSet} regardless of
     *        the stored type. A classification label then becomes the last categorical feature and a
     *        regression target the last numeric feature.
     * @return the data set read
     * @throws IOException if reading fails
     * @throws RuntimeException if the magic number is missing or the stream contains an unknown
     *         type, storage or string encoding code
     */
    protected static DataSet<?> load(InputStream inputStream, boolean forceStandard) throws IOException {
        DataInputStream in = new DataInputStream(inputStream);
        try {
            byte[] magic = new byte[MAGIC_NUMBER.length];
            in.readFully(magic);
            if (!new String(magic, "US-ASCII").startsWith("JSAT_")) {
                throw new RuntimeException("data does not contain magic number");
            }
            DatasetTypeMarker marker = decodeOrdinal(DatasetTypeMarker.values(), in.readByte(), "data set type");
            FloatStorageMethod fpMethod = decodeOrdinal(FloatStorageMethod.values(), in.readByte(), "float storage method");
            int numNumerical = in.readInt();
            int numCategorical = in.readInt();
            int numPoints = in.readInt();
            if (forceStandard) {
                marker = DatasetTypeMarker.STANDARD;
            }
            // The header counts the label/target as an extra feature; strip it unless the data
            // is being read as a plain data set, where it is kept as the last feature.
            if (marker == DatasetTypeMarker.CLASSIFICATION) {
                numCategorical--;
            } else if (marker == DatasetTypeMarker.REGRESSION) {
                numNumerical--;
            }

            CategoricalData[] categories = new CategoricalData[numCategorical];
            for (int i = 0; i < categories.length; i++) {
                categories[i] = readCategory(in);
            }
            CategoricalData predicting = marker == DatasetTypeMarker.CLASSIFICATION ? readCategory(in) : null;

            DataSet<?> dataSet;
            switch (marker) {
                case CLASSIFICATION:
                    dataSet = new ClassificationDataSet(numNumerical, categories, predicting);
                    break;
                case REGRESSION:
                    dataSet = new RegressionDataSet(numNumerical, categories);
                    break;
                default:
                    dataSet = new SimpleDataSet(categories, numNumerical);
                    break;
            }

            for (int p = 0; p < numPoints; p++) {
                double weight = fpMethod.readFP(in);
                int[] categoryValues = new int[numCategorical];
                for (int j = 0; j < categoryValues.length; j++) {
                    categoryValues[j] = in.readInt();
                }
                int label = marker == DatasetTypeMarker.CLASSIFICATION ? in.readInt() : 0;

                Vec numeric;
                boolean sparse = in.readBoolean();
                if (sparse) {
                    int nnz = in.readInt();
                    if (marker == DatasetTypeMarker.REGRESSION) {
                        nnz--; // the announced extra pair is the (length, target) trailer read below
                    }
                    int[] indices = new int[nnz];
                    double[] values = new double[nnz];
                    for (int j = 0; j < nnz; j++) {
                        indices[j] = in.readInt();
                        values[j] = fpMethod.readFP(in);
                    }
                    numeric = new SparseVector(indices, values, numNumerical, nnz);
                } else {
                    numeric = new DenseVector(numNumerical);
                    for (int j = 0; j < numNumerical; j++) {
                        numeric.set(j, fpMethod.readFP(in));
                    }
                }

                double target = 0.0d;
                if (marker == DatasetTypeMarker.REGRESSION) {
                    if (sparse) {
                        in.readInt(); // skip the stored vector length
                    }
                    target = fpMethod.readFP(in);
                }

                DataPoint dataPoint = new DataPoint(numeric, categoryValues, categories, weight);
                switch (marker) {
                    case CLASSIFICATION:
                        ((ClassificationDataSet) dataSet).addDataPoint(dataPoint, label);
                        break;
                    case REGRESSION:
                        ((RegressionDataSet) dataSet).addDataPoint(dataPoint, target);
                        break;
                    default:
                        ((SimpleDataSet) dataSet).add(dataPoint);
                        break;
                }
            }
            return dataSet;
        } finally {
            in.close();
        }
    }

    /** Maps an ordinal read from a file back to its enum constant, rejecting out-of-range codes. */
    private static <E> E decodeOrdinal(E[] values, int code, String what) {
        if (code < 0 || code >= values.length) {
            throw new RuntimeException("Unknown " + what + " code " + code);
        }
        return values[code];
    }

    /** Writes a categorical feature description: its name, the option count, then each option name. */
    private static void writeCategory(CategoricalData category, DataOutputStream out) throws IOException {
        writeString(category.getCategoryName(), out);
        out.writeInt(category.getNumOfCategories());
        for (int option = 0; option < category.getNumOfCategories(); option++) {
            writeString(category.getOptionName(option), out);
        }
    }

    /** Reads a categorical feature description in the layout produced by {@code writeCategory}. */
    private static CategoricalData readCategory(DataInputStream in) throws IOException {
        String name = readString(in);
        int numOptions = in.readInt();
        CategoricalData category = new CategoricalData(numOptions);
        category.setCategoryName(name);
        for (int option = 0; option < numOptions; option++) {
            category.setOptionName(readString(in), option);
        }
        return category;
    }

    /**
     * Writes an encoding marker byte, a length int and the string data. If every character is in
     * 1..255 the string is stored one byte per character; otherwise it is stored as UTF-16 bytes.
     */
    private static void writeString(String str, DataOutputStream out) throws IOException {
        if (isOneBytePerChar(str)) {
            out.writeByte(STRING_ENCODING_ASCII);
            out.writeInt(str.length());
            out.writeBytes(str); // writes the low 8 bits of each char
        } else {
            byte[] bytes = str.getBytes("UTF-16");
            out.writeByte(STRING_ENCODING_UTF_16);
            out.writeInt(bytes.length);
            out.write(bytes);
        }
    }

    /** Returns {@code true} if every character of {@code str} is in 1..255. */
    private static boolean isOneBytePerChar(String str) {
        for (int i = 0; i < str.length(); i++) {
            char c = str.charAt(i);
            if (c == 0 || c > 255) {
                return false;
            }
        }
        return true;
    }

    /** Reads a string in the layout produced by {@code writeString}. */
    private static String readString(DataInputStream in) throws IOException {
        byte encoding = in.readByte();
        int length = in.readInt();
        switch (encoding) {
            case STRING_ENCODING_ASCII: {
                // Characters 128..255 are stored as single bytes, so they must be read unsigned;
                // a signed read would produce negative code points.
                char[] chars = new char[length];
                for (int i = 0; i < length; i++) {
                    chars[i] = (char) in.readUnsignedByte();
                }
                return new String(chars);
            }
            case STRING_ENCODING_UTF_16: {
                byte[] bytes = new byte[length];
                in.readFully(bytes);
                return new String(bytes, "UTF-16");
            }
            default:
                throw new RuntimeException("Unkown string encoding value " + ((int) encoding));
        }
    }
}
