package jsat.classifiers.neuralnetwork.regularizers;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.linear.Matrix;
import jsat.linear.Vec;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/neuralnetwork/regularizers/Max2NormRegularizer.class */
/**
 * Weight regularizer that bounds the L2 norm of every row of a weight matrix.
 * <p>
 * When a row's Euclidean norm is greater than or equal to {@link #getMaxNorm()},
 * the row is rescaled in place by {@code maxNorm / norm}. The bias entry with the
 * same index is multiplied by the same ratio. Rows whose norm is below the limit
 * are left untouched.
 */
public class Max2NormRegularizer implements WeightRegularizer {
    private static final long serialVersionUID = 1989826758516880355L;
    private double maxNorm;

    /**
     * Creates a regularizer with the given per-row norm limit.
     *
     * @param d the maximum allowed L2 norm of each row; must be positive and finite
     * @throws IllegalArgumentException if {@code d} is NaN, infinite, or not positive
     */
    public Max2NormRegularizer(double d) {
        // Validate through a private helper instead of the overridable setter so a
        // subclass override is never invoked on a partially constructed object.
        this.maxNorm = checkMaxNorm(d);
    }

    /**
     * Sets the maximum allowed L2 norm of each row.
     *
     * @param d the new limit; must be positive and finite
     * @throws IllegalArgumentException if {@code d} is NaN, infinite, or not positive
     */
    public void setMaxNorm(double d) {
        this.maxNorm = checkMaxNorm(d);
    }

    /**
     * Returns the maximum allowed L2 norm of each row.
     *
     * @return the current per-row norm limit
     */
    public double getMaxNorm() {
        return this.maxNorm;
    }

    /**
     * Projects every row of {@code matrix} onto the L2 ball of radius
     * {@link #getMaxNorm()}, scaling the matching entry of {@code vec} alongside.
     *
     * @param matrix the weight matrix, modified in place row by row
     * @param vec    the bias vector; entry {@code i} is read and rescaled only when
     *               row {@code i} is rescaled
     */
    @Override // jsat.classifiers.neuralnetwork.regularizers.WeightRegularizer
    public void applyRegularization(Matrix matrix, Vec vec) {
        final double limit = this.maxNorm;
        for (int i = 0; i < matrix.rows(); i++) {
            projectRow(matrix, vec, i, limit);
        }
    }

    /**
     * Parallel version of {@link #applyRegularization(Matrix, Vec)}: submits one
     * task per row to {@code executorService} and waits for all of them.
     *
     * @param matrix          the weight matrix, modified in place row by row
     * @param vec             the bias vector, rescaled alongside the rows
     * @param executorService executor used to run the per-row tasks
     * @throws RuntimeException if a row task fails; the task's own unchecked
     *                          exception is rethrown, matching the serial overload
     */
    @Override // jsat.classifiers.neuralnetwork.regularizers.WeightRegularizer
    public void applyRegularization(final Matrix matrix, final Vec vec, ExecutorService executorService) {
        // Snapshot the limit so every row in this pass uses the same bound.
        final double limit = this.maxNorm;
        final int rows = matrix.rows();
        List<Future<?>> futures = new ArrayList<Future<?>>(rows);
        for (int i = 0; i < rows; i++) {
            final int rowIndex = i;
            futures.add(executorService.submit(new Runnable() {
                @Override // java.lang.Runnable
                public void run() {
                    projectRow(matrix, vec, rowIndex, limit);
                }
            }));
        }
        try {
            for (Future<?> future : futures) {
                future.get();
            }
        } catch (InterruptedException e) {
            // Restore the interrupt flag so code further up the stack can react to it.
            Thread.currentThread().interrupt();
            Logger.getLogger(Max2NormRegularizer.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
        } catch (ExecutionException e) {
            // A Runnable can only fail with an unchecked throwable. Surface it instead of
            // silently leaving the matrix partially regularized.
            Throwable cause = e.getCause();
            if (cause instanceof RuntimeException) {
                throw (RuntimeException) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw new IllegalStateException(e);
        }
    }

    /**
     * Rescales a single row so its L2 norm does not exceed {@link #getMaxNorm()}.
     *
     * @param vec the row, modified in place when its norm is at or above the limit
     * @param d   the bias value associated with the row
     * @return {@code d} unchanged if the row was not rescaled, otherwise {@code d}
     *         scaled by the same ratio as the row
     */
    @Override // jsat.classifiers.neuralnetwork.regularizers.WeightRegularizer
    public double applyRegularizationToRow(Vec vec, double d) {
        final double limit = this.maxNorm;
        double norm = vec.pNorm(2.0d);
        if (norm < limit) {
            return d;
        }
        vec.mutableMultiply(limit / norm);
        return (d * limit) / norm;
    }

    /**
     * Returns a new regularizer with the same maximum norm.
     *
     * @return an independent copy of this regularizer
     */
    @Override // jsat.classifiers.neuralnetwork.regularizers.WeightRegularizer
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public Max2NormRegularizer m568clone() {
        return new Max2NormRegularizer(this.maxNorm);
    }

    /**
     * Checks that a norm limit is a positive, finite number.
     *
     * @param d candidate limit
     * @return {@code d}, if valid
     * @throws IllegalArgumentException if {@code d} is NaN, infinite, or not positive
     */
    private static double checkMaxNorm(double d) {
        if (Double.isNaN(d) || Double.isInfinite(d) || d <= 0.0d) {
            throw new IllegalArgumentException("The maximum norm must be a positive constant, not " + d);
        }
        return d;
    }

    /**
     * Shared per-row projection used by both applyRegularization overloads.
     * If row {@code i} of {@code matrix} has L2 norm at or above {@code limit}, the
     * row is scaled by {@code limit / norm} and {@code biases[i]} by the same ratio.
     * Otherwise nothing is touched.
     */
    private static void projectRow(Matrix matrix, Vec biases, int i, double limit) {
        Vec row = matrix.getRowView(i);
        double norm = row.pNorm(2.0d);
        if (norm >= limit) {
            row.mutableMultiply(limit / norm);
            biases.set(i, (biases.get(i) * limit) / norm);
        }
    }
}
