package jsat.classifiers.bayesian;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.utils.IntSet;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/bayesian/ConditionalProbabilityTable.class */
/**
 * A conditional probability table (CPT) over a chosen subset of a data set's
 * categorical variables, optionally including the class label.
 *
 * <p>The table is stored as a flat array of weighted counts. Every selected
 * variable is one dimension of the table; a coordinate holds one category value
 * per dimension and is flattened into an index into the count array. The class
 * label is addressed by the index {@code getNumCategoricalVars()} of the training
 * data set, i.e. one past the last categorical feature.
 *
 * <p>Instances are mutable and perform no synchronization.
 */
public class ConditionalProbabilityTable implements Classifier {
    private static final long serialVersionUID = -287709075031023626L;

    /** Category information of the class label of the training set. */
    private CategoricalData predicting;
    /** Flattened table of weighted counts; every cell starts at 1.0. */
    private double[] countArray;
    /** Category information for every variable index that is part of the table. */
    private Map<Integer, CategoricalData> valid;
    /** Maps a table dimension to the variable index it represents. */
    private int[] realIndexToCatIndex;
    /** Maps a variable index to its table dimension, or -1 if it is not in the table. */
    private int[] catIndexToRealIndex;
    /** Number of categories of each table dimension. */
    private int[] dimSize;
    /** Variable index used to address the class label. */
    private int predictingIndex;

    /**
     * Classifies a data point by reading, for every possible class, the count
     * stored for the point's feature values combined with that class, and then
     * normalizing the counts.
     *
     * @param dataPoint the point to classify
     * @return the normalized class distribution
     * @throws UntrainedModelException if the table has not been trained, or was
     *         trained without the class label as one of its dimensions
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        // The null test covers a model that was never trained at all; the original
        // code dereferenced the array directly and failed with a NullPointerException.
        if (this.catIndexToRealIndex == null || this.catIndexToRealIndex[this.predictingIndex] < 0) {
            throw new UntrainedModelException("CPT has not been trained for a classification problem");
        }
        final int labelDim = this.catIndexToRealIndex[this.predictingIndex];
        CategoricalResults results = new CategoricalResults(this.predicting.getNumOfCategories());
        int[] cord = new int[this.dimSize.length];
        // The label slot is filled with the placeholder -1 here and overwritten below.
        dataPointToCord(new DataPointPair<>(dataPoint, -1), this.predictingIndex, cord);
        for (int label = 0; label < results.size(); label++) {
            cord[labelDim] = label;
            results.setProb(label, this.countArray[cordToIndex(cord)]);
        }
        results.normalize();
        return results;
    }

    /**
     * Returns the number of dimensions (selected variables) of the table.
     *
     * @return the number of table dimensions
     * @throws UntrainedModelException if the table has not been trained
     */
    public int getDimensionSize() {
        requireTrained();
        return this.dimSize.length;
    }

    /**
     * Converts a labeled data point into a table coordinate.
     *
     * <p>For the dimension that represents the class label the value of the pair
     * is used; for every other dimension the categorical value of the data point
     * is used.
     *
     * @param dataPointPair the data point together with its class label
     * @param targetClass the variable index whose value should be returned
     * @param cord the array the coordinate is written into; its length must equal
     *        {@link #getDimensionSize()}
     * @return the value written for {@code targetClass}, or -1 if that variable is
     *         not a dimension of this table
     * @throws ArithmeticException if {@code cord} has the wrong length
     * @throws UntrainedModelException if the table has not been trained
     */
    public int dataPointToCord(DataPointPair<Integer> dataPointPair, int targetClass, int[] cord) {
        if (cord.length != getDimensionSize()) {
            throw new ArithmeticException("Storage space and CPT dimension miss match");
        }
        DataPoint dataPoint = dataPointPair.getDataPoint();
        int targetValue = -1;
        for (int dim = 0; dim < this.dimSize.length; dim++) {
            final int catIndex = this.realIndexToCatIndex[dim];
            if (catIndex == this.predictingIndex) {
                cord[dim] = dataPointPair.getPair().intValue();
            } else {
                cord[dim] = dataPoint.getCategoricalValue(catIndex);
            }
            // Report exactly the value stored in the coordinate, so the label is
            // recognized by the same test in both places.
            if (catIndex == targetClass) {
                targetValue = cord[dim];
            }
        }
        return targetValue;
    }

    /**
     * Trains the table on all categorical variables and the class label. The
     * executor is not used; training runs in the calling thread.
     *
     * @param classificationDataSet the training data
     * @param executorService ignored
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        trainC(classificationDataSet);
    }

    /**
     * Trains the table on all categorical variables and the class label.
     *
     * @param classificationDataSet the training data
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        IntSet everything = new IntSet();
        // Indices 0 .. numCategoricalVars, the last one addressing the class label.
        for (int index = 0; index <= classificationDataSet.getNumCategoricalVars(); index++) {
            everything.add(Integer.valueOf(index));
        }
        trainC(classificationDataSet, everything);
    }

    /**
     * Trains the table using only the given variables.
     *
     * <p>Dimensions are assigned in the iteration order of {@code categoriesToUse};
     * the class label, if requested, is always the last dimension. Every cell of
     * the table starts with a count of 1.0, then the weight of every data point is
     * added to the cell selected by its values.
     *
     * @param classificationDataSet the training data
     * @param categoriesToUse the variable indices to include; the value
     *        {@code classificationDataSet.getNumCategoricalVars()} selects the
     *        class label
     * @throws FailedToFitException if more variables are requested than the data
     *         set has, an index is out of range, or the table would need more
     *         than {@link Integer#MAX_VALUE} cells
     */
    public void trainC(ClassificationDataSet classificationDataSet, Set<Integer> categoriesToUse) {
        if (categoriesToUse.size() > classificationDataSet.getNumFeatures() + 1) {
            throw new FailedToFitException("CPT can not train on a number of features greater then the dataset's feature count. Specified " + categoriesToUse.size() + " but data set has only " + classificationDataSet.getNumFeatures());
        }
        final int labelIndex = classificationDataSet.getNumCategoricalVars();
        // Validate before touching any field so a rejected call leaves the model as it was.
        for (Integer requested : categoriesToUse) {
            if (requested == null || requested.intValue() < 0 || requested.intValue() > labelIndex) {
                throw new FailedToFitException("CPT can not use variable index " + requested + ", valid indices are 0 to " + labelIndex + " where " + labelIndex + " is the class label");
            }
        }

        CategoricalData[] categories = classificationDataSet.getCategories();
        CategoricalData label = classificationDataSet.getPredicting();
        Map<Integer, CategoricalData> used = new HashMap<>();
        int[] dimToCat = new int[categoriesToUse.size()];
        int[] catToDim = new int[labelIndex + 1];
        Arrays.fill(catToDim, -1);
        int[] sizes = new int[dimToCat.length];

        long tableSize = 1;
        int nextDim = 0;
        for (Integer requested : categoriesToUse) {
            final int catIndex = requested.intValue();
            if (catIndex == labelIndex) {
                continue; // the label is appended after all features
            }
            CategoricalData info = categories[catIndex];
            tableSize = growTableSize(tableSize, info.getNumOfCategories());
            used.put(requested, info);
            dimToCat[nextDim] = catIndex;
            catToDim[catIndex] = nextDim;
            sizes[nextDim] = info.getNumOfCategories();
            nextDim++;
        }
        if (categoriesToUse.contains(Integer.valueOf(labelIndex))) {
            tableSize = growTableSize(tableSize, label.getNumOfCategories());
            dimToCat[nextDim] = labelIndex;
            catToDim[labelIndex] = nextDim;
            sizes[nextDim] = label.getNumOfCategories();
            used.put(Integer.valueOf(labelIndex), label);
        }

        this.predicting = label;
        this.predictingIndex = labelIndex;
        this.valid = used;
        this.realIndexToCatIndex = dimToCat;
        this.catIndexToRealIndex = catToDim;
        this.dimSize = sizes;
        this.countArray = new double[(int) tableSize];
        // Start every cell at 1 so combinations never seen in training keep a non-zero count.
        Arrays.fill(this.countArray, 1.0d);

        int[] cord = new int[this.dimSize.length];
        for (int sample = 0; sample < classificationDataSet.getSampleSize(); sample++) {
            DataPoint dataPoint = classificationDataSet.getDataPoint(sample);
            for (int dim = 0; dim < this.realIndexToCatIndex.length; dim++) {
                if (this.realIndexToCatIndex[dim] == this.predictingIndex) {
                    cord[dim] = classificationDataSet.getDataPointCategory(sample);
                } else {
                    cord[dim] = dataPoint.getCategoricalValue(this.realIndexToCatIndex[dim]);
                }
            }
            this.countArray[cordToIndex(cord)] += dataPoint.getWeight();
        }
    }

    /**
     * Queries the table for the probability that the given variable takes the
     * value observed in the data point, conditioned on the point's values for all
     * other dimensions of the table.
     *
     * @param targetClass the variable index to query
     * @param dataPointPair the data point together with its class label
     * @return the conditional probability
     * @throws IllegalArgumentException if {@code targetClass} is not a dimension
     *         of this table
     * @throws UntrainedModelException if the table has not been trained
     */
    public double query(int targetClass, DataPointPair<Integer> dataPointPair) {
        int[] cord = new int[getDimensionSize()];
        int targetValue = dataPointToCord(dataPointPair, targetClass, cord);
        return query(targetClass, targetValue, cord);
    }

    /**
     * Queries the table for the probability that the given variable takes
     * {@code targetValue}, conditioned on the other values in {@code cord}.
     *
     * <p>The entry of {@code cord} that belongs to {@code targetClass} is
     * overwritten while the counts are summed and is left at the last category
     * value on return.
     *
     * @param targetClass the variable index to query
     * @param targetValue the value of that variable whose probability is wanted
     * @param cord the coordinate supplying the values of all other dimensions
     * @return the count for {@code targetValue} divided by the sum of the counts
     *         over all values of {@code targetClass}
     * @throws IllegalArgumentException if {@code targetClass} is not a dimension
     *         of this table or {@code cord} has the wrong length
     * @throws UntrainedModelException if the table has not been trained
     */
    public double query(int targetClass, int targetValue, int[] cord) {
        requireTrained();
        if (targetClass < 0 || targetClass >= this.catIndexToRealIndex.length
                || this.catIndexToRealIndex[targetClass] < 0) {
            throw new IllegalArgumentException("Variable " + targetClass + " is not part of this CPT");
        }
        final int targetDim = this.catIndexToRealIndex[targetClass];
        final int numValues = this.valid.get(Integer.valueOf(targetClass)).getNumOfCategories();
        double total = 0.0d;
        double matching = 0.0d;
        for (int value = 0; value < numValues; value++) {
            cord[targetDim] = value;
            double count = this.countArray[cordToIndex(cord)];
            total += count;
            if (value == targetValue) {
                matching = count;
            }
        }
        return matching / total;
    }

    /**
     * Flattens a table coordinate into an index into the count array, treating
     * the first dimension as the most significant one.
     *
     * @param cord one category value per table dimension
     * @return the index of the matching cell
     * @throws IllegalArgumentException if {@code cord} does not have one entry per
     *         dimension
     */
    private int cordToIndex(int... cord) {
        if (cord.length != this.realIndexToCatIndex.length) {
            throw new IllegalArgumentException("Coordinate has " + cord.length + " values but the CPT has "
                    + this.realIndexToCatIndex.length + " dimensions");
        }
        int index = 0;
        for (int dim = 0; dim < cord.length; dim++) {
            index = cord[dim] + (this.dimSize[dim] * index);
        }
        return index;
    }

    /** Multiplies the running cell count by one more dimension, rejecting tables too large for an array. */
    private static long growTableSize(long current, int numCategories) {
        // Both factors fit in an int, so the product cannot overflow a long.
        long grown = current * numCategories;
        if (grown > Integer.MAX_VALUE) {
            throw new FailedToFitException("CPT would need " + grown + " or more cells, which is more than an array can hold");
        }
        return grown;
    }

    /** Fails with an {@link UntrainedModelException} when no table has been built yet. */
    private void requireTrained() {
        if (this.dimSize == null || this.countArray == null) {
            throw new UntrainedModelException("CPT has not been trained");
        }
    }

    /**
     * Reports that data point weights are honored: training adds each point's
     * weight, not a fixed 1, to its cell.
     *
     * @return {@code true}
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return true;
    }

    /**
     * Copying is not implemented.
     *
     * <p>NOTE(review): the name is a decompiler artifact for {@code clone()}; it is
     * kept unchanged because other decompiled classes may refer to it.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override // jsat.classifiers.Classifier
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public Classifier m494clone() {
        throw new UnsupportedOperationException("Not supported yet.");
    }
}
