Source code for rastervision.core.evaluation.chip_classification_evaluation
from typing import TYPE_CHECKING
import numpy as np
from sklearn.metrics import confusion_matrix
from rastervision.core.evaluation import (ClassificationEvaluation,
ClassEvaluationItem)
if TYPE_CHECKING:
from rastervision.core.data import (ChipClassificationLabels, ClassConfig)
class ChipClassificationEvaluation(ClassificationEvaluation):
    """Evaluate chip classification predictions against ground truth labels.

    The evaluation is derived from a multiclass confusion matrix built over
    the cells of the ground truth labels.
    """

    def __init__(self, class_config: 'ClassConfig'):
        """Constructor.

        Args:
            class_config: Class configuration. Its length determines the
                size of the confusion matrix and its ``names`` provide the
                per-class names used in the evaluation items.
        """
        super().__init__()
        self.class_config = class_config

    def compute(self, gt_labels: 'ChipClassificationLabels',
                pred_labels: 'ChipClassificationLabels') -> None:
        """Compute per-class evaluation items from ground truth/predictions.

        Iterates over the cells of ``gt_labels`` and looks up the class ID of
        each cell in both label sets. Cells for which either lookup yields
        ``None`` are ignored. The result is stored on the instance
        (``conf_mat`` and ``class_to_eval_item``); nothing is returned.

        Args:
            gt_labels: Ground truth labels; these define which cells are
                considered.
            pred_labels: Predicted labels, queried with the ground truth
                cells.
        """
        # reset() is not defined in this class -- presumably inherited from
        # ClassificationEvaluation.
        self.reset()
        self.class_to_eval_item = {}

        y_true = []
        y_pred = []
        for cell in gt_labels.get_cells():
            true_id = gt_labels.get_cell_class_id(cell)
            guessed_id = pred_labels.get_cell_class_id(cell)
            # Only cells that have a class ID on both sides are comparable.
            if true_id is None or guessed_id is None:
                continue
            y_true.append(true_id)
            y_pred.append(guessed_id)

        # Pass the full range of class IDs explicitly so that the matrix has
        # one row/column per configured class, even for classes that do not
        # occur in the collected IDs.
        num_classes = len(self.class_config)
        all_class_ids = np.arange(num_classes)
        self.conf_mat = confusion_matrix(
            y_true, y_pred, labels=all_class_ids)

        for idx, name in enumerate(self.class_config.names):
            self.class_to_eval_item[idx] = (
                ClassEvaluationItem.from_multiclass_conf_mat(
                    conf_mat=self.conf_mat, class_id=idx, class_name=name))

        # compute_avg() is not defined in this class -- presumably inherited
        # from ClassificationEvaluation.
        self.compute_avg()