package jsat.regression;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.classifiers.DataPoint;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.LUPDecomposition;
import jsat.linear.Matrix;
import jsat.linear.SingularValueDecomposition;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.SystemInfo;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/regression/OrdinaryKriging.class */
/**
 * Ordinary kriging regressor.
 * <p>
 * Training first fits the configured {@link Variogram} to the data. It then solves
 * {@code M * X = [y_1, ..., y_n, 0]}, where for {@code i, j < n}
 * {@code M[i][j] = val(||x_i - x_j||)} with the squared measurement error subtracted
 * from the diagonal. {@code M} is bordered by a row and a column of ones and has a
 * zero in the corner. A prediction at {@code x} is
 * {@code X . [val(||x - x_1||), ..., val(||x - x_n||), 1]}.
 * <p>
 * The training data set is retained by reference because predictions measure
 * distances to its points.
 */
public class OrdinaryKriging implements Regressor, Parameterized {
    private static final long serialVersionUID = -5774553215322383751L;

    /** Nugget used by the constructors that do not take one. */
    public static final double DEFAULT_NUGGET = 0.1d;
    /** Measurement error used by the constructors that do not take one. */
    public static final double DEFAULT_ERROR = 0.1d;

    /** Variogram model; fitted at the start of every {@code train} call. */
    private final Variogram vari;
    /**
     * Solution of the kriging system. It holds one coefficient per training point,
     * followed by the coefficient paired with the constant 1 in {@link #regress}.
     * {@code null} until trained.
     */
    private Vec X;
    /** Training data, kept by reference for {@link #regress}. {@code null} until trained. */
    private RegressionDataSet dataSet;
    /** Square of the measurement error; subtracted from the diagonal of the kriging matrix. */
    private double errorSqrd;
    /** Nugget handed to {@link Variogram#train}; validated by {@link #setNugget}. */
    private double nugget;
    List<Parameter> params;
    private final Map<String, Parameter> paramMap;

    /**
     * Power-law variogram {@code val(r) = alpha * r^beta}.
     * <p>
     * {@code beta} is fixed at construction. {@code alpha} is fitted in {@link #train}
     * by least squares against the empirical semivariances
     * {@code 0.5 * (y_i - y_j)^2}, after subtracting the squared nugget.
     */
    public static class PowVariogram implements Variogram {
        private double alpha;
        private final double beta;

        /** Creates a power variogram with exponent {@code beta = 1.5}. */
        public PowVariogram() {
            this(1.5d);
        }

        /**
         * Creates a power variogram with the given exponent.
         *
         * @param beta the exponent applied to the distance
         */
        public PowVariogram(double beta) {
            this.beta = beta;
        }

        /**
         * Fits {@code alpha} to minimize
         * {@code sum_{i<j} (0.5*(y_i-y_j)^2 - nugget^2 - alpha*r_ij^beta)^2},
         * where {@code r_ij} is the Euclidean distance between points i and j.
         * If no pair has a non-zero {@code r_ij^beta}, there is nothing to fit and
         * {@code alpha} is set to 0.
         *
         * @param regressionDataSet the training data
         * @param nugget the nugget; its square is subtracted from each semivariance
         */
        @Override
        public void train(RegressionDataSet regressionDataSet, double nugget) {
            final int n = regressionDataSet.getSampleSize();
            final double nuggetSqrd = nugget * nugget;
            double num = 0.0d;
            double denom = 0.0d;
            for (int i = 0; i < n; i++) {
                Vec xi = regressionDataSet.getDataPoint(i).getNumericalValues();
                double yi = regressionDataSet.getTargetValue(i);
                for (int j = i + 1; j < n; j++) {
                    Vec xj = regressionDataSet.getDataPoint(j).getNumericalValues();
                    double diff = yi - regressionDataSet.getTargetValue(j);
                    double rb = Math.pow(xi.pNormDist(2.0d, xj), this.beta);
                    num += rb * (0.5d * diff * diff - nuggetSqrd);
                    denom += rb * rb;
                }
            }
            // With fewer than two points, or all points coincident, denom is 0 and the
            // closed form would be 0/0 = NaN, which would poison the kriging system.
            this.alpha = denom == 0.0d ? 0.0d : num / denom;
        }

        @Override
        public double val(double r) {
            return this.alpha * Math.pow(r, this.beta);
        }

        /** Returns a copy with the same {@code beta} and fitted {@code alpha}. */
        @Override
        public Variogram m716clone() {
            PowVariogram copy = new PowVariogram(this.beta);
            copy.alpha = this.alpha;
            return copy;
        }
    }

    /**
     * A variogram model: a function of the distance between two points, fitted to a
     * data set before use.
     */
    public interface Variogram extends Cloneable {
        /**
         * Fits this variogram's parameters to the given data.
         *
         * @param regressionDataSet the training data
         * @param nugget the nugget configured on the owning {@link OrdinaryKriging}
         */
        void train(RegressionDataSet regressionDataSet, double nugget);

        /**
         * @param r the Euclidean distance between two points
         * @return the variogram value at distance {@code r}
         */
        double val(double r);

        /**
         * Returns an independent copy of this variogram. The copy should carry any
         * fitted state, because {@link OrdinaryKriging#clone()} copies the solved
         * coefficients alongside it. The method keeps the name {@code m716clone}
         * because {@link OrdinaryKriging#clone()} calls it by that name.
         */
        Variogram m716clone();
    }

    /**
     * @param vari the variogram model to fit and use; must not be {@code null}
     * @param error the measurement error; only its square is used
     * @param nugget the nugget; must be finite and non-negative
     * @throws NullPointerException if {@code vari} is {@code null}
     * @throws ArithmeticException if {@code error} or {@code nugget} is invalid
     */
    public OrdinaryKriging(Variogram vari, double error, double nugget) {
        if (vari == null) {
            throw new NullPointerException("vari must not be null");
        }
        this.params = Collections.unmodifiableList(Parameter.getParamsFromMethods(this));
        this.paramMap = Parameter.toParameterMap(this.params);
        this.vari = vari;
        setMeasurementError(error);
        // Go through the setter so the constructor enforces the same validation.
        setNugget(nugget);
    }

    /** Uses {@link #DEFAULT_NUGGET} as the nugget. */
    public OrdinaryKriging(Variogram vari, double error) {
        this(vari, error, DEFAULT_NUGGET);
    }

    /** Uses {@link #DEFAULT_ERROR} and {@link #DEFAULT_NUGGET}. */
    public OrdinaryKriging(Variogram vari) {
        this(vari, DEFAULT_ERROR);
    }

    /** Uses a default {@link PowVariogram}, {@link #DEFAULT_ERROR} and {@link #DEFAULT_NUGGET}. */
    public OrdinaryKriging() {
        this(new PowVariogram());
    }

    /**
     * Predicts the target for a point using the solved kriging coefficients.
     *
     * @throws IllegalStateException if the model has not been trained
     */
    @Override
    public double regress(DataPoint dataPoint) {
        if (this.X == null || this.dataSet == null) {
            throw new IllegalStateException("OrdinaryKriging has not been trained");
        }
        Vec x = dataPoint.getNumericalValues();
        int npt = this.X.length() - 1;
        double[] vstar = new double[npt + 1];
        for (int i = 0; i < npt; i++) {
            vstar[i] = this.vari.val(x.pNormDist(2.0d, this.dataSet.getDataPoint(i).getNumericalValues()));
        }
        vstar[npt] = 1.0d; // pairs with the border (constraint) entry of X
        return this.X.dot(DenseVector.toDenseVec(vstar));
    }

    /**
     * Fits the variogram and solves the kriging system.
     * <p>
     * The system is solved with an LUP decomposition. If its determinant is NaN or
     * smaller than 1e-5 in magnitude, the SVD solution is used instead. The model's
     * state ({@code dataSet}, {@code X}) is updated only once the solve succeeds.
     *
     * @param regressionDataSet the training data
     * @param executorService pool used to build and decompose the system, or
     *        {@code null} to work on the calling thread
     */
    @Override
    public void train(RegressionDataSet regressionDataSet, ExecutorService executorService) {
        int n = regressionDataSet.getSampleSize();
        DenseVector y = new DenseVector(n + 1);
        DenseMatrix v = new DenseMatrix(n + 1, n + 1);

        this.vari.train(regressionDataSet, this.nugget);
        if (executorService == null) {
            setUpVectorMatrix(n, regressionDataSet, v, y);
        } else {
            setUpVectorMatrix(n, regressionDataSet, v, y, executorService);
        }
        for (int i = 0; i < n; i++) {
            v.increment(i, i, -this.errorSqrd);
        }

        LUPDecomposition lup = executorService == null
                ? new LUPDecomposition(v)
                : new LUPDecomposition(v, executorService);
        Vec solution = lup.solve(y);
        double det = lup.det();
        if (Double.isNaN(det) || Math.abs(det) < 1.0E-5d) {
            // Near-singular system: fall back to the more robust SVD solve.
            solution = new SingularValueDecomposition(v).solve(y);
        }
        this.dataSet = regressionDataSet;
        this.X = solution;
    }

    /** Builds the kriging matrix and right-hand side on the calling thread. */
    private void setUpVectorMatrix(int n, RegressionDataSet data, Matrix v, Vec y) {
        fillRows(0, 1, n, data, v, y);
        v.set(n, n, 0.0d);
    }

    /**
     * Builds the kriging matrix and right-hand side on {@code executorService}. Rows are
     * interleaved across {@link SystemInfo#LogicalCores} tasks, and the method blocks
     * until all tasks finish.
     *
     * @throws IllegalStateException if interrupted while waiting (the interrupt flag is
     *         restored), or wrapping a checked failure of a task
     */
    private void setUpVectorMatrix(final int n, final RegressionDataSet data, final Matrix v, final Vec y,
            ExecutorService executorService) {
        final int tasks = Math.max(1, SystemInfo.LogicalCores);
        // Set the corner before any task starts, so no cell is written concurrently.
        v.set(n, n, 0.0d);
        List<Future<?>> futures = new ArrayList<Future<?>>(tasks);
        for (int t = 0; t < tasks; t++) {
            final int start = t;
            futures.add(executorService.submit(new Runnable() {
                @Override
                public void run() {
                    fillRows(start, tasks, n, data, v, y);
                }
            }));
        }
        awaitAll(futures);
    }

    /**
     * Fills rows {@code start, start+step, ...} (those below {@code n}) of the system.
     * <p>
     * For each such row {@code i} it sets {@code v[i][j] = v[j][i] = val(||x_i - x_j||)}
     * for {@code j >= i}, sets {@code v[i][n] = v[n][i] = 1}, and sets
     * {@code y[i] = target_i}. Each cell is written by exactly one row's iteration, so
     * disjoint row sets never write the same cell.
     */
    private void fillRows(int start, int step, int n, RegressionDataSet data, Matrix v, Vec y) {
        for (int i = start; i < n; i += step) {
            Vec xi = data.getDataPoint(i).getNumericalValues();
            for (int j = i; j < n; j++) {
                double val = this.vari.val(xi.pNormDist(2.0d, data.getDataPoint(j).getNumericalValues()));
                v.set(i, j, val);
                v.set(j, i, val);
            }
            v.set(i, n, 1.0d);
            v.set(n, i, 1.0d);
            y.set(i, data.getTargetValue(i));
        }
    }

    /**
     * Waits for every future. On failure or interrupt it cancels the remaining futures
     * and rethrows, so training never continues with a partially built system.
     */
    private static void awaitAll(List<Future<?>> futures) {
        try {
            for (Future<?> future : futures) {
                future.get();
            }
        } catch (InterruptedException e) {
            cancelAll(futures);
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while building the kriging system", e);
        } catch (ExecutionException e) {
            cancelAll(futures);
            Throwable cause = e.getCause();
            if (cause instanceof RuntimeException) {
                throw (RuntimeException) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw new IllegalStateException("Building the kriging system failed", cause);
        }
    }

    /** Cancels every future that has not started yet; already-finished ones are unaffected. */
    private static void cancelAll(List<Future<?>> futures) {
        for (Future<?> future : futures) {
            future.cancel(false);
        }
    }

    @Override
    public void train(RegressionDataSet regressionDataSet) {
        train(regressionDataSet, null);
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Returns a copy with a cloned variogram, the same settings, and a copy of the
     * solved coefficients. The training data set is shared, not copied.
     */
    @Override
    public OrdinaryKriging clone() {
        OrdinaryKriging copy = new OrdinaryKriging(this.vari.m716clone(), DEFAULT_ERROR, this.nugget);
        // Copy the squared error directly; recomputing it from getMeasurementError()
        // (sqrt, then square again) can differ by an ulp.
        copy.errorSqrd = this.errorSqrd;
        if (this.X != null) {
            copy.X = this.X.mo524clone();
        }
        copy.dataSet = this.dataSet;
        return copy;
    }

    /**
     * Sets the measurement error. Only its square is stored, so the sign is discarded.
     *
     * @param error the measurement error
     * @throws ArithmeticException if {@code error} is NaN or infinite
     */
    public void setMeasurementError(double error) {
        if (Double.isNaN(error) || Double.isInfinite(error)) {
            throw new ArithmeticException("Measurement error must be a finite value, not " + error);
        }
        this.errorSqrd = error * error;
    }

    /** @return the (non-negative) measurement error */
    public double getMeasurementError() {
        return Math.sqrt(this.errorSqrd);
    }

    /**
     * @param nugget the nugget passed to the variogram during training
     * @throws ArithmeticException if {@code nugget} is negative, NaN or infinite
     */
    public void setNugget(double nugget) {
        if (nugget < 0.0d || Double.isNaN(nugget) || Double.isInfinite(nugget)) {
            throw new ArithmeticException("Nugget must be a non-negative finite value, not " + nugget);
        }
        this.nugget = nugget;
    }

    /** @return the nugget passed to the variogram during training */
    public double getNugget() {
        return this.nugget;
    }

    @Override
    public List<Parameter> getParameters() {
        return this.params;
    }

    @Override
    public Parameter getParameter(String str) {
        return this.paramMap.get(str);
    }
}
