# Source code for rastervision.core.data.raster_source.multi_raster_source

from typing import TYPE_CHECKING, Sequence

from pydantic import NonNegativeInt as NonNegInt
import numpy as np
from pystac import Item

from rastervision.core.box import Box
from rastervision.core.data.raster_source import RasterSource, RasterioSource
from rastervision.core.data.raster_source.stac_config import subset_assets
from rastervision.core.data.utils import all_equal

if TYPE_CHECKING:
    from typing import Self
    from rastervision.core.data import RasterTransformer, CRSTransformer


class MultiRasterSource(RasterSource):
    """Merge multiple ``RasterSources`` by concatenating along the channel dim.

    One of the sub raster sources is designated the "primary" source; its
    bbox, shape (minus the channel dim), and CRS transformer are the ones
    exposed by this class.
    """

    def __init__(self,
                 raster_sources: Sequence[RasterSource],
                 primary_source_idx: NonNegInt = 0,
                 channel_order: Sequence[NonNegInt] | None = None,
                 raster_transformers: Sequence['RasterTransformer']
                 | None = None,
                 bbox: Box | None = None):
        """Constructor.

        Args:
            raster_sources: Sequence of RasterSources.
            primary_source_idx: Index of the raster source whose CRS, dtype,
                and other attributes will override those of the other raster
                sources.
            channel_order: Channel ordering that will be used by
                :meth:`.MultiRasterSource.get_chip`. If ``None`` (or empty),
                all channels of all sub raster sources are used, in order.
                Defaults to ``None``.
            raster_transformers: Sequence of transformers. If ``None``, no
                transformers are applied. Defaults to ``None``.
            bbox: User-specified crop of the extent. If specified, the primary
                raster source's bbox is set to this. If ``None``, the bbox of
                the primary raster source is used.

        Raises:
            IndexError: If ``primary_source_idx`` is not a valid index into
                ``raster_sources``.
            ValueError: If the sub raster sources do not all share the same
                dtype.
        """
        # Validate before the index is used, so that an invalid index always
        # produces the descriptive error below.
        if not (0 <= primary_source_idx < len(raster_sources)):
            raise IndexError('primary_source_idx must be in range '
                             '[0, len(raster_sources)).')

        # Use a fresh list per instance rather than a shared mutable default.
        if raster_transformers is None:
            raster_transformers = []

        primary_source = raster_sources[primary_source_idx]
        dtype_raw = primary_source.dtype
        num_channels_raw = sum(rs.num_channels for rs in raster_sources)

        if not channel_order:
            channel_order = list(range(num_channels_raw))

        if bbox is None:
            bbox = primary_source.bbox
        else:
            primary_source.set_bbox(bbox)

        super().__init__(
            channel_order=channel_order,
            num_channels_raw=num_channels_raw,
            dtype_raw=dtype_raw,
            bbox=bbox,
            raster_transformers=raster_transformers)

        self.raster_sources = raster_sources
        self.primary_source_idx = primary_source_idx
        self.non_primary_sources = [
            rs for i, rs in enumerate(raster_sources)
            if i != primary_source_idx
        ]

        self.validate_raster_sources()

    @classmethod
    def from_stac(
            cls,
            item: Item,
            assets: list[str] | None = None,
            primary_source_idx: NonNegInt = 0,
            raster_transformers: list['RasterTransformer'] | None = None,
            channel_order: Sequence[int] | None = None,
            bbox: Box | tuple[int, int, int, int] | None = None,
            bbox_map_coords: Box | tuple[int, int, int, int] | None = None,
            allow_streaming: bool = False) -> 'Self':
        """Construct a ``MultiRasterSource`` from a STAC Item.

        This creates a :class:`.RasterioSource` for each asset and puts all
        the raster sources together into a ``MultiRasterSource``. If
        ``assets`` is not specified, all the assets in the STAC item are used.

        Only assets that are readable by rasterio are supported.

        Args:
            item: STAC Item.
            assets: List of names of assets to use. If ``None``, all assets
                present in the item will be used. Defaults to ``None``.
            primary_source_idx: Index of the raster source whose CRS, dtype,
                and other attributes will override those of the other raster
                sources.
            raster_transformers: RasterTransformers to use to transform chips
                after they are read. If ``None``, no transformers are
                applied. Defaults to ``None``.
            channel_order: List of indices of channels to extract from raw
                imagery. Can be a subset of the available channels. If
                ``None``, all channels available in the image will be read.
                Defaults to ``None``.
            bbox: User-specified crop of the extent. Can be :class:`.Box` or
                (ymin, xmin, ymax, xmax) tuple. If ``None``, the full extent
                available in the source file is used. Mutually exclusive with
                ``bbox_map_coords``. Defaults to ``None``.
            bbox_map_coords: User-specified bbox in EPSG:4326 coords. Can be
                :class:`.Box` or (ymin, xmin, ymax, xmax) tuple. Useful for
                cropping the raster source so that only part of the raster is
                read from. Mutually exclusive with ``bbox``. Defaults to
                ``None``.
            allow_streaming: Passed to :class:`.RasterioSource`. If
                ``False``, assets will be downloaded. Defaults to ``False``.

        Returns:
            A new instance of this class built from the item's assets.

        Raises:
            ValueError: If both ``bbox`` and ``bbox_map_coords`` are given.
        """
        if bbox is not None and bbox_map_coords is not None:
            raise ValueError('Specify either bbox or bbox_map_coords, '
                             'but not both.')

        if assets is not None:
            item = subset_assets(item, assets)

        uris = [asset.href for asset in item.assets.values()]
        raster_sources = [
            RasterioSource(uri, allow_streaming=allow_streaming)
            for uri in uris
        ]

        if bbox_map_coords is not None:
            # Map coords are converted to pixel coords using the primary
            # source's CRS transformer, since its bbox is the one that is set.
            crs_transformer = raster_sources[primary_source_idx].crs_transformer
            bbox_map_coords = Box(*bbox_map_coords)
            bbox = crs_transformer.map_to_pixel(bbox_map_coords).normalize()
        elif bbox is not None:
            bbox = Box(*bbox)

        # Instantiate via ``cls`` so that subclasses get an instance of their
        # own type, as promised by the ``Self`` return annotation.
        return cls(
            raster_sources,
            primary_source_idx=primary_source_idx,
            raster_transformers=raster_transformers,
            channel_order=channel_order,
            bbox=bbox)

    def validate_raster_sources(self) -> None:
        """Validate sub-``RasterSources``.

        Checks if all raster sources have the same dtype.

        Raises:
            ValueError: If the dtypes are not all equal.
        """
        dtypes = [rs.dtype for rs in self.raster_sources]
        if not all_equal(dtypes):
            raise ValueError(
                'dtypes of all sub raster sources must be the same. '
                f'Got: {dtypes}.')

    @property
    def primary_source(self) -> RasterSource:
        """Primary sub-``RasterSource``"""
        return self.raster_sources[self.primary_source_idx]

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the raster as a (..., H, W, C) tuple."""
        # Leading dims come from the primary source; the channel dim is this
        # source's own channel count.
        *shape, _ = self.primary_source.shape
        return (*shape, self.num_channels)

    @property
    def crs_transformer(self) -> 'CRSTransformer':
        """CRS transformer of the primary sub-``RasterSource``."""
        return self.primary_source.crs_transformer

    def _get_sub_chips(self,
                       window: Box,
                       out_shape: tuple[int, int] | None = None
                       ) -> list[np.ndarray]:
        """Return chips from sub raster sources as a list.

        Algorithm:

        - using the pixel-coords window, get a chip from the primary sub
          raster source
        - convert the window to map coords using the CRS transformer and bbox
          of the primary sub raster source
        - for each remaining sub raster source, get a chip for that
          map-coords window; ``out_shape`` is specified when reading so that
          the shape matches the reference chip from the primary sub raster
          source

        Args:
            window: The window for which to get the chip, in pixel
                coordinates.
            out_shape: (height, width) to resize the chip to. If ``None``,
                the (height, width) of the primary source's chip is used for
                the other sources.

        Returns:
            List of chips from each sub raster source, in the same order as
            ``self.raster_sources``.
        """
        primary_rs = self.primary_source

        primary_sub_chip = primary_rs.get_chip(window, out_shape=out_shape)
        if out_shape is None:
            # Force the remaining chips to match the reference chip's size.
            out_shape = primary_sub_chip.shape[:2]

        window_map_coords = primary_rs.crs_transformer.pixel_to_map(
            window, bbox=primary_rs.bbox)
        sub_chips = [
            rs.get_chip_by_map_window(window_map_coords, out_shape=out_shape)
            for rs in self.non_primary_sources
        ]
        # Put the primary chip back at its original position so that channel
        # order follows the order of ``self.raster_sources``.
        sub_chips.insert(self.primary_source_idx, primary_sub_chip)
        return sub_chips

    def _get_chip(self,
                  window: Box,
                  out_shape: tuple[int, int] | None = None) -> np.ndarray:
        """Get chip w/o applying channel_order and transformers.

        Args:
            window: The window for which to get the chip, in pixel
                coordinates.
            out_shape: (height, width) to resize the chip to.

        Returns:
            Array of shape (height, width, channels).
        """
        sub_chips = self._get_sub_chips(window, out_shape=out_shape)
        chip = np.concatenate(sub_chips, axis=-1)
        return chip

    def get_chip(self,
                 window: Box,
                 out_shape: tuple[int, int] | None = None) -> np.ndarray:
        """Return the transformed chip in the window.

        Get processed chips from sub raster sources (with their respective
        channel orders and transformations applied), concatenate them along
        the channel dimension, apply channel_order, followed by
        transformations.

        Args:
            window: The window for which to get the chip, in pixel
                coordinates.
            out_shape: (height, width) to resize the chip to.

        Returns:
            Array of shape (height, width, channels).
        """
        sub_chips = self._get_sub_chips(window, out_shape=out_shape)
        chip = np.concatenate(sub_chips, axis=-1)
        chip = chip[..., self.channel_order]
        for transformer in self.raster_transformers:
            chip = transformer.transform(chip)
        return chip