package jsat.math.optimization;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.LUPDecomposition;
import jsat.linear.Vec;
import jsat.math.Function;
import jsat.utils.FakeExecutor;
import jsat.utils.SystemInfo;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/math/optimization/IterativelyReweightedLeastSquares.class */
/**
 * Iteratively reweighted least squares: repeatedly applies a Newton-style update to a parameter
 * vector whose first entry multiplies a constant bias column.
 *
 * <p>For every sample {@code x_i} with target {@code y_i}, each step computes:
 * <ul>
 *   <li>the residual {@code e_i = f(x_i) - y_i};</li>
 *   <li>the weight {@code w_i = fd(x_i)};</li>
 *   <li>the gradient {@code g = X^T e};</li>
 *   <li>the Hessian {@code H = X^T diag(w) X}.</li>
 * </ul>
 * Row {@code i} of {@code X} is {@code [1, x_i[0], ..., x_i[d-2]]} with {@code d = vars.length()}.
 * The step then applies {@code vars -= H^-1 g}.
 *
 * <p>NOTE(review): {@code fd} is presumably the derivative of {@code f} with respect to the linear
 * predictor (e.g. {@code p(1-p)} for a logistic model), and {@code f}/{@code fd} presumably read
 * the in-place-updated {@code vars}. Confirm both against callers.
 *
 * <p>Not thread-safe: the working matrices and vectors live in instance fields, so a single
 * instance must not run two optimizations concurrently.
 */
public class IterativelyReweightedLeastSquares implements Optimizer {
    private static final long serialVersionUID = -6872953184371630318L;

    /** A step is abandoned (and iteration stops) when |det(H)| falls below this value. */
    private static final double SINGULAR_DETERMINANT = 1.0E-14d;

    /** Hessian {@code H = X^T diag(w) X}; rebuilt from zero at the start of every step. */
    private DenseMatrix hessian;
    /** Design matrix {@code X}: column 0 is the constant 1, column j holds feature j-1. */
    private DenseMatrix coefficentMatrix;
    /** Per-sample weights {@code w_i = fd(x_i)} for the current step. */
    private DenseVector derivatives;
    /** Per-sample residuals {@code e_i = f(x_i) - y_i} for the current step. */
    private DenseVector errors;
    /** Gradient {@code g = X^T e} for the current step. */
    private DenseVector gradiant;

    /**
     * Runs the optimization entirely on the calling thread.
     *
     * <p>Equivalent to {@link #optimize(double, int, Function, Function, Vec, List, Vec, ExecutorService)}
     * with a {@code null} executor.
     */
    @Override // jsat.math.optimization.Optimizer
    public Vec optimize(double eps, int iterationLimit, Function f, Function fd, Vec vars, List<Vec> inputs, Vec outputs) {
        return optimize(eps, iterationLimit, f, fd, vars, inputs, outputs, null);
    }

    /**
     * Iteratively updates {@code vars} until one of three things happens: a step becomes small
     * enough, the Hessian becomes numerically singular, or the iteration budget is used up.
     *
     * @param eps stop once the largest absolute component of a step is {@code <= eps}
     * @param iterationLimit iteration budget. The loop condition is {@code iterationLimit-- > 0},
     *        evaluated after each step, so up to {@code iterationLimit + 1} steps are taken.
     *        NOTE(review): possibly an off-by-one; kept unchanged for backward compatibility.
     * @param f model evaluated on each input; residuals are {@code f(x_i) - outputs.get(i)}
     * @param fd per-sample weight function used to build the Hessian
     * @param vars initial parameters (index 0 multiplies the bias column); updated in place
     * @param inputs sample feature vectors; only the first {@code vars.length() - 1} features
     *        of each are read
     * @param outputs target value for each sample, indexed like {@code inputs}
     * @param threadpool executor used to build the Hessian in parallel, or {@code null} (or a
     *        {@link FakeExecutor}) to run everything on the calling thread
     * @return {@code vars}, after the final applied update
     */
    @Override // jsat.math.optimization.Optimizer
    public Vec optimize(double eps, int iterationLimit, Function f, Function fd, Vec vars, List<Vec> inputs, Vec outputs, ExecutorService threadpool) {
        final int numVars = vars.length();
        final int numSamples = inputs.size();

        this.hessian = new DenseMatrix(numVars, numVars);
        this.coefficentMatrix = new DenseMatrix(numSamples, numVars);
        for (int i = 0; i < numSamples; i++) {
            Vec x = inputs.get(i);
            coefficentMatrix.set(i, 0, 1.0d); // constant bias column
            for (int j = 1; j < numVars; j++) {
                coefficentMatrix.set(i, j, x.get(j - 1));
            }
        }
        // Both per-sample vectors are indexed by sample, so size them by the sample count.
        this.derivatives = new DenseVector(numSamples);
        this.errors = new DenseVector(numSamples);
        this.gradiant = new DenseVector(numVars);

        final boolean serial = threadpool == null || threadpool instanceof FakeExecutor;
        int remaining = iterationLimit;
        do {
            double stepSize = serial
                    ? iterationStep(f, fd, vars, inputs, outputs)
                    : iterationStep(f, fd, vars, inputs, outputs, threadpool);
            // NaN signals "stop now": singular Hessian or interrupted while waiting on workers.
            if (Double.isNaN(stepSize) || stepSize <= eps) {
                break;
            }
        } while (remaining-- > 0);
        return vars;
    }

    /**
     * Performs one update on the calling thread.
     *
     * @return the largest absolute component of the applied step, or {@code NaN} if the Hessian
     *         was singular (in which case {@code vars} is left unchanged)
     */
    private double iterationStep(Function f, Function fd, Vec vars, List<Vec> inputs, Vec outputs) {
        beginStep(f, fd, inputs, outputs);
        accumulateRows(0, hessian.rows());
        return solveAndUpdate(new LUPDecomposition(hessian), vars);
    }

    /**
     * Performs one update, building the Hessian and gradient rows in parallel.
     *
     * <p>The rows are split into {@link SystemInfo#LogicalCores} contiguous ranges, and each task
     * writes only its own rows. If the calling thread is interrupted while waiting, the interrupt
     * flag is restored, the outstanding tasks are cancelled, and {@code NaN} is returned without
     * updating {@code vars}.
     *
     * @return the largest absolute component of the applied step, or {@code NaN} if the Hessian
     *         was singular or the wait was interrupted
     * @throws RuntimeException if a worker task fails (the task's exception is the cause)
     */
    private double iterationStep(Function f, Function fd, Vec vars, List<Vec> inputs, Vec outputs, ExecutorService threadpool) {
        beginStep(f, fd, inputs, outputs);

        final int workers = Math.max(1, SystemInfo.LogicalCores);
        final int rows = hessian.rows();
        final int baseSize = rows / workers;
        int extra = rows % workers; // the first `extra` ranges get one additional row

        List<Future<?>> pending = new ArrayList<Future<?>>(workers);
        int start = 0;
        for (int w = 0; w < workers; w++) {
            final int from = start;
            final int to = from + baseSize + (extra-- > 0 ? 1 : 0);
            start = to;
            pending.add(threadpool.submit(new Runnable() {
                @Override // java.lang.Runnable
                public void run() {
                    accumulateRows(from, to);
                }
            }));
        }

        if (!awaitAll(pending)) {
            return Double.NaN;
        }
        return solveAndUpdate(new LUPDecomposition(hessian, threadpool), vars);
    }

    /**
     * Prepares a new step: evaluates residuals and weights on every sample, then starts a fresh
     * (all-zero) Hessian.
     */
    private void beginStep(Function f, Function fd, List<Vec> inputs, Vec outputs) {
        for (int i = 0; i < inputs.size(); i++) {
            Vec x = inputs.get(i);
            errors.set(i, f.f(x) - outputs.get(i));
            derivatives.set(i, fd.f(x));
        }
        // The update needs the Hessian of the *current* step; accumulating on top of earlier
        // steps' sums would shrink every subsequent step.
        final int n = hessian.rows();
        hessian = new DenseMatrix(n, n);
    }

    /**
     * Fills Hessian rows {@code [from, to)} and the matching gradient entries.
     *
     * <p>Writes only to those rows, so disjoint ranges may run concurrently.
     */
    private void accumulateRows(int from, int to) {
        // Read the field once, so the whole task works on a single Hessian instance.
        final DenseMatrix h = hessian;
        final int samples = coefficentMatrix.rows();
        final int n = h.rows();
        for (int j = from; j < to; j++) {
            double grad = 0.0d;
            for (int i = 0; i < samples; i++) {
                double xij = coefficentMatrix.get(i, j);
                grad += xij * errors.get(i);
                double weighted = derivatives.get(i) * xij;
                for (int k = 0; k < n; k++) {
                    h.increment(j, k, coefficentMatrix.get(i, k) * weighted);
                }
            }
            gradiant.set(j, grad);
        }
    }

    /**
     * Solves {@code H delta = g} and subtracts {@code delta} from {@code vars}.
     *
     * @return the largest absolute component of {@code delta}, or {@code NaN} if
     *         {@code |det(H)| < SINGULAR_DETERMINANT} (no update applied)
     */
    private double solveAndUpdate(LUPDecomposition lup, Vec vars) {
        if (Math.abs(lup.det()) < SINGULAR_DETERMINANT) {
            return Double.NaN;
        }
        Vec delta = lup.solve(gradiant);
        vars.mutableSubtract(delta);
        return Math.max(delta.max(), Math.abs(delta.min()));
    }

    /**
     * Waits for every task to finish.
     *
     * @return {@code true} if all tasks completed; {@code false} if the calling thread was
     *         interrupted (the flag is restored and the remaining tasks are cancelled)
     * @throws RuntimeException wrapping the failure of any task
     */
    private static boolean awaitAll(List<Future<?>> pending) {
        for (int i = 0; i < pending.size(); i++) {
            try {
                pending.get(i).get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                for (int j = i; j < pending.size(); j++) {
                    pending.get(j).cancel(true);
                }
                return false;
            } catch (ExecutionException e) {
                for (Future<?> other : pending) {
                    other.cancel(true);
                }
                throw new RuntimeException("Parallel Hessian computation failed", e.getCause());
            }
        }
        return true;
    }
}
