package smile.math.matrix;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;

/* loaded from: input_file:smile-math-2.4.0.jar:smile/math/matrix/BiconjugateGradient.class */
/**
 * Iterative solver for the linear system {@code A x = b} based on the preconditioned
 * biconjugate gradient (BCG) method.
 *
 * <p>The convergence test is selected with {@link #setConvergenceTest(int)}:
 * <ul>
 *   <li>1: {@code ||b - A x|| / ||b||} (L2 norm);</li>
 *   <li>2: {@code ||M^-1 (b - A x)|| / ||M^-1 b||} (L2 norm), where {@code M} is the
 *       preconditioner;</li>
 *   <li>3: an estimate of the relative error in {@code x} derived from the change between
 *       successive preconditioned residual norms (L2 norm); when that estimate is not usable
 *       the test falls back to {@code ||M^-1 r|| / ||M^-1 b||};</li>
 *   <li>4: same as 3 but every norm is the L-infinity (max-abs) norm.</li>
 * </ul>
 *
 * <p>{@link #solve(Matrix, double[], double[])} reads but never modifies this object's
 * configuration, so one configured instance can be reused for different matrices.
 * {@link #getInstance()} returns a shared instance; calling its setters changes the
 * configuration seen by every caller that uses it.
 */
public class BiconjugateGradient {
    private static final Logger logger = LoggerFactory.getLogger(BiconjugateGradient.class);
    private static final BiconjugateGradient instance = new BiconjugateGradient();
    /** Convergence tolerance; must be positive. */
    private double tol = 1.0E-10d;
    /** Convergence test selector, 1..4 (see class documentation). */
    private int itol = 1;
    /** Maximum number of iterations; non-positive means 2 * max(nrows, ncols) of the matrix. */
    private int maxIter = 0;
    /** User preconditioner; null means a diagonal (Jacobi) preconditioner of the matrix. */
    private Preconditioner preconditioner;

    /** Creates a solver with tolerance 1E-10, convergence test 1 and automatic maxIter. */
    public BiconjugateGradient() {
    }

    /**
     * Creates a solver with the given settings.
     *
     * @param tol the convergence tolerance, must be positive.
     * @param itol the convergence test, 1 to 4.
     * @param maxIter the maximum number of iterations; non-positive means automatic.
     * @throws IllegalArgumentException if {@code tol} or {@code itol} is invalid.
     */
    public BiconjugateGradient(double tol, int itol, int maxIter) {
        setTolerance(tol);
        setConvergenceTest(itol);
        setMaxIter(maxIter);
    }

    /**
     * Returns the shared default instance. Its configuration is shared by all callers.
     *
     * @return the shared solver instance.
     */
    public static BiconjugateGradient getInstance() {
        return instance;
    }

    /**
     * Sets the convergence tolerance.
     *
     * @param tol the tolerance, must be positive.
     * @return this solver.
     * @throws IllegalArgumentException if {@code tol <= 0}.
     */
    public BiconjugateGradient setTolerance(double tol) {
        if (tol <= 0.0d) {
            throw new IllegalArgumentException("Invalid tolerance: " + tol);
        }
        this.tol = tol;
        return this;
    }

    /**
     * Selects the convergence test (see class documentation).
     *
     * @param itol the convergence test, 1 to 4.
     * @return this solver.
     * @throws IllegalArgumentException if {@code itol} is outside 1..4.
     */
    public BiconjugateGradient setConvergenceTest(int itol) {
        if (itol < 1 || itol > 4) {
            throw new IllegalArgumentException(String.format("Invalid itol: %d", itol));
        }
        this.itol = itol;
        return this;
    }

    /**
     * Sets the maximum number of iterations.
     *
     * @param maxIter the iteration limit; a non-positive value means
     *                {@code 2 * max(nrows, ncols)} of the matrix passed to {@code solve}.
     * @return this solver.
     */
    public BiconjugateGradient setMaxIter(int maxIter) {
        this.maxIter = maxIter;
        return this;
    }

    /**
     * Sets the preconditioner.
     *
     * @param preconditioner the preconditioner, or null to use the diagonal of the matrix.
     * @return this solver.
     */
    public BiconjugateGradient setPreconditioner(Preconditioner preconditioner) {
        this.preconditioner = preconditioner;
        return this;
    }

    /**
     * Returns a Jacobi preconditioner that divides by the matrix diagonal. Zero diagonal
     * entries are treated as 1 (the component is passed through unchanged).
     */
    private static Preconditioner diagonalPreconditioner(Matrix matrix) {
        // Extract the diagonal once instead of on every application (twice per iteration).
        double[] diag = matrix.diag();
        return (b, x) -> {
            for (int i = 0; i < diag.length; i++) {
                x[i] = diag[i] != 0.0d ? b[i] / diag[i] : b[i];
            }
        };
    }

    /**
     * Solves {@code A x = b} by the preconditioned biconjugate gradient method.
     *
     * @param matrix the coefficient matrix A.
     * @param b the right-hand side.
     * @param x on input the initial guess; on output the approximate solution.
     * @return the error measure selected by the convergence test after the last iteration;
     *         0 if the initial residual is exactly zero. If the iteration breaks down, the
     *         last error measure computed before the breakdown is returned.
     */
    public double solve(Matrix matrix, double[] b, double[] x) {
        // Resolve defaults locally: writing them back to the fields would pin the first
        // matrix's size and diagonal into a reused (or the shared singleton) instance.
        int iterLimit = maxIter > 0 ? maxIter : 2 * Math.max(matrix.nrows(), matrix.ncols());
        Preconditioner precond = preconditioner != null ? preconditioner : diagonalPreconditioner(matrix);

        int n = b.length;
        double[] p = new double[n];
        double[] pp = new double[n];
        double[] r = new double[n];
        double[] rr = new double[n];
        double[] z = new double[n];
        double[] zz = new double[n];

        // r = b - A x; the shadow residual rr starts as a copy of r.
        matrix.ax(x, r);
        for (int j = 0; j < n; j++) {
            r[j] = b[j] - r[j];
            rr[j] = r[j];
        }

        // An exact initial guess needs no iteration; iterating would compute 0/0 and
        // overwrite x with NaN.
        if (isZero(r)) {
            return 0.0;
        }

        double bnrm;
        double znrm = 0.0d;
        if (itol == 1) {
            bnrm = snorm(b);
            precond.solve(r, z);
        } else if (itol == 2) {
            precond.solve(b, z);
            bnrm = snorm(z);
            precond.solve(r, z);
        } else if (itol == 3 || itol == 4) {
            precond.solve(b, z);
            bnrm = snorm(z);
            precond.solve(r, z);
            znrm = snorm(z);
        } else {
            throw new IllegalArgumentException(String.format("Illegal itol: %d", itol));
        }

        // Error before any update; returned unchanged if the method breaks down immediately.
        double err = (itol == 1 ? snorm(r) : snorm(z)) / bnrm;
        double bkden = 1.0d;

        for (int iter = 1; iter <= iterLimit; iter++) {
            precond.solve(rr, zz);
            double bknum = dot(z, rr);
            if (bknum == 0.0d) {
                // BiCG breakdown: the next search direction (and bk) would divide by zero.
                logger.warn(String.format("BCG: breakdown after %3d iterations, error: %.5g", iter - 1, err));
                break;
            }

            if (iter == 1) {
                System.arraycopy(z, 0, p, 0, n);
                System.arraycopy(zz, 0, pp, 0, n);
            } else {
                double bk = bknum / bkden;
                for (int j = 0; j < n; j++) {
                    p[j] = bk * p[j] + z[j];
                    pp[j] = bk * pp[j] + zz[j];
                }
            }
            bkden = bknum;

            // z = A p (z is reused as scratch space here) and zz = A^T pp.
            matrix.ax(p, z);
            double akden = dot(z, pp);
            if (akden == 0.0d) {
                // Step length would be infinite/NaN and corrupt x; stop with the last estimate.
                logger.warn(String.format("BCG: breakdown after %3d iterations, error: %.5g", iter - 1, err));
                break;
            }
            double ak = bknum / akden;
            matrix.atx(pp, zz);
            for (int j = 0; j < n; j++) {
                x[j] += ak * p[j];
                r[j] -= ak * z[j];
                rr[j] -= ak * zz[j];
            }

            precond.solve(r, z);
            if (itol == 1) {
                err = snorm(r) / bnrm;
            } else if (itol == 2) {
                err = snorm(z) / bnrm;
            } else {
                // itol 3 or 4: estimate the error in x from the change in ||z||; fall back
                // to ||z|| / ||M^-1 b|| when the estimate is unreliable.
                double zm1nrm = znrm;
                znrm = snorm(z);
                err = znrm / bnrm;
                if (Math.abs(zm1nrm - znrm) > MathEx.EPSILON * znrm) {
                    double estimate = (znrm / Math.abs(zm1nrm - znrm)) * Math.abs(ak) * snorm(p);
                    double xnrm = snorm(x);
                    if (estimate <= 0.5d * xnrm) {
                        err = estimate / xnrm;
                    }
                }
            }

            if (iter % 10 == 0) {
                logger.info(String.format("BCG: the error after %3d iterations: %.5g", iter, err));
            }
            if (err <= tol) {
                logger.info(String.format("BCG: the error after %3d iterations: %.5g", iter, err));
                break;
            }
        }
        return err;
    }

    /** Returns the dot product of two equally sized vectors. */
    private static double dot(double[] a, double[] b) {
        double sum = 0.0d;
        for (int i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }

    /** Returns true if every component of {@code v} is exactly zero (true for an empty array). */
    private static boolean isZero(double[] v) {
        for (double vi : v) {
            if (vi != 0.0d) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the vector norm used by the convergence test: the L2 norm for itol 1..3,
     * the L-infinity (max-abs) norm for itol 4. An empty vector has norm 0.
     */
    private double snorm(double[] v) {
        if (itol <= 3) {
            double sum = 0.0d;
            for (double vi : v) {
                sum += vi * vi;
            }
            return Math.sqrt(sum);
        }
        if (v.length == 0) {
            return 0.0d;
        }
        double max = Math.abs(v[0]);
        for (int i = 1; i < v.length; i++) {
            double a = Math.abs(v[i]);
            if (a > max) {
                max = a;
            }
        }
        return max;
    }
}
