Source code for nsdlib.source_detection

"""Source detection algorithm."""

from typing import Dict, List, Tuple

from networkx import Graph

from nsdlib.algorithms.algorithms_utils import (
    compute_source_detection_evaluation,
    evaluate_nodes_cached,
    identify_outbreaks_cached,
    reconstruct_propagation_cached,
)
from nsdlib.common.models import (
    NODE_TYPE,
    EnsembleSourceDetectionConfig,
    EnsembleSourceDetectionResult,
    SourceDetectionConfig,
    SourceDetectionEvaluation,
    SourceDetectionResult,
)
from nsdlib.common.nx_utils import create_subgraphs_based_on_outbreaks
from nsdlib.commons import normalize_dict_values
from nsdlib.taxonomies import EnsembleVotingType


class SourceDetector:
    """Source detection generic algorithm.

    Runs a configurable pipeline over the graph ``IG``: optional propagation
    reconstruction, optional outbreak detection, per-outbreak node
    evaluation, and per-outbreak source selection.
    """

    def __init__(self, config: SourceDetectionConfig):
        self.config = config

    def detect_sources(self, IG: Graph, G: Graph) -> SourceDetectionResult:
        """Detect sources in ``IG``.

        Args:
            IG: Graph in which sources are searched (presumably the infected
                subgraph -- confirm with callers).
            G: Full network; passed to propagation reconstruction.

        Returns:
            A ``SourceDetectionResult`` carrying the config, ``G``, the
            (possibly reconstructed) ``IG``, the merged global scores, the
            per-outbreak scores and the selected sources.
        """
        IG = self._reconstruct_propagation(IG, G)
        outbreaks = self._detect_outbreaks(IG)
        scores_in_outbreaks = self._evaluate_outbreaks(outbreaks)
        global_scores = self._get_global_scores(scores_in_outbreaks)
        detected_sources = self._select_sources(IG, scores_in_outbreaks)
        return SourceDetectionResult(
            config=self.config,
            G=G,
            IG=IG,
            global_scores=global_scores,
            scores_in_outbreaks=scores_in_outbreaks,
            detected_sources=detected_sources,
        )

    def detect_sources_and_evaluate(
        self, IG: Graph, G: Graph, real_sources: List[NODE_TYPE]
    ) -> Tuple[SourceDetectionResult, SourceDetectionEvaluation]:
        """Detect sources and evaluate them against ``real_sources``.

        The evaluation is computed on the result's ``IG``, i.e. after any
        propagation reconstruction.

        Returns:
            ``(detection_result, evaluation)``.
        """
        sd_result = self.detect_sources(IG, G)
        evaluation = compute_source_detection_evaluation(
            G=sd_result.IG,
            real_sources=real_sources,
            detected_sources=sd_result.detected_sources,
        )
        return sd_result, evaluation

    def _reconstruct_propagation(self, IG: Graph, G: Graph) -> Graph:
        """Return ``IG`` reconstructed against ``G`` when configured, else ``IG``."""
        if self.config.propagation_reconstruction_algorithm:
            IG = reconstruct_propagation_cached(
                G=G,
                IG=IG,
                reconstruction_alg=self.config.propagation_reconstruction_algorithm,
            )
        return IG

    def _detect_outbreaks(self, IG: Graph) -> List[Graph]:
        """Split ``IG`` into outbreak subgraphs.

        Without an outbreaks detection algorithm the whole graph is treated
        as a single outbreak.
        """
        if not self.config.outbreaks_detection_algorithm:
            return [IG]
        outbreaks = identify_outbreaks_cached(
            network=IG,
            outbreaks_alg=self.config.outbreaks_detection_algorithm,
        )
        return list(create_subgraphs_based_on_outbreaks(G=IG, outbreaks=outbreaks))

    def _get_global_scores(
        self, outbreaks_evaluation: List[Dict[NODE_TYPE, float]]
    ) -> Dict[NODE_TYPE, float]:
        """Merge per-outbreak scores into one node -> score mapping.

        If a node appears in several outbreaks, the score from the later
        outbreak wins.
        """
        global_scores: Dict[NODE_TYPE, float] = {}
        for outbreak_evaluation in outbreaks_evaluation:
            global_scores.update(outbreak_evaluation)
        return global_scores

    def _evaluate_outbreaks(
        self, outbreaks: List[Graph]
    ) -> List[Dict[NODE_TYPE, float]]:
        """Score the nodes of every outbreak with the configured node evaluation."""
        return [
            evaluate_nodes_cached(
                network=outbreak,
                evaluation_alg=self.config.node_evaluation_algorithm,
            )
            for outbreak in outbreaks
        ]

    def _select_sources(
        self, IG: Graph, outbreaks_evaluation: List[Dict[NODE_TYPE, float]]
    ) -> List[NODE_TYPE]:
        """Select source nodes from every outbreak's node scores.

        Strategy, chosen from ``config.selection_algorithm``:

        * ``selection_method`` set: the top-scored node(s), with ties broken
          by re-evaluating the outbreak with ``selection_method``;
        * otherwise, ``selection_threshold`` is None: the single top-scored
          node;
        * otherwise: every node whose normalized score is at least
          ``selection_threshold``.

        Empty outbreak evaluations are skipped. They contain no candidate
        nodes, and ``max`` over them would raise ``ValueError``.
        """
        selection = self.config.selection_algorithm
        sources: List[NODE_TYPE] = []
        for outbreak_evaluation in outbreaks_evaluation:
            if not outbreak_evaluation:
                continue
            if selection.selection_method:
                sources.extend(
                    self._select_top_with_tie_break(IG, outbreak_evaluation)
                )
            elif selection.selection_threshold is None:
                sources.append(max(outbreak_evaluation, key=outbreak_evaluation.get))
            else:
                sources.extend(self._select_above_threshold(outbreak_evaluation))
        return sources

    def _select_top_with_tie_break(
        self, IG: Graph, outbreak_evaluation: Dict[NODE_TYPE, float]
    ) -> List[NODE_TYPE]:
        """Return the top-scored node(s), breaking ties with the selection method."""
        max_score = max(outbreak_evaluation.values())
        best_nodes = [
            node for node, score in outbreak_evaluation.items() if score == max_score
        ]
        if len(best_nodes) == 1:
            return best_nodes
        # Tie: re-score the whole outbreak (as an IG subgraph) with the
        # secondary method and keep the tied nodes that score highest there.
        subgraph = IG.subgraph(list(outbreak_evaluation.keys()))
        selection_evaluation = evaluate_nodes_cached(
            network=subgraph,
            evaluation_alg=self.config.selection_algorithm.selection_method,
        )
        tie_scores = {node: selection_evaluation[node] for node in best_nodes}
        max_tie_score = max(tie_scores.values())
        return [node for node, score in tie_scores.items() if score == max_tie_score]

    def _select_above_threshold(
        self, outbreak_evaluation: Dict[NODE_TYPE, float]
    ) -> List[NODE_TYPE]:
        """Return nodes whose normalized score reaches the selection threshold."""
        threshold = self.config.selection_algorithm.selection_threshold
        normalized = normalize_dict_values(outbreak_evaluation)
        return [node for node, score in normalized.items() if score >= threshold]
class EnsembleSourceDetector:
    """Ensemble source detection algorithm.

    Runs one ``SourceDetector`` per entry of ``config.detection_configs`` and
    combines the results. Soft voting takes a weighted average of the global
    scores. Hard voting counts weighted votes for the detected sources. Both
    optionally use ``config.classifier_weights``.
    """

    def __init__(self, config: EnsembleSourceDetectionConfig):
        self.config = config

    def detect_sources(self, IG: Graph, G: Graph) -> List[SourceDetectionResult]:
        """Run every configured detector on ``IG``/``G``.

        Returns:
            One ``SourceDetectionResult`` per entry of
            ``config.detection_configs``, in the same order.
        """
        return [
            SourceDetector(config).detect_sources(IG, G)
            for config in self.config.detection_configs
        ]

    def detect_sources_and_evaluate(
        self, IG: Graph, G: Graph, real_sources: List[NODE_TYPE]
    ) -> Tuple[EnsembleSourceDetectionResult, SourceDetectionEvaluation]:
        """Run the ensemble, combine by voting, and evaluate against ``real_sources``.

        Returns:
            ``(ensemble_result, evaluation)``.
        """
        sd_results = self.detect_sources(IG, G)
        ensemble_result = self._combine_results(sd_results)
        evaluation = compute_source_detection_evaluation(
            G=ensemble_result.IG,
            real_sources=real_sources,
            detected_sources=ensemble_result.detected_sources,
        )
        return ensemble_result, evaluation

    def _combine_results(
        self, results: List[SourceDetectionResult]
    ) -> EnsembleSourceDetectionResult:
        """Dispatch to soft voting for ``EnsembleVotingType.SOFT``, else hard voting."""
        if self.config.voting_type == EnsembleVotingType.SOFT:
            return self._soft_voting(results)
        return self._hard_voting(results)

    def _classifier_weights(self, count: int) -> list:
        """Return one weight per detector result, positionally aligned with results.

        Without configured ``classifier_weights`` every detector gets weight 1.
        Weights are paired by position rather than via ``results.index(...)``.
        The index lookup is O(n) per call, and it returns the first match when
        two results compare equal, so duplicates would share the wrong weight.
        Fewer weights than results raises ``IndexError``, as before. Extra
        trailing weights are ignored.
        """
        weights = self.config.classifier_weights
        if not weights:
            return [1] * count
        return [weights[index] for index in range(count)]

    def _soft_voting(
        self, results: List[SourceDetectionResult]
    ) -> EnsembleSourceDetectionResult:
        """Combine results by a weighted average of their global scores.

        Each node's combined score is ``sum(weight_i * score_i) / sum(weights)``.
        A detector that did not score a node contributes 0 for it. Detected
        sources are all scored nodes, ordered by descending combined score.
        """
        weights = self._classifier_weights(len(results))
        combined_scores: Dict[NODE_TYPE, float] = {}
        for result, weight in zip(results, weights):
            for node, score in result.global_scores.items():
                combined_scores[node] = combined_scores.get(node, 0) + score * weight
        # Normalize by the weights actually applied (equals
        # sum(classifier_weights) when its length matches the results).
        total_weight = sum(weights)
        for node in combined_scores:
            combined_scores[node] /= total_weight
        detected_sources = sorted(
            combined_scores, key=combined_scores.get, reverse=True
        )
        return EnsembleSourceDetectionResult(
            config=self.config,
            G=results[0].G,
            IG=results[0].IG,
            global_scores=combined_scores,
            ensemble_scores=results,
            detected_sources=detected_sources,
        )

    def _hard_voting(
        self, results: List[SourceDetectionResult]
    ) -> EnsembleSourceDetectionResult:
        """Combine results by weighted votes for each detector's detected sources.

        Each detector adds its weight (1 when unweighted) to every node it
        detected. Detected sources are all voted-for nodes, ordered by
        descending vote total.
        """
        weights = self._classifier_weights(len(results))
        vote_counts: Dict[NODE_TYPE, float] = {}
        for result, weight in zip(results, weights):
            for node in result.detected_sources:
                vote_counts[node] = vote_counts.get(node, 0) + weight
        detected_sources = sorted(vote_counts, key=vote_counts.get, reverse=True)
        return EnsembleSourceDetectionResult(
            config=self.config,
            G=results[0].G,
            IG=results[0].IG,
            global_scores=vote_counts,
            ensemble_scores=results,
            detected_sources=detected_sources,
        )