Source code for rastervision.pytorch_backend.pytorch_object_detection

from typing import TYPE_CHECKING, Optional
from os.path import join, basename, dirname
import uuid

from rastervision.pipeline.file_system import json_to_file
from rastervision.core.data_sample import DataSample
from rastervision.core.data.label import ObjectDetectionLabels
from rastervision.pytorch_backend.pytorch_learner_backend import (
    PyTorchLearnerSampleWriter, PyTorchLearnerBackend)
from rastervision.pytorch_backend.utils import chip_collate_fn_od
from rastervision.pytorch_learner.utils import predict_scene_od

if TYPE_CHECKING:
    from rastervision.core.data import DatasetConfig, Scene
    from rastervision.core.rv_pipeline import (ChipOptions,
                                               ObjectDetectionPredictOptions)
    from rastervision.pytorch_learner.object_detection_utils import BoxList
    from rastervision.pytorch_learner.object_detection_learner_config import (
        ObjectDetectionGeoDataConfig)


class PyTorchObjectDetectionSampleWriter(PyTorchLearnerSampleWriter):
    """Writes data in COCO format.

    Chip images are written through the parent writer (``get_image_path`` /
    ``write_chip``). Image metadata and bounding boxes are accumulated per
    split and dumped as a COCO-style ``labels.json`` for each non-empty
    split when the writer is closed.
    """

    def __enter__(self):
        super().__enter__()

        # Per-split COCO "images" and "annotations" lists. 'train' and
        # 'valid' are pre-registered; any other split (including 'default',
        # used for samples whose split is None) is added lazily by
        # update_coco_data() instead of raising KeyError.
        self.splits = {
            'train': {
                'images': [],
                'annotations': []
            },
            'valid': {
                'images': [],
                'annotations': []
            }
        }
        # Directory in which each split's labels.json is written, recorded
        # from the image paths actually produced for that split.
        self._split_dirs = {}
        # COCO categories: one entry per class, id = index in class_config.
        self.categories = [{
            'id': class_id,
            'name': class_name
        } for class_id, class_name in enumerate(self.class_config.names)]

        return self

    def __exit__(self, type, value, traceback):
        """Write COCO labels for every non-empty split, then close.

        For each split with at least one image, a ``labels.json`` containing
        the split's ``images``, ``annotations`` and the shared
        ``categories`` is written next to the split's ``img/`` directory
        (``(train|valid)/labels.json`` for the standard splits). The parent
        writer's ``__exit__`` is called afterwards.
        """
        for split, split_data in self.splits.items():
            images = split_data['images']
            if not images:
                continue
            # Fall back to the original <sample_dir>/<split> location if no
            # directory was recorded for this split.
            split_dir = self._split_dirs.get(split,
                                             join(self.sample_dir, split))
            labels_path = join(split_dir, 'labels.json')
            coco_dict = {
                'images': images,
                'annotations': split_data['annotations'],
                'categories': self.categories
            }
            json_to_file(coco_dict, labels_path)

        super().__exit__(type, value, traceback)

    def write_sample(self, sample: 'DataSample'):
        """
        This writes a training or validation sample to
        (train|valid)/img/{scene_id}-{ind}.png and updates
        some COCO data structures.
        """
        img_path = self.get_image_path(sample)
        self.write_chip(sample.chip, img_path)
        self.update_coco_data(sample, img_path)
        # Each sample's index doubles as its COCO image id.
        self.sample_ind += 1

    def update_coco_data(self, sample: 'DataSample', img_path: str):
        """Record a sample's image entry and box annotations for its split.

        Args:
            sample: The sample whose chip was written to ``img_path``. Its
                ``label`` is used as a ``BoxList`` providing boxes (converted
                to ``xywh``) and a ``class_ids`` field.
            img_path: Path the chip image was written to.
        """
        split = 'default' if sample.split is None else sample.split
        split_data = self.splits.setdefault(split, {
            'images': [],
            'annotations': []
        })
        # Images live at <split_dir>/img/<file>, so labels.json belongs two
        # levels up from the image path (keeps labels next to img/ even for
        # splits whose directory layout is decided by the parent writer).
        self._split_dirs.setdefault(split, dirname(dirname(img_path)))

        images = split_data['images']
        annotations = split_data['annotations']

        images.append({
            'file_name': basename(img_path),
            'id': self.sample_ind,
            'height': sample.chip.shape[0],
            'width': sample.chip.shape[1]
        })

        boxlist: 'BoxList' = sample.label
        npboxes = boxlist.convert_boxes('xywh')
        class_ids = boxlist.get_field('class_ids')
        for i, (bbox, class_id) in enumerate(zip(npboxes, class_ids)):
            annotations.append({
                # Unique within the split: "<image id>-<box index>".
                'id': f'{self.sample_ind}-{i}',
                'image_id': self.sample_ind,
                'bbox': [int(v) for v in bbox],
                'category_id': int(class_id),
            })
class PyTorchObjectDetection(PyTorchLearnerBackend):
    """PyTorch learner backend for object detection."""

    def get_sample_writer(self):
        """Return a COCO-format sample writer.

        The writer targets a uniquely named (``uuid4``) zip file under
        ``pipeline_cfg.chip_uri``.
        """
        output_uri = join(self.pipeline_cfg.chip_uri, f'{uuid.uuid4()}.zip')
        return PyTorchObjectDetectionSampleWriter(
            output_uri, self.pipeline_cfg.dataset.class_config, self.tmp_dir)

    def chip_dataset(self,
                     dataset: 'DatasetConfig',
                     chip_options: 'ChipOptions',
                     dataloader_kw: Optional[dict] = None) -> None:
        """Chip ``dataset`` using the object-detection collate function.

        Args:
            dataset: Dataset config to chip.
            chip_options: Chipping options.
            dataloader_kw: Extra DataLoader keyword arguments. ``collate_fn``
                is always set to ``chip_collate_fn_od``; also passing
                ``collate_fn`` here raises ``TypeError``.
        """
        # None sentinel instead of a shared mutable {} default.
        if dataloader_kw is None:
            dataloader_kw = {}
        dataloader_kw = dict(**dataloader_kw, collate_fn=chip_collate_fn_od)
        return super().chip_dataset(dataset, chip_options, dataloader_kw)

    def predict_scene(self, scene: 'Scene',
                      predict_options: 'ObjectDetectionPredictOptions'
                      ) -> ObjectDetectionLabels:
        """Predict object-detection labels for ``scene``.

        The model is loaded first if no learner has been set up yet.
        """
        if self.learner is None:
            self.load_model()
        labels = predict_scene_od(self.learner, scene, predict_options)
        return labels

    def _make_chip_data_config(
            self, dataset: 'DatasetConfig',
            chip_options: 'ChipOptions') -> 'ObjectDetectionGeoDataConfig':
        """Build the geo data config used for chipping ``dataset``."""
        # Imported locally, as in the original, rather than at module level.
        from rastervision.pytorch_learner import (ObjectDetectionGeoDataConfig)
        data_config = ObjectDetectionGeoDataConfig(
            scene_dataset=dataset, sampling=chip_options.sampling)
        return data_config