Source code for rastervision.pytorch_learner.object_detection_learner

from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Union)
import warnings

import logging

import numpy as np

from rastervision.pytorch_learner.learner import Learner
from rastervision.pytorch_learner.object_detection_utils import (
    BoxList, TorchVisionODAdapter, compute_coco_eval, collate_fn)
from rastervision.pytorch_learner.dataset.visualizer import (
    ObjectDetectionVisualizer)

if TYPE_CHECKING:
    from torch import nn

warnings.filterwarnings('ignore')

log = logging.getLogger(__name__)


[docs]class ObjectDetectionLearner(Learner):
[docs] def get_visualizer_class(self): return ObjectDetectionVisualizer
[docs] def build_model(self, model_def_path: Optional[str] = None) -> 'nn.Module': """Override to pass img_sz.""" cfg = self.cfg model = cfg.model.build( num_classes=cfg.data.num_classes, in_channels=cfg.data.img_channels, save_dir=self.modules_dir, hubconf_dir=model_def_path, img_sz=cfg.data.img_sz) return model
[docs] def setup_model(self, model_weights_path: Optional[str] = None, model_def_path: Optional[str] = None) -> None: """Override to apply the TorchVisionODAdapter wrapper.""" if self.model is not None: self.model.to(self.device) return model = self.build_model(model_def_path) if self.cfg.model.external_def is not None: # this model will have 1 extra output classes that we will ignore self.model = TorchVisionODAdapter(model, ignored_output_inds=[0]) else: # this model will have 2 extra output classes that we will ignore num_classes = self.cfg.data.num_classes self.model = TorchVisionODAdapter( model, ignored_output_inds=[0, num_classes + 1]) self.model.to(self.device) self.load_init_weights(model_weights_path)
[docs] def build_metric_names(self): metric_names = [ 'epoch', 'train_time', 'valid_time', 'train_loss', 'map', 'map50' ] return metric_names
[docs] def get_collate_fn(self): return collate_fn
[docs] def train_step(self, batch, batch_ind): x, y = batch loss_dict = self.model(x, y) return {'train_loss': loss_dict['total_loss']}
[docs] def validate_step(self, batch, batch_ind): x, y = batch outs = self.model(x) ys = self.to_device(y, 'cpu') outs = self.to_device(outs, 'cpu') return {'ys': ys, 'outs': outs}
[docs] def validate_end(self, outputs, num_samples): outs = [] ys = [] for o in outputs: outs.extend(o['outs']) ys.extend(o['ys']) num_class_ids = len(self.cfg.data.class_names) coco_eval = compute_coco_eval(outs, ys, num_class_ids) metrics = {'map': 0.0, 'map50': 0.0} if coco_eval is not None: coco_metrics = coco_eval.stats metrics = {'map': coco_metrics[0], 'map50': coco_metrics[1]} return metrics
[docs] def output_to_numpy( self, out: Iterable[BoxList] ) -> Union[Dict[str, np.ndarray], List[Dict[str, np.ndarray]]]: def boxlist_to_numpy(boxlist: BoxList) -> Dict[str, np.ndarray]: return { 'boxes': boxlist.convert_boxes('yxyx').numpy(), 'class_ids': boxlist.get_field('class_ids').numpy(), 'scores': boxlist.get_field('scores').numpy() } if isinstance(out, BoxList): return boxlist_to_numpy(out) else: return [boxlist_to_numpy(boxlist) for boxlist in out]
[docs] def prob_to_pred(self, x): return x