Source code for rastervision.core.data.raster_source.multi_raster_source

from typing import Optional, Sequence, List, Tuple
from pydantic import conint

import numpy as np

from rastervision.core.box import Box
from rastervision.core.data.raster_source import RasterSource
from rastervision.core.data.crs_transformer import CRSTransformer
from rastervision.core.data.utils import all_equal


class MultiRasterSource(RasterSource):
    """Merge multiple ``RasterSources`` by concatenating along channel dim."""

    def __init__(self,
                 raster_sources: Sequence[RasterSource],
                 primary_source_idx: conint(ge=0) = 0,
                 force_same_dtype: bool = False,
                 channel_order: Optional[Sequence[conint(ge=0)]] = None,
                 raster_transformers: Optional[Sequence] = None,
                 bbox: Optional[Box] = None):
        """Constructor.

        Args:
            raster_sources (Sequence[RasterSource]): Sequence of RasterSources.
            primary_source_idx (0 <= int < len(raster_sources)): Index of the
                raster source whose CRS, dtype, and other attributes will
                override those of the other raster sources. Defaults to 0.
            force_same_dtype (bool): If true, force all sub-chips to have the
                same dtype as the primary_source_idx-th sub-chip. No careful
                conversion is done, just a quick cast. Use with caution.
                Defaults to False.
            channel_order (Sequence[conint(ge=0)], optional): Channel ordering
                that will be used by .get_chip(). If None (or empty), all
                channels of all sub raster sources are used, in order.
                Defaults to None.
            raster_transformers (Sequence, optional): Sequence of
                transformers. If None, no transformers are applied.
                Defaults to None.
            bbox (Optional[Box], optional): User-specified crop of the extent.
                If given, the primary raster source's bbox is set to this
                (note: this mutates the primary raster source). If None, the
                primary raster source's current bbox is used.

        Raises:
            IndexError: If primary_source_idx is not in
                [0, len(raster_sources)).
            ValueError: If the sub raster sources have differing dtypes and
                force_same_dtype is False.
        """
        # Validate the index before touching any sub-source attributes.
        num_sources = len(raster_sources)
        if not (0 <= primary_source_idx < num_sources):
            raise IndexError(
                'primary_source_idx must be in range '
                f'[0, len(raster_sources)) = [0, {num_sources}). '
                f'Got: {primary_source_idx}')

        # Fresh list per instance; avoids sharing a mutable default.
        if raster_transformers is None:
            raster_transformers = []

        num_channels_raw = sum(rs.num_channels_raw for rs in raster_sources)
        if not channel_order:
            # get_chip() concatenates *processed* sub-chips, so the default
            # ordering spans the processed (not raw) channel count.
            num_channels = sum(rs.num_channels for rs in raster_sources)
            channel_order = list(range(num_channels))

        primary_rs = raster_sources[primary_source_idx]
        if bbox is None:
            bbox = primary_rs.bbox
        else:
            primary_rs.set_bbox(bbox)

        super().__init__(
            channel_order,
            num_channels_raw,
            bbox=bbox,
            raster_transformers=raster_transformers)

        self.force_same_dtype = force_same_dtype
        self.raster_sources = raster_sources
        self.primary_source_idx = primary_source_idx

        self.validate_raster_sources()

    def validate_raster_sources(self) -> None:
        """Validate sub-``RasterSources``.

        Checks if:

        - dtypes are same or ``force_same_dtype`` is True.

        Raises:
            ValueError: If dtypes differ and ``force_same_dtype`` is False.
        """
        dtypes = [rs.dtype for rs in self.raster_sources]
        if not self.force_same_dtype and not all_equal(dtypes):
            raise ValueError(
                'dtypes of all sub raster sources must be the same. '
                f'Got: {dtypes} '
                '(Use force_same_dtype to cast all to the dtype of the '
                'primary source)')

    @property
    def primary_source(self) -> RasterSource:
        """Primary sub-``RasterSource``"""
        return self.raster_sources[self.primary_source_idx]

    @property
    def dtype(self) -> np.dtype:
        """dtype of the primary sub-``RasterSource``."""
        return self.primary_source.dtype

    @property
    def crs_transformer(self) -> CRSTransformer:
        """CRS transformer of the primary sub-``RasterSource``."""
        return self.primary_source.crs_transformer

    def _get_sub_chips(self,
                       window: Box,
                       raw: bool = False,
                       out_shape: Optional[Tuple[int, int]] = None
                       ) -> List[np.ndarray]:
        """Read one chip from each sub raster source for the same area.

        Algorithm:

        - using the pixel-coords window, get a chip from the primary sub
          raster source
        - convert the window to world coords using the CRS of the primary
          sub raster source
        - for each remaining sub raster source
            - convert the world-coords window to pixel coords using the sub
              raster source's CRS
            - get a chip from the sub raster source using this window;
              specify ``out_shape`` when reading to ensure the shape matches
              the reference chip from the primary sub raster source

        Args:
            window (Box): window to read, in pixel coordinates of the primary
                sub raster source.
            raw (bool, optional): If True, uses RasterSource._get_chip.
                Otherwise, RasterSource.get_chip. Defaults to False.
            out_shape (Optional[Tuple[int, int]], optional): (height, width)
                to read the primary chip at. Other sub-chips are always read
                at the primary chip's spatial shape. Defaults to None.

        Returns:
            List[np.ndarray]: List of chips from each sub raster source, in
            the same order as ``self.raster_sources``.
        """

        def get_chip(rs: RasterSource,
                     window: Box,
                     out_shape: Optional[Tuple[int, int]] = None
                     ) -> np.ndarray:
            if raw:
                return rs._get_chip(window, out_shape=out_shape)
            return rs.get_chip(window, out_shape=out_shape)

        primary_idx = self.primary_source_idx
        primary_rs = self.primary_source
        primary_sub_chip = get_chip(primary_rs, window, out_shape=out_shape)
        # All other sub-chips are resampled to the primary chip's size so
        # they can be concatenated along the channel dimension.
        out_shape = primary_sub_chip.shape[:2]
        world_window = primary_rs.crs_transformer.pixel_to_map(
            window, bbox=primary_rs.bbox)

        # Select non-primary sources by index rather than by equality, so a
        # source object that appears more than once is not dropped.
        sub_chips = []
        for i, rs in enumerate(self.raster_sources):
            if i == primary_idx:
                sub_chips.append(primary_sub_chip)
                continue
            pixel_window = rs.crs_transformer.map_to_pixel(
                world_window, bbox=rs.bbox)
            sub_chips.append(get_chip(rs, pixel_window, out_shape=out_shape))

        if self.force_same_dtype:
            dtype = primary_sub_chip.dtype
            # copy=False avoids a needless copy when dtype already matches.
            sub_chips = [chip.astype(dtype, copy=False) for chip in sub_chips]
        return sub_chips

    def _get_chip(self,
                  window: Box,
                  out_shape: Optional[Tuple[int, int]] = None) -> np.ndarray:
        """Return the raw chip located in the window.

        Get raw chips from sub raster sources and concatenate them.

        Args:
            window: Box
            out_shape: Optional (height, width) of the output chip.

        Returns:
            [height, width, channels] numpy array
        """
        sub_chips = self._get_sub_chips(window, raw=True, out_shape=out_shape)
        chip = np.concatenate(sub_chips, axis=-1)
        return chip

    def get_chip(self,
                 window: Box,
                 out_shape: Optional[Tuple[int, int]] = None) -> np.ndarray:
        """Return the transformed chip in the window.

        Get processed chips from sub raster sources (with their respective
        channel orders and transformations applied), concatenate them along
        the channel dimension, apply channel_order, followed by
        transformations.

        Args:
            window: Box
            out_shape: Optional (height, width) of the output chip.

        Returns:
            np.ndarray with shape [height, width, channels]
        """
        sub_chips = self._get_sub_chips(
            window, raw=False, out_shape=out_shape)
        chip = np.concatenate(sub_chips, axis=-1)
        chip = chip[..., self.channel_order]

        for transformer in self.raster_transformers:
            chip = transformer.transform(chip, self.channel_order)

        return chip