Source code for rastervision.pytorch_learner.dataset.transform

from typing import Any, overload
from collections.abc import Callable
from enum import Enum

from pydantic import PositiveInt as PosInt
import numpy as np
import albumentations as A
import torch

from rastervision.pytorch_learner.object_detection_utils import BoxList

TransformFunc = Callable[[tuple[np.ndarray, Any], A.BasicTransform], tuple[
    np.ndarray, Any]]


class TransformType(Enum):
    noop = 'noop'
    classification = 'classification'
    regression = 'regression'
    object_detection = 'object_detection'
    semantic_segmentation = 'semantic_segmentation'


def apply_transform(transform: A.BasicTransform, **kwargs) -> dict:
    """Apply Albumentations transform to possibly batched images.

    In case of batched images, the same transform is applied to all of them.
    This is useful for when the images represent a time-series.

    Args:
        transform: An albumentations transform.
        **kwargs: Extra args for ``transform``.

    Returns:
        Output of ``transform``. If ``ndim == 4``, the transformed image in
        the dict is also 4-dimensional.
    """
    img = kwargs['image']
    if img.ndim == 3:
        return transform(**kwargs)
    if img.ndim != 4:
        raise NotImplementedError(
            f'Image should have 3 or 4 dims. Found {img.ndim}.')
    batch_size = len(img)
    if len(transform._additional_targets) != (batch_size - 1):
        additional_targets = {
            f'img{i}': 'image'
            for i in range(1, batch_size)
        }
        transform.add_targets(additional_targets)
    img = kwargs.pop('image')
    img_keys = transform._additional_targets.keys()
    img_args = dict(zip(img_keys, img[1:]))
    out = transform(image=img[0], **kwargs, **img_args)
    out['image'] = np.stack([out.pop('image')] + [out[k] for k in img_keys])
    return out


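# Illustrative sketch (assumptions: the A.Resize transform and the array
# shapes below are chosen only to show how batching is handled; they are not
# part of this module). A single HWC chip passes straight through the
# transform, while a THWC time-series stack has the same transform applied
# to every frame and comes back re-stacked:
#
#   >>> tf = A.Resize(100, 100)
#   >>> img = np.zeros((256, 256, 3), dtype=np.uint8)
#   >>> apply_transform(tf, image=img)['image'].shape
#   (100, 100, 3)
#   >>> stack = np.zeros((4, 256, 256, 3), dtype=np.uint8)
#   >>> apply_transform(tf, image=stack)['image'].shape
#   (4, 100, 100, 3)

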
@overload
def classification_transformer(
        inp: tuple[np.ndarray, int],
        transform: A.BasicTransform | None) -> tuple[np.ndarray, np.ndarray]:
    ...


@overload
def classification_transformer(
        inp: tuple[np.ndarray, None],
        transform: A.BasicTransform | None) -> tuple[np.ndarray, None]:
    ...


def classification_transformer(inp, transform):
    """Apply transform to image only."""
    x, y = inp
    x = np.array(x)
    if transform is not None:
        out = apply_transform(transform, image=x)
        x = out['image']
    if y is not None:
        y = np.array(y, dtype=int)
    return x, y


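# Usage sketch (illustrative; the chip contents, class id, and the
# A.HorizontalFlip transform are assumed values, not fixtures from this
# module). The label is an integer class id that bypasses the transform and
# is coerced to an integer array:
#
#   >>> chip = np.zeros((64, 64, 3), dtype=np.uint8)
#   >>> x, y = classification_transformer((chip, 2), A.HorizontalFlip(p=1.0))
#   >>> x.shape, y
#   ((64, 64, 3), array(2))

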
@overload
def regression_transformer(
        inp: tuple[np.ndarray, Any],
        transform: A.BasicTransform | None) -> tuple[np.ndarray, np.ndarray]:
    ...


@overload
def regression_transformer(
        inp: tuple[np.ndarray, None],
        transform: A.BasicTransform | None) -> tuple[np.ndarray, None]:
    ...


def regression_transformer(inp, transform):
    """Apply transform to image only."""
    x, y = inp
    x = np.array(x)
    if transform is not None:
        out = apply_transform(transform, image=x)
        x = out['image']
    if y is not None:
        y = np.array(y, dtype=float)
    return x, y


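# Usage sketch (illustrative; the chip and the target values are assumed).
# Regression labels are coerced to float arrays, so a multi-target label such
# as [0.3, 0.7] stays a length-2 array:
#
#   >>> chip = np.zeros((64, 64, 3), dtype=np.uint8)
#   >>> _, y = regression_transformer((chip, [0.3, 0.7]), None)
#   >>> y
#   array([0.3, 0.7])

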
def yxyx_to_albu(yxyx: np.ndarray,
                 img_size: tuple[PosInt, PosInt]) -> np.ndarray:
    """Unnormalized [ymin, xmin, ymax, xmax] to Albumentations format i.e.
    normalized [xmin, ymin, xmax, ymax].
    """
    h, w = img_size
    ymin, xmin, ymax, xmax = yxyx.T
    ymin, ymax = ymin / h, ymax / h
    xmin, xmax = xmin / w, xmax / w

    xmin = np.clip(xmin, 0., 1., out=xmin)
    xmax = np.clip(xmax, 0., 1., out=xmax)
    ymin = np.clip(ymin, 0., 1., out=ymin)
    ymax = np.clip(ymax, 0., 1., out=ymax)

    xyxy = np.stack([xmin, ymin, xmax, ymax], axis=1).reshape((-1, 4))
    return xyxy


def xywh_to_albu(xywh: np.ndarray,
                 img_size: tuple[PosInt, PosInt]) -> np.ndarray:
    """Unnormalized [xmin, ymin, w, h] to Albumentations format i.e.
    normalized [xmin, ymin, xmax, ymax].
    """
    h, w = img_size
    xmin, ymin, box_w, box_h = xywh.T
    ymin, box_h = ymin / h, box_h / h
    xmin, box_w = xmin / w, box_w / w
    xmin, ymin, xmax, ymax = xmin, ymin, xmin + box_w, ymin + box_h

    xmin = np.clip(xmin, 0., 1., out=xmin)
    xmax = np.clip(xmax, 0., 1., out=xmax)
    ymin = np.clip(ymin, 0., 1., out=ymin)
    ymax = np.clip(ymax, 0., 1., out=ymax)

    xyxy = np.stack([xmin, ymin, xmax, ymax], axis=1).reshape((-1, 4))
    return xyxy


def albu_to_yxyx(xyxy: np.ndarray,
                 img_size: tuple[PosInt, PosInt]) -> np.ndarray:
    """Albumentations format (i.e. normalized [xmin, ymin, xmax, ymax]) to
    unnormalized [ymin, xmin, ymax, xmax].
    """
    h, w = img_size
    xmin, ymin, xmax, ymax = xyxy.T
    xmin, ymin, xmax, ymax = xmin * w, ymin * h, xmax * w, ymax * h

    xmin = np.clip(xmin, 0., w, out=xmin)
    xmax = np.clip(xmax, 0., w, out=xmax)
    ymin = np.clip(ymin, 0., h, out=ymin)
    ymax = np.clip(ymax, 0., h, out=ymax)

    yxyx = np.stack([ymin, xmin, ymax, xmax], axis=1).reshape((-1, 4))
    return yxyx


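# Worked example (illustrative; the box coordinates and the 100x200 image
# size are assumed values). For an image of height 100 and width 200, a
# pixel-space box converts to normalized xyxy and back without loss:
#
#   >>> yxyx_to_albu(np.array([[10., 20., 50., 100.]]), (100, 200))
#   array([[0.1, 0.1, 0.5, 0.5]])
#   >>> xywh_to_albu(np.array([[20., 10., 80., 40.]]), (100, 200))
#   array([[0.1, 0.1, 0.5, 0.5]])
#   >>> albu_to_yxyx(np.array([[0.1, 0.1, 0.5, 0.5]]), (100, 200))
#   array([[ 10.,  20.,  50., 100.]])

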
@overload
def object_detection_transformer(
        inp: tuple[np.ndarray, tuple[np.ndarray, np.ndarray, str]],
        transform: A.BasicTransform | None
) -> tuple[torch.Tensor, BoxList | None]:
    ...


@overload
def object_detection_transformer(
        inp: tuple[np.ndarray, None],
        transform: A.BasicTransform | None) -> tuple[torch.Tensor, None]:
    ...


def object_detection_transformer(inp, transform):
    """Apply transform to image, bounding boxes, and labels. Also perform
    normalization and conversion to pytorch tensors.

    The transform's BBoxParams are expected to have the format
    'albumentations' (i.e. normalized [xmin, ymin, xmax, ymax]).

    Args:
        inp: Tuple of the form: ``(image, (boxes, class_ids, box_format))``.
            box_format must be ``'yxyx'`` or ``'xywh'``.
        transform: A transform. Defaults to ``None``.

    Raises:
        NotImplementedError: If box_format is not ``'yxyx'`` or ``'xywh'``.

    Returns:
        Transformed image and boxes.
    """
    x, y = inp
    img_size = x.shape[:2]

    if y is not None:
        boxes, class_ids, box_format = y

    if transform is not None:
        if y is None:
            x = apply_transform(
                transform, image=x, bboxes=[], category_id=[])['image']
        else:
            # The albumentations transform expects the bboxes to be in the
            # Albumentations format (i.e. normalized [xmin, ymin, xmax, ymax]),
            # so we convert to that format before applying the transform.
            if box_format == 'yxyx':  # used by ObjectDetectionGeoDataset
                boxes = yxyx_to_albu(boxes, img_size)
            elif box_format == 'xywh':  # used by ObjectDetectionImageDataset
                boxes = xywh_to_albu(boxes, img_size)
            else:
                raise NotImplementedError(
                    f'Unknown box_format: {box_format}.')
            out = apply_transform(
                transform, image=x, bboxes=boxes, category_id=class_ids)
            x = out['image']
            boxes = np.array(out['bboxes']).reshape((-1, 4))
            class_ids = np.array(out['category_id'])
            if len(boxes) > 0:
                boxes = albu_to_yxyx(boxes, x.shape[:2])
            new_box_format = 'yxyx'
    elif y is not None:
        new_box_format = box_format

    if y is not None:
        boxes = torch.from_numpy(boxes).float()
        class_ids = torch.from_numpy(class_ids).long()
        if len(boxes) == 0:
            boxes = torch.empty((0, 4)).float()
        y = BoxList(boxes, format=new_box_format, class_ids=class_ids)

    # normalize x
    if np.issubdtype(x.dtype, np.unsignedinteger):
        max_val = np.iinfo(x.dtype).max
        x = x.astype(float) / max_val

    # convert to pytorch
    x = torch.from_numpy(x).permute(2, 0, 1).float()

    return x, y


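# Usage sketch (illustrative; the chip, box coordinates, and class id are
# assumed values). With ``transform=None`` the boxes skip the Albumentations
# round trip, but the image is still normalized and converted to a CHW float
# tensor, and the labels come back wrapped in a BoxList:
#
#   >>> chip = np.zeros((100, 100, 3), dtype=np.uint8)
#   >>> boxes = np.array([[10., 10., 40., 40.]])
#   >>> class_ids = np.array([1])
#   >>> x, y = object_detection_transformer(
#   ...     (chip, (boxes, class_ids, 'yxyx')), transform=None)
#   >>> x.shape, x.dtype
#   (torch.Size([3, 100, 100]), torch.float32)
#   >>> type(y).__name__
#   'BoxList'

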
@overload
def semantic_segmentation_transformer(
        inp: tuple[np.ndarray, np.ndarray],
        transform: A.BasicTransform | None
) -> tuple[np.ndarray, np.ndarray | None]:
    ...


@overload
def semantic_segmentation_transformer(
        inp: tuple[np.ndarray, None],
        transform: A.BasicTransform | None) -> tuple[np.ndarray, None]:
    ...


def semantic_segmentation_transformer(inp, transform):
    """Apply transform to image and mask."""
    x, y = inp
    x = np.array(x)
    if transform is not None:
        if y is None:
            x = apply_transform(transform, image=x)['image']
        else:
            y = np.array(y)
            out = apply_transform(transform, image=x, mask=y)
            x, y = out['image'], out['mask']
    if y is not None:
        y = y.astype(int)
    return x, y


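# Usage sketch (illustrative; the chip, mask, and the A.RandomRotate90
# transform are assumed values). The mask is passed to Albumentations as the
# ``mask`` target, so geometric transforms stay aligned with the image:
#
#   >>> chip = np.zeros((64, 64, 3), dtype=np.uint8)
#   >>> mask = np.zeros((64, 64), dtype=np.uint8)
#   >>> x, y = semantic_segmentation_transformer(
#   ...     (chip, mask), A.RandomRotate90(p=1.0))
#   >>> x.shape, y.shape
#   ((64, 64, 3), (64, 64))

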
TF_TYPE_TO_TF_FUNC: dict[TransformType, TransformFunc] = {
    TransformType.noop: lambda x, tf: x,
    TransformType.classification: classification_transformer,
    TransformType.regression: regression_transformer,
    TransformType.object_detection: object_detection_transformer,
    TransformType.semantic_segmentation: semantic_segmentation_transformer
}
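

# Dispatch sketch (illustrative; the chip, mask, and A.HorizontalFlip
# transform are assumed values). A caller can pick the transformer for its
# task via the mapping and invoke it with an ``(x, y)`` pair plus an optional
# Albumentations transform:
#
#   >>> chip = np.zeros((64, 64, 3), dtype=np.uint8)
#   >>> mask = np.zeros((64, 64), dtype=np.uint8)
#   >>> tf_func = TF_TYPE_TO_TF_FUNC[TransformType.semantic_segmentation]
#   >>> x, y = tf_func((chip, mask), A.HorizontalFlip(p=1.0))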