package javastat.survival.regression;

import JSci.maths.statistics.NormalDistribution;
import Jama.Matrix;
import java.util.Hashtable;
import javastat.Regression;
import javastat.StatisticalAnalysis;
import javastat.util.Argument;
import javastat.util.BasicStatistics;
import javastat.util.DataManager;
import javastat.util.Output;

/* loaded from: input_file:javastat-1.4.jar:javastat/survival/regression/CoxRegression.class */
/**
 * Cox proportional hazards regression.
 *
 * <p>Coefficients are estimated by Newton-Raphson iteration on the log partial likelihood. The
 * risk set of the event at {@code time[i]} contains every subject {@code j} with
 * {@code time[j] >= time[i]}, which corresponds to Breslow's handling of tied times. Wald test
 * statistics, two-sided normal p-values and Wald confidence intervals are derived from the
 * negated inverse of the log-partial-likelihood Hessian of the final Newton step.
 *
 * <p>Covariates are laid out covariate-major: {@code covariate[k][i]} is the value of the
 * {@code k}-th covariate for subject {@code i}; {@code time[i]} and {@code censor[i]} belong to
 * the same subject.
 *
 * <p>NOTE(review): every fit writes static fields of {@code BasicStatistics}
 * ({@code convergenceCriterion}, {@code maxIterationNumber}, {@code initialEstimate}), so
 * concurrent fits on different threads are not isolated from each other.
 */
public class CoxRegression extends Regression {

    /** Level of significance used when the caller does not supply one. */
    private static final double DEFAULT_ALPHA = 0.05d;

    /** Level of significance used for the confidence intervals. */
    public double alpha;
    /** Estimated regression coefficients, one per covariate. */
    public double[] coefficients;
    /** Estimated covariance matrix of the coefficients (negated inverse Hessian). */
    public double[][] variance;
    /** Wald statistics: each coefficient divided by its estimated standard error. */
    public double[] testStatistic;
    /** Two-sided p-values of the Wald statistics against a standard normal distribution. */
    public double[] pValue;
    /** Per-covariate {lower, upper} Wald confidence limits at level {@link #alpha}. */
    public double[][] confidenceInterval;
    /** Observed times, one per subject; validated as positive. */
    public double[] time;
    /**
     * Censoring indicators, one per subject. {@code censor[i]} multiplies subject {@code i}'s
     * event contribution, so presumably 1 = observed event and 0 = censored (validated by
     * {@code DataManager.checkCensor} — TODO confirm).
     */
    public double[] censor;
    /** Covariates indexed {@code [covariate][subject]}. */
    public double[][] covariate;
    /** Covariate matrix extracted from the generic {@code Object[]} data argument. */
    private double[][] doubleCovariate;
    /** Fitted analysis created by {@link #CoxRegression(Hashtable, Object[])}. */
    public StatisticalAnalysis statisticalAnalysis;
    /** Optional starting coefficients; {@code null} means 0.1 for every covariate. */
    public double[] initialEstimate;
    /** Square root of the Frobenius norm of the most recent Newton step. */
    public double error1;
    /** Square root of the Frobenius norm of the score at the most recent step's start point. */
    public double error2;
    /** Maximum number of Newton-Raphson iterations before the fit is declared divergent. */
    public int maxIterationNumber = 1000;
    private double zAlpha;
    /** exp(beta' x_i) for each subject at the current coefficients. */
    private double[] riskFactor;
    /** Risk-set weighted mean of covariate k at the time of subject i: [k][i]. */
    private double[][] expectedCovariate;
    /** Gradient of the log partial likelihood at the current coefficients. */
    private double[] score;
    /** Hessian of the log partial likelihood (negative of the observed information). */
    private double[][] information;
    private Matrix scoreMatrix;
    private Matrix informationMatrix;
    private Matrix coefficientMatrix;
    private Matrix errorMatrix;
    private DataManager dataManager;
    private NormalDistribution normalDistribution;
    private int iterationIndex;

    /** Creates an unfitted instance; call one of the analysis methods to fit data. */
    public CoxRegression() {
    }

    /**
     * Fits the model from generic arguments and stores the fitted analysis in
     * {@link #statisticalAnalysis}.
     *
     * @param hashtable optional arguments; only {@code Argument.ALPHA} is recognized.
     *     {@code null} or empty means the default level 0.05.
     * @param objArr {time, censor, covariates}; the covariates are either one
     *     {@code double[][]} or further vectors convertible by {@code DataManager.castDoubleObject}.
     *     {@code null} creates an unfitted analysis.
     * @throws IllegalArgumentException on malformed data or unrecognized arguments
     */
    public CoxRegression(Hashtable hashtable, Object[] objArr) {
        this.argument = hashtable;
        this.dataObject = objArr;
        if (objArr == null) {
            this.statisticalAnalysis = new CoxRegression();
            return;
        }
        this.doubleCovariate = extractCovariate(objArr, "Wrong input data type");
        this.statisticalAnalysis = new CoxRegression(alphaArgument(hashtable),
                (double[]) objArr[0], (double[]) objArr[1], this.doubleCovariate);
    }

    /**
     * Fits the model and computes test statistics, p-values and confidence intervals.
     *
     * @param alpha level of significance in (0, 1]
     * @param time observed times
     * @param censor censoring indicators
     * @param covariate covariates indexed {@code [covariate][subject]}
     * @throws IllegalArgumentException on invalid input or if the iteration does not converge
     */
    public CoxRegression(double alpha, double[] time, double[] censor, double[][] covariate) {
        this.alpha = alpha;
        this.time = time;
        this.censor = censor;
        this.covariate = covariate;
        // confidenceInterval(...) performs the fit itself (setting testStatistic, pValue,
        // coefficients and variance), so a separate testStatistic(...) call would refit.
        this.confidenceInterval = confidenceInterval(alpha, time, censor, covariate);
    }

    /** Same as {@link #CoxRegression(double, double[], double[], double[][])} with alpha 0.05. */
    public CoxRegression(double[] time, double[] censor, double[][] covariate) {
        this(DEFAULT_ALPHA, time, censor, covariate);
    }

    /**
     * Extracts the covariate matrix from a generic {time, censor, covariates...} data array.
     *
     * @param data the generic data array
     * @param errorMessage message of the exception thrown when {@code data} is malformed
     * @return the covariate matrix indexed {@code [covariate][subject]}
     * @throws IllegalArgumentException if {@code data} is {@code null}, too short, or mistyped
     */
    private static double[][] extractCovariate(Object[] data, String errorMessage) {
        if (data == null || data.length < 3
                || !(data[0] instanceof double[]) || !(data[1] instanceof double[])) {
            throw new IllegalArgumentException(errorMessage);
        }
        if (data.length == 3 && data[2] instanceof double[][]) {
            return (double[][]) data[2];
        }
        Class<?> container = data.getClass();
        if (container != Object[].class && container != double[][].class) {
            throw new IllegalArgumentException(errorMessage);
        }
        return DataManager.castDoubleObject(2, data);
    }

    /**
     * Reads the level of significance from an argument table.
     *
     * @param hashtable argument table, may be {@code null}
     * @return the {@code Argument.ALPHA} entry, or 0.05 when the table is {@code null} or empty
     * @throws IllegalArgumentException if the table is non-empty but has no ALPHA entry
     */
    private static double alphaArgument(Hashtable hashtable) {
        if (hashtable == null) {
            return DEFAULT_ALPHA;
        }
        Object value = hashtable.get(Argument.ALPHA);
        if (value != null) {
            // Number instead of Double so that e.g. Float or Integer entries are accepted too.
            return ((Number) value).doubleValue();
        }
        if (!hashtable.isEmpty()) {
            throw new IllegalArgumentException("Wrong input argument(s).");
        }
        return DEFAULT_ALPHA;
    }

    /**
     * Fits the model from a generic data array and returns the coefficients.
     *
     * @throws IllegalArgumentException on malformed data or non-convergence
     */
    @Override
    public double[] coefficients(Hashtable hashtable, Object[] objArr) {
        this.argument = hashtable;
        this.dataObject = objArr;
        double[][] covariateData = extractCovariate(objArr, "Wrong input arguments or data.");
        this.coefficients = coefficients((double[]) objArr[0], (double[]) objArr[1], covariateData);
        return this.coefficients;
    }

    /** Fits the model and returns the estimated coefficients. */
    public double[] coefficients(double[] time, double[] censor, double[][] covariate) {
        this.time = time;
        this.censor = censor;
        this.covariate = covariate;
        this.testStatistic = testStatistic(time, censor, covariate);
        return this.coefficients;
    }

    /**
     * Fits the model from a generic data array and returns the Wald statistics.
     *
     * @throws IllegalArgumentException on malformed data or non-convergence
     */
    @Override
    public double[] testStatistic(Hashtable hashtable, Object[] objArr) {
        this.argument = hashtable;
        this.dataObject = objArr;
        double[][] covariateData = extractCovariate(objArr, "Wrong input arguments or data.");
        this.testStatistic = testStatistic((double[]) objArr[0], (double[]) objArr[1], covariateData);
        return this.testStatistic;
    }

    /**
     * Fits the model by Newton-Raphson and returns the Wald statistics.
     *
     * <p>Also sets {@link #coefficients}, {@link #variance}, {@link #pValue} and
     * {@link #testStatistic} and records them in {@code output}.
     *
     * @param time observed times (validated as positive)
     * @param censor censoring indicators
     * @param covariate covariates indexed {@code [covariate][subject]}
     * @return coefficient / standard error for each covariate
     * @throws IllegalArgumentException on invalid input, on a wrong-length
     *     {@link #initialEstimate}, or if the iteration does not converge within
     *     {@link #maxIterationNumber} steps
     */
    public double[] testStatistic(double[] time, double[] censor, double[][] covariate) {
        this.time = time;
        this.censor = censor;
        this.covariate = covariate;
        BasicStatistics.convergenceCriterion = new double[]{1.0E-5d};
        this.error1 = 1.0d;
        this.error2 = 1.0d;
        this.normalDistribution = new NormalDistribution();
        this.dataManager = new DataManager();
        this.dataManager.checkDimension(covariate);
        this.dataManager.checkPositiveRange(time, "time");
        this.dataManager.checkCensor(censor);
        if (time.length != covariate[0].length || time.length != censor.length) {
            throw new IllegalArgumentException("The time vector, censor vector, and rows of the covariate matrix must have the same length.");
        }

        int covariateCount = covariate.length;
        int sampleSize = time.length;
        this.pValue = new double[covariateCount];
        this.testStatistic = new double[covariateCount];
        this.riskFactor = new double[sampleSize];
        this.expectedCovariate = new double[covariateCount][sampleSize];
        this.score = new double[covariateCount];
        this.information = new double[covariateCount][covariateCount];

        BasicStatistics.maxIterationNumber = this.maxIterationNumber;
        BasicStatistics.initialEstimate = startingValues(covariateCount);
        this.coefficientMatrix = new Matrix(BasicStatistics.initialEstimate, covariateCount);
        this.coefficients = this.coefficientMatrix.getColumnPackedCopy();

        double tolerance = BasicStatistics.convergenceCriterion[0];
        for (this.iterationIndex = 1;
                this.iterationIndex <= BasicStatistics.maxIterationNumber;
                this.iterationIndex++) {
            newtonRaphsonStep(time, censor, covariate);
            // Tested right after every step so that convergence on the last allowed iteration
            // counts as success instead of being reported as a failure to converge.
            if (this.error1 <= tolerance && this.error2 <= tolerance) {
                return summarizeFit();
            }
        }
        throw new IllegalArgumentException("The algorithm can not converge.");
    }

    /**
     * Returns the starting coefficients: {@link #initialEstimate} if set, otherwise 0.1 each.
     *
     * @throws IllegalArgumentException if {@link #initialEstimate} has the wrong length
     */
    private double[] startingValues(int covariateCount) {
        if (this.initialEstimate != null) {
            if (this.initialEstimate.length != covariateCount) {
                throw new IllegalArgumentException(
                        "The initial estimate must contain exactly one value per covariate.");
            }
            return this.initialEstimate;
        }
        double[] start = new double[covariateCount];
        for (int k = 0; k < covariateCount; k++) {
            start[k] = 0.1d;
        }
        return start;
    }

    /**
     * Performs one Newton-Raphson update of the coefficients.
     *
     * <p>Computes, at the current coefficients, the risk factors, the score (gradient) and the
     * Hessian of the log partial likelihood. It then applies
     * {@code beta <- beta - H^{-1} U} and refreshes {@link #error1} / {@link #error2}.
     */
    private void newtonRaphsonStep(double[] time, double[] censor, double[][] covariate) {
        int covariateCount = covariate.length;
        int sampleSize = time.length;

        // Relative risk exp(beta' x_j) of every subject.
        for (int j = 0; j < sampleSize; j++) {
            this.riskFactor[j] = 1.0d;
            for (int k = 0; k < covariateCount; k++) {
                this.riskFactor[j] *= Math.exp(this.coefficients[k] * covariate[k][j]);
            }
        }

        // Total risk of each subject's risk set {j : time[j] >= time[i]}. It is the same for
        // every covariate, so it is computed once per iteration instead of once per covariate.
        double[] riskSetTotal = new double[sampleSize];
        for (int i = 0; i < sampleSize; i++) {
            double total = 0.0d;
            for (int j = 0; j < sampleSize; j++) {
                if (time[j] >= time[i]) {
                    total += this.riskFactor[j];
                }
            }
            riskSetTotal[i] = total;
        }

        // Score U_k = sum_i censor_i * (x_ki - risk-set weighted mean of x_k at time_i).
        for (int k = 0; k < covariateCount; k++) {
            double scoreK = 0.0d;
            for (int i = 0; i < sampleSize; i++) {
                double weighted = 0.0d;
                for (int j = 0; j < sampleSize; j++) {
                    if (time[j] >= time[i]) {
                        weighted += covariate[k][j] * this.riskFactor[j];
                    }
                }
                this.expectedCovariate[k][i] = weighted / riskSetTotal[i];
                scoreK += censor[i] * (covariate[k][i] - this.expectedCovariate[k][i]);
            }
            this.score[k] = scoreK;
        }

        // Hessian H_ab = sum_i censor_i * (E[x_a] E[x_b] - E[x_a x_b]) over each risk set,
        // i.e. the negative of the observed information.
        for (int a = 0; a < covariateCount; a++) {
            for (int b = 0; b < covariateCount; b++) {
                double sum = 0.0d;
                for (int i = 0; i < sampleSize; i++) {
                    double weighted = 0.0d;
                    for (int j = 0; j < sampleSize; j++) {
                        if (time[j] >= time[i]) {
                            weighted += covariate[a][j] * covariate[b][j] * this.riskFactor[j];
                        }
                    }
                    sum += censor[i] * ((this.expectedCovariate[a][i] * this.expectedCovariate[b][i])
                            - (weighted / riskSetTotal[i]));
                }
                this.information[a][b] = sum;
            }
        }

        this.informationMatrix = new Matrix(this.information);
        this.scoreMatrix = new Matrix(this.score, covariateCount);
        this.errorMatrix = this.informationMatrix.inverse().times(this.scoreMatrix);
        this.coefficientMatrix.minusEquals(this.errorMatrix);
        this.coefficients = this.coefficientMatrix.getColumnPackedCopy();
        this.error1 = Math.sqrt(this.errorMatrix.normF());
        this.error2 = Math.sqrt(this.scoreMatrix.normF());
    }

    /**
     * Derives the covariance matrix, Wald statistics and p-values after convergence.
     *
     * <p>The Hessian used is the one from the final Newton step, i.e. evaluated at the
     * coefficients before that step's update.
     *
     * @return the Wald statistics
     */
    private double[] summarizeFit() {
        this.informationMatrix = new Matrix(this.information);
        this.variance = this.informationMatrix.inverse().times(-1.0d).getArray();
        for (int k = 0; k < this.testStatistic.length; k++) {
            this.testStatistic[k] = this.coefficients[k] / Math.sqrt(this.variance[k][k]);
            this.pValue[k] = 2.0d * (1.0d - this.normalDistribution.cumulative(Math.abs(this.testStatistic[k])));
        }
        this.output.put(Output.COEFFICIENTS, this.coefficients);
        this.output.put(Output.COEFFICIENT_VARIANCE, this.variance);
        this.output.put(Output.TEST_STATISTIC, this.testStatistic);
        this.output.put(Output.PVALUE, this.pValue);
        return this.testStatistic;
    }

    /**
     * Fits the model from generic arguments and returns Wald confidence intervals.
     *
     * @param hashtable optional arguments; only {@code Argument.ALPHA} is recognized
     * @param objArr {time, censor, covariates...}
     * @throws IllegalArgumentException on malformed data or unrecognized arguments
     */
    public double[][] confidenceInterval(Hashtable hashtable, Object[] objArr) {
        this.argument = hashtable;
        this.dataObject = objArr;
        if (objArr == null) {
            throw new IllegalArgumentException("Wrong input data.");
        }
        this.doubleCovariate = extractCovariate(objArr, "Wrong input data type");
        this.confidenceInterval = confidenceInterval(alphaArgument(hashtable),
                (double[]) objArr[0], (double[]) objArr[1], this.doubleCovariate);
        return this.confidenceInterval;
    }

    /**
     * Fits the model and returns {@code coefficient -/+ z_{1-alpha/2} * standard error}.
     *
     * @param alpha level of significance in (0, 1]
     * @return per-covariate {lower, upper} limits
     * @throws IllegalArgumentException if alpha is outside (0, 1] (or NaN), on invalid data, or
     *     on non-convergence
     */
    public double[][] confidenceInterval(double alpha, double[] time, double[] censor, double[][] covariate) {
        // Written as a negated in-range test so that NaN is rejected as well.
        if (!(alpha > 0.0d && alpha <= 1.0d)) {
            throw new IllegalArgumentException("The level of significance should be (strictly) positive and not greater than 1.");
        }
        this.alpha = alpha;
        this.time = time;
        this.censor = censor;
        this.covariate = covariate;
        this.testStatistic = testStatistic(time, censor, covariate);
        this.zAlpha = new NormalDistribution().inverse(1.0d - (alpha / 2.0d));
        this.confidenceInterval = new double[this.testStatistic.length][2];
        for (int k = 0; k < this.testStatistic.length; k++) {
            double halfWidth = this.zAlpha * Math.sqrt(this.variance[k][k]);
            this.confidenceInterval[k][0] = this.coefficients[k] - halfWidth;
            this.confidenceInterval[k][1] = this.coefficients[k] + halfWidth;
        }
        this.output.put(Output.CONFIDENCE_INTERVAL, this.confidenceInterval);
        return this.confidenceInterval;
    }

    /** Same as {@link #confidenceInterval(double, double[], double[], double[][])} with alpha 0.05. */
    public double[][] confidenceInterval(double[] time, double[] censor, double[][] covariate) {
        return confidenceInterval(DEFAULT_ALPHA, time, censor, covariate);
    }

    /**
     * Fits the model from a generic data array and returns the two-sided p-values.
     *
     * @throws IllegalArgumentException on malformed data or non-convergence
     */
    @Override
    public double[] pValue(Hashtable hashtable, Object[] objArr) {
        this.argument = hashtable;
        this.dataObject = objArr;
        double[][] covariateData = extractCovariate(objArr, "Wrong input arguments or data.");
        this.pValue = pValue((double[]) objArr[0], (double[]) objArr[1], covariateData);
        return this.pValue;
    }

    /** Fits the model and returns the two-sided p-values of the Wald statistics. */
    public double[] pValue(double[] time, double[] censor, double[][] covariate) {
        this.time = time;
        this.censor = censor;
        this.covariate = covariate;
        this.testStatistic = testStatistic(time, censor, covariate);
        return this.pValue;
    }

    // The Object-returning coefficients/pValue/testStatistic(Hashtable, Object[]) variants that
    // appeared here were decompiled compiler bridge methods: they duplicated the covariant
    // overrides' signatures (a compile error) and called themselves. javac generates the
    // bridges automatically from the double[] overrides above.
}
