Source code for rastervision.core.data.raster_source.multi_raster_source_config

from typing import TYPE_CHECKING

from typing_extensions import Annotated
from pydantic import NonNegativeInt as NonNegInt

from rastervision.pipeline.config import (Field, register_config,
                                          model_validator)
from rastervision.core.box import Box
from rastervision.core.data.raster_source import (RasterSourceConfig,
                                                  MultiRasterSource)

if TYPE_CHECKING:
    from typing import Self


def multi_rs_config_upgrader(cfg_dict: dict, version: int) -> dict:
    """Migrate a serialized ``multi_raster_source`` config by one version.

    The dict is modified in place and the same object is returned.

    Args:
        cfg_dict: Serialized config to migrate.
        version: Version that ``cfg_dict`` is currently at. Only versions 1
            and 13 trigger a change; any other value leaves the dict
            untouched.

    Returns:
        ``cfg_dict`` itself, after any applicable migration.
    """
    if version == 13:
        # Version 14 removed ``force_same_dtype``; discard it if present.
        cfg_dict.pop('force_same_dtype', None)
        return cfg_dict

    if version == 1:
        # Version 2 renamed ``crs_source`` to ``primary_source_idx``. Move
        # the old value across, falling back to 0 when it was never set.
        cfg_dict['primary_source_idx'] = cfg_dict.pop('crs_source', 0)

    return cfg_dict


@register_config('multi_raster_source', upgrader=multi_rs_config_upgrader)
class MultiRasterSourceConfig(RasterSourceConfig):
    """Configure a :class:`.MultiRasterSource`.

    Or :class:`.TemporalMultiRasterSource`, if ``temporal=True``.
    """

    raster_sources: Annotated[list[
        RasterSourceConfig], Field(min_length=1)] = Field(
            ..., description='List of RasterSourceConfig to combine.')
    primary_source_idx: NonNegInt = Field(
        0,
        description=
        'Index of the raster source whose CRS, dtype, and other attributes '
        'will override those of the other raster sources. Defaults to 0.')
    temporal: bool = Field(
        False,
        description='Stack images from sub raster sources into a time-series '
        'of shape (T, H, W, C) instead of concatenating bands.')

    @model_validator(mode='after')
    def validate_primary_source_idx(self) -> 'Self':
        """Check that ``primary_source_idx`` indexes into ``raster_sources``.

        Raises:
            IndexError: If ``primary_source_idx`` is not a valid index, i.e.
                not in the half-open range ``[0, len(raster_sources))``.
        """
        primary_source_idx = self.primary_source_idx
        raster_sources = self.raster_sources
        if not (0 <= primary_source_idx < len(raster_sources)):
            # Half-open interval: len(raster_sources) itself is NOT valid.
            raise IndexError('primary_source_idx must be in range '
                             '[0, len(raster_sources)).')
        return self

    @model_validator(mode='after')
    def validate_temporal(self) -> 'Self':
        """Reject ``channel_order`` when ``temporal=True``.

        Raises:
            ValueError: If ``temporal`` is set and ``channel_order`` is not
                ``None``.
        """
        if self.temporal and self.channel_order is not None:
            raise ValueError(
                'Setting channel_order is not allowed if temporal=True.')
        return self

    def build(self, tmp_dir: str | None = None,
              use_transformers: bool = True) -> MultiRasterSource:
        """Build the raster source described by this config.

        Args:
            tmp_dir: Passed through to each sub raster source's ``build()``.
                Defaults to ``None``.
            use_transformers: If ``True``, build this config's transformers
                and attach them to the result; otherwise attach none. The
                flag is also passed through to each sub raster source's
                ``build()``. Defaults to ``True``.

        Returns:
            A :class:`.TemporalMultiRasterSource` if ``temporal=True``,
            otherwise a :class:`.MultiRasterSource`.
        """
        if use_transformers:
            raster_transformers = [
                t.build(channel_order=self.channel_order)
                for t in self.transformers
            ]
        else:
            raster_transformers = []

        built_raster_sources = [
            rs.build(tmp_dir, use_transformers) for rs in self.raster_sources
        ]
        bbox = Box(*self.bbox) if self.bbox is not None else None

        if self.temporal:
            # Imported lazily, as in the original code (presumably to avoid
            # a circular import -- confirm before hoisting to module level).
            from rastervision.core.data.raster_source import (
                TemporalMultiRasterSource)
            # channel_order is not passed here; validate_temporal()
            # guarantees it is None whenever temporal=True.
            multi_raster_source = TemporalMultiRasterSource(
                raster_sources=built_raster_sources,
                primary_source_idx=self.primary_source_idx,
                raster_transformers=raster_transformers,
                bbox=bbox)
        else:
            multi_raster_source = MultiRasterSource(
                raster_sources=built_raster_sources,
                primary_source_idx=self.primary_source_idx,
                channel_order=self.channel_order,
                raster_transformers=raster_transformers,
                bbox=bbox)
        return multi_raster_source

    def update(self, pipeline=None, scene=None):
        """Forward ``update()`` to each of this config's transformers.

        Args:
            pipeline: Passed through unchanged to each transformer's
                ``update()``. Defaults to ``None``.
            scene: Passed through unchanged to each transformer's
                ``update()``. Defaults to ``None``.
        """
        for t in self.transformers:
            t.update(pipeline, scene)