package jsat.classifiers.boosting;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.exceptions.FailedToFitException;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/boosting/AdaBoostM1.class */
/**
 * AdaBoost.M1 style boosting ensemble. A single weak learner is trained
 * repeatedly on the same data points while their weights are adjusted after
 * each round; a clone of the learner is stored after every successful round
 * together with a voting weight of {@code log(1 / beta)}, where
 * {@code beta = error / (1 - error)}.
 *
 * <p>The weak learner must report {@code supportsWeightedData() == true}.
 * This class itself reports {@code false}, because training begins by
 * overwriting every data point weight with {@code 1.0}.
 */
public class AdaBoostM1 implements Classifier, Parameterized {
    private static final long serialVersionUID = 4205232097748332861L;

    /** Learner that is re-trained every boosting round and cloned into {@link #hypoths}. */
    private Classifier weakLearner;

    /** Upper bound on the number of boosting rounds; always at least 1. */
    private int maxIterations;

    /** Learners retained from training; {@code null} until {@code trainC} has been called. */
    protected List<Classifier> hypoths;

    /** Voting weight for the learner at the same index in {@link #hypoths}. */
    protected List<Double> hypWeights;

    /** Description of the target class variable; {@code null} until {@code trainC} has been called. */
    protected CategoricalData predicting;

    /**
     * Creates a new, untrained booster.
     *
     * @param classifier the weak learner to boost; must support weighted data
     * @param i the maximum number of boosting rounds, must be at least 1
     * @throws FailedToFitException if the weak learner does not support weighted data
     * @throws IllegalArgumentException if {@code i} is less than 1
     */
    public AdaBoostM1(Classifier classifier, int i) {
        setWeakLearner(classifier);
        // Validate here as well, so an invalid count cannot slip past the setter's check.
        this.maxIterations = checkIterations(i);
    }

    /**
     * Validates an iteration count. Shared by the constructor and
     * {@link #setMaxIterations(int)} so both enforce the same rule.
     *
     * @param iterations the requested number of boosting rounds
     * @return {@code iterations} unchanged when it is valid
     * @throws IllegalArgumentException if {@code iterations} is less than 1
     */
    private static int checkIterations(int iterations) {
        if (iterations < 1) {
            throw new IllegalArgumentException("Number of iterations must be a positive value, not " + iterations);
        }
        return iterations;
    }

    /**
     * @return the maximum number of boosting rounds
     */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Returns a read-only view of the learners retained by training.
     *
     * @return an unmodifiable list of the learners, empty if training has not been run
     */
    public List<Classifier> getModels() {
        if (this.hypoths == null) {
            // Not trained yet: report "no models" instead of failing with a NullPointerException.
            return Collections.<Classifier>emptyList();
        }
        return Collections.unmodifiableList(this.hypoths);
    }

    /**
     * Returns a read-only view of the voting weights, index-aligned with
     * {@link #getModels()}.
     *
     * @return an unmodifiable list of the weights, empty if training has not been run
     */
    public List<Double> getModelWeights() {
        if (this.hypWeights == null) {
            return Collections.<Double>emptyList();
        }
        return Collections.unmodifiableList(this.hypWeights);
    }

    /**
     * Sets the maximum number of boosting rounds used by the next training run.
     *
     * @param i the new maximum, must be at least 1
     * @throws IllegalArgumentException if {@code i} is less than 1
     */
    public void setMaxIterations(int i) {
        this.maxIterations = checkIterations(i);
    }

    /**
     * @return the weak learner that is trained in every boosting round
     */
    public Classifier getWeakLearner() {
        return this.weakLearner;
    }

    /**
     * Sets the weak learner to boost.
     *
     * @param classifier the weak learner; must support weighted data
     * @throws FailedToFitException if {@code classifier.supportsWeightedData()} is false
     */
    public void setWeakLearner(Classifier classifier) {
        if (!classifier.supportsWeightedData()) {
            throw new FailedToFitException("WeakLearner must support weighted data to be boosted");
        }
        this.weakLearner = classifier;
    }

    /**
     * Classifies a data point by weighted vote: each retained learner adds its
     * voting weight to the class it considers most likely, and the result is
     * then normalized.
     *
     * @param dataPoint the data point to classify
     * @return the normalized vote totals per class
     * @throws RuntimeException if the classifier has not been trained
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        if (this.predicting == null) {
            throw new RuntimeException("Classifier has not been trained yet");
        }
        CategoricalResults votes = new CategoricalResults(this.predicting.getNumOfCategories());
        for (int i = 0; i < this.hypoths.size(); i++) {
            int votedClass = this.hypoths.get(i).classify(dataPoint).mostLikely();
            votes.incProb(votedClass, this.hypWeights.get(i).doubleValue());
        }
        votes.normalize();
        return votes;
    }

    /**
     * Trains the ensemble. Any previously learned models are discarded.
     *
     * <p>Every data point obtained from {@code getAsDPPList()} has its weight
     * overwritten, first with {@code 1.0} and then by the boosting updates.
     * NOTE(review): whether those data points are shared with the caller's
     * data set depends on {@code getAsDPPList()} - confirm.
     *
     * <p>Training stops early when a round's weighted error is exactly 0 or
     * exceeds 0.5. If the very first round is error-free, that learner is kept
     * as the only model so the ensemble is not left empty.
     *
     * @param classificationDataSet the training data
     * @param executorService executor handed to the weak learner, or {@code null}
     *        to use the weak learner's single-argument {@code trainC}
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        this.predicting = classificationDataSet.getPredicting();
        this.hypWeights = new DoubleList(this.maxIterations);
        this.hypoths = new ArrayList<Classifier>(this.maxIterations);

        List<DataPointPair<Integer>> points = classificationDataSet.getAsDPPList();
        for (DataPointPair<Integer> point : points) {
            point.getDataPoint().setWeight(1.0d);
        }

        // Stored weights are kept scaled up by this factor; dividing by it yields the distribution value.
        double scaledBy = points.size();
        boolean[] wasCorrect = new boolean[points.size()];

        for (int round = 0; round < this.maxIterations; round++) {
            ClassificationDataSet weighted = new ClassificationDataSet(points, this.predicting);
            if (executorService != null) {
                this.weakLearner.trainC(weighted, executorService);
            } else {
                this.weakLearner.trainC(weighted);
            }

            // Accumulate the weight of every point the freshly trained learner gets wrong.
            double missedWeight = 0.0d;
            for (int j = 0; j < points.size(); j++) {
                DataPointPair<Integer> point = points.get(j);
                boolean correct = this.weakLearner.classify(point.getDataPoint()).mostLikely() == point.getPair().intValue();
                wasCorrect[j] = correct;
                if (!correct) {
                    missedWeight += point.getDataPoint().getWeight();
                }
            }

            double error = missedWeight / scaledBy;
            if (error > 0.5d || error == 0.0d) {
                if (error == 0.0d && this.hypoths.isEmpty()) {
                    // A perfect first learner would otherwise be thrown away, leaving nothing to vote with.
                    // Its log(1 / beta) weight would be infinite, so use 1.0: it is the only voter.
                    this.hypoths.add(this.weakLearner.mo480clone());
                    this.hypWeights.add(Double.valueOf(1.0d));
                }
                return;
            }

            double beta = error / (1.0d - error);
            double normalizer = 0.0d;
            double nextScale = scaledBy;
            for (int j = 0; j < wasCorrect.length; j++) {
                DataPoint dataPoint = points.get(j).getDataPoint();
                if (wasCorrect[j]) {
                    // Shrink the weight of correctly classified points.
                    dataPoint.setWeight(dataPoint.getWeight() * beta);
                }
                double unscaled = dataPoint.getWeight() / scaledBy;
                // Track the largest reciprocal so the next scale is at least 1 / (smallest unscaled weight).
                if (1.0d / unscaled > nextScale) {
                    nextScale = 1.0d / unscaled;
                }
                normalizer += dataPoint.getWeight() / scaledBy;
            }

            // Renormalize the distribution and re-express it at the new scale.
            for (DataPointPair<Integer> point : points) {
                DataPoint dataPoint = point.getDataPoint();
                dataPoint.setWeight(((dataPoint.getWeight() / scaledBy) * nextScale) / normalizer);
            }
            scaledBy = nextScale;

            this.hypoths.add(this.weakLearner.mo480clone());
            this.hypWeights.add(Double.valueOf(Math.log(1.0d / beta)));
        }
    }

    /**
     * Trains the ensemble without an executor; equivalent to
     * {@code trainC(classificationDataSet, null)}.
     *
     * @param classificationDataSet the training data
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        trainC(classificationDataSet, null);
    }

    /**
     * @return always {@code false}; training resets all data point weights to 1.0
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Creates a copy with a cloned weak learner and, when present, copied
     * voting weights, cloned learners and a cloned class description.
     *
     * @return a new {@code AdaBoostM1} in the same state as this one
     */
    @Override // jsat.classifiers.Classifier
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public AdaBoostM1 mo480clone() {
        AdaBoostM1 copy = new AdaBoostM1(this.weakLearner.mo480clone(), this.maxIterations);
        if (this.hypWeights != null) {
            copy.hypWeights = new DoubleList(this.hypWeights);
        }
        if (this.hypoths != null) {
            copy.hypoths = new ArrayList<Classifier>(this.hypoths.size());
            for (Classifier hypothesis : this.hypoths) {
                copy.hypoths.add(hypothesis.mo480clone());
            }
        }
        if (this.predicting != null) {
            copy.predicting = this.predicting.m481clone();
        }
        return copy;
    }

    /**
     * @return the parameters obtained from {@code Parameter.getParamsFromMethods(this)}
     */
    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    /**
     * Looks up a single parameter by name.
     *
     * @param str the key to look up in the map built from {@link #getParameters()}
     * @return the mapped parameter, or whatever that map returns for an unknown key
     */
    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }
}
