Source code for rastervision.pytorch_learner.dataset.object_detection_dataset

from typing import TYPE_CHECKING
from os.path import join
from collections import defaultdict
import logging

import albumentations as A
import numpy as np
from torch.utils.data import Dataset

from rastervision.pipeline.file_system import file_to_json
from rastervision.core.box import Box
from rastervision.core.data import ObjectDetectionLabels
from rastervision.pytorch_learner.dataset import (
    TransformType, ImageDataset, SlidingWindowGeoDataset,
    RandomWindowGeoDataset, load_image)
from rastervision.core.data.utils import make_od_scene

if TYPE_CHECKING:
    from rastervision.core.data import ClassConfig, ObjectDetectionLabelSource
log = logging.getLogger(__name__)


[docs]class CocoDataset(Dataset): """Read Object Detection data in the COCO format."""
[docs] def __init__(self, img_dir: str, annotation_uri: str): """Constructor. Args: img_dir: Directory containing the images. Image filenames must match the image IDs in the annotations file. annotation_uri: URI to a JSON file containing annotations in the COCO format. """ self.annotation_uri = annotation_uri ann_json = file_to_json(annotation_uri) self.img_ids: list[str] = [img['id'] for img in ann_json['images']] self.img_paths = { img['id']: join(img_dir, img['file_name']) for img in ann_json['images'] } self.img_anns = {id: defaultdict(list) for id in self.img_ids} for ann in ann_json['annotations']: img_ann = self.img_anns[ann['image_id']] img_ann['bboxes'].append(ann['bbox']) img_ann['category_id'].append(ann['category_id'])
[docs] def __getitem__(self, ind: int ) -> tuple[np.ndarray, tuple[np.ndarray, np.ndarray, str]]: img_id = self.img_ids[ind] path = self.img_paths[img_id] ann: dict[str, list] = self.img_anns[img_id] x = load_image(path) bboxes = np.array(ann['bboxes']) class_ids = np.array(ann['category_id'], dtype=np.int64) if len(bboxes) == 0: bboxes = np.empty((0, 4)) class_ids = np.empty((0, ), dtype=np.int64) return x, (bboxes, class_ids, 'xywh')
def __len__(self): return len(self.img_anns)
[docs]class ObjectDetectionImageDataset(ImageDataset): """Read Object Detection data in the COCO format. Uses :class:`.CocoDataset` to read the data. """
[docs] def __init__(self, img_dir: str, annotation_uri: str, *args, **kwargs): """Constructor. Args: img_dir: Directory containing the images. Image filenames must match the image IDs in the annotations file. annotation_uri: URI to a JSON file containing annotations in the COCO format. *args: See :meth:`.ImageDataset.__init__`. **kwargs: See :meth:`.ImageDataset.__init__`. """ ds = CocoDataset(img_dir, annotation_uri) super().__init__( ds, *args, **kwargs, transform_type=TransformType.object_detection)
[docs]def make_od_geodataset(cls, image_uri: str | list[str], label_vector_uri: str | None = None, class_config: 'ClassConfig | None' = None, aoi_uri: str | list[str] = [], label_vector_default_class_id: int | None = None, image_raster_source_kw: dict = {}, label_vector_source_kw: dict = {}, label_source_kw: dict = {}, **kwargs): """Create an instance of this class from image and label URIs. This is a convenience method. For more fine-grained control, it is recommended to use the default constructor. Args: image_uri: URI or list of URIs of GeoTIFFs to use as the source of image data. label_vector_uri: URI of GeoJSON file to use as the source of label. Defaults to ``None``. class_config: The ClassConfig. Must be non-None if creating a scene without a ``LabelSource``. Defaults to ``None``. aoi_uri: URI or list of URIs of GeoJSONs that specify the area-of-interest. If provided, the dataset will only access data from this area. Defaults to ``[]``. label_vector_default_class_id: If using label_vector_uri and all polygons in that file belong to the same class and they do not contain a `class_id` property, then use this argument to map all of the polygons to the appropriate class ID. See docs for ClassInferenceTransformer for more details. Defaults to ``None``. image_raster_source_kw: Additional arguments to pass to the :class:`.RasterioSource` used for image data. See docs for :class:`.RasterioSource` for more details. Defaults to ``{}``. label_vector_source_kw: Additional arguments to pass to the :class:`.GeoJSONVectorSourceConfig` used for label data, if label_vector_uri is set. See docs for :class:`.GeoJSONVectorSourceConfig` for more details. Defaults to ``{}``. label_source_kw: Additional arguments to pass to the :class:`.ObjectDetectionLabelSourceConfig` used for label data, if label_vector_uri is set. See docs for :class:`.ObjectDetectionLabelSourceConfig` for more details. Defaults to ``{}``. **kwargs: All other keyword args are passed to the default constructor for this class. Returns: An instance of this GeoDataset subclass. """ scene = make_od_scene( image_uri=image_uri, label_vector_uri=label_vector_uri, class_config=class_config, aoi_uri=aoi_uri, label_vector_default_class_id=label_vector_default_class_id, image_raster_source_kw=image_raster_source_kw, label_vector_source_kw=label_vector_source_kw, label_source_kw=label_source_kw) ds = cls(scene, **kwargs) return ds
[docs]class ObjectDetectionSlidingWindowGeoDataset(SlidingWindowGeoDataset): from_uris = classmethod(make_od_geodataset)
[docs] def __init__(self, *args, **kwargs): super().__init__( *args, **kwargs, transform_type=TransformType.object_detection)
[docs]class ObjectDetectionRandomWindowGeoDataset(RandomWindowGeoDataset): from_uris = classmethod(make_od_geodataset)
[docs] def __init__(self, *args, **kwargs): """Constructor. Args: *args: See :meth:`.RandomWindowGeoDataset.__init__`. Keyword Args: bbox_params: Optional ``bbox_params`` to use when resizing windows. Defaults to ``None``. ioa_thresh: Minimum IoA of a bounding box with a given window for it to be included in the labels for that window. Defaults to ``0.9``. clip: Clip bounding boxes to window limits when retrieving labels for a window. Defaults to ``False``. neg_ratio: Ratio of sampling probabilities of negative windows (windows w/o bboxes) vs positive windows (windows w/ at least 1 bbox). E.g. ``neg_ratio=2`` means 2/3 probability of sampling a negative window. If ``None``, the default sampling behavior of ``RandomWindowGeoDataset`` is used, without taking bboxes into account. Defaults to ``None``. neg_ioa_thresh: A window will be considered negative if its max IoA with any bounding box is less than this threshold. Defaults to ``0.2``. **kwargs: See :meth:`.RandomWindowGeoDataset.__init__`. """ from rastervision.pytorch_learner import DEFAULT_BBOX_PARAMS self.bbox_params: A.BboxParams | None = kwargs.pop( 'bbox_params', DEFAULT_BBOX_PARAMS) ioa_thresh: float = kwargs.pop('ioa_thresh', 0.9) clip: bool = kwargs.pop('clip', False) neg_ratio: float | None = kwargs.pop('neg_ratio', None) neg_ioa_thresh: float = kwargs.pop('neg_ioa_thresh', 0.2) super().__init__( *args, **kwargs, transform_type=TransformType.object_detection) label_source: 'ObjectDetectionLabelSource | None' = self.scene.label_source if label_source is not None: label_source.ioa_thresh = ioa_thresh label_source.clip = clip if neg_ratio is not None: if label_source is None: raise ValueError( 'Scene must have a LabelSource if neg_ratio is set.') self.neg_probability = neg_ratio / (neg_ratio + 1) self.neg_ioa_thresh: float = neg_ioa_thresh # Get labels for the AOI. clip=True here to ensure that it is # possible to draw a window (that lies within the extent) around # each bbox. self.labels = label_source.get_labels( ioa_thresh=ioa_thresh, clip=True) num_bboxes_in_scene = len(self.labels) if num_bboxes_in_scene == 0: raise ValueError( 'neg_ratio specified, but no bboxes found in scene.') if self.has_aoi_polygons: self.labels = self.labels.filter_by_aoi( self.scene.aoi_polygons) num_bboxes_in_aoi = len(self.labels) if num_bboxes_in_aoi == 0: raise ValueError( 'neg_ratio specified, but no bboxes found in AOI. ' 'Total bboxes in scene (ignoring AOI):' f'{num_bboxes_in_scene}.') self.bboxes = self.labels.get_boxes() else: self.neg_probability = None
[docs] def append_resize_transform(self, transform: A.BasicTransform, out_size: tuple[int, int]) -> A.BasicTransform: resize_tf = A.Resize(*out_size, always_apply=True) if transform is None: transform = resize_tf else: transform = A.Compose( [transform, resize_tf], bbox_params=self.bbox_params) return transform
def _sample_pos_window(self) -> Box: """Sample a window containing at least one bounding box. This is done by randomly sampling one of the bounding boxes in the scene and drawing a random window around it. """ bbox: Box = np.random.choice(self.bboxes) box_h, box_w = bbox.size # check if it is possible to sample a containing window hmax, wmax = self.max_size if box_h > hmax or box_w > wmax: raise ValueError( f'Cannot sample containing window because bounding box {bbox}' f'is larger than self.max_size ({self.max_size}).') # try to sample a window size that is larger than the box's size for _ in range(self.max_sample_attempts): h, w = self.sample_window_size() if h >= box_h and w >= box_w: window = bbox.make_random_box_container(h, w) return window log.warning('ObjectDetectionRandomWindowGeoDataset: Failed to find ' 'suitable (h, w) for positive window. ' f'Using (hmax, wmax) = ({hmax}, {wmax}) instead.') window = bbox.make_random_box_container(hmax, wmax) return window def _sample_neg_window(self) -> Box: """Attempt to sample a window containing no bounding boxes. If not found within self.max_sample_attempts, just return the last sampled window. """ for _ in range(self.max_sample_attempts): window = super()._sample_window() labels = ObjectDetectionLabels.get_overlapping( self.labels, window, ioa_thresh=self.neg_ioa_thresh) if len(labels) == 0: return window log.warning('ObjectDetectionRandomWindowGeoDataset: Failed to find ' 'negative window. Returning last sampled window.') return window def _sample_window(self) -> Box: """Sample negative or positive window based on neg_probability, if set. If neg_probability is not set, use :meth:`.RandomWindowGeoDataset._sample_window`. """ if self.neg_probability is None: return super()._sample_window() if np.random.sample() < self.neg_probability: return self._sample_neg_window() return self._sample_pos_window()