package jsat.classifiers.bayesian;

import java.util.Arrays;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.exceptions.FailedToFitException;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/bayesian/ODE.class */
public class ODE extends BaseUpdateableClassifier {
    private static final long serialVersionUID = -7732070257669428977L;

    /** Index of the categorical feature that every other categorical feature is conditioned on. */
    protected int dependent;
    /** Number of target classes (set in {@link #setUp}). */
    protected int predTargets;
    /** Number of values the dependent feature can take (set in {@link #setUp}). */
    protected int depTargets;
    /**
     * Smoothed counts indexed as {@code [class][dependentValue][feature][featureValue]}.
     * Every entry starts at 1.0 in {@link #setUp}; {@link #update} adds data-point weights.
     */
    protected double[][][][] counts;
    /**
     * Smoothed joint counts indexed as {@code [class][dependentValue]}.
     * Every entry starts at 1.0 in {@link #setUp}; {@link #update} adds data-point weights.
     */
    protected double[][] priors;
    /** Running total of every entry in {@link #priors}; the denominator of the prior term. */
    protected double priorSum;

    /**
     * Creates a new one-dependence estimator.
     *
     * @param i index of the categorical feature that all other categorical features are
     *          conditioned on; it is validated against the data in {@link #setUp}
     */
    public ODE(int i) {
        this.dependent = i;
    }

    /**
     * Deep-copy constructor.
     *
     * @param ode the instance to copy; its count arrays are duplicated, not shared
     */
    protected ODE(ODE ode) {
        this(ode.dependent);
        this.predTargets = ode.predTargets;
        this.depTargets = ode.depTargets;
        if (ode.counts != null) {
            this.counts = new double[ode.counts.length][][][];
            for (int c = 0; c < this.counts.length; c++) {
                this.counts[c] = new double[ode.counts[c].length][][];
                for (int d = 0; d < this.counts[c].length; d++) {
                    this.counts[c][d] = new double[ode.counts[c][d].length][];
                    for (int f = 0; f < this.counts[c][d].length; f++) {
                        this.counts[c][d][f] = Arrays.copyOf(ode.counts[c][d][f], ode.counts[c][d][f].length);
                    }
                }
            }
        }
        if (ode.priors != null) {
            this.priors = new double[ode.priors.length][];
            for (int c = 0; c < this.priors.length; c++) {
                this.priors[c] = Arrays.copyOf(ode.priors[c], ode.priors[c].length);
            }
        }
        this.priorSum = ode.priorSum;
    }

    /**
     * Classifies a data point using only its categorical values.
     *
     * @param dataPoint the point to classify
     * @return per-class probabilities, normalized
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        CategoricalResults results = new CategoricalResults(this.predTargets);
        int[] values = dataPoint.getCategoricalValues();

        double[] logProbs = new double[this.predTargets];
        double maxLogProb = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < this.predTargets; c++) {
            logProbs[c] = getLogPrb(values, c);
            maxLogProb = Math.max(maxLogProb, logProbs[c]);
        }

        // Shift every log-probability by the maximum before exponentiating so the largest
        // term is exp(0) = 1 and not every class underflows to 0 when there are many features.
        // The common factor exp(-max) cancels when normalize() rescales the results
        // (assuming it divides by the total). Skip the shift in degenerate infinite cases.
        double shift = Double.isInfinite(maxLogProb) ? 0.0 : maxLogProb;
        for (int c = 0; c < this.predTargets; c++) {
            results.setProb(c, Math.exp(logProbs[c] - shift));
        }
        results.normalize();
        return results;
    }

    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return true;
    }

    // Decompiler-renamed clone(): the odd name comes from merging a covariant bridge method.
    @Override // jsat.classifiers.BaseUpdateableClassifier
    public ODE mo479clone() {
        return new ODE(this);
    }

    /**
     * Allocates the count tables and initializes every count to 1.0 (Laplace smoothing).
     * Any previous training state, including {@link #priorSum}, is discarded.
     *
     * @param categoricalDataArr descriptions of the categorical features
     * @param i ignored by this classifier
     * @param categoricalData description of the target class
     * @throws FailedToFitException if there are no categorical features, or if the
     *         dependent index is not a valid feature index
     */
    @Override // jsat.classifiers.UpdateableClassifier
    public void setUp(CategoricalData[] categoricalDataArr, int i, CategoricalData categoricalData) {
        if (categoricalDataArr.length < 1) {
            throw new FailedToFitException("At least 2 categorical varaibles are needed for ODE");
        }
        if (this.dependent < 0 || this.dependent >= categoricalDataArr.length) {
            throw new FailedToFitException("Dependent variable index " + this.dependent
                    + " is out of range for " + categoricalDataArr.length + " categorical variables");
        }
        this.predTargets = categoricalData.getNumOfCategories();
        this.depTargets = categoricalDataArr[this.dependent].getNumOfCategories();

        this.counts = new double[this.predTargets][this.depTargets][categoricalDataArr.length][];
        for (double[][][] byDependent : this.counts) {
            for (double[][] byFeature : byDependent) {
                for (int f = 0; f < byFeature.length; f++) {
                    byFeature[f] = new double[categoricalDataArr[f].getNumOfCategories()];
                    Arrays.fill(byFeature[f], 1.0d);
                }
            }
        }

        this.priors = new double[this.predTargets][this.depTargets];
        // Reset before accumulating so a second setUp() does not keep stale totals.
        this.priorSum = 0.0d;
        for (double[] row : this.priors) {
            Arrays.fill(row, 1.0d);
            this.priorSum += row.length;
        }
    }

    /**
     * Adds one weighted observation to the count tables.
     *
     * @param dataPoint the observed point; only its categorical values and weight are used
     * @param i the index of the point's target class
     */
    @Override // jsat.classifiers.UpdateableClassifier
    public void update(DataPoint dataPoint, int i) {
        int[] values = dataPoint.getCategoricalValues();
        double weight = dataPoint.getWeight();
        int depValue = values[this.dependent];
        double[][] featureCounts = this.counts[i][depValue];
        for (int f = 0; f < values.length; f++) {
            if (f != this.dependent) {
                featureCounts[f][values[f]] += weight;
            }
        }
        this.priors[i][depValue] += weight;
        this.priorSum += weight;
    }

    /**
     * Computes the unnormalized log-probability of class {@code i} for the given values.
     * The result is {@code log(priors[i][dep] / priorSum)} plus, for each non-dependent
     * feature f, {@code log(counts[i][dep][f][x_f] / sum_v counts[i][dep][f][v])}.
     *
     * @param iArr categorical values of the point
     * @param i the class index
     * @return the log-probability, not normalized across classes
     */
    public double getLogPrb(int[] iArr, int i) {
        int depValue = iArr[this.dependent];
        double logProb = 0.0d;
        for (int f = 0; f < iArr.length; f++) {
            if (f == this.dependent) {
                continue;
            }
            double[] valueCounts = this.counts[i][depValue][f];
            double total = 0.0d;
            for (double count : valueCounts) {
                total += count;
            }
            logProb += Math.log(valueCounts[iArr[f]] / total);
        }
        return logProb + Math.log(this.priors[i][depValue] / this.priorSum);
    }
}
