package jsat.distributions.kernels;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.utils.DoubleList;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import org.apache.log4j.Level;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/distributions/kernels/GeneralRBFKernel.class */
/**
 * Radial basis function kernel built on top of an arbitrary {@link DistanceMetric}:
 * {@code k(x, y) = exp(-d(x, y)^2 / (2 * sigma^2))}.
 */
public class GeneralRBFKernel extends DistanceMetricBasedKernel {
    private static final long serialVersionUID = 1368225926995372017L;

    /**
     * Data sets with more points than this are sub-sampled by
     * {@link #guessSigma(DataSet, DistanceMetric)} so the pairwise-distance work stays bounded.
     * (This value was previously spelled as log4j's {@code Level.TRACE_INT}, which equals 5000;
     * it has nothing to do with logging.)
     */
    private static final int SAMPLE_THRESHOLD = 5000;

    /** Kernel width; validated by {@link #setSigma(double)} to be finite and strictly positive. */
    private double sigma;
    /** Cached {@code 1 / (2 * sigma^2)} so each evaluation needs only a multiply. */
    private double sigmaSqrd2Inv;

    /**
     * Creates a new RBF kernel.
     *
     * @param distanceMetric the metric used to measure the distance between two points
     * @param d the kernel width sigma
     * @throws IllegalArgumentException if {@code d} is not a finite, strictly positive value
     */
    public GeneralRBFKernel(DistanceMetric distanceMetric, double d) {
        super(distanceMetric);
        setSigma(d);
    }

    /**
     * Sets the kernel width.
     *
     * @param d the new sigma
     * @throws IllegalArgumentException if {@code d} is not a finite, strictly positive value
     */
    public void setSigma(double d) {
        if (d <= 0.0d || Double.isNaN(d) || Double.isInfinite(d)) {
            throw new IllegalArgumentException("Sigma must be a positive constant, not " + d);
        }
        this.sigma = d;
        this.sigmaSqrd2Inv = 0.5d / (d * d);
    }

    /**
     * @return the current kernel width sigma
     */
    public double getSigma() {
        return this.sigma;
    }

    /**
     * Returns a copy of this kernel with a cloned distance metric and the same sigma.
     * (This is the decompiled name of {@code clone()}; the name is kept so existing callers link.)
     */
    @Override // jsat.distributions.kernels.DistanceMetricBasedKernel
    public KernelTrick mo625clone() {
        return new GeneralRBFKernel(this.d.m649clone(), this.sigma);
    }

    @Override // jsat.distributions.kernels.KernelTrick
    public double eval(Vec vec, Vec vec2) {
        return fromDistance(this.d.dist(vec, vec2));
    }

    @Override // jsat.distributions.kernels.KernelTrick
    public double eval(int i, Vec vec, List<Double> list, List<? extends Vec> list2, List<Double> list3) {
        return fromDistance(this.d.dist(i, vec, list, list2, list3));
    }

    @Override // jsat.distributions.kernels.KernelTrick
    public double eval(int i, int i2, List<? extends Vec> list, List<Double> list2) {
        return fromDistance(this.d.dist(i, i2, list, list2));
    }

    /** Maps a distance to the kernel value {@code exp(-dist^2 / (2 * sigma^2))}. */
    private double fromDistance(double dist) {
        return Math.exp(-dist * dist * this.sigmaSqrd2Inv);
    }

    /**
     * Guesses a distribution of reasonable sigma values for the given data, using this
     * kernel's distance metric.
     *
     * @param dataSet the data to estimate sigma from
     * @return see {@link #guessSigma(DataSet, DistanceMetric)}
     */
    public Distribution guessSigma(DataSet dataSet) {
        return guessSigma(dataSet, this.d);
    }

    /**
     * Guesses a distribution of reasonable sigma values for an RBF kernel on the given data.
     *
     * <p>The median {@code m} of a set of sampled pairwise distances is computed, and a
     * {@link LogUniform} over {@code [m * e^-4, m * e^4]} is returned. For a binary
     * {@link ClassificationDataSet}, only distances between points of opposite classes are used.
     * If that yields no pairs (one class is absent), it falls back to all pairs of a random sample.
     * Data sets larger than {@value #SAMPLE_THRESHOLD} points are sub-sampled to
     * {@code SAMPLE_THRESHOLD + floor(sqrt(n - SAMPLE_THRESHOLD))} points.
     *
     * @param dataSet the data to estimate sigma from
     * @param distanceMetric the metric used to measure distances
     * @return a log-uniform distribution centered (in log space) on the median distance
     * @throws IllegalArgumentException if no pair of points can be formed from the data
     */
    public static Distribution guessSigma(DataSet dataSet, DistanceMetric distanceMetric) {
        int toSample = sampleSizeFor(dataSet.getSampleSize());

        DoubleList distances = null;
        if (dataSet instanceof ClassificationDataSet
                && ((ClassificationDataSet) dataSet).getPredicting().getNumOfCategories() == 2) {
            distances = crossClassDistances((ClassificationDataSet) dataSet, distanceMetric, toSample);
        }
        if (distances == null || distances.isEmpty()) {
            // Non-binary data, or a binary set where one class has no points.
            distances = allPairDistances(dataSet, distanceMetric, toSample);
        }
        if (distances.isEmpty()) {
            throw new IllegalArgumentException(
                    "At least two data points are needed to guess sigma, data set has "
                            + dataSet.getSampleSize());
        }

        Collections.sort(distances);
        double logMedian = Math.log(distances.get(distances.size() / 2));
        // NOTE(review): a median distance of 0 (many duplicate points) gives log(0) = -Infinity
        // and thus a [0, 0] range; confirm how LogUniform handles that.
        return new LogUniform(Math.exp(logMedian - 4.0d), Math.exp(logMedian + 4.0d));
    }

    /** Number of points to sample from a data set of size {@code n} (sqrt growth past the threshold). */
    private static int sampleSizeFor(int n) {
        if (n <= SAMPLE_THRESHOLD) {
            return n;
        }
        return SAMPLE_THRESHOLD + (int) Math.floor(Math.sqrt(n - SAMPLE_THRESHOLD));
    }

    /**
     * Randomly picks up to {@code toSample / 2} points from each of classes 0 and 1, and returns
     * the distances between every class-0 point and every class-1 point in that sample.
     */
    private static DoubleList crossClassDistances(ClassificationDataSet data, DistanceMetric metric, int toSample) {
        int perClass = toSample / 2;
        List<Vec> sample = new ArrayList<>(2 * perClass); // class-0 points first, then class 1
        List<Vec> class1 = new ArrayList<>(perClass);

        IntList order = new IntList(data.getSampleSize());
        ListUtils.addRange(order, 0, data.getSampleSize(), 1);
        Collections.shuffle(order);
        for (int k = 0; k < order.size(); k++) {
            if (sample.size() >= perClass && class1.size() >= perClass) {
                break; // both quotas filled
            }
            int idx = order.getI(k);
            int category = data.getDataPointCategory(idx);
            if (category == 0 && sample.size() < perClass) {
                sample.add(data.getDataPoint(idx).getNumericalValues());
            } else if (category == 1 && class1.size() < perClass) {
                class1.add(data.getDataPoint(idx).getNumericalValues());
            }
        }

        int class1Start = sample.size();
        sample.addAll(class1);
        // Indices and the acceleration cache must both refer to the same list: the sample.
        List<Double> cache = metric.getAccelerationCache(sample);
        DoubleList distances = new DoubleList(class1Start * class1.size());
        for (int i = 0; i < class1Start; i++) {
            for (int j = class1Start; j < sample.size(); j++) {
                distances.add(metric.dist(i, j, sample, cache));
            }
        }
        return distances;
    }

    /** Returns the distances between all distinct pairs of a random sample of {@code toSample} points. */
    private static DoubleList allPairDistances(DataSet data, DistanceMetric metric, int toSample) {
        // Copy so shuffling never disturbs whatever list getDataVectors() hands back.
        List<Vec> vecs = new ArrayList<>(data.getDataVectors());
        Collections.shuffle(vecs);
        if (vecs.size() > toSample) {
            vecs = vecs.subList(0, toSample);
        }

        List<Double> cache = metric.getAccelerationCache(vecs);
        int m = vecs.size();
        // Presize to the exact pair count; computed in long to avoid int overflow on huge samples.
        long pairs = (long) m * (m - 1) / 2;
        DoubleList distances = new DoubleList((int) Math.max(0L, Math.min(pairs, Integer.MAX_VALUE - 8L)));
        for (int i = 0; i < m; i++) {
            for (int j = i + 1; j < m; j++) {
                distances.add(metric.dist(i, j, vecs, cache));
            }
        }
        return distances;
    }

    /**
     * Always {@code true}: {@code k(x, x) = exp(0) = 1} for every {@code x}.
     */
    @Override // jsat.distributions.kernels.KernelTrick
    public boolean normalized() {
        return true;
    }
}
