import math
import random
from collections.abc import Callable, Iterator
from typing import (TYPE_CHECKING, Literal, Sequence, overload)

import numpy as np
from pydantic import NonNegativeInt as NonNegInt, PositiveInt as PosInt
from rasterio.windows import Window as RioWindow
from shapely.geometry import Polygon
from shapely.ops import unary_union

from rastervision.pipeline.utils import repr_with_args
from rastervision.core.utils.misc import (calculate_required_padding,
                                          ensure_tuple)

if TYPE_CHECKING:
    from typing import Self
    from shapely.geometry import MultiPolygon
    from shapely.geometry.base import BaseGeometry
class BoxSizeError(ValueError):
    """Raised when a requested box size is incompatible with a Box."""
[docs]class Box:
"""A multi-purpose box (ie. rectangle) representation."""
[docs] def __init__(self, ymin: int, xmin: int, ymax: int, xmax: int):
"""Constructor.
Although primarily intended for representing integer pixel coordinates
in a scene, this class can also be used to represent floating point
map coordinates though not all methods might be compatible with that
interpretation.
Args:
ymin: minimum y value (y is row)
xmin: minimum x value (x is column)
ymax: maximum y value
xmax: maximum x value
"""
if not all(math.isfinite(v) for v in (ymin, xmin, ymax, xmax)):
raise ValueError(
f'Invalid Box coordinates: {(ymin, xmin, ymax, xmax)}.')
self.ymin = ymin
self.xmin = xmin
self.ymax = ymax
self.xmax = xmax
def __eq__(self, other: 'Self') -> bool:
"""Return true if other has same coordinates."""
return self.tuple_format() == other.tuple_format()
def __ne__(self, other: 'Self'):
"""Return true if other has different coordinates."""
return self.tuple_format() != other.tuple_format()
@property
def height(self) -> int:
"""Height of the Box."""
return self.ymax - self.ymin
@property
def width(self) -> int:
"""Width of the Box."""
return self.xmax - self.xmin
@property
def extent(self) -> 'Self':
"""Return a Box(0, 0, h, w) representing the size of this Box."""
return Box(0, 0, self.height, self.width)
@property
def size(self) -> tuple[int, int]:
"""(height, width) tuple."""
return self.height, self.width
@property
def area(self) -> int:
"""Return area of Box."""
return self.height * self.width
[docs] def normalize(self) -> 'Self':
"""Ensure ymin <= ymax and xmin <= xmax."""
ymin, ymax = sorted((self.ymin, self.ymax))
xmin, xmax = sorted((self.xmin, self.xmax))
return Box(ymin, xmin, ymax, xmax)
[docs] def to_int(self):
"""Return a new Box with all coordinates cast to ints."""
ymin, xmin, ymax, xmax = self
ymin, xmin, ymax, xmax = int(ymin), int(xmin), int(ymax), int(xmax)
out = Box(ymin, xmin, ymax, xmax)
return out
[docs] @staticmethod
def to_npboxes(boxes: list['Self']) -> np.ndarray:
"""Return nx4 numpy array from list of Box."""
nb_boxes = len(boxes)
npboxes = np.empty((nb_boxes, 4))
for boxind, box in enumerate(boxes):
npboxes[boxind, :] = box.npbox_format()
return npboxes
def __iter__(self):
return iter(self.tuple_format())
[docs] def __getitem__(self, i: NonNegInt):
return self.tuple_format()[i]
def __repr__(self) -> str:
return repr_with_args(self, **self.to_dict())
def __hash__(self) -> int:
return hash(self.tuple_format())
[docs] def geojson_coordinates(self) -> list[tuple[int, int]]:
"""Return Box as GeoJSON coordinates."""
# Compass directions:
nw = [self.xmin, self.ymin]
ne = [self.xmin, self.ymax]
se = [self.xmax, self.ymax]
sw = [self.xmax, self.ymin]
return [nw, ne, se, sw, nw]
[docs] def make_random_square_container(self, size: int) -> 'Self':
"""Return a new square Box that contains this Box.
Args:
size: the width and height of the new Box
"""
return self.make_random_box_container(size, size)
[docs] def make_random_box_container(self, out_h: int, out_w: int) -> 'Self':
"""Return a new rectangular Box that contains this Box.
Args:
out_h (int): the height of the new Box
out_w (int): the width of the new Box
"""
self_h, self_w = self.size
if out_h < self_h:
raise BoxSizeError('size of random container cannot be < height')
if out_w < self_w:
raise BoxSizeError('size of random container cannot be < width')
ymin, xmin, _, _ = self.normalize()
lb = ymin - (out_h - self_h)
ub = ymin
out_ymin = random.randint(int(lb), int(ub))
lb = xmin - (out_w - self_w)
ub = xmin
out_xmin = random.randint(int(lb), int(ub))
return Box(out_ymin, out_xmin, out_ymin + out_h, out_xmin + out_w)
[docs] def make_random_square(self, size: int) -> 'Self':
"""Return new randomly positioned square Box that lies inside this Box.
Args:
size: the height and width of the new Box
"""
if size >= self.width:
raise BoxSizeError('size of random square cannot be >= width')
if size >= self.height:
raise BoxSizeError('size of random square cannot be >= height')
ymin, xmin, ymax, xmax = self.normalize()
lb = ymin
ub = ymax - size
rand_y = random.randint(int(lb), int(ub))
lb = xmin
ub = xmax - size
rand_x = random.randint(int(lb), int(ub))
return Box.make_square(rand_y, rand_x, size)
[docs] def intersection(self, other: 'Self') -> 'Self':
"""Return the intersection of this Box and the other.
Args:
other: The box to intersect with this one.
Returns:
The intersection of this box and the other one.
"""
if not self.intersects(other):
return Box(0, 0, 0, 0)
box1 = self.normalize()
box2 = other.normalize()
xmin = max(box1.xmin, box2.xmin)
ymin = max(box1.ymin, box2.ymin)
xmax = min(box1.xmax, box2.xmax)
ymax = min(box1.ymax, box2.ymax)
return Box(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax)
[docs] def intersects(self, other: 'Self') -> bool:
box1 = self.normalize()
box2 = other.normalize()
if box1.ymax <= box2.ymin or box1.ymin >= box2.ymax:
return False
if box1.xmax <= box2.xmin or box1.xmin >= box2.xmax:
return False
return True
[docs] @classmethod
def from_npbox(cls, npbox: np.ndarray) -> 'Self':
"""Return new Box based on npbox format.
Args:
npbox: Numpy array of form [ymin, xmin, ymax, xmax] with float type
"""
return Box(*npbox)
[docs] @classmethod
def from_shapely(cls, shape: 'BaseGeometry') -> 'Self':
"""Instantiate from the bounds of a shapely geometry."""
xmin, ymin, xmax, ymax = shape.bounds
return Box(ymin, xmin, ymax, xmax)
[docs] @classmethod
def from_rasterio(cls, rio_window: RioWindow) -> 'Self':
"""Instantiate from a rasterio window."""
yslice, xslice = rio_window.toslices()
return Box(yslice.start, xslice.start, yslice.stop, xslice.stop)
[docs] def to_xywh(self) -> tuple[int, int, int, int]:
"""Convert to (xmin, ymin, width, height) tuple"""
return (self.xmin, self.ymin, self.width, self.height)
[docs] def to_xyxy(self) -> tuple[int, int, int, int]:
"""Convert to (xmin, ymin, xmax, ymax) tuple"""
return (self.xmin, self.ymin, self.xmax, self.ymax)
[docs] def to_points(self) -> np.ndarray:
"""Get (x, y) coords of each vertex as a 4x2 numpy array."""
return np.array(self.geojson_coordinates()[:4])
[docs] def to_shapely(self) -> Polygon:
"""Convert to shapely Polygon."""
return Polygon.from_bounds(*self.shapely_format())
[docs] def to_rasterio(self) -> RioWindow:
"""Convert to a Rasterio Window."""
return RioWindow.from_slices(*self.to_slices())
[docs] def to_slices(self, h_step: int | None = None,
w_step: int | None = None) -> tuple[slice, slice]:
"""Convert to slices: ymin:ymax[:h_step], xmin:xmax[:w_step]"""
return slice(self.ymin, self.ymax, h_step), slice(
self.xmin, self.xmax, w_step)
[docs] def translate(self, dy: int, dx: int) -> 'Self':
"""Translate window along y and x axes by the given distances."""
ymin, xmin, ymax, xmax = self
return Box(ymin + dy, xmin + dx, ymax + dy, xmax + dx)
[docs] def to_global_coords(self, bbox: 'Self') -> 'Self':
"""Go from bbox coords to global coords.
E.g., Given a box Box(20, 20, 40, 40) and bbox Box(20, 20, 100, 100),
the box becomes Box(40, 40, 60, 60).
Inverse of Box.to_local_coords().
"""
return self.translate(dy=bbox.ymin, dx=bbox.xmin)
[docs] def to_local_coords(self, bbox: 'Self') -> 'Self':
"""Go from to global coords bbox coords.
E.g., Given a box Box(40, 40, 60, 60) and bbox Box(20, 20, 100, 100),
the box becomes Box(20, 20, 40, 40).
Inverse of Box.to_global_coords().
"""
return self.translate(dy=-bbox.ymin, dx=-bbox.xmin)
[docs] def reproject(self, transform_fn: Callable[[tuple], tuple]) -> 'Self':
"""Reprojects this box based on a transform function.
Args:
transform_fn: A function that takes in a tuple (x, y) and
reprojects that point to the target coordinate reference
system.
"""
(xmin, ymin) = transform_fn((self.xmin, self.ymin))
(xmax, ymax) = transform_fn((self.xmax, self.ymax))
return Box(ymin, xmin, ymax, xmax)
[docs] @staticmethod
def make_square(ymin, xmin, size) -> 'Self':
"""Return new square Box."""
return Box(ymin, xmin, ymin + size, xmin + size)
[docs] def center_crop(self, edge_offset_y: int, edge_offset_x: int) -> 'Self':
"""Return Box whose sides are eroded by the given offsets.
Box(0, 0, 10, 10).center_crop(2, 4) == Box(2, 4, 8, 6)
"""
return Box(self.ymin + edge_offset_y, self.xmin + edge_offset_x,
self.ymax - edge_offset_y, self.xmax - edge_offset_x)
[docs] def erode(self, erosion_sz) -> 'Self':
"""Return new Box whose sides are eroded by erosion_sz."""
return self.center_crop(erosion_sz, erosion_sz)
[docs] def buffer(self, buffer_sz: float, max_extent: 'Self') -> 'Self':
"""Return new Box whose sides are buffered by buffer_sz.
The resulting box is clipped so that the values of the corners are
always greater than zero and less than the height and width of
max_extent.
"""
buffer_sz = max(0., buffer_sz)
if buffer_sz < 1.:
delta_width = int(round(buffer_sz * self.width))
delta_height = int(round(buffer_sz * self.height))
else:
delta_height = delta_width = int(round(buffer_sz))
return Box(
max(0, math.floor(self.ymin - delta_height)),
max(0, math.floor(self.xmin - delta_width)),
min(max_extent.height,
int(self.ymax) + delta_height),
min(max_extent.width,
int(self.xmax) + delta_width))
[docs] def pad(self, ymin: int, xmin: int, ymax: int, xmax: int) -> 'Self':
"""Pad sides by the given amount."""
return Box(
ymin=self.ymin - ymin,
xmin=self.xmin - xmin,
ymax=self.ymax + ymax,
xmax=self.xmax + xmax)
[docs] def pad_directional(
self,
padding: tuple[NonNegInt, NonNegInt] | NonNegInt,
pad_direction: Literal['both', 'start', 'end'] = 'end') -> 'Self':
"""Pad sides based on given padding and direction."""
padding: tuple[NonNegInt, NonNegInt] = ensure_tuple(padding)
if padding == (0, 0):
return self
if padding[0] < 0 or padding[1] < 0:
raise ValueError('padding must be non-negative.')
h_pad, w_pad = padding
if pad_direction == 'both':
return self.pad(ymin=h_pad, xmin=w_pad, ymax=h_pad, xmax=w_pad)
elif pad_direction == 'end':
return self.pad(ymin=0, xmin=0, ymax=h_pad, xmax=w_pad)
elif pad_direction == 'start':
return self.pad(ymin=h_pad, xmin=w_pad, ymax=0, xmax=0)
raise ValueError('pad_directions must be one of: '
'"both", "start", "end".')
[docs] def copy(self) -> 'Self':
return Box(*self)
[docs] def get_windows(
self,
size: PosInt | tuple[PosInt, PosInt],
stride: PosInt | tuple[PosInt, PosInt],
padding: NonNegInt | tuple[NonNegInt, NonNegInt] | None = None,
pad_direction: Literal['both', 'start', 'end'] = 'end'
) -> 'SlidingWindows':
"""Return sliding windows for given size, stride, and padding.
Each of ``size``, ``stride``, and ``padding`` can be either a positive
int or a tuple ``(vertical-component, horizontal-component)`` of
positive ints.
If ``padding`` is not specified and ``stride <= size``, it will be
automatically calculated such that the windows cover the entire extent.
Args:
box: Outer box within which to generate sliding windows.
size: Size ``(h, w)`` of the windows.
stride: Step size between windows. Can be a ``(h_step, w_step)``
tuple or positive int.
padding: Optional padding to accommodate windows that overflow the
extent. Can be a ``(h_pad, w_pad)`` tuple or a non-negative
int. If ``None``, will be automatically calculated such that
the windows cover the entire extent. Defaults to ``None``.
pad_direction: Directions to add padding to.
If ``'end'``, only add padding to bottom and right.
If ``'start'``, only add padding to top and left.
If ``'both'``, add padding to all sides.
Has no effect if padding is zero. Defaults to ``'end'``.
Returns:
Lazy list of windows.
"""
windows = SlidingWindows(
self,
size=size,
stride=stride,
padding=padding,
pad_direction=pad_direction)
return windows
[docs] def to_dict(self) -> dict[str, int]:
"""Convert to a dict with keys: ymin, xmin, ymax, xmax."""
return {
'ymin': self.ymin,
'xmin': self.xmin,
'ymax': self.ymax,
'xmax': self.xmax,
}
[docs] @classmethod
def from_dict(cls, d: dict) -> 'Self':
return cls(d['ymin'], d['xmin'], d['ymax'], d['xmax'])
[docs] @staticmethod
def filter_by_aoi(windows: Sequence['Box'],
aoi_polygons: list[Polygon],
within: bool = True) -> tuple[list['Box'], list[int]]:
"""Filters windows by a list of AOI polygons
Args:
within: if True, windows are only kept if they lie fully within an
AOI polygon. Otherwise, windows are kept if they intersect an
AOI polygon.
"""
aoi: Polygon | MultiPolygon = unary_union(aoi_polygons)
keep_window = aoi.contains if within else aoi.intersects
inds = [
i for i, w in enumerate(windows) if keep_window(w.to_shapely())
]
filtered_windows = [windows[i] for i in inds]
return filtered_windows, inds
[docs] @staticmethod
def within_aoi(window: 'Box',
aoi_polygons: Polygon | list[Polygon]) -> bool:
"""Check if window is within the union of given AOI polygons."""
aoi_polygons: Polygon | MultiPolygon = unary_union(aoi_polygons)
w = window.to_shapely()
out = aoi_polygons.contains(w)
return out
[docs] @staticmethod
def intersects_aoi(window: 'Box',
aoi_polygons: Polygon | list[Polygon]) -> bool:
"""Check if window intersects with the union of given AOI polygons."""
aoi_polygons: Polygon | MultiPolygon = unary_union(aoi_polygons)
w = window.to_shapely()
out = aoi_polygons.intersects(w)
return out
[docs] def __contains__(self, query: 'Self | tuple[int, int]') -> bool:
"""Check if box or point is contained within this box.
Args:
query: Box or single point (x, y).
Raises:
NotImplementedError: if query is not a Box or tuple/list.
"""
if isinstance(query, Box):
ymin, xmin, ymax, xmax = query
return (ymin >= self.ymin and xmin >= self.xmin
and ymax <= self.ymax and xmax <= self.xmax)
elif isinstance(query, (tuple, list)):
x, y = query
return self.xmin <= x <= self.xmax and self.ymin <= y <= self.ymax
else:
raise NotImplementedError()
class SlidingWindows(Sequence[Box]):
    """Lazy representation of a list of sliding windows.

    Instead of storing a list of all windows in memory, this class dynamically
    computes the coordinates of windows as they are retrieved. Supports
    iteration, negative indexing, and slicing.
    """

    def __init__(
            self,
            box: Box,
            *,
            size: PosInt | tuple[PosInt, PosInt],
            stride: PosInt | tuple[PosInt, PosInt],
            padding: NonNegInt | tuple[NonNegInt, NonNegInt] | None = None,
            pad_direction: Literal['both', 'start', 'end'] = 'end',
    ):
        """Constructor.

        Each of ``size``, ``stride``, and ``padding`` can be either a positive
        int or a tuple ``(vertical-component, horizontal-component)`` of
        positive ints.

        If ``padding`` is not specified and ``stride <= size``, it will be
        automatically calculated such that the windows cover the entire extent.

        Args:
            box: Outer box within which to generate sliding windows.
            size: Size ``(h, w)`` of the windows.
            stride: Step size between windows. Can be a ``(h_step, w_step)``
                tuple or positive int.
            padding: Optional padding to accommodate windows that overflow the
                extent. Can be a ``(h_pad, w_pad)`` tuple or a non-negative
                int. If ``None``, will be automatically calculated such that
                the windows cover the entire extent. Defaults to ``None``.
            pad_direction: Directions to add padding to.
                If ``'end'``, only add padding to bottom and right.
                If ``'start'``, only add padding to top and left.
                If ``'both'``, add padding to all sides.
                Has no effect if padding is zero. Defaults to ``'end'``.

        Raises:
            ValueError: if any component of size or stride is not positive.
        """
        size = ensure_tuple(size)
        stride = ensure_tuple(stride)
        if size[0] <= 0 or size[1] <= 0 or stride[0] <= 0 or stride[1] <= 0:
            raise ValueError('size and stride must be positive.')

        if padding is None:
            if size[0] < stride[0] or size[1] < stride[1]:
                # Windows are not contiguous, so no padding is inferred.
                padding = (0, 0)
            else:
                padding = calculate_required_padding(box.size, size, stride,
                                                     pad_direction)

        self.box = box
        self.size = size
        self.stride = stride
        self.padding: tuple[NonNegInt, NonNegInt] = ensure_tuple(padding)
        self.pad_direction = pad_direction
        self.padded_box = box.pad_directional(self.padding, self.pad_direction)

        self.h, self.w = size
        self.y_step, self.x_step = stride

        # Range of valid top-left corners: first and last (inclusive) origin
        # along each axis such that the window stays inside the padded box.
        self.y_start = self.padded_box.ymin
        self.x_start = self.padded_box.xmin
        self.y_end = self.padded_box.ymax - self.h
        self.x_end = self.padded_box.xmax - self.w

        # Clamp at zero: if the padded box is smaller than the window, the
        # raw count is <= 0, and two negative counts would otherwise multiply
        # into a spurious positive total.
        self.nrows = max(
            0, int((self.y_end - self.y_start) // self.y_step + 1))
        self.ncols = max(
            0, int((self.x_end - self.x_start) // self.x_step + 1))
        self.total = self.nrows * self.ncols

    @overload
    def __getitem__(self, i: int | np.integer) -> Box:
        ...

    @overload
    def __getitem__(self, s: slice) -> list[Box]:
        ...

    @overload
    def __getitem__(self, inds: Sequence[int]) -> list[Box]:
        ...

    def __getitem__(self, key: int | slice | Sequence[int]) -> Box | list[Box]:
        """Get window(s) by index, slice, or sequence of indices.

        Raises:
            IndexError: if an integer index is out of range.
            TypeError: if slice bounds are not integers.
        """
        if isinstance(key, (int, np.integer)):
            row, col = self.index_to_rowcol(key)
            return self.get_by_rowcol(row, col)
        if isinstance(key, slice):
            # slice.indices() applies the standard sequence semantics:
            # negative bounds count from the end, bounds are clamped to the
            # length, and negative steps get the right default start/stop.
            return [self[i] for i in range(*key.indices(len(self)))]
        return [self[i] for i in key]

    def get_by_rowcol(self, row: int, col: int) -> Box:
        """Get window at given row and column indices.

        Raises:
            IndexError: if row or col is negative or beyond the grid.
        """
        if not (0 <= row < self.nrows and 0 <= col < self.ncols):
            raise IndexError(
                f'(row, col)=({row}, {col}) is out of range for a '
                f'{self.nrows}x{self.ncols} grid of windows.')
        ymin = self.y_start + self.y_step * row
        xmin = self.x_start + self.x_step * col
        return Box(ymin, xmin, ymin + self.h, xmin + self.w)

    def index_to_rowcol(self, i: int) -> tuple[int, int]:
        """Get row and column indices of the i-th window.

        Windows are numbered in row-major order. Negative indices count from
        the end.

        Raises:
            IndexError: if i is not in ``[-len(self), len(self))``.
        """
        n = len(self)
        idx = i + n if i < 0 else i
        if not 0 <= idx < n:
            raise IndexError(
                f'index {i} is out of range for {n} sliding windows.')
        row = idx // self.ncols
        col = idx % self.ncols
        return row, col

    def __len__(self) -> int:
        return self.total