package org.genemania.engine.actions;

import java.io.File;
import java.io.FileWriter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Vector;
import org.apache.log4j.Logger;
import org.genemania.dto.AttributeDto;
import org.genemania.dto.InteractionDto;
import org.genemania.dto.NetworkDto;
import org.genemania.dto.NodeDto;
import org.genemania.dto.RelatedGenesEngineRequestDto;
import org.genemania.dto.RelatedGenesEngineResponseDto;
import org.genemania.engine.Constants;
import org.genemania.engine.cache.DataCache;
import org.genemania.engine.core.MatrixUtils;
import org.genemania.engine.core.data.AttributeData;
import org.genemania.engine.core.data.AttributeGroups;
import org.genemania.engine.core.data.Data;
import org.genemania.engine.core.data.DataSupport;
import org.genemania.engine.core.data.NetworkIds;
import org.genemania.engine.core.data.NodeIds;
import org.genemania.engine.core.integration.Feature;
import org.genemania.engine.core.integration.FeatureWeightMap;
import org.genemania.engine.core.mania.CoreMania;
import org.genemania.engine.core.utils.Logging;
import org.genemania.engine.exception.CancellationException;
import org.genemania.engine.labels.LabelVectorGenerator;
import org.genemania.engine.matricks.Matrix;
import org.genemania.engine.matricks.SymMatrix;
import org.genemania.exception.ApplicationException;
import org.genemania.type.CombiningMethod;

/* loaded from: input_file:org/genemania/engine/actions/FindRelated.class */
/**
 * Engine action that answers a "find related genes" request.
 *
 * <p>{@link #process()} validates the request, labels the query (positive) nodes, runs
 * {@link CoreMania} over the requested networks and attribute groups, converts the resulting
 * discriminant into the requested score type and assembles a
 * {@link RelatedGenesEngineResponseDto} holding the top-scoring nodes, the interactions among
 * them in every network with a non-zero weight, and the positively weighted attributes those
 * nodes carry.
 *
 * <p>An instance is bound to the single request passed to its constructor and keeps
 * per-request state in mutable fields that are filled in while {@link #process()} runs.
 */
public class FindRelated {
    private static final Logger logger = Logger.getLogger(FindRelated.class);

    private final DataCache cache;
    private final RelatedGenesEngineRequestDto request;

    /** Number of distinct network ids in the request; set by {@link #checkQuery()}. */
    private int numRequestNetworks;
    /** Number of attribute groups in the request; set by {@link #checkQuery()}. */
    private int numRequestAttributeGroups;
    /** Result of {@code DataSupport.queryHasUserNetworks}; set by {@link #checkQuery()}. */
    private boolean hasUserNetworks;
    /** Result of {@code DataSupport.queryHasUserAttributes}; set by {@link #checkQuery()}. */
    private boolean hasUserAttributes;

    /** Label assigned to query (positive) nodes. */
    static final double posLabelValue = 1.0d;
    /** Label value for negative nodes (process() supplies no explicit negatives). */
    static final double negLabelValue = -1.0d;
    /** Label used for unlabeled nodes by {@link #process()}. */
    static final double unLabeledValueProduction = -1.0d;
    /** Alternative unlabeled value (presumably for validation runs); not referenced in this class. */
    static final double unLabeledValueValidation = 0.0d;

    private long requestStartTimeMillis;
    private long requestEndTimeMillis;

    /**
     * @param dataCache cache providing node ids, networks and attribute data
     * @param relatedGenesEngineRequestDto the request this instance will answer
     */
    public FindRelated(DataCache dataCache, RelatedGenesEngineRequestDto relatedGenesEngineRequestDto) {
        this.cache = dataCache;
        this.request = relatedGenesEngineRequestDto;
    }

    /**
     * Validates and runs the request, reporting progress through the request's progress reporter.
     *
     * @return the populated response, or {@code null} if the request was cancelled
     * @throws ApplicationException if the request is invalid or an engine step fails
     */
    public RelatedGenesEngineResponseDto process() throws ApplicationException {
        try {
            requestStartTimeMillis = System.currentTimeMillis();
            logStart();
            checkQuery();
            logQuery();

            long organismId = request.getOrganismId();

            // Query nodes are positive; no explicit negatives are given, so every other node
            // receives the unlabeled value.
            DenseVector labels = LabelVectorGenerator.createLabelsFromIds(cache.getNodeIds(organismId),
                    request.getPositiveNodes(), new ArrayList<Long>(), posLabelValue, negLabelValue,
                    unLabeledValueProduction);

            Constants.CombiningMethod combiningMethod = Constants.convertCombiningMethod(
                    request.getCombiningMethod(), request.getPositiveNodes().size());
            Constants.ScoringMethod scoringMethod = Constants.convertScoringMethod(request.getScoringMethod());
            Collection<Collection<Long>> interactionNetworks = request.getInteractionNetworks();

            CoreMania coreMania = new CoreMania(cache, request.getProgressReporter());
            coreMania.compute(safeGetNamespace(), organismId, labels, combiningMethod, interactionNetworks,
                    request.getAttributeGroups(), request.getAttributesLimit(), null, "average");

            SymMatrix combinedKernel = coreMania.getPartiallyCombinedKernel();
            FeatureWeightMap featureWeights = coreMania.getFeatureWeights();
            Vector discriminant = coreMania.getDiscriminant();

            // NOTE: for DISCRIMINANT scoring, convertScore rescales `discriminant` in place and
            // returns that same vector, so it must run before prepareResponse sees the discriminant.
            Vector scores = convertScore(scoringMethod, discriminant, combinedKernel, labels,
                    posLabelValue, negLabelValue);
            RelatedGenesEngineResponseDto response = prepareResponse(scores, discriminant, featureWeights,
                    combinedKernel, selectScoreThreshold(scoringMethod), scoringMethod,
                    Constants.convertCombiningMethod(combiningMethod));

            requestEndTimeMillis = System.currentTimeMillis();
            logEnd();
            return response;
        } catch (CancellationException e) {
            logger.info("request was cancelled");
            return null;
        }
    }

    /**
     * Fills in the attribute part of the response: an empty node-to-attribute map when the request
     * has no attribute groups, otherwise the weighted attributes and their per-node assignment.
     *
     * @param response response whose nodes have already been populated
     * @param topIndices indices of the returned nodes (currently unused here)
     * @param featureWeights feature weights produced by the engine
     */
    private void encodeAttributes(RelatedGenesEngineResponseDto response, int[] topIndices,
            FeatureWeightMap featureWeights) throws ApplicationException {
        if (request.getAttributeGroups() == null || request.getAttributeGroups().isEmpty()) {
            setEmptyAttributeResponse(response);
            return;
        }
        Map<Long, AttributeDto> attributesById = makeAllAttributeDtos(response, featureWeights);
        addAttributesForSelectedNodes(response, attributesById, featureWeights);
    }

    /** Sets an empty node-to-attributes map on the response. */
    private void setEmptyAttributeResponse(RelatedGenesEngineResponseDto response) {
        response.setNodeToAttributes(new HashMap<Long, Collection<AttributeDto>>());
    }

    /**
     * Builds the node-to-attributes map for the nodes already in the response: a node is associated
     * with an attribute when its entry in that attribute's data column is non-zero.
     *
     * <p>Only attributes present in {@code attributesById} (the positively weighted ones) are
     * considered; other attribute features are skipped instead of attaching {@code null} entries.
     *
     * @throws CancellationException if the progress reporter signals cancellation
     */
    private void addAttributesForSelectedNodes(RelatedGenesEngineResponseDto response,
            Map<Long, AttributeDto> attributesById, FeatureWeightMap featureWeights) throws ApplicationException {
        long organismId = request.getOrganismId();
        NodeIds nodeIds = cache.getNodeIds(organismId);
        AttributeGroups attributeGroups = cache.getAttributeGroups(safeGetNamespace(), organismId);
        HashMap<Long, Collection<AttributeDto>> nodeToAttributes = new HashMap<Long, Collection<AttributeDto>>();

        for (Feature feature : featureWeights.keySet()) {
            if (feature.getType() != Constants.NetworkType.ATTRIBUTE_VECTOR) {
                continue;
            }
            if (request.getProgressReporter().isCanceled()) {
                throw new CancellationException();
            }
            AttributeDto attributeDto = attributesById.get(Long.valueOf(feature.getId()));
            if (attributeDto == null) {
                // Not positively weighted, so not reported; avoid adding null to node attribute sets.
                continue;
            }
            AttributeData attributeData = cache.getAttributeData(safeGetNamespace(), organismId,
                    feature.getGroupId());
            int column = attributeGroups.getIndexForAttributeId(feature.getGroupId(), feature.getId());
            Matrix data = attributeData.getData();
            for (NodeDto node : response.getNodes()) {
                if (data.get(nodeIds.getIndexForId(node.getId()), column) != 0.0d) {
                    updateNodeAttributes(node.getId(), nodeToAttributes, attributeDto);
                }
            }
        }
        response.setNodeToAttributes(nodeToAttributes);
    }

    /** Adds {@code attributeDto} to the set stored under {@code nodeId}, creating the set if needed. */
    private void updateNodeAttributes(long nodeId, Map<Long, Collection<AttributeDto>> nodeToAttributes,
            AttributeDto attributeDto) {
        Long key = Long.valueOf(nodeId);
        Collection<AttributeDto> attributes = nodeToAttributes.get(key);
        if (attributes == null) {
            attributes = new HashSet<AttributeDto>();
            nodeToAttributes.put(key, attributes);
        }
        attributes.add(attributeDto);
    }

    /**
     * Creates a DTO for every attribute feature with weight strictly greater than zero, stores the
     * list on the response, and returns the same DTOs keyed by attribute id.
     *
     * @throws CancellationException if the progress reporter signals cancellation
     */
    private Map<Long, AttributeDto> makeAllAttributeDtos(RelatedGenesEngineResponseDto response,
            FeatureWeightMap featureWeights) throws ApplicationException {
        HashMap<Long, AttributeDto> attributesById = new HashMap<Long, AttributeDto>();
        ArrayList<AttributeDto> attributes = new ArrayList<AttributeDto>();
        for (Feature feature : featureWeights.keySet()) {
            if (feature.getType() != Constants.NetworkType.ATTRIBUTE_VECTOR) {
                continue;
            }
            double weight = ((Double) featureWeights.get(feature)).doubleValue();
            // written as `> 0` (not `<= 0` skip) so NaN weights stay excluded
            if (weight > 0.0d) {
                if (request.getProgressReporter().isCanceled()) {
                    throw new CancellationException();
                }
                AttributeDto attributeDto = new AttributeDto();
                attributeDto.setId(feature.getId());
                attributeDto.setGroupId(feature.getGroupId());
                attributeDto.setWeight(weight);
                attributesById.put(Long.valueOf(attributeDto.getId()), attributeDto);
                attributes.add(attributeDto);
            }
        }
        response.setAttributes(attributes);
        return attributesById;
    }

    /** Minimum score for returned nodes: no threshold for z-scores, otherwise zero. */
    private double selectScoreThreshold(Constants.ScoringMethod scoringMethod) {
        return scoringMethod == Constants.ScoringMethod.ZSCORE ? Double.NEGATIVE_INFINITY : 0.0d;
    }

    /**
     * Converts the raw discriminant into the requested score type.
     *
     * <p>DISCRIMINANT rescales {@code discriminant} in place (via {@code MatrixUtils.rescale}) and
     * returns it; ZSCORE returns a new vector from {@link #computeZScore}.
     *
     * @throws ApplicationException for CONTEXT (no longer supported) or any other method
     */
    private Vector convertScore(Constants.ScoringMethod scoringMethod, Vector discriminant,
            SymMatrix combinedKernel, Vector labels, double posLabel, double negLabel) throws ApplicationException {
        if (scoringMethod == Constants.ScoringMethod.DISCRIMINANT) {
            discriminant.set(MatrixUtils.rescale(discriminant));
            return discriminant;
        }
        if (scoringMethod == Constants.ScoringMethod.CONTEXT) {
            throw new ApplicationException("context score no longer supported");
        }
        if (scoringMethod != Constants.ScoringMethod.ZSCORE) {
            throw new ApplicationException("Unexpected scoring method: " + scoringMethod);
        }
        return computeZScore(discriminant, combinedKernel, labels, posLabel, negLabel);
    }

    /**
     * Computes a z-score for every node's discriminant value.
     *
     * <p>Mean and standard deviation are estimated only over nodes whose column sum ("degree") in
     * the combined kernel is positive and whose label is not {@code posLabel}. Every node is then
     * shifted by that mean and divided by (stdev + 0.01). If no node qualifies for the statistics,
     * all scores are -infinity except positively labelled nodes, which get 1.0.
     *
     * @param negLabel currently unused
     * @return a new vector; {@code discriminant} is not modified
     */
    private Vector computeZScore(Vector discriminant, SymMatrix combinedKernel, Vector labels,
            double posLabel, double negLabel) throws ApplicationException {
        logger.debug("computing z-score");
        int n = discriminant.size();

        DenseVector degrees = new DenseVector(n);
        combinedKernel.columnSums(degrees.getData());

        // Single-column matrix of values feeding the statistics; NaN marks excluded nodes.
        DenseMatrix statsInput = new DenseMatrix(n, 1);
        int connected = 0;
        for (int i = 0; i < n; i++) {
            if (degrees.get(i) > 0.0d) {
                statsInput.set(i, 0, discriminant.get(i));
                connected++;
            } else {
                statsInput.set(i, 0, Double.NaN);
            }
        }
        logger.debug("# of nodes with +ve degree in combined network: " + connected);

        for (int i = 0; i < labels.size(); i++) {
            if (labels.get(i) == posLabel) {
                logger.debug("clearing modes with postive label value for " + i);
                statsInput.set(i, 0, Double.NaN);
            }
        }

        Vector counts = MatrixUtils.columnCountsIgnoreMissingData(statsInput);
        Vector zScores = discriminant.copy();
        if (counts.get(0) == 0.0d) {
            logger.info("no nodes connected to query nodes, special casing z-scores");
            seteq(zScores, Double.NEGATIVE_INFINITY);
            setmatches(posLabel, labels, 1.0d, zScores);
        } else {
            Vector mean = MatrixUtils.columnMeanIgnoreMissingData(statsInput, counts);
            Vector stdev = MatrixUtils.columnVarianceIgnoreMissingData(statsInput, mean);
            MatrixUtils.sqrt(stdev); // variance -> standard deviation, in place
            logger.debug("count, mean, stdev: " + counts.get(0) + ", " + mean.get(0) + ", " + stdev.get(0));
            MatrixUtils.add(zScores, -mean.get(0));
            // 0.01 keeps the division finite when the standard deviation is zero
            zScores.scale(1.0d / (stdev.get(0) + 0.01d));
            logger.debug("max of z-scores: " + MatrixUtils.max(zScores));
        }
        return zScores;
    }

    /** Replaces every element strictly below {@code threshold} with {@code replacement}. */
    private static void setlt(Vector vector, double threshold, double replacement) {
        for (int i = 0, size = vector.size(); i < size; i++) {
            if (vector.get(i) < threshold) {
                vector.set(i, replacement);
            }
        }
    }

    /** Replaces every element greater than or equal to {@code threshold} with {@code replacement}. */
    private static void setge(Vector vector, double threshold, double replacement) {
        for (int i = 0, size = vector.size(); i < size; i++) {
            if (vector.get(i) >= threshold) {
                vector.set(i, replacement);
            }
        }
    }

    /** Sets every element of {@code vector} to {@code value}. */
    private static void seteq(Vector vector, double value) {
        for (int i = 0, size = vector.size(); i < size; i++) {
            vector.set(i, value);
        }
    }

    /**
     * For each index where {@code source} equals {@code match}, writes {@code value} into
     * {@code target} at the same index.
     */
    private static void setmatches(double match, Vector source, double value, Vector target) {
        for (int i = 0, size = source.size(); i < size; i++) {
            if (source.get(i) == match) {
                target.set(i, value);
            }
        }
    }

    /**
     * Builds the response: selects the top-scoring nodes, then adds their interactions in each
     * non-zero-weighted network, their attributes, and the combining method applied.
     *
     * @param scores converted node scores (used for ranking unless scoring is CONTEXT)
     * @param discriminant raw discriminant (used for ranking when scoring is CONTEXT)
     * @param scoreThreshold minimum score passed to the top-score selection
     * @throws CancellationException if the progress reporter signals cancellation
     */
    protected RelatedGenesEngineResponseDto prepareResponse(Vector scores, Vector discriminant,
            FeatureWeightMap featureWeights, SymMatrix combinedKernel, double scoreThreshold,
            Constants.ScoringMethod scoringMethod, CombiningMethod combiningMethod) throws ApplicationException {
        logPreparingOutputs();
        RelatedGenesEngineResponseDto response = new RelatedGenesEngineResponseDto();

        // Query node indices are handed to the selection (presumably to exclude them; see MatrixUtils).
        List<Integer> queryIndices = cache.getNodeIds(request.getOrganismId())
                .getIndicesForIds(request.getPositiveNodes());
        Vector rankBy = scoringMethod == Constants.ScoringMethod.CONTEXT ? discriminant : scores;
        int[] topIndices = MatrixUtils.getIndicesForTopScores(rankBy, queryIndices,
                request.getLimitResults(), scoreThreshold);
        logger.debug(String.format("number of nodes available for return: %d", Integer.valueOf(topIndices.length)));

        if (request.getProgressReporter().isCanceled()) {
            throw new CancellationException();
        }
        logger.debug("extracting source interactions");
        getSourceInteractions(response, topIndices, scores, featureWeights);
        logger.debug("extracting attributes");
        encodeAttributes(response, topIndices, featureWeights);
        response.setCombiningMethodApplied(combiningMethod);
        return response;
    }

    /**
     * Lists the non-zero entries of {@code symMatrix} between the given node indices, visiting each
     * unordered pair once (diagonal excluded), as interactions between the matching node DTOs.
     *
     * @param symMatrix network matrix indexed by node index
     * @param indices node indices to include
     * @param nodesById node DTOs keyed by node id; must contain every node in {@code indices}
     * @throws ApplicationException ("mapping error") if a node DTO is missing
     */
    public Collection<InteractionDto> matrixToInteractions(SymMatrix symMatrix, int[] indices,
            HashMap<Long, NodeDto> nodesById) throws ApplicationException {
        ArrayList<InteractionDto> interactions = new ArrayList<InteractionDto>();
        NodeIds nodeIds = cache.getNodeIds(request.getOrganismId());
        for (int i = 0; i < indices.length; i++) {
            int row = indices[i];
            long rowId = nodeIds.getIdForIndex(row);
            for (int j = 0; j < i; j++) {
                int col = indices[j];
                double weight = symMatrix.get(row, col);
                if (weight == 0.0d) {
                    continue;
                }
                long colId = nodeIds.getIdForIndex(col);
                NodeDto rowNode = nodesById.get(Long.valueOf(rowId));
                NodeDto colNode = nodesById.get(Long.valueOf(colId));
                if (rowNode == null || colNode == null) {
                    throw new ApplicationException("mapping error");
                }
                InteractionDto interaction = new InteractionDto();
                interaction.setNodeVO1(rowNode);
                interaction.setNodeVO2(colNode);
                interaction.setWeight(weight);
                interactions.add(interaction);
            }
        }
        return interactions;
    }

    /**
     * Populates the response's nodes (with their scores) and networks: one entry per SPARSE_MATRIX
     * feature with a non-null, non-zero weight, carrying its interactions among the selected nodes.
     *
     * @param indices node indices selected for return
     * @param scores per-node scores, indexed like the node ids
     * @throws CancellationException if the progress reporter signals cancellation
     */
    public void getSourceInteractions(RelatedGenesEngineResponseDto response, int[] indices, Vector scores,
            FeatureWeightMap featureWeights) throws ApplicationException {
        NodeIds nodeIds = cache.getNodeIds(request.getOrganismId());
        HashMap<Long, NodeDto> nodesById = new HashMap<Long, NodeDto>();
        for (int index : indices) {
            long nodeId = nodeIds.getIdForIndex(index);
            NodeDto node = new NodeDto();
            node.setId(nodeId);
            node.setScore(scores.get(index));
            nodesById.put(Long.valueOf(nodeId), node);
        }

        List<NetworkDto> networks = new ArrayList<NetworkDto>();
        for (Feature feature : featureWeights.keySet()) {
            if (feature.getType() != Constants.NetworkType.SPARSE_MATRIX) {
                continue;
            }
            if (request.getProgressReporter().isCanceled()) {
                throw new CancellationException();
            }
            Double weight = (Double) featureWeights.get(feature);
            long networkId = feature.getId();
            if (weight == null) {
                continue;
            }
            if (weight.doubleValue() == 0.0d) {
                logger.debug(String.format("network %s has zero weight, excluding from results", Long.valueOf(networkId)));
                continue;
            }
            NetworkDto network = new NetworkDto();
            network.setWeight(weight.doubleValue());
            network.setId(networkId);
            SymMatrix networkData = cache.getNetwork(safeGetNamespace(), request.getOrganismId(), networkId).getData();
            network.setInteractions(matrixToInteractions(networkData, indices, nodesById));
            networks.add(network);
        }
        response.setNetworks(networks);
        response.setNodes(new ArrayList<NodeDto>(nodesById.values()));
    }

    /** Logs a one-line summary of the validated request. */
    private void logQuery() {
        logger.info(String.format("findRelated query using combining method %s for organism %d contains %d nodes, %d network groups, %d networks, %d attribute groups, and requests a maximum of %d related nodes using a maximum of %d attributes per group",
                request.getCombiningMethod(), Long.valueOf(request.getOrganismId()),
                Integer.valueOf(request.getPositiveNodes().size()),
                Integer.valueOf(request.getInteractionNetworks().size()),
                Integer.valueOf(numRequestNetworks), Integer.valueOf(numRequestAttributeGroups),
                Integer.valueOf(request.getLimitResults()), Integer.valueOf(request.getAttributesLimit())));
    }

    /** Logs the start of processing and initialises the progress reporter (maximum 5, progress 0). */
    private void logStart() {
        logger.info("processing findRelated() request");
        request.getProgressReporter().setMaximumProgress(5);
        request.getProgressReporter().setStatus("starting");
        request.getProgressReporter().setProgress(0);
    }

    /** Logs that outputs are being prepared and moves progress to step 4. */
    private void logPreparingOutputs() {
        logger.info("preparing outputs for findRelated() request");
        request.getProgressReporter().setDescription(Constants.PROGRESS_OUTPUTS_MESSAGE);
        request.getProgressReporter().setProgress(4);
    }

    /** Logs the request duration and marks progress as done (step 5). */
    private void logEnd() {
        logger.info("completed processing request, duration = " + Logging.duration(requestStartTimeMillis, requestEndTimeMillis));
        request.getProgressReporter().setStatus("done");
        request.getProgressReporter().setProgress(5);
    }

    /** Debug-logs the id and score of each node in {@code indices}; no-op unless debug is enabled. */
    private void logNodeScores(int[] indices, Vector scores) throws ApplicationException {
        if (!logger.isDebugEnabled()) {
            return;
        }
        NodeIds nodeIds = cache.getNodeIds(request.getOrganismId());
        for (int index : indices) {
            logger.debug(String.format("Node %d as a score of %f", Long.valueOf(nodeIds.getIdForIndex(index)), Double.valueOf(scores.get(index))));
        }
    }

    /**
     * Validates the request and records derived counts and user-data flags.
     *
     * <p>Null interaction-network and attribute-group collections are replaced on the request with
     * empty ones, so attribute-only queries no longer fail with a NullPointerException.
     *
     * @throws ApplicationException if there are no query nodes, neither networks nor attribute
     *     groups, or any node/network/attribute-group id is duplicated or invalid
     */
    public void checkQuery() throws ApplicationException {
        if (request.getPositiveNodes() == null || request.getPositiveNodes().isEmpty()) {
            throw new ApplicationException("No query nodes given");
        }
        boolean noNetworks = request.getInteractionNetworks() == null || request.getInteractionNetworks().isEmpty();
        boolean noAttributes = request.getAttributeGroups() == null || request.getAttributeGroups().isEmpty();
        if (noNetworks && noAttributes) {
            throw new ApplicationException("No query networks or attributes given");
        }
        if (request.getInteractionNetworks() == null) {
            request.setInteractionNetworks(new ArrayList<Collection<Long>>());
        }
        checkNodes(request.getOrganismId(), request.getPositiveNodes());
        hasUserNetworks = DataSupport.queryHasUserNetworks(request.getInteractionNetworks());
        hasUserAttributes = DataSupport.queryHasUserAttributes(request.getAttributeGroups());
        // safeGetNamespace() depends on the two flags above, so they must be set first
        numRequestNetworks = checkNetworks(safeGetNamespace(), request.getOrganismId(), request.getInteractionNetworks());
        if (request.getAttributeGroups() == null) {
            request.setAttributeGroups(new ArrayList<Long>());
        }
        numRequestAttributeGroups = checkAttributeGroups(safeGetNamespace(), request.getOrganismId(), request.getAttributeGroups());
    }

    /**
     * Verifies that {@code nodes} is non-empty, contains no duplicates, and that every id resolves
     * to an index for the organism.
     *
     * @throws ApplicationException on an empty list, a duplicated id, or an unknown id
     */
    protected void checkNodes(long organismId, Collection<Long> nodes) throws ApplicationException {
        if (nodes.isEmpty()) {
            throw new ApplicationException("the list of nodes in the request is empty");
        }
        HashSet<Long> seen = new HashSet<Long>();
        NodeIds nodeIds = cache.getNodeIds(organismId);
        for (Long nodeId : nodes) {
            // add() returns false when the id was already recorded
            if (!seen.add(nodeId)) {
                throw new ApplicationException(String.format("the node id %d was passed multiple times in request", nodeId));
            }
            try {
                nodeIds.getIndexForId(nodeId.longValue());
            } catch (ApplicationException e) {
                throw new ApplicationException(String.format("node id %d is not valid for organism id %d", nodeId, Long.valueOf(organismId)));
            }
        }
    }

    /**
     * Validates every network id across all groups: no duplicates, within int range, a namespace
     * present for negative (user) ids, and resolvable in the namespace's network ids.
     *
     * @return the number of distinct network ids
     * @throws ApplicationException on the first failing check
     */
    protected int checkNetworks(String namespace, long organismId, Collection<Collection<Long>> networkGroups) throws ApplicationException {
        HashSet<Long> seen = new HashSet<Long>();
        NetworkIds networkIds = cache.getNetworkIds(namespace, organismId);
        for (Collection<Long> group : networkGroups) {
            for (Long id : group) {
                if (!seen.add(id)) {
                    throw new ApplicationException(String.format("the network id %d was passed multiple times in request", id));
                }
                long networkId = id.longValue();
                if (networkId > Integer.MAX_VALUE || networkId < Integer.MIN_VALUE) {
                    throw new ApplicationException(String.format("network ids must be in integer range, got id: %d", id));
                }
                if (networkId < 0) {
                    if (namespace == null) {
                        throw new ApplicationException(String.format("no namespace provided for user network %d organism %d", id, Long.valueOf(organismId)));
                    }
                    // NOTE(review): despite this message, the id is still validated below.
                    logger.warn("skipping validation check on user network: " + networkId);
                }
                try {
                    networkIds.getIndexForId(networkId);
                } catch (ApplicationException e) {
                    throw new ApplicationException(String.format("network id %d is not valid for organism id %d", id, Long.valueOf(organismId)));
                }
            }
        }
        return seen.size();
    }

    /**
     * Checks that the attribute group ids are distinct and all known for the organism in the
     * namespace.
     *
     * @return the number of attribute groups (0 for an empty collection)
     * @throws ApplicationException on duplicates or an unknown group id
     */
    protected int checkAttributeGroups(String namespace, long organismId, Collection<Long> groupIds) throws ApplicationException {
        if (groupIds.isEmpty()) {
            return 0;
        }
        if (groupIds.size() != new HashSet<Long>(groupIds).size()) {
            throw new ApplicationException("the list of attribute groups contains duplicates");
        }
        Map<Long, ?> knownGroups = cache.getAttributeGroups(namespace, organismId).getAttributeGroups();
        for (Long groupId : groupIds) {
            if (!knownGroups.containsKey(groupId)) {
                throw new ApplicationException(String.format("organism %d in namspace '%s' does not contain the attribute group %d", Long.valueOf(organismId), namespace, groupId));
            }
        }
        return groupIds.size();
    }

    /**
     * Returns the request namespace only when it is non-empty and the query uses user networks or
     * user attributes; otherwise {@link Data#CORE}.
     */
    private String safeGetNamespace() {
        String namespace = request.getNamespace();
        if (namespace == null || namespace.isEmpty()) {
            return Data.CORE;
        }
        return (hasUserNetworks || hasUserAttributes) ? namespace : Data.CORE;
    }

    /**
     * Debug helper: writes a tab-separated table of index, discriminant, label and degree to
     * {@code fileName}. Failures are logged as warnings, never thrown; the file is always closed.
     */
    private void dumpNumbers(String fileName, Vector discriminant, Vector labels, Vector degrees) {
        logger.info("dumping to " + fileName);
        try (FileWriter writer = new FileWriter(new File(fileName))) {
            int size = discriminant.size();
            writer.write("node\tdiscriminant\tlabels\tdegrees\n");
            for (int i = 0; i < size; i++) {
                writer.write(String.format("%d\t%.15e\t%.15e\t%.15e\n", Integer.valueOf(i), Double.valueOf(discriminant.get(i)), Double.valueOf(labels.get(i)), Double.valueOf(degrees.get(i))));
            }
        } catch (Exception e) {
            logger.warn("failed to dump data", e);
        }
    }

    /** Debug-logs each interaction (node ids and weight) of the given network. */
    public static void logInteractions(long networkId, Collection<InteractionDto> interactions) {
        logger.debug("interactions for network " + networkId);
        for (InteractionDto interaction : interactions) {
            logger.debug(String.format("   %d %d %f", Long.valueOf(interaction.getNodeVO1().getId()), Long.valueOf(interaction.getNodeVO2().getId()), Double.valueOf(interaction.getWeight())));
        }
    }
}
