from typing import Optional, Sequence, List, Tuple
from pydantic import conint
import numpy as np
from rastervision.core.box import Box
from rastervision.core.data.raster_source import RasterSource
from rastervision.core.data.crs_transformer import CRSTransformer
from rastervision.core.data.raster_source.rasterio_source import RasterioSource
from rastervision.core.data.utils import all_equal
[docs]class MultiRasterSource(RasterSource):
"""Merge multiple ``RasterSources`` by concatenating along channel dim."""
[docs] def __init__(self,
raster_sources: Sequence[RasterSource],
primary_source_idx: conint(ge=0) = 0,
force_same_dtype: bool = False,
channel_order: Optional[Sequence[conint(ge=0)]] = None,
raster_transformers: Sequence = [],
extent: Optional[Box] = None):
"""Constructor.
Args:
raster_sources (Sequence[RasterSource]): Sequence of RasterSources.
primary_source_idx (0 <= int < len(raster_sources)): Index of the
raster source whose CRS, dtype, and other attributes will
override those of the other raster sources.
force_same_dtype (bool): If true, force all sub-chips to have the
same dtype as the primary_source_idx-th sub-chip. No careful
conversion is done, just a quick cast. Use with caution.
channel_order (Sequence[conint(ge=0)], optional): Channel ordering
that will be used by .get_chip(). Defaults to None.
raster_transformers (Sequence, optional): Sequence of transformers.
Defaults to [].
"""
num_channels_raw = sum(rs.num_channels_raw for rs in raster_sources)
if not channel_order:
num_channels = sum(rs.num_channels for rs in raster_sources)
channel_order = list(range(num_channels))
# validate primary_source_idx
if not (0 <= primary_source_idx < len(raster_sources)):
raise IndexError('primary_source_idx must be in range '
'[0, len(raster_sources)].')
if extent is None:
extent = raster_sources[primary_source_idx].extent
super().__init__(
channel_order,
num_channels_raw,
raster_transformers=raster_transformers,
extent=extent)
self.force_same_dtype = force_same_dtype
self.raster_sources = raster_sources
self.primary_source_idx = primary_source_idx
self.extents = [rs.extent for rs in self.raster_sources]
self.all_extents_equal = all_equal(self.extents)
self.validate_raster_sources()
[docs] def validate_raster_sources(self) -> None:
"""Validate sub-``RasterSources``.
Checks if:
- dtypes are same or ``force_same_dtype`` is True.
- each sub-``RasterSource`` is a :class:`.RasterioSource` if extents
not identical.
"""
dtypes = [rs.dtype for rs in self.raster_sources]
if not self.force_same_dtype and not all_equal(dtypes):
raise ValueError(
'dtypes of all sub raster sources must be the same. '
f'Got: {dtypes} '
'(Use force_same_dtype to cast all to the dtype of the '
'primary source)')
if not self.all_extents_equal:
all_rasterio_sources = all(
isinstance(rs, RasterioSource) for rs in self.raster_sources)
if not all_rasterio_sources:
raise NotImplementedError(
'Non-identical extents are only '
'supported for RasterioSource raster sources.')
@property
def primary_source(self) -> RasterSource:
"""Primary sub-``RasterSource``"""
return self.raster_sources[self.primary_source_idx]
@property
def dtype(self) -> np.dtype:
return self.primary_source.dtype
@property
def crs_transformer(self) -> CRSTransformer:
return self.primary_source.crs_transformer
def _get_sub_chips(self, window: Box,
raw: bool = False) -> List[np.ndarray]:
"""If all extents are identical, simply retrieves chips from each sub
raster source. Otherwise, follows the following algorithm
- using pixel-coords window, get chip from the primary sub raster
source
- convert window to world coords using the CRS of the primary sub
raster source
- for each remaining sub raster source
- convert world-coords window to pixel coords using the sub
raster source's CRS
- get chip from the sub raster source using this window;
specify `out_shape` when reading to ensure shape matches
reference chip from the primary sub raster source
Args:
window (Box): window to read, in pixel coordinates.
raw (bool, optional): If True, uses RasterSource._get_chip.
Otherwise, RasterSource.get_chip. Defaults to False.
Returns:
List[np.ndarray]: List of chips from each sub raster source.
"""
def get_chip(
rs: RasterSource,
window: Box,
out_shape: Optional[Tuple[int, int]] = None) -> np.ndarray:
if raw:
return rs._get_chip(window, out_shape=out_shape)
return rs.get_chip(window, out_shape=out_shape)
if self.all_extents_equal:
sub_chips = [get_chip(rs, window) for rs in self.raster_sources]
else:
primary_rs = self.primary_source
other_rses = [rs for rs in self.raster_sources if rs != primary_rs]
primary_sub_chip = get_chip(primary_rs, window)
out_shape = primary_sub_chip.shape[:2]
world_window = primary_rs.crs_transformer.pixel_to_map(window)
pixel_windows = [
rs.crs_transformer.map_to_pixel(world_window)
for rs in other_rses
]
sub_chips = [
get_chip(rs, w, out_shape=out_shape)
for rs, w in zip(other_rses, pixel_windows)
]
sub_chips.insert(self.primary_source_idx, primary_sub_chip)
if self.force_same_dtype:
dtype = sub_chips[self.primary_source_idx].dtype
sub_chips = [chip.astype(dtype) for chip in sub_chips]
return sub_chips
def _get_chip(self, window: Box) -> np.ndarray:
"""Return the raw chip located in the window.
Get raw chips from sub raster sources and concatenate them.
Args:
window: Box
Returns:
[height, width, channels] numpy array
"""
sub_chips = self._get_sub_chips(window, raw=True)
chip = np.concatenate(sub_chips, axis=-1)
return chip
[docs] def get_chip(self, window: Box) -> np.ndarray:
"""Return the transformed chip in the window.
Get processed chips from sub raster sources (with their respective
channel orders and transformations applied), concatenate them along the
channel dimension, apply channel_order, followed by transformations.
Args:
window: Box
Returns:
np.ndarray with shape [height, width, channels]
"""
sub_chips = self._get_sub_chips(window, raw=False)
chip = np.concatenate(sub_chips, axis=-1)
chip = chip[..., self.channel_order]
for transformer in self.raster_transformers:
chip = transformer.transform(chip, self.channel_order)
return chip