package jsat.lossfunctions;

import jsat.classifiers.CategoricalResults;
import jsat.linear.Vec;
import jsat.math.MathTricks;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/lossfunctions/SoftmaxLoss.class */
/**
 * Multi-class loss built on the softmax transform.
 *
 * <p>The intended call order, as far as this class shows, is: {@link #process(Vec, Vec)}
 * turns a vector of raw scores into its softmax form, and the remaining methods then
 * operate on that processed vector.
 */
public class SoftmaxLoss extends LogisticLoss implements LossMC {
    private static final long serialVersionUID = 3936898932252996024L;

    /**
     * Returns the negative natural logarithm of the entry of {@code vec} at index {@code i}.
     *
     * @param vec the vector to read from (presumably the output of
     *            {@link #process(Vec, Vec)} - confirm against callers)
     * @param i   the index of the entry whose negative log is returned
     * @return {@code -ln(vec.get(i))}
     */
    @Override
    public double getLoss(Vec vec, int i) {
        final double targetValue = vec.get(i);
        return -Math.log(targetValue);
    }

    /**
     * Writes the softmax of {@code vec} into {@code vec2}.
     *
     * <p>When both arguments are the same object the copy step is skipped and the
     * transform is applied in place.
     *
     * @param vec  the source vector
     * @param vec2 the destination vector; receives a copy of {@code vec} (unless it is
     *             the same instance) and is then passed to {@code MathTricks.softmax}
     */
    @Override
    public void process(Vec vec, Vec vec2) {
        final boolean inPlace = vec == vec2;
        if (!inPlace) {
            vec.copyTo(vec2);
        }
        // NOTE(review): the meaning of the 'false' flag is defined by MathTricks.softmax
        // and is not visible from this file - confirm before changing it.
        MathTricks.softmax(vec2, false);
    }

    /**
     * Fills {@code vec2} with the entries of {@code vec}, subtracting one from the entry
     * at index {@code i}.
     *
     * <p>If {@code vec} holds softmax outputs, this is the gradient of
     * {@link #getLoss(Vec, int)} with respect to the pre-softmax scores.
     *
     * @param vec  the vector to read from
     * @param vec2 the vector that receives the result; indices {@code 0} up to
     *             {@code vec.length() - 1} are overwritten
     * @param i    the index whose entry is reduced by one
     */
    @Override
    public void deriv(Vec vec, Vec vec2, int i) {
        for (int index = 0; index < vec.length(); index++) {
            final double value = vec.get(index);
            vec2.set(index, index == i ? value - 1.0d : value);
        }
    }

    /**
     * Wraps a copy of the entries of {@code vec} in a {@link CategoricalResults}.
     *
     * @param vec the vector whose entries are copied
     * @return a new {@link CategoricalResults} built from {@code vec.arrayCopy()}
     */
    @Override
    public CategoricalResults getClassification(Vec vec) {
        final double[] entries = vec.arrayCopy();
        return new CategoricalResults(entries);
    }
}
