# Source code for rastervision.pytorch_learner.object_detection_learner

from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple,
                    Union)
import warnings

import logging

import numpy as np
import torch

from rastervision.pytorch_learner.learner import Learner
from rastervision.pytorch_learner.object_detection_utils import (
    BoxList, TorchVisionODAdapter, compute_coco_eval, collate_fn,
    ONNXRuntimeAdapterForFasterRCNN)
from rastervision.pytorch_learner.dataset.visualizer import (
    ObjectDetectionVisualizer)

if TYPE_CHECKING:
    from torch import nn, Tensor

# NOTE(review): importing this module installs an "ignore" filter for every
# warning category. The warnings filter list is interpreter-wide, so this also
# hides warnings raised by unrelated code. Kept to preserve behavior; consider
# scoping it with ``warnings.catch_warnings()`` around the noisy calls instead.
warnings.simplefilter('ignore')

# Module-level logger, named after this module.
log = logging.getLogger(__name__)


class ObjectDetectionLearner(Learner):
    """Learner for object detection models.

    ``setup_model`` wraps the underlying model (PyTorch or ONNX) in a
    ``TorchVisionODAdapter``. Validation reports COCO-style ``mAP`` and
    ``mAP50`` metrics computed via ``compute_coco_eval``.
    """

    def get_visualizer_class(self):
        """Return the visualizer class used for object detection."""
        return ObjectDetectionVisualizer

    def build_model(self,
                    model_def_path: Optional[str] = None) -> 'nn.Module':
        """Override to pass img_sz.

        Args:
            model_def_path (Optional[str]): Passed to ``cfg.model.build`` as
                ``hubconf_dir``. Defaults to None.

        Returns:
            nn.Module: The model built by ``cfg.model.build``.
        """
        cfg = self.cfg
        model = cfg.model.build(
            num_classes=cfg.data.num_classes,
            in_channels=cfg.data.img_channels,
            save_dir=self.modules_dir,
            hubconf_dir=model_def_path,
            img_sz=cfg.data.img_sz)
        return model

    def setup_model(self,
                    model_weights_path: Optional[str] = None,
                    model_def_path: Optional[str] = None) -> None:
        """Override to apply the TorchVisionODAdapter wrapper.

        If ``model_weights_path`` ends with ``.onnx`` (case-insensitive), an
        ONNX model is loaded and wrapped. Otherwise a PyTorch model is built
        with ``build_model``, wrapped, moved to ``self.device`` and
        initialized from ``model_weights_path``.

        Args:
            model_weights_path (Optional[str]): Path to weights (``.pth``-like
                or ``.onnx``). Defaults to None.
            model_def_path (Optional[str]): Forwarded to ``build_model``.
                Defaults to None.
        """
        if self.model is not None:
            # Already set up: only make sure the model is on the right device.
            self.model.to(self.device)
            return

        self._onnx_mode = (model_weights_path is not None
                           and model_weights_path.lower().endswith('.onnx'))
        if self._onnx_mode:
            model = self.load_onnx_model(model_weights_path)
        else:
            model = self.build_model(model_def_path)

        if self.cfg.model.external_def is not None:
            # this model will have 1 extra output class that we will ignore
            self.model = TorchVisionODAdapter(model, ignored_output_inds=[0])
        else:
            # this model will have 2 extra output classes that we will ignore
            num_classes = self.cfg.data.num_classes
            self.model = TorchVisionODAdapter(
                model, ignored_output_inds=[0, num_classes + 1])

        # The ONNX weights were already consumed by load_onnx_model; only the
        # PyTorch model needs to be moved to the device and have weights loaded.
        if not self._onnx_mode:
            self.model.to(self.device)
            self.load_init_weights(model_weights_path)

    def get_collate_fn(self):
        """Return the object-detection-specific ``collate_fn``."""
        return collate_fn

    def train_step(self, batch, batch_ind):
        """Compute the losses for one training batch.

        The model, called with both inputs and targets, returns a dict of
        loss values. Their sum is added under the ``'train_loss'`` key.

        Returns:
            dict: The model's loss dict plus ``'train_loss'``.
        """
        x, y = batch
        loss_dict = self.model(x, y)
        # Sum before inserting 'train_loss' so it is not counted twice.
        loss_dict['train_loss'] = sum(loss_dict.values())
        return loss_dict

    def validate_step(self, batch, batch_ind):
        """Run the model on one validation batch.

        Returns:
            dict: ``{'ys': targets, 'outs': predictions}``, both moved to the
            CPU so they can be accumulated across batches in ``validate_end``.
        """
        x, y = batch
        outs = self.model(x)
        ys = self.to_device(y, 'cpu')
        outs = self.to_device(outs, 'cpu')
        return {'ys': ys, 'outs': outs}

    def validate_end(self, outputs):
        """Aggregate per-batch validation outputs into COCO mAP metrics.

        Args:
            outputs: Sequence of dicts as returned by ``validate_step``.

        Returns:
            dict: ``{'mAP': ..., 'mAP50': ...}``. Both are 0.0 when
            ``compute_coco_eval`` returns None.
        """
        outs = []
        ys = []
        for o in outputs:
            outs.extend(o['outs'])
            ys.extend(o['ys'])
        num_class_ids = len(self.cfg.data.class_names)
        coco_eval = compute_coco_eval(outs, ys, num_class_ids)

        metrics = {'mAP': 0.0, 'mAP50': 0.0}
        if coco_eval is not None:
            # NOTE(review): presumably pycocotools COCOeval.stats, where
            # index 0 is AP@[.50:.95] and index 1 is AP@.50.
            coco_metrics = coco_eval.stats
            metrics = {'mAP': coco_metrics[0], 'mAP50': coco_metrics[1]}
        return metrics

    def predict(self,
                x: 'Tensor',
                raw_out: bool = False,
                out_shape: Optional[Tuple[int, int]] = None) -> BoxList:
        """Make prediction for an image or batch of images.

        Args:
            x (Tensor): Image or batch of images as a float Tensor with pixel
                values normalized to [0, 1].
            raw_out (bool, optional): If True, return prediction
                probabilities. Defaults to False.
            out_shape (Optional[Tuple[int, int]], optional): If provided,
                boxes are resized such that they reference pixel coordinates
                in an image of this shape. Defaults to None.

        Returns:
            BoxList: Predicted boxes.
        """
        out: List[BoxList] = super().predict(x, raw_out=raw_out)
        out = self.postprocess_model_output(x, out, out_shape=out_shape)
        return out

    def predict_onnx(self,
                     x: 'Tensor',
                     raw_out: bool = False,
                     out_shape: Optional[Tuple[int, int]] = None) -> BoxList:
        """ONNX counterpart of :meth:`predict`.

        Args:
            x (Tensor): Image or batch of images as a float Tensor with pixel
                values normalized to [0, 1].
            raw_out (bool, optional): If True, return prediction
                probabilities. Defaults to False.
            out_shape (Optional[Tuple[int, int]], optional): If provided,
                boxes are resized such that they reference pixel coordinates
                in an image of this shape. Defaults to None.

        Returns:
            BoxList: Predicted boxes.
        """
        # Delegate to the base class's ONNX inference path. The previous
        # version called super().predict (the PyTorch path), which looks like
        # a copy-paste slip from predict() above.
        out: List[BoxList] = super().predict_onnx(x, raw_out=raw_out)
        out = self.postprocess_model_output(x, out, out_shape=out_shape)
        return out

    def postprocess_model_output(
            self,
            x: 'Tensor',
            out_batch: List[BoxList],
            out_shape: Optional[Tuple[int, int]] = None) -> List[BoxList]:
        """Rescale predicted boxes from the input image size to ``out_shape``.

        Args:
            x (Tensor): The model input. Only its last two dims (H, W) are
                used.
            out_batch (List[BoxList]): Predicted boxes. Each BoxList's
                ``scale`` is called and its return value discarded, so this
                presumably mutates the BoxLists in place.
            out_shape (Optional[Tuple[int, int]]): Target (H, W). If None,
                ``out_batch`` is returned untouched. Defaults to None.

        Returns:
            List[BoxList]: ``out_batch`` itself (same object).
        """
        if out_shape is None:
            return out_batch
        h_in, w_in = x.shape[-2:]
        h_out, w_out = out_shape
        yscale, xscale = (h_out / h_in), (w_out / w_in)
        with torch.inference_mode():
            for boxlist in out_batch:
                boxlist.scale(yscale, xscale)
        return out_batch

    def output_to_numpy(
            self, out: Union[BoxList, Iterable[BoxList]]
    ) -> Union[Dict[str, np.ndarray], List[Dict[str, np.ndarray]]]:
        """Convert BoxList prediction(s) to dicts of numpy arrays.

        Args:
            out: A single BoxList or an iterable of BoxLists.

        Returns:
            For a single BoxList, a dict with keys ``'boxes'`` (converted to
            ``'yxyx'`` format), ``'class_ids'`` and ``'scores'``. Otherwise,
            a list of such dicts, one per BoxList.
        """

        def boxlist_to_numpy(boxlist: BoxList) -> Dict[str, np.ndarray]:
            return {
                'boxes': boxlist.convert_boxes('yxyx').numpy(),
                'class_ids': boxlist.get_field('class_ids').numpy(),
                'scores': boxlist.get_field('scores').numpy()
            }

        if isinstance(out, BoxList):
            return boxlist_to_numpy(out)
        return [boxlist_to_numpy(boxlist) for boxlist in out]

    def prob_to_pred(self, x):
        """Return ``x`` unchanged.

        NOTE(review): presumably the adapted model already emits final
        predictions, so no probability-to-label conversion is needed here.
        """
        return x

    def export_to_onnx(self,
                       path: str,
                       model: Optional['nn.Module'] = None,
                       sample_input: Optional[torch.Tensor] = None,
                       **kwargs) -> None:
        """Export to ONNX, unwrapping the TorchVisionODAdapter if needed.

        If no ``model`` is given and ``self.model`` is a
        ``TorchVisionODAdapter``, the wrapped inner model is exported rather
        than the adapter. All arguments are then forwarded to the base
        implementation.
        """
        if model is None and isinstance(self.model, TorchVisionODAdapter):
            model = self.model.model
        return super().export_to_onnx(path, model, sample_input, **kwargs)

    def load_onnx_model(self,
                        model_path: str) -> ONNXRuntimeAdapterForFasterRCNN:
        """Load an ONNX model from ``model_path``.

        Returns:
            ONNXRuntimeAdapterForFasterRCNN: The adapter created by
            ``ONNXRuntimeAdapterForFasterRCNN.from_file``.
        """
        log.info(f'Loading ONNX model from {model_path}')
        onnx_model = ONNXRuntimeAdapterForFasterRCNN.from_file(model_path)
        return onnx_model