from typing import TYPE_CHECKING, Any
import logging
import numpy as np
from PIL import ImageColor
from skimage.io import imsave
from rastervision.core.box import Box
if TYPE_CHECKING:
from rastervision.core.data import (RasterSource, LabelSource, LabelStore)
log = logging.getLogger(__name__)
[docs]def color_to_triple(color: str | None = None) -> tuple[int, int, int]:
"""Given a PIL ImageColor string, return a triple of integers
representing the red, green, and blue values.
If color is None, return a random color.
Args:
color: A PIL ImageColor string
Returns:
An triple of integers
"""
if color is None:
r, g, b = np.random.randint(0, 256, size=3).tolist()
return r, g, b
return ImageColor.getrgb(color)
[docs]def color_to_integer(color: str) -> int:
"""Given a PIL ImageColor string, return a packed integer.
Args:
color: A PIL ImageColor string
Returns:
An integer containing the packed RGB values.
"""
triple = color_to_triple(color) if isinstance(color, str) else color
r = triple[0] * (1 << 16)
g = triple[1] * (1 << 8)
b = triple[2] * (1 << 0)
integer = r + g + b
return integer
[docs]def normalize_color(
color: str | tuple | list | None) -> tuple[float, float, float]:
"""Convert color representation to a float 3-tuple with values in [0-1]."""
if isinstance(color, str):
color = color_to_triple(color)
if isinstance(color, (tuple, list)):
if all(isinstance(c, int) for c in color):
return tuple(c / 255. for c in color)
elif all(isinstance(c, float) for c in color):
return tuple(color)
else:
raise ValueError('RGB values must be either all ints (0-255) '
'or all floats (0.0-1.0)')
raise TypeError('Expected color to be a string or tuple or list, '
f'but found {type(color)}.')
[docs]def rgb_to_int_array(rgb_array: np.ndarray) -> np.ndarray:
r = np.array(rgb_array[..., 0], dtype=np.uint32) * (1 << 16)
g = np.array(rgb_array[..., 1], dtype=np.uint32) * (1 << 8)
b = np.array(rgb_array[..., 2], dtype=np.uint32) * (1 << 0)
return r + g + b
[docs]def all_equal(it: list):
''' Returns true if all elements are equal to each other '''
return it.count(it[0]) == len(it)
[docs]def listify_uris(uris: str | list[str]) -> list[str]:
"""Convert to URI to list if needed."""
if isinstance(uris, (list, tuple)):
pass
elif isinstance(uris, str):
uris = [uris]
else:
raise TypeError(f'Expected str or list[str], but got {type(uris)}.')
return uris
[docs]def match_bboxes(raster_source: 'RasterSource',
label_source: 'LabelSource | LabelStore') -> None:
"""Set ``label_souce`` bbox equal to ``raster_source`` bbox.
Logs a warning if ``raster_source`` and ``label_source`` extents don't
intersect when converted to map coordinates.
Args:
raster_source (RasterSource): Source of imagery for a scene.
label_source (LabelSource | LabelStore): Source of labels for a
scene. Can be a ``LabelStore``.
"""
crs_tf_img = raster_source.crs_transformer
crs_tf_label = label_source.crs_transformer
bbox_img_map = crs_tf_img.pixel_to_map(raster_source.bbox)
if label_source.bbox is not None:
bbox_label_map = crs_tf_label.pixel_to_map(label_source.bbox)
if not bbox_img_map.intersects(bbox_label_map):
rs_cls = type(raster_source).__name__
ls_cls = type(label_source).__name__
log.warning(f'{rs_cls} bbox ({bbox_img_map}) does '
f'not intersect with {ls_cls} bbox '
f'({bbox_label_map}).')
# set LabelStore bbox to RasterSource bbox
bbox_label_pixel = crs_tf_label.map_to_pixel(bbox_img_map)
label_source.set_bbox(bbox_label_pixel)
[docs]def parse_array_slices_2d(key: tuple | slice,
extent: Box) -> tuple[Box, list[Any | None]]:
"""Parse 2D array-indexing inputs into a Box and slices."""
return parse_array_slices_Nd(key, extent, dims=2, h_dim=0, w_dim=1)
[docs]def parse_array_slices_Nd(key: tuple | slice,
extent: Box,
dims: int = 3,
h_dim: int = -3,
w_dim: int = -2) -> tuple[Box, list[Any | None]]:
"""Parse multi-dim array-indexing inputs into a Box and slices.
Args:
key (tuple | slice): Input to __getitem__.
extent (Box): Extent of the raster/label source being indexed.
dims (int): Total available indexable dims. Defaults to 3.
h_dim (int): Index of height dim. Defaults to -3.
w_dim (int): Index of width dim. Defaults to -2.
Raises:
NotImplementedError: If not (1 <= dims <= 3).
TypeError: If key is not a slice or tuple.
IndexError: if not (1 <= len(key) <= dims).
TypeError: If the index for any of the dims is None.
ValueError: If more than one Ellipsis ("...") in the input.
ValueError: If h and w indices (first 2 dims) are not slices.
NotImplementedError: If input contains negative values.
Returns:
tuple[Box, list]: A Box representing the h and w slices and a list
containing slices/index-values for all the dims.
"""
if isinstance(key, slice):
key = [key]
elif isinstance(key, tuple):
pass
else:
raise TypeError('Unsupported key type.')
input_slices = list(key)
if not (1 <= len(input_slices) <= dims):
raise IndexError(f'Too many indices for {dims}-dimensional source.')
if any(s is None for s in input_slices):
raise TypeError('None is not a valid index.')
if Ellipsis in input_slices:
if input_slices.count(Ellipsis) > 1:
raise ValueError('Only one ellipsis is allowed.')
num_missing_dims = dims - (len(input_slices) - 1)
filler_slices = [slice(None)] * num_missing_dims
idx = input_slices.index(Ellipsis)
# at the start
if idx == 0:
dim_slices = filler_slices + input_slices[(idx + 1):]
# somewhere in the middle
elif idx < (len(input_slices) - 1):
dim_slices = (
input_slices[:idx] + filler_slices + input_slices[(idx + 1):])
# at the end
else:
dim_slices = input_slices[:idx] + filler_slices
else:
num_missing_dims = dims - len(input_slices)
filler_slices = [slice(None)] * num_missing_dims
dim_slices = input_slices + filler_slices
if dim_slices[h_dim] is None:
dim_slices[h_dim] = slice(None)
if dim_slices[w_dim] is None:
dim_slices[w_dim] = slice(None)
h, w = dim_slices[h_dim], dim_slices[w_dim]
if not (isinstance(h, slice) and isinstance(w, slice)):
raise ValueError('h and w indices must be slices.')
if any(x is not None and x < 0
for x in [h.start, h.stop, h.step, w.start, w.stop, w.step]):
raise NotImplementedError(
'Negative indices are currently not supported.')
# slices with missing endpoints get expanded to the extent limits
H, W = extent.size
_ymin = 0 if h.start is None else h.start
_xmin = 0 if w.start is None else w.start
_ymax = H if h.stop is None else h.stop
_xmax = W if w.stop is None else w.stop
window = Box(_ymin, _xmin, _ymax, _xmax)
h_slice, w_slice = window.to_slices(h.step, w.step)
dim_slices[h_dim] = h_slice
dim_slices[w_dim] = w_slice
return window, dim_slices
[docs]def ensure_json_serializable(obj: Any) -> dict:
"""Convert numpy types to JSON serializable equivalents."""
if obj is None or isinstance(obj, (str, int, bool)):
return obj
if isinstance(obj, dict):
return {k: ensure_json_serializable(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
return [ensure_json_serializable(o) for o in obj]
if isinstance(obj, np.ndarray):
return ensure_json_serializable(obj.tolist())
if isinstance(obj, np.integer):
return int(obj)
if isinstance(obj, (float, np.floating)):
if np.isnan(obj):
return None
return float(obj)
if isinstance(obj, Box):
return obj.to_dict()
return obj
[docs]def save_img(im_array: np.ndarray, output_path: str):
"""Save numpy array as image file."""
imsave(output_path, im_array)