Source code for rastervision.core.data.raster_source.temporal_multi_raster_source

from typing import Any, Optional, Sequence, Tuple
from pydantic import conint

import numpy as np

from rastervision.core.box import Box
from rastervision.core.data.raster_source import (RasterSource,
                                                  MultiRasterSource)
from rastervision.core.data.utils import all_equal, parse_array_slices_Nd


class TemporalMultiRasterSource(MultiRasterSource):
    """Merge multiple ``RasterSources`` by stacking them along a new dim.

    Chips are built by fetching one sub-chip per sub raster source and
    stacking them with ``np.stack``, so the new (temporal) dimension comes
    first: the result is a 4D array of shape (T, H, W, C), where T is
    ``len(raster_sources)``.
    """

    def __init__(self,
                 raster_sources: Sequence[RasterSource],
                 primary_source_idx: conint(ge=0) = 0,
                 force_same_dtype: bool = False,
                 raster_transformers: Optional[Sequence] = None,
                 bbox: Optional[Box] = None):
        """Constructor.

        Args:
            raster_sources (Sequence[RasterSource]): Sequence of RasterSources.
            primary_source_idx (0 <= int < len(raster_sources)): Index of the
                raster source whose CRS, dtype, and other attributes will
                override those of the other raster sources.
            force_same_dtype (bool): If true, force all sub-chips to have the
                same dtype as the primary_source_idx-th sub-chip. No careful
                conversion is done, just a quick cast. Use with caution.
            raster_transformers (Optional[Sequence]): Sequence of transformers.
                If None, no transformers are used. Defaults to None.
            bbox (Optional[Box]): User-specified crop of the extent. If given,
                the primary raster source's bbox is set to this. If None, the
                full extent available in the source file of the primary raster
                source is used.

        Raises:
            ValueError: If the sub raster sources do not all have the same
                num_channels.
            IndexError: If primary_source_idx is not in
                [0, len(raster_sources)).
        """
        if not all_equal([rs.num_channels for rs in raster_sources]):
            raise ValueError(
                'All sub raster sources must have the same num_channels.')

        # validate primary_source_idx
        if not (0 <= primary_source_idx < len(raster_sources)):
            raise IndexError('primary_source_idx must be in range '
                             '[0, len(raster_sources)).')

        # Build a fresh list per instance instead of sharing a mutable
        # default argument across all calls.
        if raster_transformers is None:
            raster_transformers = []

        primary_rs = raster_sources[primary_source_idx]
        num_channels_raw = primary_rs.num_channels_raw
        # The sub raster sources apply their own channel orders to their
        # chips; no explicit channel order is passed for the stacked chip.
        channel_order = None

        if bbox is None:
            bbox = primary_rs.bbox
        else:
            # Note: this mutates the primary sub raster source in place.
            primary_rs.set_bbox(bbox)

        RasterSource.__init__(
            self,
            channel_order,
            num_channels_raw,
            bbox=bbox,
            raster_transformers=raster_transformers)

        self.force_same_dtype = force_same_dtype
        self.raster_sources = raster_sources
        self.primary_source_idx = primary_source_idx
        # Select by position, not by equality: if the same source object
        # appears at several positions, only the primary position is excluded.
        self.non_primary_sources = [
            rs for i, rs in enumerate(self.raster_sources)
            if i != primary_source_idx
        ]

        self.validate_raster_sources()

    def _get_chip(self,
                  window: Box,
                  out_shape: Optional[Tuple[int, int]] = None) -> np.ndarray:
        """Get chip w/o applying channel_order and transformers.

        Args:
            window (Box): The window for which to get the chip, in pixel
                coordinates.
            out_shape (Optional[Tuple[int, int]]): (height, width) to resize
                the chip to.

        Returns:
            np.ndarray: 4D array of shape (T, H, W, C).
        """
        sub_chips = self._get_sub_chips(window, out_shape=out_shape)
        # Stack along a new leading axis: one slice per sub raster source.
        chip = np.stack(sub_chips)
        return chip

    def get_chip(self,
                 window: Box,
                 out_shape: Optional[Tuple[int, int]] = None) -> np.ndarray:
        """Return the transformed chip in the window.

        Get processed chips from sub raster sources (with their respective
        channel orders and transformations applied), stack them along a new
        temporal dimension, then apply this raster source's transformers to
        the stacked chip.

        Args:
            window (Box): The window for which to get the chip, in pixel
                coordinates.
            out_shape (Optional[Tuple[int, int]]): (height, width) to resize
                the chip to.

        Returns:
            np.ndarray: 4D array of shape (T, H, W, C).
        """
        chip = self._get_chip(window, out_shape=out_shape)
        for transformer in self.raster_transformers:
            chip = transformer.transform(chip, self.channel_order)
        return chip

    def __getitem__(self, key: Any) -> 'np.ndarray':
        """Index the source like a 4D (T, H, W, C) array.

        A ``Box`` key is forwarded to ``get_chip``. Any other key is parsed
        into a spatial window plus per-dimension indices; the window is read
        and the remaining indices/steps are applied to the resulting array.
        """
        if isinstance(key, Box):
            return self.get_chip(key)
        window, (t, h, w, c) = parse_array_slices_Nd(
            key, extent=self.extent, dims=4)
        chip = self.get_chip(window)
        # The window already covers the requested spatial region; only the
        # step (stride) of the h/w slices still needs to be applied.
        # NOTE(review): assumes parse_array_slices_Nd returns slice objects
        # for h and w — confirm.
        if h.step is not None or w.step is not None:
            chip = chip[:, ::h.step, ::w.step, :]
        chip = chip[t, ...]
        chip = chip[..., c]
        return chip

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Number of sub raster sources followed by the primary source's
        shape (presumably (H, W, C) — confirm against RasterSource)."""
        return (len(self.raster_sources), *self.primary_source.shape)