Source code for rastervision.core.evaluation.class_evaluation_item

"""Defines ``ClassEvaluationItem``."""

from typing import Optional

import numpy as np

from rastervision.core.evaluation import EvaluationItem


[docs]class ClassEvaluationItem(EvaluationItem): """A wrapper around a binary (2x2) confusion matrix of the form .. line-block:: [``TN`` ``FP``] [``FN`` ``TP``] where ``TN`` need not necessarily be available. Exposes evaluation metrics computed from the confusion matrix as properties. Attributes: class_id (int): Class ID. class_name (str): Class name. conf_mat (np.ndarray): Confusion matrix: ``[[TN, FP], [FN, TP]]``. extra_info (dict): Arbitrary extra key-value pairs that will be included in the dict returned by ``to_json()``. """
[docs] def __init__(self, class_id: int, class_name: str, tp: int, fp: int, fn: int, tn: Optional[int] = None, **kwargs): """Constructor. Args: class_id (int): Class ID. class_name (str): Class name. tp (int): True positive count. fp (int): False positive count. fn (int): False negative count. tn (Optional[int], optional): True negative count. Defaults to None. **kwargs: Additional data can be provided as keyword arguments. These will be included in the dict returned by ``to_json()``. """ self.class_id = class_id self.class_name = class_name if tn is None: tn = -1 self.conf_mat = np.array([[tn, fp], [fn, tp]]) self.extra_info = kwargs
[docs] @classmethod def from_multiclass_conf_mat(cls, conf_mat: np.ndarray, class_id: int, class_name: str, **kwargs) -> 'ClassEvaluationItem': """Construct from a multi-class confusion matrix and a target class ID. Args: conf_mat (np.ndarray): A multi-class confusion matrix. class_id (int): The ID of the target class. class_name (str): The name of the target class. **kwargs: Extra args for :meth:`.__init__`. Returns: ClassEvaluationItem: ClassEvaluationItem for target class. """ tp = conf_mat[class_id, class_id] fp = conf_mat[:, class_id].sum() - tp fn = conf_mat[class_id, :].sum() - tp tn = conf_mat.sum() - tp - fp - fn item = cls( class_id=class_id, class_name=class_name, tp=tp, fp=fp, fn=fn, tn=tn, **kwargs) return item
[docs] def merge(self, other: 'ClassEvaluationItem') -> None: """Merge with another ``ClassEvaluationItem``. This is accomplished by summing the confusion matrices. """ if self.class_id != other.class_id: raise ValueError( 'Cannot merge evaluation items for different classes.') self.conf_mat += other.conf_mat
@property def gt_count(self) -> int: """Positive ground-truth count.""" return self.conf_mat[1, :].sum() @property def pred_count(self) -> int: """Positive prediction count.""" return self.conf_mat[:, 1].sum() @property def true_pos(self) -> int: """True positive count.""" return self.conf_mat[1, 1] @property def true_neg(self) -> Optional[int]: """True negative count. Returns: Optional[int]: Count as int if available. Otherwise, None. """ tn = self.conf_mat[0, 0] if tn < 0: return None return tn @property def false_pos(self) -> int: """False positive count.""" return self.conf_mat[0, 1] @property def false_neg(self) -> int: """False negative count.""" return self.conf_mat[1, 0] @property def recall(self) -> float: """``TP / (TP + FN)``""" tp = self.true_pos fn = self.false_neg return float(tp) / (tp + fn) @property def sensitivity(self) -> float: """Equivalent to ``recall``.""" return self.recall @property def specificity(self) -> Optional[float]: """``TN / (TN + FP)``""" if self.true_neg is None: return None tn = self.true_neg fp = self.false_pos return float(tn) / (tn + fp) @property def precision(self) -> float: """``TP / (TP + FP)``""" tp = self.true_pos fp = self.false_pos return float(tp) / (tp + fp) @property def f1(self) -> float: """F1 score = ``2 * (precision * recall) / (precision + recall)``""" precision = self.precision recall = self.recall return 2 * (precision * recall) / (precision + recall)
[docs] def to_json(self) -> dict: """Serialize to a dict.""" out = { 'class_id': self.class_id, 'class_name': self.class_name, 'gt_count': self.gt_count, 'pred_count': self.pred_count, 'count_error': abs(self.gt_count - self.pred_count), 'relative_frequency': self.gt_count / self.conf_mat.sum(), 'metrics': { 'recall': self.recall, 'precision': self.precision, 'f1': self.f1, 'sensitivity': self.sensitivity, 'specificity': self.specificity, } } if self.true_neg is None: del out['relative_frequency'] out['true_pos'] = self.true_pos out['false_pos'] = self.false_pos out['false_neg'] = self.false_neg else: cm = self.conf_mat cm_frac = cm / cm.sum() out['conf_mat'] = cm.tolist() out['conf_mat_frac'] = cm_frac.tolist() [[TN, FP], [FN, TP]] = cm out['conf_mat_dict'] = dict(TN=TN, FP=FP, FN=FN, TP=TP) [[TN, FP], [FN, TP]] = cm_frac out['conf_mat_frac_dict'] = dict(TN=TN, FP=FP, FN=FN, TP=TP) out.update(self.extra_info) return out