package jsat.regression;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.distributions.kernels.KernelTrick;
import jsat.distributions.kernels.RBFKernel;
import jsat.linear.CholeskyDecomposition;
import jsat.linear.DenseMatrix;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.FakeExecutor;
import jsat.utils.SystemInfo;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/regression/KernelRidgeRegression.class */
public class KernelRidgeRegression implements Regressor, Parameterized {
    private static final long serialVersionUID = 6275333785663250072L;
    private double lambda;

    @Parameter.ParameterHolder
    private KernelTrick k;
    private List<Vec> vecs;
    private double[] alphas;

    public KernelRidgeRegression() {
        this(1.0E-6d, new RBFKernel());
    }

    public KernelRidgeRegression(double d, KernelTrick kernelTrick) {
        setLambda(d);
        setKernel(kernelTrick);
    }

    protected KernelRidgeRegression(KernelRidgeRegression kernelRidgeRegression) {
        this(kernelRidgeRegression.lambda, kernelRidgeRegression.getKernel().m628clone());
        if (kernelRidgeRegression.alphas != null) {
            this.alphas = Arrays.copyOf(kernelRidgeRegression.alphas, kernelRidgeRegression.alphas.length);
        }
        if (kernelRidgeRegression.vecs != null) {
            this.vecs = new ArrayList(kernelRidgeRegression.vecs);
        }
    }

    public static Distribution guessLambda(DataSet dataSet) {
        return new LogUniform(1.0E-7d, 0.01d);
    }

    public void setLambda(double d) {
        if (Double.isNaN(d) || Double.isInfinite(d) || d <= 0.0d) {
            throw new IllegalArgumentException("lambda must be a positive constant, not " + d);
        }
        this.lambda = d;
    }

    public double getLambda() {
        return this.lambda;
    }

    public void setKernel(KernelTrick kernelTrick) {
        this.k = kernelTrick;
    }

    public KernelTrick getKernel() {
        return this.k;
    }

    @Override // jsat.regression.Regressor
    public double regress(DataPoint dataPoint) {
        Vec numericalValues = dataPoint.getNumericalValues();
        double d = 0.0d;
        for (int i = 0; i < this.alphas.length; i++) {
            d += this.alphas[i] * this.k.eval(this.vecs.get(i), numericalValues);
        }
        return d;
    }

    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet, ExecutorService executorService) {
        final int sampleSize = regressionDataSet.getSampleSize();
        this.vecs = new ArrayList(sampleSize);
        Vec targetValues = regressionDataSet.getTargetValues();
        for (int i = 0; i < sampleSize; i++) {
            this.vecs.add(regressionDataSet.getDataPoint(i).getNumericalValues());
        }
        final DenseMatrix denseMatrix = new DenseMatrix(sampleSize, sampleSize);
        final CountDownLatch countDownLatch = new CountDownLatch(SystemInfo.LogicalCores);
        for (int i2 = 0; i2 < SystemInfo.LogicalCores; i2++) {
            final int i3 = i2;
            executorService.submit(new Runnable() { // from class: jsat.regression.KernelRidgeRegression.1
                @Override // java.lang.Runnable
                public void run() {
                    int i4 = i3;
                    while (true) {
                        int i5 = i4;
                        if (i5 >= sampleSize) {
                            countDownLatch.countDown();
                            return;
                        }
                        denseMatrix.set(i5, i5, KernelRidgeRegression.this.k.eval((Vec) KernelRidgeRegression.this.vecs.get(i5), (Vec) KernelRidgeRegression.this.vecs.get(i5)) + KernelRidgeRegression.this.lambda);
                        for (int i6 = i5 + 1; i6 < sampleSize; i6++) {
                            double eval = KernelRidgeRegression.this.k.eval((Vec) KernelRidgeRegression.this.vecs.get(i5), (Vec) KernelRidgeRegression.this.vecs.get(i6));
                            denseMatrix.set(i5, i6, eval);
                            denseMatrix.set(i6, i5, eval);
                        }
                        i4 = i5 + SystemInfo.LogicalCores;
                    }
                }
            });
        }
        try {
            countDownLatch.await();
        } catch (InterruptedException e) {
            Logger.getLogger(KernelRidgeRegression.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
        }
        this.alphas = (executorService instanceof FakeExecutor ? new CholeskyDecomposition(denseMatrix) : new CholeskyDecomposition(denseMatrix, executorService)).solve(targetValues).arrayCopy();
    }

    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet) {
        train(regressionDataSet, new FakeExecutor());
    }

    @Override // jsat.regression.Regressor
    public boolean supportsWeightedData() {
        return false;
    }

    @Override // jsat.regression.Regressor
    public KernelRidgeRegression clone() {
        return new KernelRidgeRegression(this);
    }

    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }
}
