Source code for rastervision.core.data.label.object_detection_labels

from typing import TYPE_CHECKING, Dict, Iterable, List, Optional
import numpy as np
from shapely.geometry import shape

from rastervision.core.box import Box
from rastervision.core.data.label.labels import Labels
from rastervision.core.data.label.tfod_utils.np_box_list import NpBoxList
from rastervision.core.data.label.tfod_utils.np_box_list_ops import (
    prune_non_overlapping_boxes, clip_to_window, concatenate,
    non_max_suppression)

if TYPE_CHECKING:
    from rastervision.core.data import (ClassConfig, CRSTransformer)
    from shapely.geometry import Polygon


[docs]class ObjectDetectionLabels(Labels): """A set of boxes and associated class_ids and scores. Implemented using the Tensorflow Object Detection API's BoxList class. """
[docs] def __init__(self, npboxes: np.array, class_ids: np.array, scores: np.array = None): """Construct a set of object detection labels. Args: npboxes: float numpy array of size nx4 with cols ymin, xmin, ymax, xmax. Should be in pixel coordinates within the global frame of reference. class_ids: int numpy array of size n with class ids scores: float numpy array of size n """ self.boxlist = NpBoxList(npboxes) # This field name actually needs to be 'classes' to be able to use # certain utility functions in the TF Object Detection API. self.boxlist.add_field('classes', class_ids) # We need to ensure that there is always a scores field so that the # concatenate method will work with empty labels objects. if scores is None: scores = np.ones(class_ids.shape) self.boxlist.add_field('scores', scores)
[docs] def __add__(self, other: 'ObjectDetectionLabels') -> 'ObjectDetectionLabels': return ObjectDetectionLabels.concatenate(self, other)
def __eq__(self, other: 'ObjectDetectionLabels') -> bool: return (isinstance(other, ObjectDetectionLabels) and self.to_dict() == other.to_dict())
[docs] def __setitem__(self, window: Box, item: Dict[str, np.ndarray]): boxes = item['boxes'] boxes = ObjectDetectionLabels.local_to_global(boxes, window) class_ids = item['class_ids'] scores = item.get('scores') new_labels = ObjectDetectionLabels(boxes, class_ids, scores=scores) concatenated_labels = self + new_labels self.boxlist = concatenated_labels.boxlist
[docs] def __getitem__(self, window: Box) -> 'ObjectDetectionLabels': return ObjectDetectionLabels.get_overlapping(self, window)
[docs] def assert_equal(self, expected_labels: 'ObjectDetectionLabels'): np.testing.assert_array_equal(self.get_npboxes(), expected_labels.get_npboxes()) np.testing.assert_array_equal(self.get_class_ids(), expected_labels.get_class_ids()) np.testing.assert_array_equal(self.get_scores(), expected_labels.get_scores())
[docs] def filter_by_aoi(self, aoi_polygons: Iterable['Polygon']): boxes = self.get_boxes() class_ids = self.get_class_ids() scores = self.get_scores() new_boxes = [] new_class_ids = [] new_scores = [] for box, class_id, score in zip(boxes, class_ids, scores): box_poly = box.to_shapely() for aoi in aoi_polygons: if box_poly.within(aoi): new_boxes.append(box.npbox_format()) new_class_ids.append(class_id) new_scores.append(score) break if len(new_boxes) == 0: return ObjectDetectionLabels.make_empty() return ObjectDetectionLabels( np.array(new_boxes), np.array(new_class_ids), np.array(new_scores))
[docs] @classmethod def make_empty(cls) -> 'ObjectDetectionLabels': npboxes = np.empty((0, 4)) class_ids = np.empty((0, )) scores = np.empty((0, )) return cls(npboxes, class_ids, scores)
[docs] @staticmethod def from_boxlist(boxlist: NpBoxList): """Make ObjectDetectionLabels from BoxList object.""" scores = (boxlist.get_field('scores') if boxlist.has_field('scores') else None) return ObjectDetectionLabels( boxlist.get(), boxlist.get_field('classes'), scores=scores)
[docs] @staticmethod def from_geojson(geojson: dict, bbox: Optional[Box] = None, ioa_thresh: float = 0.8, clip: bool = True) -> 'ObjectDetectionLabels': """Convert GeoJSON to ObjectDetectionLabels object. If bbox is provided, filter out the boxes that lie "more than a little bit" outside the bbox. Args: geojson: (dict) normalized GeoJSON (see VectorSource) bbox: (Box) in pixel coords Returns: ObjectDetectionLabels """ features = geojson['features'] if len(features) == 0: labels = ObjectDetectionLabels.make_empty() else: boxes = [Box.from_shapely(shape(f['geometry'])) for f in features] class_ids = [f['properties']['class_id'] for f in features] scores = [f['properties'].get('score', 1.0) for f in features] boxes = np.array([b.npbox_format() for b in boxes], dtype=float) class_ids = np.array(class_ids) scores = np.array(scores) labels = ObjectDetectionLabels(boxes, class_ids, scores=scores) if bbox is not None: labels = ObjectDetectionLabels.get_overlapping( labels, bbox, ioa_thresh=ioa_thresh, clip=clip) return labels
[docs] def get_boxes(self) -> List[Box]: """Return list of Boxes.""" return [Box.from_npbox(npbox) for npbox in self.boxlist.get()]
[docs] def get_npboxes(self) -> np.ndarray: return self.boxlist.get()
[docs] def get_scores(self) -> np.ndarray: if self.boxlist.has_field('scores'): return self.boxlist.get_field('scores') return None
[docs] def get_class_ids(self) -> np.ndarray: return self.boxlist.get_field('classes')
def __len__(self) -> int: return self.boxlist.get().shape[0] def __str__(self) -> str: # prama: no cover return str(self.boxlist.get())
[docs] def to_boxlist(self) -> NpBoxList: return self.boxlist
[docs] def to_dict(self, round_boxes: bool = True) -> dict: """Returns a dict version of these labels. The Dict has a Box as a key, and a tuple of (class_id, score) as the values. """ npboxes = self.get_npboxes() if round_boxes and np.issubdtype(npboxes.dtype, np.floating): npboxes = npboxes.round(2) classes = self.get_class_ids() scores = self.get_scores().round(6) d = { Box.from_npbox(box): (class_id, score) for box, class_id, score in zip(npboxes, classes, scores) } return d
[docs] @staticmethod def local_to_global(npboxes: np.ndarray, window: Box): """Convert from local to global coordinates. The local coordinates are row/col within the window frame of reference. The global coordinates are row/col within the extent of a RasterSource. """ xmin = window.xmin ymin = window.ymin return npboxes + np.array([[ymin, xmin, ymin, xmin]])
[docs] @staticmethod def global_to_local(npboxes: np.ndarray, window: Box): """Convert from global to local coordinates. The global coordinates are row/col within the extent of a RasterSource. The local coordinates are row/col within the window frame of reference. """ xmin = window.xmin ymin = window.ymin return npboxes - np.array([[ymin, xmin, ymin, xmin]])
[docs] @staticmethod def local_to_normalized(npboxes: np.ndarray, window: Box): """Convert from local to normalized coordinates. The local coordinates are row/col within the window frame of reference. Normalized coordinates range from 0 to 1 on each (height/width) axis. """ height, width = window.size return npboxes / np.array([[height, width, height, width]])
[docs] @staticmethod def normalized_to_local(npboxes: np.ndarray, window: Box): """Convert from normalized to local coordinates. Normalized coordinates range from 0 to 1 on each (height/width) axis. The local coordinates are row/col within the window frame of reference. """ height, width = window.size return npboxes * np.array([[height, width, height, width]])
[docs] @staticmethod def get_overlapping(labels: 'ObjectDetectionLabels', window: Box, ioa_thresh: float = 0.5, clip: bool = False) -> 'ObjectDetectionLabels': """Return subset of labels that overlap with window. Args: labels: ObjectDetectionLabels window: Box ioa_thresh: The minimum intersection-over-area (IOA) for a box to be considered as overlapping. For each box, IOA is defined as the area of the intersection of the box with the window over the area of the box. clip: If True, clip label boxes to the window. """ window_npbox = window.npbox_format() window_boxlist = NpBoxList(np.expand_dims(window_npbox, axis=0)) boxlist = prune_non_overlapping_boxes( labels.boxlist, window_boxlist, minoverlap=ioa_thresh) if clip: boxlist = clip_to_window(boxlist, window_npbox) return ObjectDetectionLabels.from_boxlist(boxlist)
[docs] @staticmethod def concatenate( labels1: 'ObjectDetectionLabels', labels2: 'ObjectDetectionLabels') -> 'ObjectDetectionLabels': """Return concatenation of labels. Args: labels1: ObjectDetectionLabels labels2: ObjectDetectionLabels """ new_boxlist = concatenate([labels1.to_boxlist(), labels2.to_boxlist()]) return ObjectDetectionLabels.from_boxlist(new_boxlist)
[docs] @staticmethod def prune_duplicates( labels: 'ObjectDetectionLabels', score_thresh: float, merge_thresh: float, max_output_size: Optional[int] = None) -> 'ObjectDetectionLabels': """Remove duplicate boxes via non-maximum suppression. Args: labels: Labels whose boxes are to be pruned. score_thresh: Prune boxes with score less than this threshold. merge_thresh: Prune boxes with intersection-over-union (IOU) greater than this threshold. max_output_size (int): Maximum number of retained boxes. If None, this is set to ``len(abels)``. Defaults to None. Returns: ObjectDetectionLabels: Pruned labels. """ if max_output_size is None: max_output_size = len(labels) pruned_boxlist = non_max_suppression( labels.boxlist, max_output_size=max_output_size, iou_threshold=merge_thresh, score_threshold=score_thresh) return ObjectDetectionLabels.from_boxlist(pruned_boxlist)
[docs] def save(self, uri: str, class_config: 'ClassConfig', crs_transformer: 'CRSTransformer', bbox: Optional[Box] = None) -> None: """Save labels as a GeoJSON file. Args: uri (str): URI of the output file. class_config (ClassConfig): ClassConfig to map class IDs to names. crs_transformer (CRSTransformer): CRSTransformer to convert from pixel-coords to map-coords before saving. bbox (Optional[Box]): User-specified crop of the extent. Must be provided if the corresponding RasterSource has bbox != extent. """ from rastervision.core.data import ObjectDetectionGeoJSONStore label_store = ObjectDetectionGeoJSONStore( uri=uri, class_config=class_config, crs_transformer=crs_transformer, bbox=bbox) label_store.save(self)