package org.genemania.engine.core.mania;

import java.util.ArrayList;
import java.util.Collection;
import no.uib.cipr.matrix.Vector;
import org.apache.log4j.Logger;
import org.genemania.engine.Constants;
import org.genemania.engine.cache.DataCache;
import org.genemania.engine.config.Config;
import org.genemania.engine.core.data.Data;
import org.genemania.engine.core.integration.CombinedKernelBuilder;
import org.genemania.engine.core.integration.Feature;
import org.genemania.engine.core.integration.FeatureWeightMap;
import org.genemania.engine.core.propagation.PropagateLabels;
import org.genemania.engine.matricks.SymMatrix;
import org.genemania.exception.ApplicationException;
import org.genemania.util.NullProgressReporter;
import org.genemania.util.ProgressReporter;

/* loaded from: input_file:org/genemania/engine/core/mania/CoreMania.class */
public class CoreMania {
    private static final Logger logger = Logger.getLogger(CoreMania.class);
    private final DataCache cache;
    private final ProgressReporter progress;

    /** Label-propagation scores produced by the most recent {@link #computeDiscriminant}. */
    private Vector discriminant;

    /** Combined matrix returned by {@code CalculateNetworkWeights}; may be null (see getCombinedKernel). */
    private SymMatrix partiallyCombinedKernel;

    /**
     * Memoized result of {@link CombinedKernelBuilder#build}. Derived from
     * {@link #partiallyCombinedKernel} and {@link #featureWeights}, so it must be
     * cleared whenever those are replaced.
     */
    private SymMatrix combinedKernel;

    /** Feature weights computed by the most recent {@link #computeWeights}. */
    private FeatureWeightMap featureWeights;

    /**
     * Creates an engine instance backed by the given data cache.
     *
     * @param dataCache        cache handed to the weight calculator and kernel builder
     * @param progressReporter reporter passed to weight computation and label propagation
     */
    public CoreMania(DataCache dataCache, ProgressReporter progressReporter) {
        this.cache = dataCache;
        this.progress = progressReporter;
    }

    /**
     * Creates an engine instance that reports progress to a no-op reporter.
     *
     * @param dataCache cache handed to the weight calculator and kernel builder
     */
    public CoreMania(DataCache dataCache) {
        this(dataCache, NullProgressReporter.instance());
    }

    /**
     * Runs the full pipeline: computes feature weights, then label-propagation scores.
     *
     * <p>Note that the discriminant step is always run against the {@code Data.CORE}
     * namespace, regardless of {@code namespace}.
     *
     * @param namespace         namespace used for weight computation (presumably a data/user namespace — confirm)
     * @param organismId        identifier forwarded to weight and kernel computation (presumably an organism id)
     * @param labels            label vector to propagate
     * @param combiningMethod   network combining method
     * @param networkIds        grouped ids forwarded to {@code CalculateNetworkWeights} (presumably network ids)
     * @param attributeGroupIds ids forwarded to {@code CalculateNetworkWeights} (presumably attribute group ids)
     * @param attributesLimit   integer limit forwarded to {@code CalculateNetworkWeights}
     * @param goCategory        forwarded to {@link #computeDiscriminant}; not used there
     * @param biasingMethod     label bias method name; only "average" is supported
     * @throws ApplicationException if weight or score computation fails
     */
    public void compute(String namespace, long organismId, Vector labels, Constants.CombiningMethod combiningMethod, Collection<Collection<Long>> networkIds, Collection<Long> attributeGroupIds, int attributesLimit, String goCategory, String biasingMethod) throws ApplicationException {
        long start = System.nanoTime();
        computeWeights(namespace, organismId, labels, combiningMethod, networkIds, attributeGroupIds, attributesLimit);
        computeDiscriminant(Data.CORE, organismId, labels, goCategory, biasingMethod);
        logger.info("total time for compute: " + (System.nanoTime() - start));
    }

    /**
     * Convenience overload of the full pipeline with an empty id collection and a limit of 0
     * for the attribute-related arguments.
     *
     * @throws ApplicationException if weight or score computation fails
     */
    public void compute(String namespace, long organismId, Vector labels, Constants.CombiningMethod combiningMethod, Collection<Collection<Long>> networkIds, String goCategory, String biasingMethod) throws ApplicationException {
        compute(namespace, organismId, labels, combiningMethod, networkIds, new ArrayList<Long>(), 0, goCategory, biasingMethod);
    }

    /**
     * Computes feature weights and the partially combined kernel via
     * {@code CalculateNetworkWeights}, storing both on this instance.
     *
     * <p>Any previously memoized combined kernel is discarded, because it was built
     * from the old weights.
     *
     * @throws ApplicationException if combined-network renormalization is enabled in
     *         {@link Config} (unsupported), if a feature has a type other than
     *         {@code SPARSE_MATRIX} or {@code ATTRIBUTE_VECTOR}, or if the weight
     *         calculation itself fails
     */
    public void computeWeights(String namespace, long organismId, Vector labels, Constants.CombiningMethod combiningMethod, Collection<Collection<Long>> networkIds, Collection<Long> attributeGroupIds, int attributesLimit) throws ApplicationException {
        logger.info("computing weights");
        long start = System.nanoTime();

        // Fail fast: renormalization is unsupported, so don't pay for the weight calculation first.
        if (Config.instance().isCombinedNetworkRenormalizationEnabled()) {
            throw new ApplicationException("renormalization of combined network not supported");
        }

        CalculateNetworkWeights calculator = new CalculateNetworkWeights(namespace, this.cache, networkIds, attributeGroupIds, organismId, labels, attributesLimit, combiningMethod, this.progress);
        calculator.process();
        SymMatrix combinedMatrix = calculator.getCombinedMatrix();
        FeatureWeightMap weights = calculator.getWeights();
        logger.debug("# weights: " + weights.size());
        logFeatureTypeCounts(weights);

        this.partiallyCombinedKernel = combinedMatrix;
        this.featureWeights = weights;
        // The memoized kernel was derived from the previous weights; force a rebuild.
        this.combinedKernel = null;
        logger.info("time for computeWeights: " + (System.nanoTime() - start));
    }

    /** Validates feature types and logs how many sparse-matrix and attribute-vector features there are. */
    private static void logFeatureTypeCounts(FeatureWeightMap weights) throws ApplicationException {
        int sparseCount = 0;
        int attributeCount = 0;
        for (Feature feature : weights.keySet()) {
            if (feature.getType() == Constants.NetworkType.SPARSE_MATRIX) {
                sparseCount++;
            } else if (feature.getType() == Constants.NetworkType.ATTRIBUTE_VECTOR) {
                attributeCount++;
            } else {
                throw new ApplicationException("unexpected feature type");
            }
        }
        logger.debug(String.format("# sparse: %d, # attribute %d", Integer.valueOf(sparseCount), Integer.valueOf(attributeCount)));
    }

    /**
     * Computes label-propagation scores over the combined kernel and stores them as the discriminant.
     *
     * @param namespace     namespace passed to {@link #getCombinedKernel}
     * @param organismId    id passed to {@link #getCombinedKernel}
     * @param labels        label vector passed to {@code PropagateLabels.process}
     * @param goCategory    currently unused
     * @param biasingMethod label bias method, compared case-insensitively; only "average"
     *                      is supported
     * @throws ApplicationException if {@code biasingMethod} is "hierarchy" (not implemented)
     *         or is any other value besides "average", including null
     */
    public void computeDiscriminant(String namespace, long organismId, Vector labels, String goCategory, String biasingMethod) throws ApplicationException {
        logger.info("computing scores");
        long start = System.nanoTime();
        // Literal-first comparisons make a null method name an ApplicationException rather than an NPE.
        if ("hierarchy".equalsIgnoreCase(biasingMethod)) {
            logger.info("using GO hierarchy label bias method");
            throw new ApplicationException("hierarchical biasing not implemented");
        }
        if (!"average".equalsIgnoreCase(biasingMethod)) {
            throw new ApplicationException("illegal biasing method name");
        }
        logger.info("using average label bias method");
        this.discriminant = PropagateLabels.process(getCombinedKernel(organismId, namespace), labels, this.progress);
        logger.info("time for computeDiscriminant: " + (System.nanoTime() - start));
    }

    /** @return scores from the last successful {@link #computeDiscriminant}, or null if none */
    public Vector getDiscriminant() {
        return this.discriminant;
    }

    /** @return the combined matrix from the last {@link #computeWeights}, or null if none */
    public SymMatrix getPartiallyCombinedKernel() {
        return this.partiallyCombinedKernel;
    }

    /**
     * Returns the combined kernel, building and memoizing it on first use.
     *
     * <p>The memoized kernel is returned regardless of the arguments passed on later
     * calls. It is discarded only when {@link #computeWeights} stores new weights.
     *
     * @param organismId id passed to {@code CombinedKernelBuilder.build}
     * @param namespace  namespace passed to {@code CombinedKernelBuilder.build}
     * @return the combined kernel, or null if neither a partially combined kernel nor
     *         feature weights are available
     * @throws ApplicationException if building the kernel fails
     */
    public SymMatrix getCombinedKernel(long organismId, String namespace) throws ApplicationException {
        if (this.combinedKernel == null) {
            // Only bail out when BOTH inputs are missing; a null partial kernel alone may be
            // legitimate (NOTE(review): presumably when all features are attributes — confirm).
            if (this.partiallyCombinedKernel == null && this.featureWeights == null) {
                return null;
            }
            this.combinedKernel = new CombinedKernelBuilder(this.cache).build(organismId, namespace, this.partiallyCombinedKernel, this.featureWeights);
        }
        return this.combinedKernel;
    }

    /** @return feature weights from the last {@link #computeWeights}, or null if none */
    public FeatureWeightMap getFeatureWeights() {
        return this.featureWeights;
    }
}
