# Source code for the MultiRasterSource raster source.

from typing import Optional, Sequence, List, Tuple
from pydantic import conint

import numpy as np

from import Box
from import RasterSource
from import CRSTransformer
from import RasterioSource
from import all_equal

class MultiRasterSource(RasterSource):
    """Merge multiple ``RasterSources`` by concatenating along channel dim."""

    def __init__(self,
                 raster_sources: Sequence[RasterSource],
                 primary_source_idx: conint(ge=0) = 0,
                 force_same_dtype: bool = False,
                 channel_order: Optional[Sequence[conint(ge=0)]] = None,
                 raster_transformers: Optional[Sequence] = None,
                 extent: Optional[Box] = None):
        """Constructor.

        Args:
            raster_sources (Sequence[RasterSource]): Sequence of RasterSources.
            primary_source_idx (0 <= int < len(raster_sources)): Index of the
                raster source whose CRS, dtype, and other attributes will
                override those of the other raster sources.
            force_same_dtype (bool): If true, force all sub-chips to have the
                same dtype as the primary_source_idx-th sub-chip. No careful
                conversion is done, just a quick cast. Use with caution.
            channel_order (Sequence[conint(ge=0)], optional): Channel ordering
                that will be used by .get_chip(). If None or empty, every
                channel of the concatenated sub-chips is used, in order.
                Defaults to None.
            raster_transformers (Sequence, optional): Sequence of transformers.
                None means no transformers. Defaults to None.
            extent (Optional[Box], optional): Extent of this raster source.
                If None, the extent of the primary sub-``RasterSource`` is
                used. Defaults to None.

        Raises:
            IndexError: If primary_source_idx is not in
                [0, len(raster_sources)).
            ValueError: If the sub-sources' dtypes differ and
                force_same_dtype is False.
            NotImplementedError: If the sub-sources' extents differ and not
                all of them are RasterioSources.
        """
        # Build a fresh list per call; a literal ``[]`` default would be one
        # list object shared by every instance created without this argument.
        if raster_transformers is None:
            raster_transformers = []

        num_channels_raw = sum(rs.num_channels_raw for rs in raster_sources)
        if not channel_order:
            # get_chip() concatenates the sub-sources' processed chips, so the
            # default order spans the sum of their num_channels (presumably
            # each sub-source's post-channel_order count).
            num_channels = sum(rs.num_channels for rs in raster_sources)
            channel_order = list(range(num_channels))

        # validate primary_source_idx
        if not (0 <= primary_source_idx < len(raster_sources)):
            raise IndexError('primary_source_idx must be in range '
                             '[0, len(raster_sources)).')

        if extent is None:
            extent = raster_sources[primary_source_idx].extent

        super().__init__(
            channel_order,
            num_channels_raw,
            raster_transformers=raster_transformers,
            extent=extent)

        self.force_same_dtype = force_same_dtype
        self.raster_sources = raster_sources
        self.primary_source_idx = primary_source_idx
        self.extents = [rs.extent for rs in self.raster_sources]
        self.all_extents_equal = all_equal(self.extents)

        self.validate_raster_sources()

    def validate_raster_sources(self) -> None:
        """Validate sub-``RasterSources``.

        Checks if:

        - dtypes are same or ``force_same_dtype`` is True.
        - each sub-``RasterSource`` is a :class:`.RasterioSource` if extents
          not identical.

        Raises:
            ValueError: If dtypes differ and ``force_same_dtype`` is False.
            NotImplementedError: If extents differ and some sub-source is not
                a :class:`.RasterioSource`.
        """
        dtypes = [rs.dtype for rs in self.raster_sources]
        if not self.force_same_dtype and not all_equal(dtypes):
            raise ValueError(
                'dtypes of all sub raster sources must be the same. '
                f'Got: {dtypes} '
                '(Use force_same_dtype to cast all to the dtype of the '
                'primary source)')
        if not self.all_extents_equal:
            # The mismatched-extent path in _get_sub_chips reads with an
            # explicit out_shape, which is only relied upon for RasterioSource.
            all_rasterio_sources = all(
                isinstance(rs, RasterioSource) for rs in self.raster_sources)
            if not all_rasterio_sources:
                raise NotImplementedError(
                    'Non-identical extents are only '
                    'supported for RasterioSource raster sources.')

    @property
    def primary_source(self) -> RasterSource:
        """Primary sub-``RasterSource``"""
        return self.raster_sources[self.primary_source_idx]

    @property
    def dtype(self) -> np.dtype:
        """dtype of the primary sub-``RasterSource``."""
        return self.primary_source.dtype

    @property
    def crs_transformer(self) -> CRSTransformer:
        """CRS transformer of the primary sub-``RasterSource``."""
        return self.primary_source.crs_transformer

    def _get_sub_chips(self, window: Box,
                       raw: bool = False) -> List[np.ndarray]:
        """If all extents are identical, simply retrieves chips from each sub
        raster source. Otherwise, follows the following algorithm

            - using pixel-coords window, get chip from the primary sub raster
              source
            - convert window to world coords using the CRS of the primary sub
              raster source
            - for each remaining sub raster source

                - convert world-coords window to pixel coords using the sub
                  raster source's CRS
                - get chip from the sub raster source using this window;
                  specify `out_shape` when reading to ensure shape matches
                  reference chip from the primary sub raster source

        Args:
            window (Box): window to read, in pixel coordinates.
            raw (bool, optional): If True, uses RasterSource._get_chip.
                Otherwise, RasterSource.get_chip. Defaults to False.

        Returns:
            List[np.ndarray]: List of chips from each sub raster source, in
            the same order as ``self.raster_sources``.
        """

        def get_chip(
                rs: RasterSource,
                window: Box,
                out_shape: Optional[Tuple[int, int]] = None) -> np.ndarray:
            # Forward out_shape only when it is set: not every RasterSource
            # accepts it (MultiRasterSource.get_chip/_get_chip take only a
            # window), so always passing out_shape=None would break nesting.
            kwargs = {} if out_shape is None else {'out_shape': out_shape}
            if raw:
                return rs._get_chip(window, **kwargs)
            return rs.get_chip(window, **kwargs)

        if self.all_extents_equal:
            sub_chips = [get_chip(rs, window) for rs in self.raster_sources]
        else:
            primary_rs = self.primary_source
            # Select the non-primary sources by index rather than by
            # comparison: `rs != primary_rs` would also drop any other entry
            # equal to the primary source, leaving the chip list misaligned.
            other_rses = [
                rs for i, rs in enumerate(self.raster_sources)
                if i != self.primary_source_idx
            ]

            primary_sub_chip = get_chip(primary_rs, window)
            # Make every other sub-chip match the primary chip's leading two
            # dims (height, width) so they can be concatenated on channels.
            out_shape = primary_sub_chip.shape[:2]
            world_window = primary_rs.crs_transformer.pixel_to_map(window)
            pixel_windows = [
                rs.crs_transformer.map_to_pixel(world_window)
                for rs in other_rses
            ]
            sub_chips = [
                get_chip(rs, w, out_shape=out_shape)
                for rs, w in zip(other_rses, pixel_windows)
            ]
            sub_chips.insert(self.primary_source_idx, primary_sub_chip)

        if self.force_same_dtype:
            dtype = sub_chips[self.primary_source_idx].dtype
            sub_chips = [chip.astype(dtype) for chip in sub_chips]

        return sub_chips

    def _get_chip(self, window: Box) -> np.ndarray:
        """Return the raw chip located in the window.

        Get raw chips from sub raster sources and concatenate them.

        Args:
            window: Box

        Returns:
            [height, width, channels] numpy array
        """
        sub_chips = self._get_sub_chips(window, raw=True)
        chip = np.concatenate(sub_chips, axis=-1)
        return chip

    def get_chip(self, window: Box) -> np.ndarray:
        """Return the transformed chip in the window.

        Get processed chips from sub raster sources (with their respective
        channel orders and transformations applied), concatenate them along
        the channel dimension, apply channel_order, followed by
        transformations.

        Args:
            window: Box

        Returns:
            np.ndarray with shape [height, width, channels]
        """
        sub_chips = self._get_sub_chips(window, raw=False)
        chip = np.concatenate(sub_chips, axis=-1)
        chip = chip[..., self.channel_order]

        for transformer in self.raster_transformers:
            chip = transformer.transform(chip, self.channel_order)

        return chip