"""Source code for nsdlib.algorithms.algorithms_utils: algorithm dispatch helpers and source-detection evaluation."""

from functools import lru_cache
from typing import Dict, List, Set, Union

from cdlib import NodeClustering
from netcenlib.common import nx_cached
from netcenlib.common.nx_cached import MAX_SIZE
from networkx import Graph

from nsdlib.algorithms import evaluation, outbreaks, reconstruction
from nsdlib.common.models import NODE_TYPE, SourceDetectionEvaluation
from nsdlib.taxonomies import (
    NodeEvaluationAlgorithm,
    OutbreaksDetectionAlgorithm,
    PropagationReconstructionAlgorithm,
)


def node_clustering_into_communities(result: NodeClustering) -> Dict[NODE_TYPE, list]:
    """Index the communities of a clustering result by their position.

    Args:
        result: Clustering object whose ``communities`` attribute is iterated.

    Returns:
        A dictionary whose keys are the 0-based positions of the communities in
        ``result.communities`` and whose values are those communities, unchanged.
    """
    # enumerate() already yields (position, community) pairs; dict() consumes them directly.
    return dict(enumerate(result.communities))
def identify_outbreaks(
    network: Graph, outbreaks_alg: OutbreaksDetectionAlgorithm, *args, **kwargs
) -> Dict[NODE_TYPE, list]:
    """Run an outbreak-detection algorithm and return its communities by index.

    The algorithm is resolved by name: the lower-cased ``outbreaks_alg.value`` must
    be an attribute of the ``nsdlib.algorithms.outbreaks`` module.

    Args:
        network: Graph handed to the selected algorithm as its first argument.
        outbreaks_alg: Enum member naming the algorithm to run.
        *args: Extra positional arguments forwarded to the algorithm.
        **kwargs: Extra keyword arguments forwarded to the algorithm.

    Returns:
        The algorithm's result passed through ``node_clustering_into_communities``.
        The result is presumably a cdlib ``NodeClustering``; verify against the
        algorithms.

    Raises:
        AttributeError: If ``outbreaks`` has no function with the resolved name.
    """
    detector = getattr(outbreaks, outbreaks_alg.value.lower())
    clustering = detector(network, *args, **kwargs)
    return node_clustering_into_communities(clustering)
def evaluate_nodes(
    network: Graph, evaluation_alg: NodeEvaluationAlgorithm, *args, **kwargs
):
    """Run a node-evaluation algorithm on ``network`` and return its raw result.

    The algorithm is resolved by name: the lower-cased ``evaluation_alg.value``
    must be an attribute of the ``nsdlib.algorithms.evaluation`` module.

    Args:
        network: Graph handed to the selected algorithm as its first argument.
        evaluation_alg: Enum member naming the algorithm to run.
        *args: Extra positional arguments forwarded to the algorithm.
        **kwargs: Extra keyword arguments forwarded to the algorithm.

    Returns:
        Whatever the selected evaluation function returns, unmodified.

    Raises:
        AttributeError: If ``evaluation`` has no function with the resolved name.
    """
    evaluator = getattr(evaluation, evaluation_alg.value.lower())
    return evaluator(network, *args, **kwargs)
def reconstruct_propagation(
    G: Graph,
    IG: Graph,
    reconstruction_alg: PropagationReconstructionAlgorithm,
    *args,
    **kwargs,
):
    """Run a propagation-reconstruction algorithm and return its raw result.

    The algorithm is resolved by name: the lower-cased ``reconstruction_alg.value``
    must be an attribute of the ``nsdlib.algorithms.reconstruction`` module.

    Args:
        G: First graph argument forwarded to the algorithm.
        IG: Second graph argument forwarded to the algorithm. From the name, this
            is presumably the infected subgraph; confirm with the algorithms.
        reconstruction_alg: Enum member naming the algorithm to run.
        *args: Extra positional arguments forwarded to the algorithm.
        **kwargs: Extra keyword arguments forwarded to the algorithm.

    Returns:
        Whatever the selected reconstruction function returns, unmodified.

    Raises:
        AttributeError: If ``reconstruction`` has no function with the resolved name.
    """
    reconstructor = getattr(reconstruction, reconstruction_alg.value.lower())
    return reconstructor(G, IG, *args, **kwargs)
@lru_cache(maxsize=MAX_SIZE)
def identify_outbreaks_cached(
    network: Graph, outbreaks_alg: OutbreaksDetectionAlgorithm, *args, **kwargs
) -> Dict[int, list]:
    """Memoized front end to ``identify_outbreaks``; same arguments and result.

    Caching notes:
        * ``lru_cache`` keys on every argument, so all positional and keyword
          arguments must be hashable.
        * networkx graphs do not override ``__eq__``/``__hash__``, so the graph is
          keyed by identity. Mutating a graph after a call can return stale results.
        * The cached dictionary object is shared between callers. Mutating it also
          changes what later cache hits return.
    """
    return identify_outbreaks(network, outbreaks_alg, *args, **kwargs)
@lru_cache(maxsize=MAX_SIZE)
def evaluate_nodes_cached(
    network: Graph, evaluation_alg: NodeEvaluationAlgorithm, *args, **kwargs
):
    """Memoized front end to ``evaluate_nodes``; same arguments and result.

    Caching notes:
        * ``lru_cache`` keys on every argument, so all positional and keyword
          arguments must be hashable.
        * networkx graphs do not override ``__eq__``/``__hash__``, so the graph is
          keyed by identity. Mutating a graph after a call can return stale results.
        * The cached result object is shared between callers. If it is mutable,
          mutating it also changes what later cache hits return.
    """
    return evaluate_nodes(network, evaluation_alg, *args, **kwargs)
@lru_cache(maxsize=MAX_SIZE)
def reconstruct_propagation_cached(
    G: Graph,
    IG: Graph,
    reconstruction_alg: PropagationReconstructionAlgorithm,
    *args,
    **kwargs,
):
    """Memoized front end to ``reconstruct_propagation``; same arguments and result.

    Caching notes:
        * ``lru_cache`` keys on every argument, so all positional and keyword
          arguments must be hashable.
        * networkx graphs do not override ``__eq__``/``__hash__``, so both graphs are
          keyed by identity. Mutating either graph after a call can return stale
          results.
        * The cached result object is shared between callers. If it is mutable,
          mutating it also changes what later cache hits return.
    """
    return reconstruct_propagation(G, IG, reconstruction_alg, *args, **kwargs)
def compute_error_distances(
    G: Graph, not_detected_sources: Set[int], invalid_detected_sources: Set[int]
) -> Dict[NODE_TYPE, float]:
    """Measure how far each missed source is from the closest false detection.

    Args:
        G: Graph on which path lengths are measured.
        not_detected_sources: Real sources that the detector did not report.
        invalid_detected_sources: Reported nodes that are not real sources.

    Returns:
        A mapping from each missed source to the minimum
        ``nx_cached.shortest_path_length`` between it and any invalid detection.
        The mapping is empty when either input set is empty.

    NOTE(review): no weight is passed, so presumably these are hop counts. If
    ``nx_cached`` mirrors networkx, an unreachable pair raises ``NetworkXNoPath``;
    confirm the behavior on disconnected graphs.
    """
    # A distance only makes sense when there is both a miss and a false alarm.
    if not (not_detected_sources and invalid_detected_sources):
        return {}

    distances = {}
    for missed in not_detected_sources:
        distances[missed] = min(
            nx_cached.shortest_path_length(G, source=missed, target=candidate)
            for candidate in invalid_detected_sources
        )
    return distances
def compute_source_detection_evaluation(
    G: Graph,
    real_sources: List[NODE_TYPE],
    detected_sources: Union[NODE_TYPE, List[NODE_TYPE]],
) -> SourceDetectionEvaluation:
    """Score a source-detection result against the true sources.

    Every node of ``G`` is treated as a binary classification instance. Real sources
    are positives and all other nodes are negatives.

    Args:
        G: Graph the detection was run on. Only its node count and path lengths
            are used.
        real_sources: The true source nodes. Duplicates are counted once.
        detected_sources: The detected source node(s). Any value that is not a
            ``list`` is treated as a single detected node.

    Returns:
        A ``SourceDetectionEvaluation`` holding the inputs, the error distances
        from ``compute_error_distances``, and the confusion counts:

        * P: number of distinct real sources.
        * N: number of non-source nodes (``len(G.nodes) - P``).
        * TP: real sources that were detected.
        * FP: detected nodes that are not real sources.
        * FN: real sources that were not detected.
        * TN: non-source nodes that were not detected.
    """
    detected_sources = (
        detected_sources if isinstance(detected_sources, list) else [detected_sources]
    )

    real_set = set(real_sources)
    detected_set = set(detected_sources)
    correctly_detected_sources = real_set & detected_set
    invalid_detected_sources = detected_set - correctly_detected_sources
    not_detected_sources = real_set - correctly_detected_sources

    # Count distinct sources so that P == TP + FN always holds, even when the
    # caller passes duplicates in real_sources.
    P = len(real_set)
    N = len(G.nodes) - P
    TP = len(correctly_detected_sources)
    FP = len(invalid_detected_sources)
    FN = len(not_detected_sources)
    # Each negative is either wrongly flagged (FP) or correctly left alone (TN),
    # so N == TN + FP. The previous `N - FN` subtracted missed *positives* from
    # the negative count.
    TN = N - FP

    error_distances = compute_error_distances(
        G=G,
        not_detected_sources=not_detected_sources,
        invalid_detected_sources=invalid_detected_sources,
    )
    return SourceDetectionEvaluation(
        real_sources=real_sources,
        detected_sources=detected_sources,
        error_distances=error_distances,
        TP=TP,
        FP=FP,
        TN=TN,
        FN=FN,
        P=P,
        N=N,
    )