Source code for rastervision.core.data.raster_source.xarray_source

from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union
import logging

import numpy as np
from xarray import DataArray

from rastervision.core.box import Box
from rastervision.core.data.crs_transformer import RasterioCRSTransformer
from rastervision.core.data.raster_source import RasterSource
from rastervision.core.data.utils import parse_array_slices_Nd, fill_overflow

if TYPE_CHECKING:
    from pystac import Item, ItemCollection
    from rastervision.core.data import RasterTransformer, CRSTransformer

log = logging.getLogger(__name__)


class XarraySource(RasterSource):
    """A RasterSource for reading an Xarray DataArray.

    .. warning:: ``XarraySource`` API is in beta.
    """

    def __init__(self,
                 data_array: DataArray,
                 crs_transformer: 'CRSTransformer',
                 raster_transformers: List['RasterTransformer'] = [],
                 channel_order: Optional[Sequence[int]] = None,
                 bbox: Optional[Box] = None,
                 temporal: bool = False):
        """Constructor.

        Args:
            data_array (DataArray): The DataArray to read from. Must have
                exactly the dimensions "x", "y", and "band" (plus "time" if
                ``temporal=True``), in any order.
            crs_transformer (CRSTransformer): A CRSTransformer defining the
                mapping between pixel and map coords.
            raster_transformers (List['RasterTransformer']):
                RasterTransformers to use to transform chips after they are
                read.
            channel_order (Optional[Sequence[int]]): List of indices of
                channels to extract from raw imagery. Can be a subset of the
                available channels. If None, all channels available in the
                image will be read. Defaults to None.
            bbox (Optional[Box], optional): User-specified crop of the
                extent. If None, the full extent of the DataArray is used.
                If it is not contained in the full extent, it is clipped to
                it and a warning is logged.
            temporal (bool): If True, data_array is expected to have a "time"
                dimension and the chips returned will be of shape
                (T, H, W, C).

        Raises:
            ValueError: If the dimensions of ``data_array`` do not match the
                ones expected for the given value of ``temporal``.
        """
        self.temporal = temporal
        if self.temporal:
            if set(data_array.dims) != {'x', 'y', 'band', 'time'}:
                raise ValueError(
                    'If temporal=True, data_array must have 4 dimensions: '
                    '"x", "y", "band", and "time" (in any order).')
        else:
            if set(data_array.dims) != {'x', 'y', 'band'}:
                raise ValueError(
                    'If temporal=False, data_array must have 3 dimensions: '
                    '"x", "y", and "band" (in any order).')

        # Copy so that the shared default list (or a list owned by the
        # caller) is never aliased by this instance.
        raster_transformers = list(raster_transformers)

        # Put the spatial and band dims last: (..., H, W, C). Any remaining
        # dim (i.e. "time") ends up in front.
        self.data_array = data_array.transpose(..., 'y', 'x', 'band')
        self.ndim = data_array.ndim
        self._crs_transformer = crs_transformer

        num_channels_raw = len(data_array.band)
        if channel_order is None:
            channel_order = np.arange(num_channels_raw, dtype=int)
        else:
            channel_order = np.array(channel_order, dtype=int)

        # With no transformers, the output dtype and channel count are known
        # up front. Otherwise they are determined lazily from a test chip.
        self._num_channels = None
        self._dtype = None
        if not raster_transformers:
            self._num_channels = len(channel_order)
            self._dtype = data_array.dtype

        height, width = len(data_array.y), len(data_array.x)
        self.full_extent = Box(0, 0, height, width)
        if bbox is None:
            bbox = self.full_extent
        elif bbox not in self.full_extent:
            new_bbox = bbox.intersection(self.full_extent)
            log.warning(f'Clipping ({bbox}) to the DataArray\'s '
                        f'full extent ({self.full_extent}). '
                        f'New bbox={new_bbox}')
            bbox = new_bbox

        super().__init__(
            channel_order,
            num_channels_raw,
            raster_transformers=raster_transformers,
            bbox=bbox)

    @classmethod
    def from_stac(
            cls,
            item_or_item_collection: Union['Item', 'ItemCollection'],
            raster_transformers: List['RasterTransformer'] = [],
            channel_order: Optional[Sequence[int]] = None,
            bbox: Optional[Box] = None,
            bbox_map_coords: Optional[Box] = None,
            temporal: bool = False,
            allow_streaming: bool = False,
            stackstac_args: dict = dict(rescale=False)) -> 'XarraySource':
        """Construct an ``XarraySource`` from a STAC Item or ItemCollection.

        Args:
            item_or_item_collection: STAC Item or ItemCollection.
            raster_transformers: RasterTransformers to use to transform chips
                after they are read.
            channel_order: List of indices of channels to extract from raw
                imagery. Can be a subset of the available channels. If None,
                all channels available in the image will be read.
                Defaults to None.
            bbox: User-specified crop of the extent. If None, the full extent
                available in the source file is used. Mutually exclusive with
                ``bbox_map_coords``. Defaults to ``None``.
            bbox_map_coords: User-specified bbox in EPSG:4326 coords of the
                form (ymin, xmin, ymax, xmax). Useful for cropping the raster
                source so that only part of the raster is read from. Mutually
                exclusive with ``bbox``. Defaults to ``None``.
            temporal: If True, data_array is expected to have a "time"
                dimension and the chips returned will be of shape
                (T, H, W, C).
            allow_streaming: If False, load the entire DataArray into memory.
                Defaults to False.
            stackstac_args: Optional arguments to pass to stackstac.stack().

        Returns:
            XarraySource: The newly constructed raster source.

        Raises:
            ValueError: If both ``bbox`` and ``bbox_map_coords`` are
                specified, or if ``temporal=False`` but the stacked
                DataArray has more than one time step.
        """
        # Validate arguments before doing any (potentially expensive) I/O.
        if bbox is not None and bbox_map_coords is not None:
            raise ValueError('Specify either bbox or bbox_map_coords, '
                             'but not both.')

        import stackstac

        data_array = stackstac.stack(item_or_item_collection, **stackstac_args)

        if not temporal and 'time' in data_array.dims:
            if len(data_array.time) > 1:
                raise ValueError('temporal=False but len(data_array.time) > 1')
            # Drop the length-1 "time" dimension.
            data_array = data_array.isel(time=0)

        if not allow_streaming:
            from humanize import naturalsize
            log.info('Loading the full DataArray into memory '
                     f'({naturalsize(data_array.nbytes)}).')
            data_array.load()

        crs_transformer = RasterioCRSTransformer(
            transform=data_array.transform, image_crs=data_array.crs)

        if bbox is not None:
            bbox = Box(*bbox)
        elif bbox_map_coords is not None:
            bbox_map_coords = Box(*bbox_map_coords)
            bbox = crs_transformer.map_to_pixel(bbox_map_coords).normalize()

        raster_source = XarraySource(
            data_array,
            crs_transformer=crs_transformer,
            raster_transformers=raster_transformers,
            channel_order=channel_order,
            bbox=bbox,
            temporal=temporal)
        return raster_source

    @property
    def shape(self) -> Tuple[int, ...]:
        """Shape of the raster.

        A (height, width, num_channels) tuple, or a
        (time, height, width, num_channels) tuple if ``temporal=True``.
        """
        H, W = self.bbox.size
        if self.temporal:
            T = len(self.data_array.time)
            return T, H, W, self.num_channels
        return H, W, self.num_channels

    @property
    def num_channels(self) -> int:
        """Number of channels in the chips read from this source.

        .. note::

            Unlike the parent class, ``XarraySource`` applies
            ``channel_order`` before ``raster_transformers``. So the number
            of output channels is not guaranteed to be equal to
            ``len(channel_order)``.
        """
        if self._num_channels is None:
            self._set_info_from_chip()
        return self._num_channels

    @property
    def dtype(self) -> np.dtype:
        """dtype of the chips read from this source.

        Determined by reading a test chip if it was not set at construction.
        """
        if self._dtype is None:
            self._set_info_from_chip()
        return self._dtype

    @property
    def crs_transformer(self) -> RasterioCRSTransformer:
        """The CRSTransformer passed to the constructor."""
        return self._crs_transformer

    def _set_info_from_chip(self):
        """Read 1x1 chip to get info not statically inferable."""
        test_chip = self.get_chip(Box(0, 0, 1, 1))
        self._dtype = test_chip.dtype
        self._num_channels = test_chip.shape[-1]

    def _get_chip(self,
                  window: Box,
                  bands: Union[int, Sequence[int], slice] = slice(None),
                  time: Union[int, Sequence[int], slice] = slice(None),
                  out_shape: Optional[Tuple[int, ...]] = None) -> np.ndarray:
        """Read a chip from the DataArray without applying transformers.

        Args:
            window (Box): Window to read, in coords relative to ``self.bbox``.
            bands: Indexer for the "band" dimension of the DataArray.
                Defaults to ``slice(None)`` i.e. all bands.
            time: Indexer for the "time" dimension. Ignored if
                ``self.temporal`` is False. Defaults to ``slice(None)``.
            out_shape (Optional[Tuple[int, ...]]): If not None, the chip is
                passed to ``self.resize()`` with this shape before being
                returned. Defaults to None.

        Returns:
            np.ndarray: The chip as a numpy array.
        """
        window = window.to_global_coords(self.bbox)
        window_within_bbox = window.intersection(self.bbox)

        yslice, xslice = window_within_bbox.to_slices()
        if self.temporal:
            chip = self.data_array.isel(
                x=xslice, y=yslice, band=bands, time=time).to_numpy()
        else:
            chip = self.data_array.isel(
                x=xslice, y=yslice, band=bands).to_numpy()

        if window != window_within_bbox:
            # Window overflows the bbox: paste the data that was read into a
            # zero-filled array of the full window size.
            *batch_dims, _, _, c = chip.shape
            # coords of window_within_bbox within window
            yslice, xslice = window_within_bbox.to_local_coords(
                window).to_slices()
            # Keep the dtype of the data that was read; np.zeros would
            # otherwise default to float64.
            tmp = np.zeros(
                (*batch_dims, *window.size, c), dtype=chip.dtype)
            tmp[..., yslice, xslice, :] = chip
            chip = tmp

        chip = fill_overflow(self.bbox, window, chip)
        if out_shape is not None:
            chip = self.resize(chip, out_shape)
        return chip

    def get_chip(self,
                 window: Box,
                 bands: Optional[Union[int, Sequence[int], slice]] = None,
                 time: Union[int, Sequence[int], slice] = slice(None),
                 out_shape: Optional[Tuple[int, ...]] = None) -> np.ndarray:
        """Read a chip specified by a window from the file.

        Args:
            window (Box): Bounding box of chip in pixel coordinates.
            bands (Optional[Union[Sequence[int], slice]], optional): Subset
                of bands to read. Note that this will be applied on top of
                the channel_order (if specified). So if this is an RGB image
                and channel_order=[2, 1, 0], then using bands=[0] will return
                the B-channel. Defaults to None.
            time (Union[int, Sequence[int], slice]): Subset of time steps to
                read. Only used if ``temporal=True``. Defaults to
                ``slice(None)`` i.e. all time steps.
            out_shape (Optional[Tuple[int, ...]], optional): (height, width)
                of the output chip. If None, no resizing is done.
                Defaults to None.

        Returns:
            np.ndarray: A chip of shape (height, width, channels), or
            (time, height, width, channels) if ``temporal=True``.
        """
        # Only compare against slice(None) when bands actually is a slice;
        # comparing an array-like with == is elementwise and has no single
        # truth value.
        read_all_bands = bands is None or (isinstance(bands, slice)
                                           and bands == slice(None))
        if read_all_bands:
            bands = self.channel_order
        else:
            bands = self.channel_order[bands]

        chip = self._get_chip(
            window, bands=bands, time=time, out_shape=out_shape)

        for transformer in self.raster_transformers:
            chip = transformer.transform(chip, bands)

        return chip

    def __getitem__(self, key: Any) -> 'np.ndarray':
        """Read a chip using either a Box or array-style indexing.

        Args:
            key (Any): Either a Box, which is passed directly to
                :meth:`get_chip`, or an index expression that is parsed by
                ``parse_array_slices_Nd``. A step on the height or width
                slice results in the chip being resized to the
                correspondingly smaller size.

        Returns:
            np.ndarray: The chip returned by :meth:`get_chip`.
        """
        if isinstance(key, Box):
            return self.get_chip(key)

        window, dim_slices = parse_array_slices_Nd(
            key, extent=self.extent, dims=self.ndim)
        if self.temporal:
            t, h, w, c = dim_slices
        else:
            h, w, c = dim_slices
            # Not used when reading from a non-temporal source.
            t = slice(None)

        out_shape = None
        if h.step is not None or w.step is not None:
            # Emulate strided indexing by downsampling the chip.
            out_h, out_w = window.size
            if h.step is not None:
                out_h //= h.step
            if w.step is not None:
                out_w //= w.step
            out_shape = (int(out_h), int(out_w))

        chip = self.get_chip(window, bands=c, time=t, out_shape=out_shape)
        return chip