Source code for rastervision.pytorch_backend.pytorch_object_detection

from typing import TYPE_CHECKING, Dict, Iterator, Optional
from os.path import join, basename
import uuid

from rastervision.pipeline.file_system import json_to_file
from rastervision.core.data.label import ObjectDetectionLabels
from rastervision.pytorch_backend.pytorch_learner_backend import (
    PyTorchLearnerSampleWriter, PyTorchLearnerBackend)
from rastervision.pytorch_learner.dataset import (
    ObjectDetectionSlidingWindowGeoDataset)

if TYPE_CHECKING:
    import numpy as np
    from rastervision.core.data import Scene
    from rastervision.core.data_sample import DataSample


class PyTorchObjectDetectionSampleWriter(PyTorchLearnerSampleWriter):
    """Sample writer for object detection that emits COCO-style labels.

    Chips are written with ``write_chip``. The matching COCO ``images`` and
    ``annotations`` entries are accumulated in memory, per split, and flushed
    to ``<sample_dir>/<split>/labels.json`` when the context exits.
    """

    def __enter__(self):
        """Enter the context and reset the per-split COCO accumulators.

        Returns:
            This writer.
        """
        super().__enter__()
        # One accumulator per split; populated by update_coco_data().
        self.splits = {
            split_name: {
                'images': [],
                'annotations': []
            }
            for split_name in ('train', 'valid')
        }
        # COCO categories: each class's id is its index in class_config.names.
        self.categories = [{
            'id': idx,
            'name': name
        } for idx, name in enumerate(self.class_config.names)]
        return self

    def __exit__(self, type, value, traceback):
        """Write ``(train|valid)/labels.json`` in COCO format, then exit.

        A split's labels file is written only if at least one image was
        recorded for it. Nothing here inspects the exception arguments, so the
        files are written even if the ``with`` body raised.
        """
        for split_name in ('train', 'valid'):
            coco_data = self.splits[split_name]
            if not coco_data['images']:
                continue
            labels_uri = join(self.sample_dir, split_name, 'labels.json')
            json_to_file(
                {
                    'images': coco_data['images'],
                    'annotations': coco_data['annotations'],
                    'categories': self.categories
                }, labels_uri)
        # Labels are written before delegating to the base class, presumably
        # so they are already in sample_dir when the base class finalizes it
        # (TODO confirm against PyTorchLearnerSampleWriter.__exit__).
        super().__exit__(type, value, traceback)

    def write_sample(self, sample: 'DataSample'):
        """Write one training or validation chip and record its COCO entries.

        The chip is written to the path returned by ``get_image_path``, and
        the split's COCO structures are updated. ``sample_ind`` is then
        advanced so that the next sample is given a new image id.
        """
        if sample.is_train:
            split_name = 'train'
        else:
            split_name = 'valid'
        chip_path = self.get_image_path(split_name, sample)
        self.write_chip(sample.chip, chip_path)
        self.update_coco_data(split_name, sample, chip_path)
        self.sample_ind += 1

    def update_coco_data(self, split_name: str, sample: 'DataSample',
                         img_path: str):
        """Append the COCO image entry and box annotations for ``sample``.

        The image id is the current ``sample_ind``. That counter is shared by
        both splits, so ids within one split need not be contiguous.

        Args:
            split_name: ``'train'`` or ``'valid'``; selects the accumulator.
            sample: Sample whose chip size and labels are recorded.
            img_path: Path the chip was written to. Only its basename is
                stored, as ``file_name``.
        """
        coco_data = self.splits[split_name]
        image_id = self.sample_ind
        coco_data['images'].append({
            'file_name': basename(img_path),
            'id': image_id,
            'height': sample.chip.shape[0],
            'width': sample.chip.shape[1]
        })

        # Convert boxes from global coordinates to coordinates relative to
        # this sample's window.
        local_boxes = ObjectDetectionLabels.global_to_local(
            sample.labels.get_npboxes(), sample.window)
        class_ids = sample.labels.get_class_ids()
        for box_ind, (box, class_id) in enumerate(zip(local_boxes,
                                                      class_ids)):
            # Rows are read as (ymin, xmin, ymax, xmax). COCO expects
            # [x, y, width, height]. int() truncates each value.
            ymin, xmin, ymax, xmax = box[:4]
            coco_bbox = [
                int(xmin),
                int(ymin),
                int(xmax - xmin),
                int(ymax - ymin)
            ]
            coco_data['annotations'].append({
                'id': '{}-{}'.format(image_id, box_ind),
                'image_id': image_id,
                'bbox': coco_bbox,
                'category_id': int(class_id)
            })


class PyTorchObjectDetection(PyTorchLearnerBackend):
    """PyTorch learner backend for object detection pipelines."""

    def get_sample_writer(self):
        """Create a COCO-format sample writer targeting a fresh zip URI.

        Returns:
            A ``PyTorchObjectDetectionSampleWriter`` whose output is a
            uniquely named ``.zip`` under ``pipeline_cfg.chip_uri``.
        """
        chip_dir = self.pipeline_cfg.chip_uri
        # A random UUID keeps zips from different writers from colliding.
        output_uri = join(chip_dir, f'{uuid.uuid4()}.zip')
        class_config = self.pipeline_cfg.dataset.class_config
        return PyTorchObjectDetectionSampleWriter(output_uri, class_config,
                                                  self.tmp_dir)

    def predict_scene(self,
                      scene: 'Scene',
                      chip_sz: int,
                      stride: Optional[int] = None) -> ObjectDetectionLabels:
        """Predict object detection labels for a whole scene.

        Windows of size ``chip_sz`` are taken over ``scene`` with a
        sliding-window dataset. Each window is run through the learner, and
        the per-window predictions are merged into one labels object.

        Args:
            scene: Scene to predict on.
            chip_sz: Window size passed to the sliding-window dataset.
            stride: Step between windows. Defaults to ``chip_sz``.

        Returns:
            ``ObjectDetectionLabels`` built from the dataset windows and the
            raw numpy predictions.
        """
        stride = chip_sz if stride is None else stride

        # Load the model lazily on first use.
        if self.learner is None:
            self.load_model()
        learner = self.learner

        # Take the transforms from learner.cfg.data rather than
        # self.learner_cfg.data: Learner.from_model_bundle() updates the
        # custom transforms on the former.
        base_tf, _ = learner.cfg.data.get_data_transforms()
        dataset = ObjectDetectionSlidingWindowGeoDataset(
            scene, size=chip_sz, stride=stride, transform=base_tf)

        progress_desc = f'Making predictions on {scene.id}'
        predictions: Iterator[Dict[str, 'np.ndarray']] = (
            learner.predict_dataset(
                dataset,
                raw_out=True,
                numpy_out=True,
                progress_bar=True,
                progress_bar_kw={'desc': progress_desc}))
        return ObjectDetectionLabels.from_predictions(dataset.windows,
                                                      predictions)