package jsat.classifiers.neuralnetwork;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.neuralnetwork.activations.ReLU;
import jsat.classifiers.neuralnetwork.activations.SoftmaxLayer;
import jsat.classifiers.neuralnetwork.initializers.ConstantInit;
import jsat.classifiers.neuralnetwork.initializers.GaussianNormalInit;
import jsat.classifiers.neuralnetwork.regularizers.Max2NormRegularizer;
import jsat.linear.SparseVector;
import jsat.linear.Vec;
import jsat.math.optimization.stochastic.AdaDelta;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.IntList;
import jsat.utils.ListUtils;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/classifiers/neuralnetwork/DReDNetSimple.class */
/**
 * Feed-forward neural-network classifier: one ReLU layer per configured hidden size,
 * followed by a softmax output layer with one neuron per class.
 * <p>
 * Training uses an {@link SGDNetworkTrainer} with mini-batches, an {@link AdaDelta}
 * gradient updater (eta = 1.0), a {@link Max2NormRegularizer} with bound 25.0,
 * {@code GaussianNormalInit(0.01)} weight initialization and a constant bias
 * initialization of 0.1 (see {@link #setup(ClassificationDataSet)}).
 * <p>
 * Defaults: hidden layers {1024, 1024}, batch size 256, 100 epochs.
 */
public class DReDNetSimple implements Classifier, Parameterized {
    private static final long serialVersionUID = -342281027279571332L;
    /** The trained network; {@code null} until {@code trainC} (or clone of a trained model). */
    private SGDNetworkTrainer network;
    private int[] hiddenSizes;
    private int batchSize;
    private int epochs;

    /** Creates a classifier with two hidden layers of 1024 neurons each. */
    public DReDNetSimple() {
        this(1024, 1024);
    }

    /**
     * Creates a classifier with the given hidden layer sizes.
     *
     * @param iArr number of neurons in each hidden layer, in order; each must be positive
     * @throws IllegalArgumentException if any size is not positive
     */
    public DReDNetSimple(int... iArr) {
        this.batchSize = 256;
        this.epochs = 100;
        setHiddenSizes(iArr);
    }

    /**
     * Sets the hidden layer sizes. The array is copied, so later changes by the caller
     * have no effect on this object.
     *
     * @param iArr number of neurons in each hidden layer; each must be positive
     * @throws IllegalArgumentException if any size is not positive
     */
    public void setHiddenSizes(int[] iArr) {
        for (int i = 0; i < iArr.length; i++) {
            if (iArr[i] <= 0) {
                throw new IllegalArgumentException("Hidden layer " + i + " must contain a positive number of neurons, not " + iArr[i]);
            }
        }
        this.hiddenSizes = Arrays.copyOf(iArr, iArr.length);
    }

    /**
     * Returns the hidden layer sizes.
     *
     * @return a copy of the hidden layer sizes; modifying it does not affect this object
     */
    public int[] getHiddenSizes() {
        // Defensive copy: the setter validates and copies, so the getter must not
        // hand out the internal array and let callers bypass that validation.
        return Arrays.copyOf(this.hiddenSizes, this.hiddenSizes.length);
    }

    /**
     * Sets the number of training points per mini-batch update.
     *
     * @param batchSize the mini-batch size; must be positive
     * @throws IllegalArgumentException if {@code batchSize} is not positive
     */
    public void setBatchSize(int batchSize) {
        if (batchSize <= 0) {
            // A non-positive size would never advance through the data during training
            throw new IllegalArgumentException("Batch size must be positive, not " + batchSize);
        }
        this.batchSize = batchSize;
    }

    /** @return the number of training points per mini-batch update */
    public int getBatchSize() {
        return this.batchSize;
    }

    /**
     * Sets how many full passes over the training data are performed.
     *
     * @param i the number of epochs; must be positive
     * @throws IllegalArgumentException if {@code i} is not positive
     */
    public void setEpochs(int i) {
        if (i <= 0) {
            throw new IllegalArgumentException("Number of epochs must be positive");
        }
        this.epochs = i;
    }

    /** @return the number of full passes over the training data */
    public int getEpochs() {
        return this.epochs;
    }

    /**
     * Classifies a point by feeding its numeric features through the trained network.
     * Calling this before the model has been trained throws a NullPointerException,
     * because the network is only created during training.
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        return new CategoricalResults(this.network.feedfoward(dataPoint.getNumericalValues()).arrayCopy());
    }

    /**
     * Builds a fresh network for the data set and trains it with shuffled mini-batches
     * for {@link #getEpochs()} epochs.
     *
     * @param classificationDataSet the training data
     * @param executorService pool used for parallel mini-batch updates, or {@code null}
     *                        to train single-threaded
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        setup(classificationDataSet);
        List<Vec> inputs = classificationDataSet.getDataVectors();
        int n = inputs.size();

        // One-hot target vector for every training point
        List<Vec> targets = new ArrayList<Vec>(classificationDataSet.getSampleSize());
        for (int i = 0; i < classificationDataSet.getSampleSize(); i++) {
            SparseVector oneHot = new SparseVector(classificationDataSet.getClassSize(), 1);
            oneHot.set(classificationDataSet.getDataPointCategory(i), 1.0d);
            targets.add(oneHot);
        }

        // Visiting order of the training points, reshuffled at the start of each epoch
        IntList order = new IntList(n);
        ListUtils.addRange(order, 0, n, 1);

        int bufferSize = Math.min(this.batchSize, Math.max(n, 1));
        List<Vec> batchInputs = new ArrayList<Vec>(bufferSize);
        List<Vec> batchTargets = new ArrayList<Vec>(bufferSize);
        for (int epoch = 0; epoch < this.epochs; epoch++) {
            Collections.shuffle(order);
            int start = 0;
            while (start < n) {
                // Written as start + min(...) so the bound cannot overflow for huge batch sizes
                int end = start + Math.min(this.batchSize, n - start);
                batchInputs.clear();
                batchTargets.clear();
                for (int pos = start; pos < end; pos++) {
                    // Use the shuffled order so each epoch sees different mini-batches
                    int idx = order.get(pos);
                    batchInputs.add(inputs.get(idx));
                    batchTargets.add(targets.get(idx));
                }
                if (executorService != null) {
                    this.network.updateMiniBatch(batchInputs, batchTargets, executorService);
                } else {
                    this.network.updateMiniBatch(batchInputs, batchTargets);
                }
                start = end;
            }
        }
        this.network.finishUpdating();
    }

    /**
     * Creates and configures a new network sized for the given data set: the input layer
     * matches the number of numeric features, the hidden layers use
     * {@link #getHiddenSizes()}, and the output layer has one neuron per class.
     */
    private void setup(ClassificationDataSet classificationDataSet) {
        this.network = new SGDNetworkTrainer();
        int[] layerSizes = new int[this.hiddenSizes.length + 2];
        layerSizes[0] = classificationDataSet.getNumNumericalVars();
        System.arraycopy(this.hiddenSizes, 0, layerSizes, 1, this.hiddenSizes.length);
        layerSizes[layerSizes.length - 1] = classificationDataSet.getClassSize();
        this.network.setLayerSizes(layerSizes);

        // One ReLU per hidden layer, then softmax for the output layer. The list stays
        // raw because its declared element type is defined by
        // SGDNetworkTrainer.setLayersActivation, which is not imported in this file.
        ArrayList activations = new ArrayList(this.hiddenSizes.length + 2);
        for (int layer = 0; layer < this.hiddenSizes.length; layer++) {
            activations.add(new ReLU());
        }
        activations.add(new SoftmaxLayer());
        this.network.setLayersActivation(activations);

        this.network.setRegularizer(new Max2NormRegularizer(25.0d));
        this.network.setWeightInit(new GaussianNormalInit(0.01d));
        this.network.setBiasInit(new ConstantInit(0.1d));
        this.network.setEta(1.0d);
        this.network.setGradientUpdater(new AdaDelta());
        this.network.setup();
    }

    /** Trains single-threaded; equivalent to {@code trainC(dataSet, null)}. */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        trainC(classificationDataSet, null);
    }

    /** @return {@code false}; data point weights are not used during training */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Copies the configuration and, if present, a clone of the trained network.
     */
    @Override // jsat.classifiers.Classifier
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public DReDNetSimple m546clone() {
        DReDNetSimple copy = new DReDNetSimple(this.hiddenSizes);
        if (this.network != null) {
            copy.network = this.network.m556clone();
        }
        copy.batchSize = this.batchSize;
        copy.epochs = this.epochs;
        return copy;
    }

    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }
}
