from typing import TYPE_CHECKING, Any
import numpy as np
from rastervision.core.box import Box
from rastervision.core.data.label import ObjectDetectionLabels
from rastervision.core.data.label_source import LabelSource
from rastervision.core.data.vector_source import VectorSource
from rastervision.core.data.utils import parse_array_slices_2d
if TYPE_CHECKING:
from rastervision.core.data import CRSTransformer
[docs]class ObjectDetectionLabelSource(LabelSource):
"""A read-only label source for object detection."""
[docs] def __init__(self,
vector_source: VectorSource,
bbox: Box | None = None,
ioa_thresh: float | None = None,
clip: bool = False):
"""Constructor.
Args:
vector_source: A ``VectorSource``.
bbox: User-specified crop of the extent. If ``None``, the full
extent available in the source file is used.
ioa_thresh: IOA threshold to apply when retrieving labels for a
window. Defaults to ``None``.
clip: Clip bounding boxes to window limits when retrieving labels
for a window. Defaults to ``False``.
"""
self.vector_source = vector_source
geojson = self.vector_source.get_geojson()
self.validate_geojson(geojson)
self.labels = ObjectDetectionLabels.from_geojson(geojson, bbox=bbox)
if bbox is None:
bbox = vector_source.extent
self._bbox = bbox
self.ioa_thresh = ioa_thresh if ioa_thresh is not None else 1e-6
self.clip = clip
[docs] def get_labels(self,
window: Box | None = None,
ioa_thresh: float = 1e-6,
clip: bool = False) -> ObjectDetectionLabels:
"""Get labels (in global coords) for a window.
Args:
window: Window coords.
Returns:
Labels with sufficient overlap with the window. The returned labels
are in global coods (i.e. coords within the full extent of the
source).
"""
if window is None:
return self.labels
window = window.to_global_coords(self.bbox)
return ObjectDetectionLabels.get_overlapping(
self.labels, window, ioa_thresh=ioa_thresh, clip=clip)
[docs] def __getitem__(self, key: Any) -> tuple[np.ndarray, np.ndarray, str]:
"""Get labels (in window coords) for a window.
Returns a 3-tuple: (npboxes, class_ids, box_format).
- npboxes is a float np.ndarray of shape (num_boxes, 4) representing
pixel coords of bounding boxes in the form [ymin, xmin, ymax, xmax].
- class_ids is a np.ndarray of shape (num_boxes,) representing the
class labels for each of the boxes.
- box_format is the format of npboxes which, in this case, is always
'yxyx'.
Args:
window (Box): Window coords.
Returns:
tuple[np.ndarray, np.ndarray, str]: 3-tuple of
(npboxes, class_ids, box_format). The returned npboxes are in
window coords (i.e. coords within the window).
"""
if isinstance(key, Box):
window = key
labels = self.get_labels(
window, ioa_thresh=self.ioa_thresh, clip=self.clip)
class_ids = labels.get_class_ids()
npboxes = labels.get_npboxes()
window_global = window.to_global_coords(self.bbox)
npboxes = ObjectDetectionLabels.global_to_local(
npboxes, window_global)
return npboxes, class_ids, 'yxyx'
window, (h, w) = parse_array_slices_2d(key, extent=self.extent)
npboxes, class_ids, fmt = self[window]
# rescale if steps specified
if h.step is not None:
# assume fmt='yxyx'
npboxes[:, [0, 2]] /= h.step
if w.step is not None:
# assume fmt='yxyx'
npboxes[:, [1, 3]] /= w.step
return npboxes, class_ids, fmt
[docs] def validate_geojson(self, geojson: dict) -> None:
for f in geojson['features']:
geom_type = f.get('geometry', {}).get('type', '')
if 'Point' in geom_type or 'LineString' in geom_type:
raise ValueError(
'LineStrings and Points are not supported '
'in ChipClassificationLabelSource. Use BufferTransformer '
'to buffer them into Polygons.')
for f in geojson['features']:
if f.get('properties', {}).get('class_id') is None:
raise ValueError('All GeoJSON features must have a class_id '
'field in their properties.')
@property
def bbox(self) -> Box:
return self._bbox
@property
def crs_transformer(self) -> 'CRSTransformer':
return self.vector_source.crs_transformer
[docs] def set_bbox(self, bbox: 'Box') -> None:
self._bbox = bbox