Source code for rastervision.core.data.label.object_detection_labels

from typing import TYPE_CHECKING, Dict, Iterable, List, Optional
import numpy as np
from shapely.geometry import shape

from rastervision.core.box import Box
from rastervision.core.data.label.labels import Labels
from rastervision.core.data.label.tfod_utils.np_box_list import BoxList
from rastervision.core.data.label.tfod_utils.np_box_list_ops import (
    prune_non_overlapping_boxes, clip_to_window, concatenate,
    non_max_suppression)

if TYPE_CHECKING:
    from rastervision.core.data import (ClassConfig, CRSTransformer)
    from shapely.geometry import Polygon


[docs]class ObjectDetectionLabels(Labels): """A set of boxes and associated class_ids and scores. Implemented using the Tensorflow Object Detection API's BoxList class. """
[docs] def __init__(self, npboxes: np.array, class_ids: np.array, scores: np.array = None): """Construct a set of object detection labels. Args: npboxes: float numpy array of size nx4 with cols ymin, xmin, ymax, xmax. Should be in pixel coordinates within the global frame of reference. class_ids: int numpy array of size n with class ids scores: float numpy array of size n """ self.boxlist = BoxList(npboxes) # This field name actually needs to be 'classes' to be able to use # certain utility functions in the TF Object Detection API. self.boxlist.add_field('classes', class_ids) # We need to ensure that there is always a scores field so that the # concatenate method will work with empty labels objects. if scores is None: scores = np.ones(class_ids.shape) self.boxlist.add_field('scores', scores)
def __add__(self, other: 'ObjectDetectionLabels') -> 'ObjectDetectionLabels': return ObjectDetectionLabels.concatenate(self, other) def __eq__(self, other: 'ObjectDetectionLabels') -> bool: return (isinstance(other, ObjectDetectionLabels) and self.to_dict() == other.to_dict()) def __setitem__(self, window: Box, item: Dict[str, np.ndarray]): boxes = item['boxes'] boxes = ObjectDetectionLabels.local_to_global(boxes, window) class_ids = item['class_ids'] scores = item.get('scores') new_labels = ObjectDetectionLabels(boxes, class_ids, scores=scores) concatenated_labels = self + new_labels self.boxlist = concatenated_labels.boxlist
[docs] def assert_equal(self, expected_labels: 'ObjectDetectionLabels'): np.testing.assert_array_equal(self.get_npboxes(), expected_labels.get_npboxes()) np.testing.assert_array_equal(self.get_class_ids(), expected_labels.get_class_ids()) np.testing.assert_array_equal(self.get_scores(), expected_labels.get_scores())
[docs] def filter_by_aoi(self, aoi_polygons: Iterable['Polygon']): boxes = self.get_boxes() class_ids = self.get_class_ids() scores = self.get_scores() new_boxes = [] new_class_ids = [] new_scores = [] for box, class_id, score in zip(boxes, class_ids, scores): box_poly = box.to_shapely() for aoi in aoi_polygons: if box_poly.within(aoi): new_boxes.append(box.npbox_format()) new_class_ids.append(class_id) new_scores.append(score) break if len(new_boxes) == 0: return ObjectDetectionLabels.make_empty() return ObjectDetectionLabels( np.array(new_boxes), np.array(new_class_ids), np.array(new_scores))
[docs] @classmethod def make_empty(cls) -> 'ObjectDetectionLabels': npboxes = np.empty((0, 4)) class_ids = np.empty((0, )) scores = np.empty((0, )) return cls(npboxes, class_ids, scores)
[docs] @staticmethod def from_boxlist(boxlist: BoxList): """Make ObjectDetectionLabels from BoxList object.""" scores = (boxlist.get_field('scores') if boxlist.has_field('scores') else None) return ObjectDetectionLabels( boxlist.get(), boxlist.get_field('classes'), scores=scores)
[docs] @staticmethod def from_geojson(geojson: dict, extent: Optional[Box] = None) -> 'ObjectDetectionLabels': """Convert GeoJSON to ObjectDetectionLabels object. If extent is provided, filter out the boxes that lie "more than a little bit" outside the extent. Args: geojson: (dict) normalized GeoJSON (see VectorSource) extent: (Box) in pixel coords Returns: ObjectDetectionLabels """ boxes = [] class_ids = [] scores = [] for f in geojson['features']: geom = shape(f['geometry']) (xmin, ymin, xmax, ymax) = geom.bounds boxes.append(Box(ymin, xmin, ymax, xmax)) props = f['properties'] class_ids.append(props['class_id']) scores.append(props.get('score', 1.0)) if len(boxes): boxes = np.array( [box.npbox_format() for box in boxes], dtype=float) class_ids = np.array(class_ids) scores = np.array(scores) labels = ObjectDetectionLabels(boxes, class_ids, scores=scores) else: labels = ObjectDetectionLabels.make_empty() if extent is not None: labels = ObjectDetectionLabels.get_overlapping( labels, extent, ioa_thresh=0.8, clip=True) return labels
[docs] def get_boxes(self) -> List[Box]: """Return list of Boxes.""" return [Box.from_npbox(npbox) for npbox in self.boxlist.get()]
[docs] def get_npboxes(self) -> np.ndarray: return self.boxlist.get()
[docs] def get_scores(self) -> np.ndarray: if self.boxlist.has_field('scores'): return self.boxlist.get_field('scores') return None
[docs] def get_class_ids(self) -> np.ndarray: return self.boxlist.get_field('classes')
def __len__(self) -> int: return self.boxlist.get().shape[0] def __str__(self) -> str: return str(self.boxlist.get())
[docs] def to_boxlist(self) -> BoxList: return self.boxlist
[docs] def to_dict(self) -> dict: """Returns a dict version of these labels. The Dict has a Box as a key, and a tuple of (class_id, score) as the values. """ d = {} boxes = list(map(Box.from_npbox, self.get_npboxes())) classes = list(self.get_class_ids()) scores = list(self.get_scores()) for box, class_id, score in zip(boxes, classes, scores): d[box.tuple_format()] = (class_id, score) return d
[docs] @staticmethod def local_to_global(npboxes: np.ndarray, window: Box): """Convert from local to global coordinates. The local coordinates are row/col within the window frame of reference. The global coordinates are row/col within the extent of a RasterSource. """ xmin = window.xmin ymin = window.ymin return npboxes + np.array([[ymin, xmin, ymin, xmin]])
[docs] @staticmethod def global_to_local(npboxes: np.ndarray, window: Box): """Convert from global to local coordinates. The global coordinates are row/col within the extent of a RasterSource. The local coordinates are row/col within the window frame of reference. """ xmin = window.xmin ymin = window.ymin return npboxes - np.array([[ymin, xmin, ymin, xmin]])
[docs] @staticmethod def local_to_normalized(npboxes: np.ndarray, window: Box): """Convert from local to normalized coordinates. The local coordinates are row/col within the window frame of reference. Normalized coordinates range from 0 to 1 on each (height/width) axis. """ height, width = window.size return npboxes / np.array([[height, width, height, width]])
[docs] @staticmethod def normalized_to_local(npboxes: np.ndarray, window: Box): """Convert from normalized to local coordinates. Normalized coordinates range from 0 to 1 on each (height/width) axis. The local coordinates are row/col within the window frame of reference. """ height, width = window.size return npboxes * np.array([[height, width, height, width]])
[docs] @staticmethod def get_overlapping(labels: 'ObjectDetectionLabels', window: Box, ioa_thresh: float = 0.000001, clip: bool = False) -> 'ObjectDetectionLabels': """Return subset of labels that overlap with window. Args: labels: ObjectDetectionLabels window: Box ioa_thresh: the minimum IOA for a box to be considered as overlapping clip: if True, clip label boxes to the window """ window_npbox = window.npbox_format() window_boxlist = BoxList(np.expand_dims(window_npbox, axis=0)) boxlist = prune_non_overlapping_boxes( labels.boxlist, window_boxlist, minoverlap=ioa_thresh) if clip: boxlist = clip_to_window(boxlist, window_npbox) return ObjectDetectionLabels.from_boxlist(boxlist)
[docs] @staticmethod def concatenate( labels1: 'ObjectDetectionLabels', labels2: 'ObjectDetectionLabels') -> 'ObjectDetectionLabels': """Return concatenation of labels. Args: labels1: ObjectDetectionLabels labels2: ObjectDetectionLabels """ new_boxlist = concatenate([labels1.to_boxlist(), labels2.to_boxlist()]) return ObjectDetectionLabels.from_boxlist(new_boxlist)
[docs] @staticmethod def prune_duplicates(labels: 'ObjectDetectionLabels', score_thresh: float, merge_thresh: float) -> 'ObjectDetectionLabels': """Remove duplicate boxes. Runs non-maximum suppression to remove duplicate boxes that result from sliding window prediction algorithm. Args: labels: ObjectDetectionLabels score_thresh: the minimum allowed score of boxes merge_thresh: the minimum IOA allowed when merging two boxes together Returns: ObjectDetectionLabels """ max_output_size = 1000000 pruned_boxlist = non_max_suppression( labels.boxlist, max_output_size=max_output_size, iou_threshold=merge_thresh, score_threshold=score_thresh) return ObjectDetectionLabels.from_boxlist(pruned_boxlist)
[docs] def save(self, uri: str, class_config: 'ClassConfig', crs_transformer: 'CRSTransformer') -> None: """Save labels as a GeoJSON file. Args: uri (str): URI of the output file. class_config (ClassConfig): ClassConfig to map class IDs to names. crs_transformer (CRSTransformer): CRSTransformer to convert from pixel-coords to map-coords before saving. """ from rastervision.core.data import ObjectDetectionGeoJSONStore label_store = ObjectDetectionGeoJSONStore( uri=uri, class_config=class_config, crs_transformer=crs_transformer) label_store.save(self)