package jsat.classifiers.bayesian.graphicalmodel;

import java.util.Iterator;
import java.util.Set;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.math.SpecialMath;
import jsat.utils.IntList;
import jsat.utils.IntSet;
import jsat.utils.ListUtils;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/bayesian/graphicalmodel/K2NetworkLearner.class */
/**
 * Discrete Bayesian network classifier whose structure is learned with the greedy
 * K2 search.
 *
 * <p>The class target is placed first in the variable ordering, followed by the
 * categorical attributes in index order. A variable may only take parents from the
 * variables that precede it in that ordering, so the edges added by
 * {@link #learnNetwork(ClassificationDataSet)} never form a cycle.
 */
public class K2NetworkLearner extends DiscreteBayesNetwork {
    private static final long serialVersionUID = -9681177007308829L;

    /**
     * Number of values each variable can take, indexed by variable. The class target
     * uses index {@code getNumCategoricalVars()}. Only populated while
     * {@link #learnNetwork(ClassificationDataSet)} runs; {@code null} otherwise.
     */
    private int[] ri;

    /** Maximum number of parents per node; a non-positive value means "no limit". */
    private int maxParents;

    /**
     * Sets the maximum number of parents a node may receive during structure
     * learning. A non-positive value removes the restriction.
     *
     * @param i the maximum number of parents per node
     */
    public void setMaxParents(int i) {
        this.maxParents = i;
    }

    /**
     * Returns the maximum number of parents allowed per node during structure
     * learning, or {@code 0} when no limit is applied.
     *
     * @return the parent limit, never negative
     */
    public int getMaxParents() {
        return Math.max(this.maxParents, 0);
    }

    /**
     * Learns the network structure from data with the greedy K2 search.
     *
     * <p>Variables are visited in a fixed order: the class target first, then the
     * categorical attributes {@code 0 .. n-1}. For each variable, parents are chosen
     * greedily from the variables visited before it. The candidate that most increases
     * {@link #f} is added until no candidate improves the score or the parent limit is
     * reached. Every chosen parent is registered with {@code depends(parent, child)}.
     *
     * @param classificationDataSet the data to learn the structure from
     */
    public void learnNetwork(ClassificationDataSet classificationDataSet) {
        final int numVars = classificationDataSet.getNumCategoricalVars();
        final int classIndex = numVars; // the class target is indexed right after the attributes

        // Evaluation order: class target first, then every attribute in index order.
        IntList varOrder = new IntList(numVars + 1);
        varOrder.add(classIndex);
        ListUtils.addRange(varOrder, 0, numVars, 1);

        ri = new int[varOrder.size()];
        for (int var : varOrder) {
            ri[var] = var == classIndex
                    ? classificationDataSet.getClassSize()
                    : classificationDataSet.getCategories()[var].getNumOfCategories();
        }

        final int parentLimit = maxParents <= 0 ? ri.length : maxParents;

        try {
            Set<Integer> preceding = new IntSet();
            for (int child : varOrder) {
                Set<Integer> parents = selectParents(child, preceding, parentLimit, classificationDataSet);
                for (int parent : parents) {
                    depends(parent, child);
                }
                preceding.add(child);
            }
        } finally {
            // ri is only scratch space for scoring; release it even if learning fails.
            ri = null;
        }
    }

    /**
     * Greedily chooses the parent set of one variable (the inner loop of K2).
     *
     * @param child the variable whose parents are being chosen
     * @param candidates the variables that precede {@code child} in the ordering; not modified
     * @param parentLimit the maximum number of parents to select
     * @param data the training data
     * @return the selected parents, possibly empty
     */
    private Set<Integer> selectParents(int child, Set<Integer> candidates, int parentLimit,
                                       ClassificationDataSet data) {
        Set<Integer> parents = new IntSet();
        Set<Integer> remaining = new IntSet(candidates);
        double bestScore = f(child, parents, data);

        while (parents.size() < parentLimit && !remaining.isEmpty()) {
            double bestCandidateScore = Double.NEGATIVE_INFINITY;
            int bestCandidate = -1;
            for (int candidate : remaining) {
                // Score the parent set extended by this candidate, then undo the extension.
                parents.add(candidate);
                double score = f(child, parents, data);
                if (score > bestCandidateScore) {
                    bestCandidateScore = score;
                    bestCandidate = candidate;
                }
                parents.remove(candidate);
            }
            if (bestCandidate < 0 || bestCandidateScore <= bestScore) {
                break; // no remaining candidate improves the score
            }
            bestScore = bestCandidateScore;
            parents.add(bestCandidate);
            remaining.remove(bestCandidate);
        }
        return parents;
    }

    /**
     * Learns the network structure when needed, then trains the conditional
     * probability tables through the superclass.
     *
     * <p>The structure is (re)learned when the DAG has no nodes, or when the class
     * target's node has no parents.
     * NOTE(review): learnNetwork visits the class target first, so it never gives that
     * node parents. This condition is therefore presumably true on every call after an
     * automatic structure learn. Confirm that relearning each time is intended.
     *
     * @param classificationDataSet the training data
     */
    @Override
    public void trainC(ClassificationDataSet classificationDataSet) {
        Integer classNode = classificationDataSet.getNumCategoricalVars();
        if (this.dag.getNodes().isEmpty() || this.dag.getParents(classNode).isEmpty()) {
            learnNetwork(classificationDataSet);
        }
        super.trainC(classificationDataSet);
    }

    /**
     * Returns the weighted number of data points for which every variable
     * {@code vars[j]} takes the value {@code vals[j]}, plus 1.
     *
     * <p>NOTE(review): the count starts at 1.0, which acts as a pseudo-count added to
     * every cell. The lnGamma terms in {@link #f} are written for raw counts. Confirm
     * this smoothing is intended.
     *
     * @param vars variable indices; {@code getNumCategoricalVars()} denotes the class target
     * @param vals the required value for each entry of {@code vars}
     * @param data the data set to count over
     * @return 1 + the sum of weights of the matching data points
     */
    private double query(int[] vars, int[] vals, ClassificationDataSet data) {
        double count = 1.0;
        for (int n = 0; n < data.getSampleSize(); n++) {
            DataPoint dp = data.getDataPoint(n);
            if (matchesAll(data, n, dp, vars, vals)) {
                count += dp.getWeight();
            }
        }
        return count;
    }

    /**
     * Returns whether data point {@code n} ({@code dp}) has value {@code vals[j]} for
     * every variable {@code vars[j]}.
     */
    private static boolean matchesAll(ClassificationDataSet data, int n, DataPoint dp,
                                      int[] vars, int[] vals) {
        final int classIndex = data.getNumCategoricalVars();
        for (int j = 0; j < vars.length; j++) {
            // Look up the attribute by its variable index vars[j], not by the position j.
            int value = vars[j] == classIndex
                    ? data.getDataPointCategory(n)
                    : dp.getCategoricalValue(vars[j]);
            if (value != vals[j]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Computes the log-space K2 family score of variable {@code i} given the parent set
     * {@code set}.
     *
     * <p>For every combination {@code j} of parent values, the following term is added:
     * {@code lnGamma(r) - lnGamma(N_j + r) + sum_k lnGamma(N_jk + 1)}.
     * Here {@code r} is the number of values of {@code i}, and {@code N_jk} is the
     * {@link #query} count for {@code i = k} together with combination {@code j}.
     * {@code N_j} is the sum of {@code N_jk} over {@code k}. Working with lnGamma keeps
     * the products of factorials from overflowing.
     *
     * <p>Only valid while {@link #learnNetwork(ClassificationDataSet)} is running.
     *
     * @param i the variable to score
     * @param set the parent set of {@code i}; iterated but not modified
     * @param classificationDataSet the training data
     * @return the log score, higher is better
     * @throws IllegalStateException if called outside of structure learning
     */
    public double f(int i, Set<Integer> set, ClassificationDataSet classificationDataSet) {
        if (ri == null) {
            throw new IllegalStateException("f may only be called while learnNetwork is running");
        }
        final int numParents = set.size();
        // vars/vals: parent variables in slots 0..numParents-1, the scored variable last.
        final int[] vars = new int[numParents + 1];
        final int[] vals = new int[numParents + 1];
        int pos = 0;
        for (int parent : set) {
            vars[pos++] = parent;
        }
        vars[numParents] = i;

        // Starts at the all-zero parent combination; with no parents this runs exactly once.
        double score = 0.0;
        do {
            score += familyTerm(i, vars, vals, classificationDataSet);
        } while (nextParentCombination(vars, vals, numParents));
        return score;
    }

    /**
     * Returns {@code lnGamma(r) - lnGamma(N_j + r) + sum_k lnGamma(N_jk + 1)} for the
     * parent combination stored in {@code vals[0 .. vals.length-2]}. The last slot of
     * {@code vals} is overwritten with each value {@code k} of {@code child}.
     */
    private double familyTerm(int child, int[] vars, int[] vals, ClassificationDataSet data) {
        final int last = vals.length - 1;
        final int r = ri[child];
        double totalCount = 0.0;
        double sumLnGammaCounts = 0.0;
        for (int k = 0; k < r; k++) {
            vals[last] = k;
            double count = query(vars, vals, data);
            totalCount += count;
            sumLnGammaCounts += SpecialMath.lnGamma(count + 1.0);
        }
        return (SpecialMath.lnGamma(r) - SpecialMath.lnGamma(totalCount + r)) + sumLnGammaCounts;
    }

    /**
     * Advances the parent values {@code vals[0 .. numParents-1]} to the next combination,
     * odometer style, with slot 0 varying fastest.
     *
     * @return {@code false} once every combination has been visited (or there are no parents)
     */
    private boolean nextParentCombination(int[] vars, int[] vals, int numParents) {
        for (int p = 0; p < numParents; p++) {
            if (++vals[p] < ri[vars[p]]) {
                return true;
            }
            vals[p] = 0; // carry into the next parent
        }
        return false;
    }
}
