package mikera.matrixx.solve.impl.lu;

import mikera.matrixx.AMatrix;
import mikera.matrixx.Matrix;
import mikera.matrixx.decompose.impl.lu.AltLU;
import mikera.matrixx.decompose.impl.lu.LUPResult;
import mikera.matrixx.impl.ADenseArrayMatrix;

/* loaded from: input_file:vectorz-0.48.0.jar:mikera/matrixx/solve/impl/lu/LUSolver.class */
/**
 * Solves square linear systems {@code A X = B} and computes inverses using an {@link AltLU}
 * decomposition of {@code A}.
 *
 * <p>Call {@link #setA(AMatrix)} first; {@link #solve(AMatrix)}, {@link #invert()},
 * {@link #quality()} and {@link #improveSol(AMatrix, AMatrix)} all reuse that decomposition.
 *
 * <p>NOTE(review): every solve overwrites the decomposition's shared work vector
 * ({@code decomp._getVV()}), so an instance presumably must not be used from several threads at
 * once without external locking.
 */
public class LUSolver {
    /** Decomposition of {@link #A}; {@code null} until {@link #setA(AMatrix)} is called. */
    protected AltLU decomp;
    /** Result of the last decomposition; used for the determinant-based singularity test. */
    private LUPResult result;
    /** When {@code true}, {@link #solve(AMatrix)} refines its answer with {@link #improveSol}. */
    boolean doImprove;
    /** The square matrix passed to {@link #setA(AMatrix)}, held by reference (not copied). */
    protected AMatrix A;
    /** Row count of {@link #A}. */
    protected int numRows;
    /** Column count of {@link #A}; equal to {@link #numRows} because A must be square. */
    protected int numCols;

    /**
     * Creates a solver.
     *
     * @param doImprove whether {@link #solve(AMatrix)} should apply one step of iterative
     *     refinement to every solution it computes
     */
    public LUSolver(boolean doImprove) {
        this.doImprove = doImprove;
    }

    /** Creates a solver that does not refine its solutions. */
    public LUSolver() {
        this(false);
    }

    /**
     * Returns the matrix most recently passed to {@link #setA(AMatrix)}.
     *
     * @return the system matrix, or {@code null} if {@code setA} has not been called
     */
    public AMatrix getA() {
        return this.A;
    }

    /**
     * Decomposes {@code aMatrix} and stores it as the system matrix for later calls.
     *
     * <p>The matrix is kept by reference: {@link #improveSol} reads its current contents, so
     * modifying it after this call makes refinement disagree with the stored decomposition.
     *
     * @param aMatrix the square system matrix
     * @return the decomposition result
     * @throws IllegalArgumentException if {@code aMatrix} is not square
     */
    public LUPResult setA(AMatrix aMatrix) {
        if (!aMatrix.isSquare()) {
            throw new IllegalArgumentException("Input must be a square matrix.");
        }
        this.A = aMatrix;
        this.numRows = aMatrix.rowCount();
        this.numCols = aMatrix.columnCount();
        this.decomp = new AltLU();
        this.result = this.decomp._decompose(aMatrix);
        return this.result;
    }

    /**
     * Returns the quality measure reported by the underlying {@link AltLU} decomposition.
     *
     * @return the decomposition quality
     * @throws IllegalStateException if {@link #setA(AMatrix)} has not been called
     */
    public double quality() {
        checkDecomposed();
        return this.decomp.quality();
    }

    /**
     * Computes the inverse of the system matrix, one column at a time, by solving
     * {@code A x = e_col} for every unit vector {@code e_col}.
     *
     * @return a new matrix holding the inverse
     * @throws IllegalStateException if {@link #setA(AMatrix)} has not been called
     * @throws IllegalArgumentException if the system matrix is not square
     */
    public AMatrix invert() {
        checkDecomposed();
        if (!this.A.isSquare()) {
            throw new IllegalArgumentException("Matrix must be square for inverse!");
        }
        double[] vv = this.decomp._getVV();
        AMatrix lu = this.decomp.getLU();
        Matrix inverse = Matrix.create(lu.rowCount(), lu.columnCount());
        int n = this.A.columnCount();
        double[] out = inverse.data;
        for (int col = 0; col < n; col++) {
            // vv holds the right-hand side e_col on entry and the solution (column col of A^-1) after the solve.
            for (int i = 0; i < n; i++) {
                vv[i] = (i == col) ? 1.0d : 0.0d;
            }
            this.decomp._solveVectorInternal(vv);
            for (int row = 0; row < n; row++) {
                out[row * n + col] = vv[row];
            }
        }
        return inverse;
    }

    /**
     * Solves {@code A X = B} for {@code X}, column by column.
     *
     * @param aMatrix the right-hand side {@code B}; must have {@code numCols} rows
     * @return a new matrix {@code X}, or {@code null} when {@code |det(A)| < 1e-10}
     * @throws IllegalStateException if {@link #setA(AMatrix)} has not been called
     * @throws IllegalArgumentException if {@code B} has the wrong number of rows
     */
    public ADenseArrayMatrix solve(AMatrix aMatrix) {
        checkDecomposed();
        if (aMatrix.rowCount() != this.numCols) {
            throw new IllegalArgumentException("Unexpected matrix size");
        }
        // Absolute threshold on |det(A)|: note this test depends on the scale of A.
        if (Math.abs(this.result.computeDeterminant()) < 1.0E-10d) {
            return null;
        }
        int n = this.numCols;
        int nc = aMatrix.columnCount();
        Matrix x = Matrix.create(n, nc);
        double[] dataB = rowMajorData(aMatrix);
        double[] dataX = x.data;
        double[] vv = this.decomp._getVV();
        for (int k = 0; k < nc; k++) {
            for (int i = 0; i < n; i++) {
                vv[i] = dataB[i * nc + k];
            }
            // vv is overwritten with the solution for column k.
            this.decomp._solveVectorInternal(vv);
            for (int i = 0; i < n; i++) {
                dataX[i * nc + k] = vv[i];
            }
        }
        if (this.doImprove) {
            improveSol(aMatrix, x);
        }
        return x;
    }

    /**
     * Applies one step of iterative refinement to a solution {@code x} of {@code A x = b}.
     *
     * <p>For every column it forms the residual {@code r = A x - b}, solves {@code A d = r} with
     * the stored decomposition, and updates {@code x -= d}. {@code x} is modified in place.
     *
     * @param aMatrix the right-hand side {@code b}
     * @param aMatrix2 the solution {@code x} to refine, same shape as {@code b}
     * @throws IllegalStateException if {@link #setA(AMatrix)} has not been called
     * @throws IllegalArgumentException if {@code b} and {@code x} do not both have shape
     *     {@code numCols x k} for the same {@code k}
     */
    public void improveSol(AMatrix aMatrix, AMatrix aMatrix2) {
        if (aMatrix.columnCount() != aMatrix2.columnCount()) {
            throw new IllegalArgumentException("bad shapes");
        }
        checkDecomposed();
        // The system dimension comes from A, not from b's column count.
        int n = this.numCols;
        if (aMatrix.rowCount() != n || aMatrix2.rowCount() != n) {
            throw new IllegalArgumentException("bad shapes");
        }
        int nc = aMatrix.columnCount();
        double[] dataA = rowMajorData(this.A);
        double[] dataB = rowMajorData(aMatrix);
        double[] xBacking = aMatrix2.asDoubleArray();
        // Without a backing array the updates go to a copy and are written back at the end.
        double[] dataX = (xBacking != null) ? xBacking : aMatrix2.toDoubleArray();
        double[] vv = this.decomp._getVV();
        for (int k = 0; k < nc; k++) {
            for (int i = 0; i < n; i++) {
                double residual = -dataB[i * nc + k];
                for (int j = 0; j < n; j++) {
                    residual += dataA[i * n + j] * dataX[j * nc + k];
                }
                vv[i] = residual;
            }
            this.decomp._solveVectorInternal(vv);
            for (int i = 0; i < n; i++) {
                dataX[i * nc + k] -= vv[i];
            }
        }
        if (xBacking == null) {
            for (int i = 0; i < n; i++) {
                for (int k = 0; k < nc; k++) {
                    aMatrix2.set(i, k, dataX[i * nc + k]);
                }
            }
        }
    }

    /** Throws if {@link #setA(AMatrix)} has not yet provided a decomposition. */
    private void checkDecomposed() {
        if (this.decomp == null || this.result == null) {
            throw new IllegalStateException("setA() must be called before using the solver");
        }
    }

    /**
     * Returns the elements of {@code m} as a flat array in the row-major layout the index
     * arithmetic above expects. It returns the backing array when {@code asDoubleArray()} exposes
     * one, otherwise the copy produced by {@code toDoubleArray()}.
     */
    private static double[] rowMajorData(AMatrix m) {
        double[] data = m.asDoubleArray();
        return (data != null) ? data : m.toDoubleArray();
    }
}
