Source code for rastervision.core.data.label_source.object_detection_label_source

from typing import Any, Optional, Tuple

import numpy as np

from rastervision.core.box import Box
from rastervision.core.data.label import ObjectDetectionLabels
from rastervision.core.data.label_source import LabelSource
from rastervision.core.data.vector_source import VectorSource


class ObjectDetectionLabelSource(LabelSource):
    """A read-only label source for object detection."""

    def __init__(self,
                 vector_source: VectorSource,
                 extent: Box,
                 ioa_thresh: Optional[float] = None,
                 clip: bool = False):
        """Constructor.

        Args:
            vector_source (VectorSource): A VectorSource.
            extent (Box): Box used to filter the labels by extent.
            ioa_thresh (Optional[float], optional): IOA threshold to apply
                when retrieving labels for a window via ``__getitem__``.
                If None, 1e-6 is used. Defaults to None.
            clip (bool, optional): Clip bounding boxes to window limits when
                retrieving labels for a window via ``__getitem__``.
                Defaults to False.

        Raises:
            ValueError: If the GeoJSON obtained from ``vector_source`` fails
                ``validate_geojson``.
        """
        geojson = vector_source.get_geojson()
        self.validate_geojson(geojson)
        self.labels = ObjectDetectionLabels.from_geojson(
            geojson, extent=extent)
        self._extent = extent
        # Fall back to 1e-6 when no explicit threshold is given.
        self.ioa_thresh = ioa_thresh if ioa_thresh is not None else 1e-6
        self.clip = clip

    def get_labels(self,
                   window: Optional[Box] = None,
                   ioa_thresh: float = 1e-6,
                   clip: bool = False) -> ObjectDetectionLabels:
        """Get labels (in global coords) for a window.

        Note: this method does not read ``self.ioa_thresh`` / ``self.clip``;
        pass them explicitly (as ``__getitem__`` does) if they are wanted.

        Args:
            window (Optional[Box]): Window coords, relative to the origin of
                ``self.extent``. If None, all labels are returned.
                Defaults to None.
            ioa_thresh (float): IOA threshold forwarded to
                ``ObjectDetectionLabels.get_overlapping``. Defaults to 1e-6.
            clip (bool): Whether to clip boxes to the window; forwarded to
                ``ObjectDetectionLabels.get_overlapping``. Defaults to False.

        Returns:
            ObjectDetectionLabels: Labels with sufficient overlap with the
            window. The returned labels are in global coords (i.e. coords
            within the full extent).
        """
        if window is None:
            return self.labels
        # Shift the extent-relative window by the extent's origin so it is
        # compared against self.labels in the same (global) frame.
        window = window.shift_origin(self.extent)
        return ObjectDetectionLabels.get_overlapping(
            self.labels, window, ioa_thresh=ioa_thresh, clip=clip)

    def __getitem__(self, key: Any) -> Tuple[np.ndarray, np.ndarray, str]:
        """Get labels (in window coords) for a window.

        Returns a 3-tuple: (npboxes, class_ids, box_format).
        - npboxes is a float np.ndarray of shape (num_boxes, 4) representing
          pixel coords of bounding boxes in the form [ymin, xmin, ymax, xmax].
        - class_ids is a np.ndarray of shape (num_boxes,) representing the
          class labels for each of the boxes.
        - box_format is the format of npboxes which, in this case, is always
          'yxyx'.

        Args:
            key: Either a Box (window coords relative to the origin of
                ``self.extent``), a single slice (rows), or a tuple of one or
                two slices (rows, cols). Slice starts/stops must be
                non-negative; an omitted start is 0 and an omitted stop is the
                extent's height/width. A slice step rescales the returned
                box coords by dividing them by that step.

        Returns:
            Tuple[np.ndarray, np.ndarray, str]: 3-tuple of
            (npboxes, class_ids, box_format). The returned npboxes are in
            window coords (i.e. coords within the window).

        Raises:
            TypeError: If ``key`` is not a Box, slice, or tuple.
            NotImplementedError: If a slice has a negative start or stop.
        """
        if isinstance(key, Box):
            window = key
            labels = self.get_labels(
                window, ioa_thresh=self.ioa_thresh, clip=self.clip)
            class_ids = labels.get_class_ids()
            npboxes = labels.get_npboxes()
            # get_labels() returns boxes in global coords, so they must be
            # made local relative to the window expressed in that same frame
            # (i.e. shifted exactly as get_labels() shifts it).
            global_window = window.shift_origin(self.extent)
            npboxes = ObjectDetectionLabels.global_to_local(
                npboxes, global_window)
            return npboxes, class_ids, 'yxyx'
        elif isinstance(key, slice):
            key = [key]
        elif isinstance(key, tuple):
            pass
        else:
            raise TypeError('Unsupported key type.')

        slices = list(key)

        assert 1 <= len(slices) <= 2
        assert all(s is not None for s in slices)
        assert isinstance(slices[0], slice)
        if len(slices) == 1:
            h, = slices
            w = slice(None, None)
        else:
            assert isinstance(slices[1], slice)
            h, w = slices

        if any(x is not None and x < 0
               for x in [h.start, h.stop, w.start, w.stop]):
            raise NotImplementedError()

        ymin, xmin, ymax, xmax = self.extent
        # Slices are relative to the extent's origin (start defaults to 0),
        # so an open-ended stop must be the extent's height/width rather than
        # its absolute ymax/xmax.
        _ymin = 0 if h.start is None else h.start
        _xmin = 0 if w.start is None else w.start
        _ymax = (ymax - ymin) if h.stop is None else h.stop
        _xmax = (xmax - xmin) if w.stop is None else w.stop
        window = Box(_ymin, _xmin, _ymax, _xmax)

        npboxes, class_ids, fmt = self[window]

        # rescale if steps specified
        if h.step is not None or w.step is not None:
            # astype() returns a float copy: true division then works even
            # for integer boxes and cannot modify any shared label storage.
            npboxes = npboxes.astype(float)
        if h.step is not None:
            # fmt is always 'yxyx': columns 0 and 2 are y coords.
            npboxes[:, [0, 2]] /= h.step
        if w.step is not None:
            # fmt is always 'yxyx': columns 1 and 3 are x coords.
            npboxes[:, [1, 3]] /= w.step

        return npboxes, class_ids, fmt

    def validate_geojson(self, geojson: dict) -> None:
        """Check that a GeoJSON dict can be turned into object detection labels.

        Args:
            geojson (dict): GeoJSON dict with a ``features`` list.

        Raises:
            ValueError: If any feature's geometry type contains "Point" or
                "LineString", or if any feature lacks a non-null ``class_id``
                in its properties.
        """
        features = geojson['features']
        # Two separate passes preserve the original error precedence:
        # geometry problems are reported before missing class_ids.
        for f in features:
            geom_type = f.get('geometry', {}).get('type', '')
            if 'Point' in geom_type or 'LineString' in geom_type:
                raise ValueError(
                    'LineStrings and Points are not supported '
                    'in ObjectDetectionLabelSource. Use BufferTransformer '
                    'to buffer them into Polygons.')
        for f in features:
            # `or {}` also handles an explicit null "properties" member, which
            # GeoJSON permits, so it yields the ValueError below rather than
            # an AttributeError.
            properties = f.get('properties') or {}
            if properties.get('class_id') is None:
                raise ValueError('All GeoJSON features must have a class_id '
                                 'field in their properties.')

    @property
    def extent(self) -> Box:
        """The Box passed to the constructor (used to filter labels)."""
        return self._extent