package smile.sequence;

import java.io.Serializable;
import java.util.Arrays;
import java.util.function.ToIntFunction;
import smile.math.MathEx;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.Matrix;
import smile.util.Strings;

/* loaded from: input_file:smile-core-2.4.0.jar:smile/sequence/HMM.class */
/**
 * First-order hidden Markov model over discrete observation symbols.
 *
 * <p>The model is defined by:
 * <ul>
 *   <li>{@code pi}: the initial state probabilities, one entry per state;</li>
 *   <li>{@code a}: the state transition matrix, where entry (i, j) is the
 *       probability of moving from state i to state j;</li>
 *   <li>{@code b}: the symbol emission matrix, where entry (i, k) is the
 *       probability of emitting symbol k while in state i.</li>
 * </ul>
 * States and symbols are 0-based integer indices.
 *
 * <p>NOTE(review): the constructor stores, and the getters return, the caller's
 * array/matrices without copying. {@link #update(int[][], int)} modifies them in place.
 */
public class HMM implements Serializable {
    private static final long serialVersionUID = 2;

    /** Initial state probabilities; {@code pi[i]} is the probability of starting in state i. */
    private double[] pi;
    /** State transition probabilities; entry (i, j) = P(next state j | current state i). */
    private DenseMatrix a;
    /** Symbol emission probabilities; entry (i, k) = P(symbol k | state i). */
    private DenseMatrix b;

    /**
     * Constructor.
     *
     * @param pi the initial state probabilities, one per state.
     * @param a the square state transition probability matrix (states x states).
     * @param b the symbol emission probability matrix (states x symbols).
     * @throws IllegalArgumentException if the dimensions are inconsistent.
     */
    public HMM(double[] pi, DenseMatrix a, DenseMatrix b) {
        if (pi.length == 0) {
            throw new IllegalArgumentException("Invalid initial state probabilities.");
        }
        // The transition matrix needs exactly one row and one column per state.
        if (pi.length != a.nrows() || a.nrows() != a.ncols()) {
            throw new IllegalArgumentException("Invalid state transition probability matrix.");
        }
        if (a.nrows() != b.nrows()) {
            throw new IllegalArgumentException("Invalid symbol emission probability matrix.");
        }
        this.pi = pi;
        this.a = a;
        this.b = b;
    }

    /**
     * Returns the initial state probabilities (the internal array, not a copy).
     *
     * @return the initial state probabilities.
     */
    public double[] getInitialStateProbabilities() {
        return pi;
    }

    /**
     * Returns the state transition probability matrix (the internal matrix, not a copy).
     *
     * @return the state transition probabilities.
     */
    public DenseMatrix getStateTransitionProbabilities() {
        return a;
    }

    /**
     * Returns the symbol emission probability matrix (the internal matrix, not a copy).
     *
     * @return the symbol emission probabilities.
     */
    public DenseMatrix getSymbolEmissionProbabilities() {
        return b;
    }

    /**
     * Returns the joint probability of an observation sequence and a state sequence.
     *
     * @param o the observation sequence.
     * @param s the state sequence.
     * @return the joint probability P(o, s).
     * @throws IllegalArgumentException if the two sequences differ in length.
     */
    public double p(int[] o, int[] s) {
        return Math.exp(logp(o, s));
    }

    /**
     * Returns the log joint probability of an observation sequence and a state sequence.
     *
     * @param o the observation sequence.
     * @param s the state sequence.
     * @return log P(o, s); 0 (probability 1) for a pair of empty sequences.
     * @throws IllegalArgumentException if the two sequences differ in length.
     */
    public double logp(int[] o, int[] s) {
        if (o.length != s.length) {
            throw new IllegalArgumentException("The observation sequence and state sequence are not the same length.");
        }

        int n = s.length;
        if (n == 0) {
            // An empty sequence carries no events: log(1) = 0.
            return 0.0;
        }

        double logp = MathEx.log(pi[s[0]]) + MathEx.log(b.get(s[0], o[0]));
        for (int t = 1; t < n; t++) {
            logp += MathEx.log(a.get(s[t - 1], s[t])) + MathEx.log(b.get(s[t], o[t]));
        }
        return logp;
    }

    /**
     * Returns the probability of an observation sequence under this model.
     *
     * @param o the observation sequence.
     * @return P(o).
     */
    public double p(int[] o) {
        return Math.exp(logp(o));
    }

    /**
     * Returns the log probability of an observation sequence under this model.
     *
     * <p>The value is computed with the scaled forward algorithm as the sum of the
     * logs of the per-step scaling factors.
     *
     * @param o the observation sequence.
     * @return log P(o); 0 for an empty sequence, and negative infinity if the
     *         sequence is impossible under the model.
     */
    public double logp(int[] o) {
        if (o.length == 0) {
            return 0.0;
        }

        double[][] alpha = new double[o.length][a.nrows()];
        double[] scaling = new double[o.length];
        forward(o, alpha, scaling);

        double logp = 0.0;
        for (double c : scaling) {
            logp += Math.log(c);
        }
        return logp;
    }

    /**
     * Normalizes {@code alpha[t]} in place to sum to 1 and stores the pre-normalization
     * sum in {@code scaling[t]}.
     *
     * <p>If the sum is zero, the observations up to t are impossible under the model.
     * In that case the row is left at zero rather than being turned into NaN by 0/0.
     */
    private void scale(double[] scaling, double[][] alpha, int t) {
        double[] row = alpha[t];

        double sum = 0.0;
        for (double v : row) {
            sum += v;
        }
        scaling[t] = sum;

        if (sum != 0.0) {
            for (int i = 0; i < row.length; i++) {
                row[i] /= sum;
            }
        }
    }

    /**
     * Runs the scaled forward pass.
     *
     * @param o the observation sequence (non-empty).
     * @param alpha output; {@code alpha[t]} is the normalized forward variable at step t.
     * @param scaling output; {@code scaling[t]} is the normalizer applied at step t.
     */
    private void forward(int[] o, double[][] alpha, double[] scaling) {
        int numStates = a.nrows();

        for (int i = 0; i < numStates; i++) {
            alpha[0][i] = pi[i] * b.get(i, o[0]);
        }
        scale(scaling, alpha, 0);

        for (int t = 1; t < o.length; t++) {
            for (int j = 0; j < numStates; j++) {
                double sum = 0.0;
                for (int i = 0; i < numStates; i++) {
                    sum += alpha[t - 1][i] * a.get(i, j);
                }
                alpha[t][j] = sum * b.get(j, o[t]);
            }
            scale(scaling, alpha, t);
        }
    }

    /**
     * Runs the backward pass, using the scaling factors produced by
     * {@link #forward}. All scaling factors must be nonzero.
     *
     * @param o the observation sequence (non-empty).
     * @param beta output; {@code beta[t]} is the scaled backward variable at step t.
     * @param scaling the scaling factors from the forward pass.
     */
    private void backward(int[] o, double[][] beta, double[] scaling) {
        int numStates = a.nrows();
        int last = o.length - 1;

        for (int i = 0; i < numStates; i++) {
            beta[last][i] = 1.0 / scaling[last];
        }

        for (int t = last - 1; t >= 0; t--) {
            for (int i = 0; i < numStates; i++) {
                double sum = 0.0;
                for (int j = 0; j < numStates; j++) {
                    sum += beta[t + 1][j] * a.get(i, j) * b.get(j, o[t + 1]);
                }
                beta[t][i] = sum / scaling[t];
            }
        }
    }

    /**
     * Returns the most likely state sequence for an observation sequence (Viterbi decoding).
     *
     * <p>Ties are broken in favor of the lowest state index.
     *
     * @param o the observation sequence.
     * @return the most likely state sequence; empty for an empty input.
     */
    public int[] predict(int[] o) {
        int numStates = a.nrows();
        int len = o.length;
        if (len == 0) {
            return new int[0];
        }

        // Transition log-probabilities are reused at every step, so compute them once.
        double[][] logA = new double[numStates][numStates];
        for (int i = 0; i < numStates; i++) {
            for (int j = 0; j < numStates; j++) {
                logA[i][j] = MathEx.log(a.get(i, j));
            }
        }

        // delta[t][j]: best log score of any path ending in state j at step t.
        // psi[t][j]: predecessor state on that best path.
        double[][] delta = new double[len][numStates];
        int[][] psi = new int[len][numStates];

        for (int i = 0; i < numStates; i++) {
            delta[0][i] = MathEx.log(pi[i]) + MathEx.log(b.get(i, o[0]));
        }

        for (int t = 1; t < len; t++) {
            for (int j = 0; j < numStates; j++) {
                double best = Double.NEGATIVE_INFINITY;
                int argBest = 0;
                for (int i = 0; i < numStates; i++) {
                    double score = delta[t - 1][i] + logA[i][j];
                    if (score > best) {
                        best = score;
                        argBest = i;
                    }
                }
                delta[t][j] = best + MathEx.log(b.get(j, o[t]));
                psi[t][j] = argBest;
            }
        }

        int[] s = new int[len];
        int last = len - 1;
        double best = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < numStates; i++) {
            if (delta[last][i] > best) {
                best = delta[last][i];
                s[last] = i;
            }
        }

        for (int t = last - 1; t >= 0; t--) {
            s[t] = psi[t + 1][s[t + 1]];
        }
        return s;
    }

    /**
     * Learns an HMM by maximum likelihood (frequency counting) from labeled sequences.
     *
     * <p>The number of states (symbols) is one more than the largest label (observation)
     * index seen. Rows with no observed counts (for example, a state that only ever
     * appears at the end of a sequence) are left as all zeros instead of becoming NaN.
     *
     * @param observations the observation sequences.
     * @param labels the state sequences aligned with {@code observations}.
     * @return the trained model.
     * @throws IllegalArgumentException if the inputs are misaligned or a sequence is empty.
     */
    public static HMM fit(int[][] observations, int[][] labels) {
        if (observations.length != labels.length) {
            throw new IllegalArgumentException("The number of observation sequences and that of label sequences are different.");
        }

        int numStates = 0;
        int numSymbols = 0;
        for (int k = 0; k < observations.length; k++) {
            if (observations[k].length != labels[k].length) {
                throw new IllegalArgumentException(String.format("The length of observation sequence %d and that of corresponding label sequence are different.", k));
            }
            if (observations[k].length == 0) {
                throw new IllegalArgumentException(String.format("Observation sequence %d is empty.", k));
            }
            numStates = Math.max(numStates, MathEx.max(labels[k]) + 1);
            numSymbols = Math.max(numSymbols, MathEx.max(observations[k]) + 1);
        }

        double[] pi = new double[numStates];
        double[][] a = new double[numStates][numStates];
        double[][] b = new double[numStates][numSymbols];

        for (int k = 0; k < observations.length; k++) {
            int[] o = observations[k];
            int[] s = labels[k];

            pi[s[0]]++;
            b[s[0]][o[0]]++;
            for (int t = 1; t < o.length; t++) {
                a[s[t - 1]][s[t]]++;
                b[s[t]][o[t]]++;
            }
        }

        normalize(pi);
        for (int i = 0; i < numStates; i++) {
            normalize(a[i]);
            normalize(b[i]);
        }

        return new HMM(pi, Matrix.of(a), Matrix.of(b));
    }

    /**
     * Learns an HMM by maximum likelihood from labeled sequences of arbitrary symbols.
     *
     * @param observations the observation sequences.
     * @param labels the state sequences aligned with {@code observations}.
     * @param ordinal maps each observation symbol to its integer index.
     * @param <T> the observation symbol type.
     * @return the trained model.
     */
    public static <T> HMM fit(T[][] observations, int[][] labels, ToIntFunction<T> ordinal) {
        if (observations.length != labels.length) {
            throw new IllegalArgumentException("The number of observation sequences and that of label sequences are different.");
        }
        return fit(encode(observations, ordinal), labels);
    }

    /**
     * Refines the model with unlabeled sequences of arbitrary symbols (Baum-Welch).
     *
     * @param observations the training observation sequences.
     * @param iterations the number of re-estimation passes.
     * @param ordinal maps each observation symbol to its integer index.
     * @param <T> the observation symbol type.
     */
    public <T> void update(T[][] observations, int iterations, ToIntFunction<T> ordinal) {
        update(encode(observations, ordinal), iterations);
    }

    /**
     * Refines the model in place with unlabeled observation sequences (Baum-Welch).
     *
     * @param observations the training observation sequences, each of length at least 3.
     * @param iterations the number of re-estimation passes.
     */
    public void update(int[][] observations, int iterations) {
        for (int iter = 0; iter < iterations; iter++) {
            iterate(observations);
        }
    }

    /** Converts each symbol sequence into its integer-index sequence using {@code ordinal}. */
    private static <T> int[][] encode(T[][] sequences, ToIntFunction<T> ordinal) {
        return Arrays.stream(sequences)
                .map(seq -> Arrays.stream(seq).mapToInt(ordinal).toArray())
                .toArray(int[][]::new);
    }

    /**
     * Scales a vector of non-negative counts in place so it sums to 1.
     * All-zero vectors are left unchanged, avoiding 0/0 = NaN.
     */
    private static void normalize(double[] x) {
        double sum = 0.0;
        for (double v : x) {
            sum += v;
        }
        if (sum > 0.0) {
            for (int i = 0; i < x.length; i++) {
                x[i] /= sum;
            }
        }
    }

    /**
     * Performs one Baum-Welch re-estimation step of pi, a and b, in place.
     *
     * <p>States whose expected occupancy is zero keep their previous transition and
     * emission rows, instead of having them replaced by NaN. An empty training set
     * leaves the model unchanged.
     *
     * @throws IllegalArgumentException if a sequence is shorter than 3, or has zero
     *         probability under the current model.
     */
    private void iterate(int[][] sequences) {
        if (sequences.length == 0) {
            // No evidence: nothing to re-estimate, and averaging pi would divide by zero.
            return;
        }

        int numStates = a.nrows();
        int numSymbols = b.ncols();

        // gamma[k][t][i]: posterior probability of state i at step t of sequence k.
        double[][][] gamma = new double[sequences.length][][];
        // a(i, j) = aNum[i][j] / aDen[i]: expected i->j transitions over expected visits to i.
        double[][] aNum = new double[numStates][numStates];
        double[] aDen = new double[numStates];

        for (int k = 0; k < sequences.length; k++) {
            int[] o = sequences[k];
            if (o.length <= 2) {
                throw new IllegalArgumentException(String.format("Training sequence %d is too short.", k));
            }

            double[][] alpha = new double[o.length][numStates];
            double[][] beta = new double[o.length][numStates];
            double[] scaling = new double[o.length];
            forward(o, alpha, scaling);

            // The backward pass divides by the scaling factors; a zero factor would
            // fill beta with Inf/NaN and corrupt every parameter.
            for (double c : scaling) {
                if (c == 0.0) {
                    throw new IllegalArgumentException(String.format("Training sequence %d has zero probability under the current model.", k));
                }
            }
            backward(o, beta, scaling);

            double[][][] xi = estimateXi(o, alpha, beta);
            double[][] g = estimateGamma(xi);
            gamma[k] = g;

            int n = o.length - 1;
            for (int i = 0; i < numStates; i++) {
                for (int t = 0; t < n; t++) {
                    aDen[i] += g[t][i];
                    for (int j = 0; j < numStates; j++) {
                        aNum[i][j] += xi[t][i][j];
                    }
                }
            }
        }

        for (int i = 0; i < numStates; i++) {
            if (aDen[i] != 0.0) {
                for (int j = 0; j < numStates; j++) {
                    a.set(i, j, aNum[i][j] / aDen[i]);
                }
            }
        }

        // pi is the average posterior of the first state across sequences.
        Arrays.fill(pi, 0.0);
        for (double[][] g : gamma) {
            for (int i = 0; i < numStates; i++) {
                pi[i] += g[0][i];
            }
        }
        for (int i = 0; i < numStates; i++) {
            pi[i] /= sequences.length;
        }

        // b(i, k) = expected emissions of k in state i over expected visits to i (all steps).
        double[][] bNum = new double[numStates][numSymbols];
        double[] bDen = new double[numStates];
        for (int k = 0; k < sequences.length; k++) {
            int[] o = sequences[k];
            double[][] g = gamma[k];
            for (int t = 0; t < o.length; t++) {
                for (int i = 0; i < numStates; i++) {
                    bNum[i][o[t]] += g[t][i];
                    bDen[i] += g[t][i];
                }
            }
        }
        for (int i = 0; i < numStates; i++) {
            if (bDen[i] != 0.0) {
                for (int sym = 0; sym < numSymbols; sym++) {
                    b.set(i, sym, bNum[i][sym] / bDen[i]);
                }
            }
        }
    }

    /**
     * Computes xi[t][i][j], the posterior probability of being in state i at step t
     * and in state j at step t + 1, for t in [0, length - 2].
     */
    private double[][][] estimateXi(int[] o, double[][] alpha, double[][] beta) {
        if (o.length <= 1) {
            throw new IllegalArgumentException("Observation sequence is too short.");
        }

        int numStates = a.nrows();
        int n = o.length - 1;
        double[][][] xi = new double[n][numStates][numStates];

        for (int t = 0; t < n; t++) {
            for (int i = 0; i < numStates; i++) {
                for (int j = 0; j < numStates; j++) {
                    xi[t][i][j] = alpha[t][i] * a.get(i, j) * b.get(j, o[t + 1]) * beta[t + 1][j];
                }
            }
        }
        return xi;
    }

    /**
     * Derives the per-step state posteriors gamma[t][i] from xi.
     *
     * <p>The result has one more row than {@code xi}. The final row is obtained by
     * summing the last xi slice over its source state.
     */
    private double[][] estimateGamma(double[][][] xi) {
        int numStates = a.nrows();
        int n = xi.length;
        double[][] gamma = new double[n + 1][numStates];

        for (int t = 0; t < n; t++) {
            for (int i = 0; i < numStates; i++) {
                for (int j = 0; j < numStates; j++) {
                    gamma[t][i] += xi[t][i][j];
                }
            }
        }

        double[][] lastXi = xi[n - 1];
        for (int j = 0; j < numStates; j++) {
            for (int i = 0; i < numStates; i++) {
                gamma[n][j] += lastXi[i][j];
            }
        }
        return gamma;
    }

    @Override
    public String toString() {
        return String.format("HMM (%d states, %d emission symbols)%n", a.nrows(), b.ncols())
                + "Initial state probability: " + Strings.toString(pi)
                + "\nState transition probability:\n" + a.toString()
                + "Symbol emission probability:\n" + b.toString();
    }
}
