Source code for rastervision.core.evaluation.semantic_segmentation_evaluation

from typing import TYPE_CHECKING
import logging

import numpy as np
from sklearn.metrics import confusion_matrix
from tqdm.auto import tqdm

from rastervision.core.evaluation import ClassEvaluationItem
from rastervision.core.evaluation import ClassificationEvaluation

if TYPE_CHECKING:
    from rastervision.core.data import (ClassConfig,
                                        SemanticSegmentationLabels)

log = logging.getLogger(__name__)


class SemanticSegmentationEvaluation(ClassificationEvaluation):
    """Evaluation of semantic segmentation predictions against ground truth.

    A multiclass confusion matrix is accumulated window-by-window and one
    ``ClassEvaluationItem`` is then derived from it for every class in the
    class config.
    """

    def __init__(self, class_config: 'ClassConfig'):
        """Constructor.

        Args:
            class_config: Class config whose length, ``names`` and
                ``null_class_id`` are used by :meth:`compute`.
        """
        super().__init__()
        self.class_config = class_config

    def compute(self, gt_labels: 'SemanticSegmentationLabels',
                pred_labels: 'SemanticSegmentationLabels') -> None:
        """Compute per-class evaluation items from the given labels.

        Resets this evaluation, builds ``self.conf_mat`` by summing the
        confusion matrices of every window reported by ``pred_labels``,
        stores one evaluation item per class in ``self.class_to_eval_item``
        and finally calls ``compute_avg()``.

        Args:
            gt_labels: Ground truth labels.
            pred_labels: Predicted labels. Its windows determine which
                regions are compared.
        """
        self.reset()

        class_config = self.class_config
        null_id = class_config.null_class_id
        n_classes = len(class_config)
        class_ids = np.arange(n_classes)

        # Accumulator for the confusion matrix. With sklearn's
        # confusion_matrix(y_true, y_pred), rows index the ground truth
        # class and columns index the predicted class.
        self.conf_mat = np.zeros((n_classes, n_classes))

        # delay=5: the progress bar is only shown if this takes > 5 seconds.
        with tqdm(
                pred_labels.get_windows(), delay=5,
                desc='Computing metrics') as progress:
            for win in progress:
                # NOTE(review): null_id is presumably the fill value for
                # pixels without a label -- confirm against get_label_arr().
                gt_flat = gt_labels.get_label_arr(win, null_id).ravel()
                pred_flat = pred_labels.get_label_arr(win, null_id).ravel()
                # Passing labels explicitly keeps every per-window matrix at
                # (n_classes, n_classes) even when some classes are absent.
                self.conf_mat += confusion_matrix(
                    gt_flat, pred_flat, labels=class_ids)

        for idx, name in enumerate(class_config.names):
            self.class_to_eval_item[idx] = (
                ClassEvaluationItem.from_multiclass_conf_mat(
                    conf_mat=self.conf_mat, class_id=idx, class_name=name))

        self.compute_avg()