Source code for rastervision.core.data.raster_source.raster_source

from typing import TYPE_CHECKING, Any, List, Optional, Tuple
from abc import ABC, abstractmethod, abstractproperty

from rastervision.core.box import Box

if TYPE_CHECKING:
    from rastervision.core.data import (CRSTransformer, RasterTransformer)
    import numpy as np


class ChannelOrderError(Exception):
    """Raised when a channel_order contains an index >= num_channels_raw."""

    def __init__(self, channel_order: List[int], num_channels_raw: int):
        """Constructor.

        Args:
            channel_order: The offending list of channel indices.
            num_channels_raw: Number of channels in the raw imagery.
        """
        self.channel_order = channel_order
        self.num_channels_raw = num_channels_raw
        # The two f-string pieces are concatenated implicitly, so the
        # separating space must be inside one of them (it was missing,
        # producing "contains anindex").
        msg = (f'The channel_order ({channel_order}) contains an '
               f'index >= num_channels_raw ({num_channels_raw}).')
        super().__init__(msg)
class RasterSource(ABC):
    """A source of raster data.

    This should be subclassed when adding a new source of raster data such as
    a set of files, an API, a TMS URI schema, etc.

    Subclasses must implement :attr:`dtype`, :attr:`crs_transformer` and
    :meth:`_get_chip`.
    """

    def __init__(self,
                 channel_order: Optional[List[int]],
                 num_channels_raw: int,
                 raster_transformers: Optional[
                     List['RasterTransformer']] = None,
                 extent: Optional[Box] = None):
        """Constructor.

        Args:
            channel_order: list of channel indices to use when extracting
                chip from raw imagery. If None, all ``num_channels_raw``
                channels are used in their original order.
            num_channels_raw: Number of channels in the raw imagery before
                applying channel_order.
            raster_transformers: ``RasterTransformers`` for transforming
                chips whenever they are retrieved. Defaults to None, which is
                treated as an empty list.
            extent: User-specified extent. If None, the full extent of the
                raster source is used. This base class only stores the value
                and returns it from :attr:`extent`.

        Raises:
            ChannelOrderError: If any index in ``channel_order`` is
                >= ``num_channels_raw``.
        """
        if channel_order is None:
            channel_order = list(range(num_channels_raw))
        if any(c >= num_channels_raw for c in channel_order):
            raise ChannelOrderError(channel_order, num_channels_raw)
        # Build a fresh list per instance: a mutable default ([]) would be
        # shared by every instance created without this argument.
        if raster_transformers is None:
            raster_transformers = []
        self.channel_order = channel_order
        self.num_channels_raw = num_channels_raw
        self.raster_transformers = raster_transformers
        self._extent = extent

    @property
    def num_channels(self) -> int:
        """Number of channels in the chips read from this source."""
        return len(self.channel_order)

    @property
    def shape(self) -> Tuple[int, int, int]:
        """Shape of the raster as a (height, width, num_channels) tuple."""
        ymin, xmin, ymax, xmax = self.extent
        return ymax - ymin, xmax - xmin, self.num_channels

    # ``@property`` + ``@abstractmethod`` replaces the deprecated
    # ``abc.abstractproperty``; subclasses must still override it.
    @property
    @abstractmethod
    def dtype(self) -> 'np.dtype':
        """``numpy.dtype`` of the chips read from this source."""

    @property
    def extent(self) -> 'Box':
        """Extent of the RasterSource."""
        return self._extent

    @property
    @abstractmethod
    def crs_transformer(self) -> 'CRSTransformer':
        """Associated :class:`.CRSTransformer`."""

    @abstractmethod
    def _get_chip(self, window: 'Box') -> 'np.ndarray':
        """Return raw chip without applying channel_order or transforms.

        Args:
            window: Box

        Returns:
            [height, width, channels] numpy array
        """

    def __getitem__(self, key: Any) -> 'np.ndarray':
        """Read a transformed chip using a Box or slice-based indexing.

        Supported keys:

        - a :class:`Box`: equivalent to ``self.get_chip(key)``.
        - a slice ``src[h]`` or a tuple of two slices ``src[h, w]``
          selecting rows and columns. Slice steps are applied to the chip
          after it has been read.

        Raises:
            TypeError: If the key is not a Box, a slice, or a tuple of
                slices.
            IndexError: If a tuple key does not hold exactly 1 or 2 slices.
            NotImplementedError: If any slice start/stop is negative.
        """
        if isinstance(key, Box):
            return self.get_chip(key)
        if isinstance(key, slice):
            slices = [key]
        elif isinstance(key, tuple):
            slices = list(key)
        else:
            raise TypeError('Unsupported key type.')

        # Explicit checks rather than asserts, which are removed under -O.
        if not 1 <= len(slices) <= 2:
            raise IndexError(f'Expected 1 or 2 slices, got {len(slices)}.')
        if not all(isinstance(s, slice) for s in slices):
            raise TypeError('Only slices are supported in a tuple key.')

        h = slices[0]
        w = slices[1] if len(slices) == 2 else slice(None, None)

        if any(x is not None and x < 0
               for x in (h.start, h.stop, w.start, w.stop)):
            raise NotImplementedError()

        _, _, ymax, xmax = self.extent
        # NOTE(review): a missing start maps to 0 while a missing stop maps
        # to the extent's ymax/xmax, so for an extent whose ymin/xmin are
        # nonzero, ``src[:, :]`` covers more than ``self.extent``. Kept as
        # is; confirm whether the extent's ymin/xmin were intended here.
        win_ymin = 0 if h.start is None else h.start
        win_xmin = 0 if w.start is None else w.start
        win_ymax = ymax if h.stop is None else h.stop
        win_xmax = xmax if w.stop is None else w.stop
        window = Box(win_ymin, win_xmin, win_ymax, win_xmax)

        chip = self.get_chip(window)
        if h.step is not None or w.step is not None:
            chip = chip[::h.step, ::w.step]
        return chip

    def get_chip(self, window: 'Box') -> 'np.ndarray':
        """Return the transformed chip in the window.

        Get a raw chip, extract subset of channels using channel_order, and
        then apply transformations.

        Args:
            window (Box): The window for which to get the chip.

        Returns:
            np.ndarray: Array of shape (height, width, channels).
        """
        chip = self._get_chip(window)
        chip = chip[:, :, self.channel_order]
        for transformer in self.raster_transformers:
            chip = transformer.transform(chip, self.channel_order)
        return chip

    def get_raw_chip(self, window: 'Box') -> 'np.ndarray':
        """Return raw chip without applying channel_order or transforms.

        Args:
            window (Box): The window for which to get the chip.

        Returns:
            np.ndarray: Array of shape (height, width, channels).
        """
        return self._get_chip(window)

    def get_image_array(self) -> 'np.ndarray':
        """Return entire transformed image array.

        .. warning:: Not safe to call on very large RasterSources.

        Returns:
            np.ndarray: Array of shape (height, width, channels).
        """
        return self.get_chip(self.extent)

    def get_raw_image_array(self) -> 'np.ndarray':
        """Return raw image for the full extent.

        .. warning:: Not safe to call on very large RasterSources.

        Returns:
            np.ndarray: Array of shape (height, width, channels).
        """
        return self.get_raw_chip(self.extent)