package jsat.classifiers.boosting;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.DataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.classifiers.trees.DecisionTree;
import jsat.distributions.Distribution;
import jsat.distributions.Uniform;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;
import jsat.utils.FakeExecutor;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/boosting/EmphasisBoost.class */
/**
 * Boosted ensemble of weak classifiers for two-class problems.
 *
 * <p>Training repeatedly fits a fresh copy of the weak learner on re-weighted
 * data. Class index 0 is treated as the label -1 and class index 1 as +1. After
 * each round every point receives the weight
 * {@code exp(lambda * (f - y)^2 - (1 - lambda) * f^2)}, where {@code f} is the
 * ensemble's accumulated score for the point and {@code y} its +/-1 label; the
 * weights are then normalised to sum to one.
 */
public class EmphasisBoost implements Classifier, Parameterized, BinaryScoreClassifier {
    private static final long serialVersionUID = -6372897830449685891L;

    /** Weight substituted for a point whose emphasis value overflows to infinity. */
    private static final double OVERFLOW_WEIGHT = 50.0d;

    /** Prototype learner; a fresh copy is trained in every boosting round. */
    @Parameter.ParameterHolder
    private Classifier weakLearner;
    /** Upper bound on the number of boosting rounds. */
    private int maxIterations;
    /** Trained weak learners, in the order they were added. */
    protected List<Classifier> hypoths;
    /** Weight of each entry of {@link #hypoths}, index-aligned. */
    protected List<Double> hypWeights;
    /** Description of the target attribute seen at training time; {@code null} until trained. */
    protected CategoricalData predicting;
    /** Mixing coefficient of the emphasis function used to re-weight the points. */
    private double lambda;

    /**
     * Creates a booster using a {@code DecisionTree(6)} weak learner, at most
     * 200 iterations and a lambda of 0.35.
     */
    public EmphasisBoost() {
        this(new DecisionTree(6), 200, 0.35d);
    }

    /**
     * Creates a booster with the given configuration.
     *
     * @param classifier the weak learner to boost; must support weighted data
     * @param i the maximum number of boosting iterations; must be at least 1
     * @param d the lambda coefficient of the emphasis function
     * @throws IllegalArgumentException if the learner does not support weighted
     *         data or the iteration count is less than 1
     */
    public EmphasisBoost(Classifier classifier, int i, double d) {
        setWeakLearner(classifier);
        setMaxIterations(i);
        setLambda(d);
    }

    /**
     * Copy constructor. The weak learner is always cloned; the trained state
     * (hypotheses, their weights and the target description) is copied only
     * when the source has hypothesis weights.
     *
     * @param emphasisBoost the instance to copy
     */
    protected EmphasisBoost(EmphasisBoost emphasisBoost) {
        this(emphasisBoost.weakLearner.m504clone(), emphasisBoost.maxIterations, emphasisBoost.lambda);
        if (emphasisBoost.hypWeights != null) {
            this.hypWeights = new DoubleList(emphasisBoost.hypWeights);
            this.hypoths = new ArrayList<Classifier>(emphasisBoost.maxIterations);
            for (Classifier hypothesis : emphasisBoost.hypoths) {
                this.hypoths.add(hypothesis.m504clone());
            }
            this.predicting = emphasisBoost.predicting.m481clone();
        }
    }

    /**
     * Returns a read-only view of the trained weak learners.
     *
     * @return an unmodifiable view of the hypothesis list
     */
    public List<Classifier> getModels() {
        return Collections.unmodifiableList(this.hypoths);
    }

    /**
     * Returns a read-only view of the weights of the trained weak learners.
     *
     * @return an unmodifiable view of the hypothesis weight list
     */
    public List<Double> getModelWeights() {
        return Collections.unmodifiableList(this.hypWeights);
    }

    /**
     * @return the maximum number of boosting iterations
     */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Sets the maximum number of boosting iterations.
     *
     * @param i the new limit; must be at least 1
     * @throws IllegalArgumentException if {@code i} is less than 1
     */
    public void setMaxIterations(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("Iterations must be positive, not " + i);
        }
        this.maxIterations = i;
    }

    /**
     * @return the weak learner prototype that is copied for every round
     */
    public Classifier getWeakLearner() {
        return this.weakLearner;
    }

    /**
     * Sets the weak learner to boost.
     *
     * @param classifier the learner; must report support for weighted data
     * @throws IllegalArgumentException if the learner does not support weighted data
     */
    public void setWeakLearner(Classifier classifier) {
        if (!classifier.supportsWeightedData()) {
            throw new IllegalArgumentException("WeakLearner must support weighted data to be boosted");
        }
        this.weakLearner = classifier;
    }

    /**
     * Suggests a distribution of lambda values to search over. The data set is
     * not consulted.
     *
     * @param dataSet unused
     * @return a {@code Uniform} distribution constructed from 0.25 and 0.45
     */
    public static Distribution guessLambda(DataSet dataSet) {
        return new Uniform(0.25d, 0.45d);
    }

    /**
     * Sets the lambda coefficient of the emphasis function. No validation is
     * performed on the value.
     *
     * @param d the new lambda
     */
    public void setLambda(double d) {
        this.lambda = d;
    }

    /**
     * @return the lambda coefficient of the emphasis function
     */
    public double getLambda() {
        return this.lambda;
    }

    /**
     * Computes the weighted sum of the weak learners' outputs for a point.
     *
     * @param dataPoint the point to score
     * @return the ensemble score; {@link #classify} maps negative scores to
     *         class 0 and all others to class 1
     */
    @Override // jsat.classifiers.calibration.BinaryScoreClassifier
    public double getScore(DataPoint dataPoint) {
        double score = 0.0d;
        for (int i = 0; i < this.hypoths.size(); i++) {
            score += H(this.hypoths.get(i), dataPoint) * this.hypWeights.get(i).doubleValue();
        }
        return score;
    }

    /**
     * Classifies a point by the sign of its ensemble score. The chosen class
     * receives probability 1.
     *
     * @param dataPoint the point to classify
     * @return results with all probability on class 0 when the score is
     *         negative, otherwise on class 1
     * @throws RuntimeException if the classifier has not been trained
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        if (this.predicting == null) {
            throw new RuntimeException("Classifier has not been trained yet");
        }
        CategoricalResults results = new CategoricalResults(this.predicting.getNumOfCategories());
        int chosenClass = getScore(dataPoint) < 0.0d ? 0 : 1;
        results.setProb(chosenClass, 1.0d);
        return results;
    }

    /**
     * Output of one weak learner for a point: its probability for class 1
     * mapped linearly from [0, 1] to [-1, 1].
     */
    private double H(Classifier classifier, DataPoint dataPoint) {
        return (classifier.classify(dataPoint).getProb(1) * 2.0d) - 1.0d;
    }

    /** Maps a class index to the signed label used by the boosting arithmetic: {@code 2 * index - 1}. */
    private static int signedLabel(DataPointPair<Integer> pair) {
        return (pair.getPair().intValue() * 2) - 1;
    }

    /**
     * Trains the ensemble. Any previously trained state is discarded.
     *
     * <p>Training stops early, keeping the hypotheses collected so far, as soon
     * as a weak learner's weighted agreement with the labels is negative.
     *
     * @param classificationDataSet the training data; point weights are
     *        modified on a {@code getTwiceShallowClone()} copy of the set
     * @param executorService passed to the weak learner's training unless it is
     *        {@code null} or a {@code FakeExecutor}, in which case the
     *        single-argument training method is used
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        this.predicting = classificationDataSet.getPredicting();
        this.hypWeights = new DoubleList(this.maxIterations);
        this.hypoths = new ArrayList<Classifier>(this.maxIterations);
        int sampleSize = classificationDataSet.getSampleSize();
        List<DataPointPair<Integer>> points = classificationDataSet.getTwiceShallowClone().getAsDPPList();
        // Start from uniform weights that sum to one.
        for (DataPointPair<Integer> pair : points) {
            pair.getDataPoint().setWeight(1.0d / sampleSize);
        }
        // Output of the current round's weak learner for every point.
        double[] roundOutput = new double[sampleSize];
        // Accumulated ensemble score for every point across rounds.
        double[] ensembleScore = new double[sampleSize];
        boolean serial = executorService == null || (executorService instanceof FakeExecutor);
        for (int round = 0; round < this.maxIterations; round++) {
            Classifier weak = this.weakLearner.m504clone();
            if (serial) {
                weak.trainC(new ClassificationDataSet(points, this.predicting));
            } else {
                weak.trainC(new ClassificationDataSet(points, this.predicting), executorService);
            }
            // Weighted agreement between the weak learner's output and the +/-1 labels.
            double edge = 0.0d;
            for (int i = 0; i < points.size(); i++) {
                DataPointPair<Integer> pair = points.get(i);
                double output = H(weak, pair.getDataPoint());
                roundOutput[i] = output;
                edge += pair.getDataPoint().getWeight() * output * signedLabel(pair);
            }
            if (edge < 0.0d) {
                // The weak learner disagrees with the labels more than it agrees: stop here.
                return;
            }
            // NOTE(review): an edge of exactly 1 makes this logarithm infinite; that case is
            // not handled here - confirm whether the weak learners in use can reach it.
            double alpha = Math.log((1.0d + edge) / (1.0d - edge)) / 2.0d;
            double weightSum = 0.0d;
            for (int i = 0; i < points.size(); i++) {
                DataPointPair<Integer> pair = points.get(i);
                ensembleScore[i] += alpha * roundOutput[i];
                double f = ensembleScore[i];
                double emphasis = Math.exp((this.lambda * Math.pow(f - signedLabel(pair), 2.0d))
                        - (((1.0d - this.lambda) * f) * f));
                if (Double.isInfinite(emphasis)) {
                    // Cap an overflowed weight so the normalisation below stays finite.
                    emphasis = OVERFLOW_WEIGHT;
                }
                weightSum += emphasis;
                pair.getDataPoint().setWeight(emphasis);
            }
            // Normalise so the weights sum to one again.
            for (DataPointPair<Integer> pair : points) {
                DataPoint point = pair.getDataPoint();
                point.setWeight(point.getWeight() / weightSum);
            }
            this.hypoths.add(weak);
            this.hypWeights.add(Double.valueOf(alpha));
        }
    }

    /**
     * Trains the ensemble without an executor.
     *
     * @param classificationDataSet the training data
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        trainC(classificationDataSet, null);
    }

    /**
     * @return {@code false}; training overwrites the weights of the points it is given
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Creates a copy of this classifier via the copy constructor.
     *
     * @return the copy
     */
    @Override // jsat.classifiers.Classifier
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public EmphasisBoost m504clone() {
        return new EmphasisBoost(this);
    }

    /**
     * @return the parameters that {@code Parameter.getParamsFromMethods} finds on this object
     */
    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    /**
     * Looks up one parameter by name.
     *
     * @param str the parameter name
     * @return the entry stored under that name in the parameter map
     *         (presumably {@code null} when absent - depends on the map returned by
     *         {@code Parameter.toParameterMap})
     */
    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }
}
