package jsat.text;

import java.util.List;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.ClassificationDataSet;
import jsat.text.tokenizer.Tokenizer;
import jsat.text.wordweighting.WordWeighting;
import jsat.utils.IntList;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/text/ClassificationTextDataLoader.class */
/**
 * A {@link TextDataLoader} that produces a {@link ClassificationDataSet}. Every document is added
 * together with an integer class label through {@link #addOriginalDocument(String, int)}.
 *
 * <p>Subclasses implement {@link #setLabelInfo()}, which must assign {@link #labelInfo} before any
 * labelled document is added.
 */
public abstract class ClassificationTextDataLoader extends TextDataLoader {
    private static final long serialVersionUID = -3826551504785236576L;

    /**
     * Class label of each document, indexed by the document index returned from the superclass's
     * {@code addOriginalDocument(String)}. Slots that have been skipped over but not yet labelled
     * hold {@code -1}.
     */
    protected final List<Integer> classLabels;

    /** Describes the target class variable; assigned by the subclass in {@link #setLabelInfo()}. */
    protected CategoricalData labelInfo;

    /**
     * Creates a new loader with an empty label list.
     *
     * @param tokenizer the tokenizer handed to the superclass
     * @param wordWeighting the word weighting handed to the superclass
     */
    public ClassificationTextDataLoader(Tokenizer tokenizer, WordWeighting wordWeighting) {
        super(tokenizer, wordWeighting);
        this.classLabels = new IntList();
    }

    /**
     * Assigns {@link #labelInfo}. {@link #getDataSet()} invokes this before {@code initialLoad()},
     * so label information is available when documents are added.
     */
    protected abstract void setLabelInfo();

    /**
     * Not supported by this loader, because every document needs a class label.
     *
     * @param text ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always; use {@link #addOriginalDocument(String, int)}
     */
    @Override // jsat.text.TextDataLoader
    public int addOriginalDocument(String text) {
        throw new UnsupportedOperationException("addOriginalDocument(String text, int label) should be used instead");
    }

    /**
     * Adds a document through the superclass and records its class label at the index the
     * superclass assigns.
     *
     * @param text the raw document text
     * @param label the class label, in {@code [0, labelInfo.getNumOfCategories())}
     * @return the index assigned to the document by the superclass
     * @throws IllegalStateException if {@link #labelInfo} has not been assigned yet
     * @throws IllegalArgumentException if {@code label} is outside the valid category range
     */
    protected int addOriginalDocument(String text, int label) {
        if (this.labelInfo == null) {
            throw new IllegalStateException("labelInfo has not been set; setLabelInfo() must assign it before documents are added");
        }
        int numCategories = this.labelInfo.getNumOfCategories();
        // Negative labels must be rejected too; -1 is reserved internally as the "unlabelled" padding value.
        if (label < 0 || label >= numCategories) {
            throw new IllegalArgumentException("Invalid label given: " + label + " (expected a value in [0, " + numCategories + "))");
        }
        int index = super.addOriginalDocument(text);
        synchronized (this.classLabels) {
            // Indices may arrive out of order, so pad any gap with the -1 placeholder first.
            while (this.classLabels.size() < index) {
                this.classLabels.add(-1);
            }
            if (this.classLabels.size() == index) {
                this.classLabels.add(label);
            } else {
                this.classLabels.set(index, label);
            }
        }
        return index;
    }

    /**
     * Builds a classification data set from the loaded documents. If adding has not yet been
     * finished, this first calls {@link #setLabelInfo()}, {@code initialLoad()} and
     * {@code finishAdding()}.
     *
     * @return a data set with one data point per loaded vector, no categorical features, and the
     *     recorded class label of each document
     * @throws IllegalStateException if no documents were loaded
     */
    @Override // jsat.text.TextDataLoader
    public ClassificationDataSet getDataSet() {
        if (!this.noMoreAdding) {
            setLabelInfo();
            initialLoad();
            finishAdding();
        }
        int numDocuments = this.vectors.size();
        // The numeric dimension is taken from the first vector, so at least one document is required.
        if (numDocuments == 0) {
            throw new IllegalStateException("No documents were loaded; cannot determine the data set dimension");
        }
        ClassificationDataSet dataSet = new ClassificationDataSet(this.vectors.get(0).length(), new CategoricalData[0], this.labelInfo);
        for (int i = 0; i < numDocuments; i++) {
            dataSet.addDataPoint(this.vectors.get(i), new int[0], this.classLabels.get(i).intValue());
        }
        return dataSet;
    }
}
