from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union
from networkx import Graph
from nsdlib.taxonomies import (
EnsembleVotingType,
NodeEvaluationAlgorithm,
OutbreaksDetectionAlgorithm,
PropagationReconstructionAlgorithm,
)
# Type of a node identifier as used in the annotations of this module
# (keys of score dicts, items of source lists): an int or a str.
NODE_TYPE = Union[int, str]
[docs]
@dataclass
class SelectionAlgorithm:
    """Configuration of the source selection step.

    At most one of the two options may be set:

    * ``selection_method`` - a ``NodeEvaluationAlgorithm`` member,
    * ``selection_threshold`` - a number in the closed range ``[0, 1]``.

    Raises:
        ValueError: if ``selection_threshold`` is outside ``[0, 1]`` or if
            both options are provided at the same time.
    """

    selection_method: Optional[NodeEvaluationAlgorithm] = None
    selection_threshold: Optional[float] = None

    def __post_init__(self):
        """Validate the option values right after dataclass initialisation."""
        threshold = self.selection_threshold
        if threshold is not None and not 0 <= threshold <= 1:
            raise ValueError("selection_threshold must be None or between 0 and 1.")
        # Compare against None explicitly: a threshold of 0 is a valid value
        # but is falsy, so a truthiness test would let it slip past the
        # mutual-exclusion check.
        if self.selection_method is not None and threshold is not None:
            raise ValueError(
                "selection_method and selection_threshold cannot be used together."
            )
[docs]
@dataclass
class SourceDetectionConfig:
    """Source detection configuration.

    Groups the algorithm choices of a single source detection run. Only the
    node evaluation algorithm has a concrete default (``CENTRALITY_DEGREE``);
    the outbreak detection and propagation reconstruction algorithms default
    to ``None``. A missing ``selection_algorithm`` is replaced by a default
    ``SelectionAlgorithm()`` once the instance has been initialised.
    """

    node_evaluation_algorithm: NodeEvaluationAlgorithm = (
        NodeEvaluationAlgorithm.CENTRALITY_DEGREE
    )
    selection_algorithm: Optional[SelectionAlgorithm] = None
    outbreaks_detection_algorithm: Optional[OutbreaksDetectionAlgorithm] = None
    propagation_reconstruction_algorithm: Optional[
        PropagationReconstructionAlgorithm
    ] = None

    def __post_init__(self):
        """Fall back to a default ``SelectionAlgorithm`` when none was given."""
        self.selection_algorithm = self.selection_algorithm or SelectionAlgorithm()
[docs]
@dataclass
class EnsembleSourceDetectionConfig:
    """Ensemble source detection configuration.

    Holds a list of ``SourceDetectionConfig`` objects together with the
    voting type and a list of classifier weights. Nothing is validated here;
    both lists default to empty and the voting type to
    ``EnsembleVotingType.HARD``.
    """

    # Configurations of the individual detectors forming the ensemble.
    detection_configs: List[SourceDetectionConfig] = field(default_factory=list)
    # How the individual results are combined.
    voting_type: EnsembleVotingType = EnsembleVotingType.HARD
    # NOTE(review): presumably one weight per entry of detection_configs -
    # confirm against the code consuming this config.
    classifier_weights: List[float] = field(default_factory=list)
# Names of the ClassificationMetrics attributes that are collected, in this
# order, by ClassificationMetrics.get_classification_report().
CLASSIFICATION_REPORT_FIELDS = (
    # raw counts
    "P", "N", "TP", "TN", "FP", "FN",
    # summary scores
    "ACC", "F1",
    # rates derived from the counts
    "TPR", "TNR", "PPV", "NPV", "FNR", "FPR", "FDR", "FOR", "TS",
)
[docs]
@dataclass
class ClassificationMetrics:
    """Confusion matrix representation.

    It is based on https://en.wikipedia.org/wiki/Confusion_matrix.

    The six counts are stored as given; every other metric is a property
    computed from them on access. No property guards its denominator, so a
    metric whose denominator is zero raises ``ZeroDivisionError``.
    """

    TP: int  # true positive (TP)
    TN: int  # true negative (TN)
    FP: int  # false positive (FP)
    FN: int  # false negative (FN)
    P: int  # condition positive (P) - the number of real positive cases in
    # the data
    N: int  # condition negative (N) - the number of real negative cases in
    # the data

    @property
    def confusion_matrix(self) -> List[List[float]]:
        """Confusion matrix laid out as ``[[TP, FP], [FN, TN]]``."""
        return [[self.TP, self.FP], [self.FN, self.TN]]

    @property
    def TPR(self) -> float:
        """Sensitivity, recall, hit rate, or true positive rate (TPR)."""
        return self.TP / self.P

    @property
    def TNR(self) -> float:
        """Specificity, selectivity or true negative rate (TNR)."""
        return self.TN / self.N

    @property
    def PPV(self) -> float:
        """Precision or positive predictive value (PPV)."""
        return self.TP / (self.TP + self.FP)

    @property
    def NPV(self) -> float:
        """Negative predictive value (NPV)."""
        return self.TN / (self.TN + self.FN)

    @property
    def FNR(self) -> float:
        """Miss rate or false negative rate (FNR)."""
        # Missed positives over all real positives: FN / (FN + TP).
        # (TN / (TN + FN) would be the negative predictive value instead.)
        return self.FN / (self.FN + self.TP)

    @property
    def FPR(self) -> float:
        """Fall-out or false positive rate (FPR)."""
        return self.FP / (self.FP + self.TN)

    @property
    def FDR(self) -> float:
        """False discovery rate (FDR)."""  # noqa
        return self.FP / (self.FP + self.TP)

    @property
    def FOR(self) -> float:
        """False omission rate (FOR)."""  # noqa
        return self.FN / (self.FN + self.TN)

    @property
    def TS(self) -> float:
        """Threat score (TS) or critical success index (CSI)."""  # noqa
        return self.TP / (self.TP + self.FN + self.FP)

    @property
    def ACC(self) -> float:
        """Accuracy (ACC)."""
        return (self.TP + self.TN) / (self.P + self.N)

    @property
    def F1(self) -> float:
        """F1 score: harmonic mean of PPV and TPR, or 0 when both are 0."""
        precision, recall = self.PPV, self.TPR
        if precision + recall == 0:
            return 0
        return 2 * precision * recall / (precision + recall)

    def get_classification_report(self) -> Dict[str, float]:
        """Classification report as a dict.

        Returns:
            A mapping from each name in ``CLASSIFICATION_REPORT_FIELDS`` to
            the value of the attribute or property of that name.

        Raises:
            ZeroDivisionError: if any of the reported metrics has a zero
                denominator for the stored counts.
        """
        return {attr: getattr(self, attr) for attr in CLASSIFICATION_REPORT_FIELDS}
[docs]
@dataclass
class SourceDetectionEvaluation(ClassificationMetrics):
    """Classification metrics of a source detection run plus error distances.

    Extends ``ClassificationMetrics`` with the real and the detected source
    nodes and with a per-node error distance.
    """

    real_sources: List[NODE_TYPE]
    detected_sources: List[NODE_TYPE]
    # shortest path length from the detected invalid source to the closest
    # real source
    error_distances: Dict[NODE_TYPE, float]

    @property
    def avg_error_distance(self) -> float:
        """Average error distance.

        Returns:
            The arithmetic mean of the values of ``error_distances``, or
            ``0.0`` when ``error_distances`` is empty (there is no error
            distance to average, and dividing by its length would raise
            ``ZeroDivisionError``).
        """
        if not self.error_distances:
            return 0.0
        return sum(self.error_distances.values()) / len(self.error_distances)
[docs]
@dataclass
class SourceDetectionResult:
    """Result record of a single source detection run.

    A plain container: the fields are stored exactly as passed in and no
    validation or post-processing takes place.
    """

    # Configuration associated with this result.
    config: SourceDetectionConfig
    # Two graphs; NOTE(review): presumably the full network (G) and the
    # infected part of it (IG) - confirm against the code producing results.
    G: Graph
    IG: Graph
    # Score per node.
    global_scores: Dict[NODE_TYPE, float]
    # A list of per-node score dicts; judging by the name, one per outbreak.
    scores_in_outbreaks: List[Dict[NODE_TYPE, float]]
    # Nodes reported as sources.
    detected_sources: List[NODE_TYPE]
[docs]
@dataclass
class EnsembleSourceDetectionResult:
    """Result record of an ensemble source detection run.

    A plain container: the fields are stored exactly as passed in and no
    validation or post-processing takes place.
    """

    # Ensemble configuration associated with this result.
    config: EnsembleSourceDetectionConfig
    # Two graphs; NOTE(review): presumably the full network (G) and the
    # infected part of it (IG) - confirm against the code producing results.
    G: Graph
    IG: Graph
    # Score per node.
    global_scores: Dict[NODE_TYPE, float]
    # Results of single detection runs; NOTE(review): presumably one per
    # member of the ensemble - confirm.
    ensemble_scores: List[SourceDetectionResult]
    # Nodes reported as sources.
    detected_sources: List[NODE_TYPE]