Source code for rastervision.pytorch_backend.utils

from typing import TYPE_CHECKING, List, Tuple

import numpy as np

from rastervision.pytorch_learner.object_detection_utils import BoxList

if TYPE_CHECKING:
    from torch import Tensor
    from rastervision.core.box import Box


[docs]def chip_collate_fn_ss(batch: List[Tuple[Tuple[np.ndarray, np.ndarray], 'Box']] ) -> Tuple[Tuple[np.ndarray, np.ndarray], List['Box']]: xs = np.stack([x for (x, _), _ in batch]) ys = [None if np.isnan(y).all() else y for (_, y), _ in batch] ws = [w for (_, _), w in batch] return (xs, ys), ws
[docs]def chip_collate_fn_cc(batch: List[Tuple[Tuple[np.ndarray, np.ndarray], 'Box']] ) -> Tuple[Tuple[np.ndarray, np.ndarray], List['Box']]: xs = np.stack([x for (x, _), _ in batch]) ys = [None if np.isnan(y).all() else y for (_, y), _ in batch] ws = [w for (_, _), w in batch] return (xs, ys), ws
[docs]def chip_collate_fn_od(batch: List[Tuple[Tuple['Tensor', 'BoxList'], 'Box']] ) -> Tuple[Tuple[np.ndarray, 'BoxList'], List['Box']]: xs = np.stack([x.numpy() for (x, _), _ in batch]) # (..., c, h, w) --> (..., h, w, c) xs = xs.swapaxes(-3, -2).swapaxes(-2, -1) xs = (xs * 255).astype(np.uint8) ys = [y if isinstance(y, BoxList) else None for (_, y), _ in batch] ws = [w for (_, _), w in batch] return (xs, ys), ws