package jsat.text.wordweighting;

import java.util.Iterator;
import java.util.List;
import jsat.linear.IndexValue;
import jsat.linear.Vec;

/* loaded from: input_file:JSAT-0.0.7.jar:jsat/text/wordweighting/OkapiBM25.class */
/**
 * Okapi BM25 word weighting.
 *
 * <p>{@link #setWeight(List, List)} collects corpus statistics (document
 * count, average summed term value per document, and per-term document
 * frequency). {@link #applyTo(Vec)} then rewrites every value reported by a
 * vector's iterator as
 *
 * <pre>
 *   log((N - df + 0.5) / (df + 0.5)) * (tf * (k1 + 1))
 *       / (tf + k1 * (1 - b + b * docSum / docAvg))
 * </pre>
 *
 * where {@code tf} is the stored value, {@code df} the document frequency of
 * that index, {@code docSum} the sum of the vector being weighted and
 * {@code docAvg} the average document sum seen by {@code setWeight}.
 */
public class OkapiBM25 extends WordWeighting {
    private static final long serialVersionUID = 6456657674702490465L;

    /** Term-frequency saturation coefficient; validated to be finite and non-negative. */
    private double k1;
    /** Length-normalisation coefficient; validated to lie in [0, 1]. */
    private double b;
    /** Number of documents passed to the last {@link #setWeight(List, List)} call. */
    private double N;
    /** Sum of all iterated values over all documents, divided by {@link #N}. */
    private double docAvg;
    /** Per-index count of documents whose iterator reported that index; null until setWeight runs. */
    private int[] df;

    /** Creates a weighting with {@code k1 = 1.5} and {@code b = 0.75}. */
    public OkapiBM25() {
        this(1.5d, 0.75d);
    }

    /**
     * Creates a weighting with the given coefficients.
     *
     * @param k1 non-negative, finite saturation coefficient
     * @param b  length-normalisation coefficient in the range [0, 1]
     * @throws IllegalArgumentException if either coefficient is NaN or out of range
     */
    public OkapiBM25(double k1, double b) {
        if (Double.isNaN(k1) || Double.isInfinite(k1) || k1 < 0.0d) {
            throw new IllegalArgumentException("coefficient k1 must be a non negative constant, not " + k1);
        }
        this.k1 = k1;
        if (Double.isNaN(b) || b < 0.0d || b > 1.0d) {
            throw new IllegalArgumentException("coefficient b must be in the range [0,1], not " + b);
        }
        this.b = b;
    }

    /**
     * Gathers the corpus statistics used by {@link #applyTo(Vec)}.
     *
     * <p>Only the <em>size</em> of {@code df} is read (to size the internal
     * document-frequency table); the counts themselves are recomputed from
     * {@code allDocuments}. If {@code allDocuments} is empty the average
     * document sum becomes {@code 0.0 / 0.0}, i.e. NaN.
     *
     * @param allDocuments the documents to collect statistics from
     * @param df           list whose size gives the number of term indices
     */
    @Override // jsat.text.wordweighting.WordWeighting
    public void setWeight(List<? extends Vec> allDocuments, List<Integer> df) {
        this.df = new int[df.size()];
        this.docAvg = 0.0d;
        for (Vec document : allDocuments) {
            Iterator<IndexValue> entries = document.iterator();
            while (entries.hasNext()) {
                IndexValue entry = entries.next();
                this.docAvg += entry.getValue();
                this.df[entry.getIndex()]++;
            }
        }
        this.N = allDocuments.size();
        this.docAvg /= this.N;
    }

    /**
     * Replaces, in place, each value reported by {@code vec}'s iterator with
     * its BM25 weight.
     *
     * @param vec the document vector to re-weight
     * @throws RuntimeException if {@link #setWeight(List, List)} has not been called yet
     */
    @Override // jsat.text.wordweighting.WordWeighting
    public void applyTo(Vec vec) {
        if (this.df == null) {
            throw new RuntimeException("OkapiBM25 weightings haven't been initialized, setWeight method must be called before first use.");
        }
        // The document sum is taken once, before any value is overwritten.
        double sum = vec.sum();
        // Length-normalisation term; identical for every entry of this vector.
        double lengthNorm = this.k1 * ((1.0d - this.b) + ((this.b * sum) / this.docAvg));

        Iterator<IndexValue> entries = vec.iterator();
        while (entries.hasNext()) {
            IndexValue entry = entries.next();
            double value = entry.getValue();
            int index = entry.getIndex();

            double idf = Math.log(((this.N - this.df[index]) + 0.5d) / (this.df[index] + 0.5d));
            vec.set(index, (idf * (value * (this.k1 + 1.0d))) / (value + lengthNorm));
        }
    }

    /**
     * Always returns {@code 0.0}, whatever the arguments.
     *
     * @param value the value at the given index (unused)
     * @param index the index being queried (unused)
     * @return {@code 0.0}
     */
    @Override // jsat.math.IndexFunction
    public double indexFunc(double value, int index) {
        return 0.0d;
    }
}
