package jsat.classifiers.boosting;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.UpdateableClassifier;
import jsat.linear.DenseVector;
import jsat.regression.BaseUpdateableRegressor;
import jsat.regression.RegressionDataSet;
import jsat.regression.UpdateableRegressor;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/boosting/UpdatableStacking.class */
/**
 * Stacking ensemble in which both the base models and the aggregating (meta) model are trained
 * online.
 * <p>
 * For each training point, every base model first makes a prediction. Those predictions form a
 * new, purely numeric feature vector, and the aggregating model is updated on that vector. Only
 * after that are the base models updated with the original point.
 * <p>
 * An instance can serve as a classifier, a regressor, or both. It serves both roles when every
 * model passed to a constructor implements both {@link UpdateableClassifier} and
 * {@link UpdateableRegressor}; in that case the same model objects (and the same list) back both
 * roles. Using a role the instance was not built for throws
 * {@link UnsupportedOperationException}.
 */
public class UpdatableStacking implements UpdateableClassifier, UpdateableRegressor {
    private static final long serialVersionUID = -5111303510263114862L;

    /**
     * Number of meta-features produced per base classifier.
     * <p>
     * This is 1 for two-class problems and the number of classes otherwise; it is always 1 after
     * regression set-up. The value is assigned by the {@code setUp} methods (or copied by the copy
     * constructor) and is 0 before that.
     */
    private int weightsPerModel;
    /** Meta-model for classification; {@code null} when the classifier role is unavailable. */
    private UpdateableClassifier aggregatingClassifier;
    /** Base models for classification; {@code null} when the classifier role is unavailable. */
    private List<UpdateableClassifier> baseClassifiers;
    /** Meta-model for regression; {@code null} when the regressor role is unavailable. */
    private UpdateableRegressor aggregatingRegressor;
    /** Base models for regression; {@code null} when the regressor role is unavailable. */
    private List<UpdateableRegressor> baseRegressors;

    /**
     * Creates a stacking classifier.
     * <p>
     * If the aggregator and every base classifier also implement {@link UpdateableRegressor}, the
     * instance can be used for regression as well.
     *
     * @param aggregatingClassifier the meta-model trained on the base classifiers' outputs
     * @param baseClassifiers the base models; the list is stored, not copied
     * @throws IllegalArgumentException if fewer than 2 base classifiers are given
     */
    public UpdatableStacking(UpdateableClassifier aggregatingClassifier, List<UpdateableClassifier> baseClassifiers) {
        if (baseClassifiers.size() < 2) {
            throw new IllegalArgumentException("base classifiers must contain at least 2 elements, not " + baseClassifiers.size());
        }
        this.aggregatingClassifier = aggregatingClassifier;
        this.baseClassifiers = baseClassifiers;
        boolean dualRole = allInstancesOf(baseClassifiers, UpdateableRegressor.class)
                && aggregatingClassifier instanceof UpdateableRegressor;
        if (dualRole) {
            this.aggregatingRegressor = (UpdateableRegressor) aggregatingClassifier;
            this.baseRegressors = aliasAs(baseClassifiers);
        }
    }

    /**
     * Varargs form of {@link #UpdatableStacking(UpdateableClassifier, List)}.
     *
     * @param aggregatingClassifier the meta-model trained on the base classifiers' outputs
     * @param baseClassifiers the base models (at least 2)
     */
    public UpdatableStacking(UpdateableClassifier aggregatingClassifier, UpdateableClassifier... baseClassifiers) {
        this(aggregatingClassifier, Arrays.asList(baseClassifiers));
    }

    /**
     * Creates a stacking regressor.
     * <p>
     * If the aggregator and every base regressor also implement {@link UpdateableClassifier}, the
     * instance can be used for classification as well.
     *
     * @param aggregatingRegressor the meta-model trained on the base regressors' outputs
     * @param baseRegressors the base models; the list is stored, not copied
     */
    public UpdatableStacking(UpdateableRegressor aggregatingRegressor, List<UpdateableRegressor> baseRegressors) {
        // NOTE(review): unlike the classifier constructor, the number of base models is not
        // validated here. Confirm whether that asymmetry is intended.
        this.aggregatingRegressor = aggregatingRegressor;
        this.baseRegressors = baseRegressors;
        boolean dualRole = allInstancesOf(baseRegressors, UpdateableClassifier.class)
                && aggregatingRegressor instanceof UpdateableClassifier;
        if (dualRole) {
            this.aggregatingClassifier = (UpdateableClassifier) aggregatingRegressor;
            this.baseClassifiers = aliasAs(baseRegressors);
        }
    }

    /**
     * Varargs form of {@link #UpdatableStacking(UpdateableRegressor, List)}.
     *
     * @param aggregatingRegressor the meta-model trained on the base regressors' outputs
     * @param baseRegressors the base models
     */
    public UpdatableStacking(UpdateableRegressor aggregatingRegressor, UpdateableRegressor... baseRegressors) {
        this(aggregatingRegressor, Arrays.asList(baseRegressors));
    }

    /**
     * Deep-copy constructor.
     * <p>
     * Every model is cloned. When the source serves both roles with shared objects, the copy
     * shares its clones between both roles in the same way.
     *
     * @param toCopy the instance to copy
     */
    public UpdatableStacking(UpdatableStacking toCopy) {
        this.weightsPerModel = toCopy.weightsPerModel;
        if (toCopy.aggregatingClassifier == null) {
            // Regressor-only instance.
            this.aggregatingRegressor = toCopy.aggregatingRegressor.clone();
            this.baseRegressors = new ArrayList<>(toCopy.baseRegressors.size());
            for (UpdateableRegressor base : toCopy.baseRegressors) {
                this.baseRegressors.add(base.clone());
            }
            return;
        }
        this.aggregatingClassifier = toCopy.aggregatingClassifier.clone();
        this.baseClassifiers = new ArrayList<>(toCopy.baseClassifiers.size());
        for (UpdateableClassifier base : toCopy.baseClassifiers) {
            this.baseClassifiers.add(base.clone());
        }
        if (toCopy.aggregatingRegressor == toCopy.aggregatingClassifier) {
            // Dual-role source: reuse the freshly cloned objects for the regressor role too.
            // This assumes clone() preserves the concrete type, so the clone is still an
            // UpdateableRegressor.
            this.aggregatingRegressor = (UpdateableRegressor) this.aggregatingClassifier;
            this.baseRegressors = aliasAs(this.baseClassifiers);
        }
    }

    /** Returns true when every element of {@code models} is an instance of {@code type}. */
    private static boolean allInstancesOf(List<?> models, Class<?> type) {
        for (Object model : models) {
            if (!type.isInstance(model)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Reinterprets {@code list} as holding elements of type {@code T}, without copying it.
     * <p>
     * Callers must already have checked that every element is a {@code T}. Sharing the one list
     * keeps the classifier and regressor views of the same base models identical.
     */
    @SuppressWarnings("unchecked")
    private static <T> List<T> aliasAs(List<?> list) {
        return (List<T>) list;
    }

    /** Returns the aggregating classifier, or fails clearly if this instance cannot classify. */
    private UpdateableClassifier requireClassifier() {
        if (aggregatingClassifier == null) {
            throw new UnsupportedOperationException("UpdatableStacking was not built with classifiers for every model; classification is unavailable");
        }
        return aggregatingClassifier;
    }

    /** Returns the aggregating regressor, or fails clearly if this instance cannot regress. */
    private UpdateableRegressor requireRegressor() {
        if (aggregatingRegressor == null) {
            throw new UnsupportedOperationException("UpdatableStacking was not built with regressors for every model; regression is unavailable");
        }
        return aggregatingRegressor;
    }

    /**
     * Classifies {@code dataPoint} with the aggregator, using the base classifiers' outputs as
     * its input.
     *
     * @throws UnsupportedOperationException if this instance has no classifier role
     */
    @Override
    public CategoricalResults classify(DataPoint dataPoint) {
        UpdateableClassifier aggregator = requireClassifier();
        return aggregator.classify(getPredVecC(dataPoint, 1.0));
    }

    /**
     * Builds the meta-feature point fed to the aggregating classifier.
     * <p>
     * When {@code weightsPerModel == 1} (two-class problems), each base classifier contributes a
     * single feature: its class-0 probability {@code p} rescaled as {@code 2p - 1}. Otherwise each
     * contributes {@code weightsPerModel} features, namely its probability for each class index in
     * order.
     *
     * @param dataPoint the original input
     * @param weight the weight to give the returned meta-feature point
     */
    private DataPoint getPredVecC(DataPoint dataPoint, double weight) {
        int numBase = baseClassifiers.size();
        DenseVector features = new DenseVector(weightsPerModel * numBase);
        for (int m = 0; m < numBase; m++) {
            CategoricalResults prediction = baseClassifiers.get(m).classify(dataPoint);
            if (weightsPerModel == 1) {
                features.set(m, prediction.getProb(0) * 2.0 - 1.0);
            } else {
                for (int c = 0; c < weightsPerModel; c++) {
                    features.set(m * weightsPerModel + c, prediction.getProb(c));
                }
            }
        }
        return new DataPoint(features, weight);
    }

    /**
     * Builds the meta-feature point fed to the aggregating regressor.
     * <p>
     * The point has one feature per base regressor, namely that regressor's prediction.
     *
     * @param dataPoint the original input
     * @param weight the weight to give the returned meta-feature point
     */
    private DataPoint getPredVecR(DataPoint dataPoint, double weight) {
        int numBase = baseRegressors.size();
        DenseVector features = new DenseVector(numBase);
        for (int m = 0; m < numBase; m++) {
            features.set(m, baseRegressors.get(m).regress(dataPoint));
        }
        return new DataPoint(features, weight);
    }

    /**
     * Prepares the aggregator and the base classifiers for online classification training.
     *
     * @throws UnsupportedOperationException if this instance has no classifier role
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        UpdateableClassifier aggregator = requireClassifier();
        int numClasses = predicting.getNumOfCategories();
        // A two-class problem needs only one feature per base model (see getPredVecC).
        weightsPerModel = numClasses == 2 ? 1 : numClasses;
        // The aggregator sees only the numeric meta-features, never the original attributes.
        aggregator.setUp(new CategoricalData[0], weightsPerModel * baseClassifiers.size(), predicting);
        for (UpdateableClassifier base : baseClassifiers) {
            base.setUp(categoricalAttributes, numericAttributes, predicting);
        }
    }

    /**
     * Performs one online classification update.
     *
     * @throws UnsupportedOperationException if this instance has no classifier role
     */
    @Override
    public void update(DataPoint dataPoint, int targetClass) {
        UpdateableClassifier aggregator = requireClassifier();
        // Compute the meta-features BEFORE the base models learn from this point. The aggregator
        // is therefore trained on predictions that were made without this point.
        aggregator.update(getPredVecC(dataPoint, dataPoint.getWeight()), targetClass);
        for (UpdateableClassifier base : baseClassifiers) {
            base.update(dataPoint, targetClass);
        }
    }

    /**
     * Prepares the aggregator and the base regressors for online regression training.
     *
     * @throws UnsupportedOperationException if this instance has no regressor role
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes) {
        UpdateableRegressor aggregator = requireRegressor();
        weightsPerModel = 1;
        // The aggregator sees only the numeric meta-features, never the original attributes.
        aggregator.setUp(new CategoricalData[0], weightsPerModel * baseRegressors.size());
        for (UpdateableRegressor base : baseRegressors) {
            base.setUp(categoricalAttributes, numericAttributes);
        }
    }

    /**
     * Performs one online regression update.
     *
     * @throws UnsupportedOperationException if this instance has no regressor role
     */
    @Override
    public void update(DataPoint dataPoint, double targetValue) {
        UpdateableRegressor aggregator = requireRegressor();
        // Compute the meta-features BEFORE the base models learn from this point (see the
        // classification update).
        aggregator.update(getPredVecR(dataPoint, dataPoint.getWeight()), targetValue);
        for (UpdateableRegressor base : baseRegressors) {
            base.update(dataPoint, targetValue);
        }
    }

    /** Trains sequentially; {@code threadPool} is ignored. */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        trainC(dataSet);
    }

    /**
     * Trains with a single online pass.
     * <p>
     * The pass is delegated to {@code BaseUpdateableClassifier.trainEpochs(dataSet, this, 1)}.
     */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        BaseUpdateableClassifier.trainEpochs(dataSet, this, 1);
    }

    /**
     * Reports weighted-data support based solely on the aggregating model.
     * <p>
     * The classifier-role aggregator is checked when present; otherwise the regressor-role one is.
     */
    @Override
    public boolean supportsWeightedData() {
        return aggregatingClassifier != null
                ? aggregatingClassifier.supportsWeightedData()
                : aggregatingRegressor.supportsWeightedData();
    }

    /**
     * Predicts with the aggregator, using the base regressors' outputs as its input.
     *
     * @throws UnsupportedOperationException if this instance has no regressor role
     */
    @Override
    public double regress(DataPoint dataPoint) {
        UpdateableRegressor aggregator = requireRegressor();
        return aggregator.regress(getPredVecR(dataPoint, 1.0));
    }

    /** Trains sequentially; {@code threadPool} is ignored. */
    @Override
    public void train(RegressionDataSet dataSet, ExecutorService threadPool) {
        train(dataSet);
    }

    /**
     * Trains with a single online pass.
     * <p>
     * The pass is delegated to {@code BaseUpdateableRegressor.trainEpochs(dataSet, this, 1)}.
     */
    @Override
    public void train(RegressionDataSet dataSet) {
        BaseUpdateableRegressor.trainEpochs(dataSet, this, 1);
    }

    /** Returns a deep copy; see {@link #UpdatableStacking(UpdatableStacking)}. */
    @Override
    public UpdatableStacking clone() {
        return new UpdatableStacking(this);
    }
}
