Source code for rastervision.core.data.raster_transformer.stats_transformer_config

from typing import TYPE_CHECKING
from os.path import join

from rastervision.pipeline.config import register_config, Field
from rastervision.core.data.raster_transformer import (RasterTransformerConfig,
                                                       StatsTransformer)

if TYPE_CHECKING:
    from rastervision.core.rv_pipeline import RVPipelineConfig
    from rastervision.core.data import SceneConfig


def stats_transformer_config_upgrader(cfg_dict: dict, version: int) -> dict:
    """Upgrade a serialized ``StatsTransformerConfig`` dict from ``version``.

    The dict is modified in place and the same object is returned. Only the
    versions listed below trigger a change; any other version leaves the dict
    untouched.

    Args:
        cfg_dict: Serialized config to upgrade.
        version: Version that ``cfg_dict`` is being upgraded from.

    Returns:
        ``cfg_dict`` itself, with at most one key set.
    """
    # (version being upgraded from, key to set, value to set it to)
    upgrades = (
        # `scene_group` was added in version 3. Since it cannot be None, it
        # is set to a special value so that `update_root()`, which is called
        # by the predictor, knows to set `stats_uri` to the old location of
        # `stats.json`.
        (2, 'scene_group', '__N/A__'),
        (13, 'needs_channel_order', True),
    )
    for from_version, key, value in upgrades:
        if version == from_version:
            cfg_dict[key] = value
            break
    return cfg_dict


@register_config(
    'stats_transformer', upgrader=stats_transformer_config_upgrader)
class StatsTransformerConfig(RasterTransformerConfig):
    """Configure a :class:`.StatsTransformer`."""

    stats_uri: str | None = Field(
        None,
        description='The URI of the output of the StatsAnalyzer. '
        'If None, and this Config is inside an RVPipeline, '
        'this field will be auto-generated.')
    scene_group: str = Field(
        'train_scenes',
        description='Name of the group of scenes whose stats to use. Defaults '
        'to "train_scenes".')
    needs_channel_order: bool = Field(
        False,
        description='Whether the means and stds in the stats_uri file need to '
        'be re-ordered/subsetted using ``channel_order`` to be compatible '
        'with the chips that will be passed to the :class:`.StatsTransformer` '
        'by the :class:`.RasterSource`. This field exists for backward '
        'compatibility with Raster Vision versions <= 0.30. It will be set '
        'automatically when loading stats from older model-bundles.')

    def update(self,
               pipeline: 'RVPipelineConfig | None' = None,
               scene: 'SceneConfig | None' = None) -> None:
        """Fill in ``stats_uri`` from the pipeline if it has not been set.

        When a pipeline is given and ``stats_uri`` is None, it is set to
        ``<pipeline.analyze_uri>/stats/<scene_group>/stats.json``. An
        explicitly set ``stats_uri`` is never overwritten.

        Args:
            pipeline: Enclosing pipeline config, if any.
            scene: Enclosing scene config, if any. Not used by this method.
        """
        if pipeline is None or self.stats_uri is not None:
            return
        self.stats_uri = join(pipeline.analyze_uri, 'stats', self.scene_group,
                              'stats.json')

    def build(self,
              channel_order: list[int] | None = None) -> StatsTransformer:
        """Build a :class:`.StatsTransformer` from the file at ``stats_uri``.

        Args:
            channel_order: Forwarded to
                :meth:`.StatsTransformer.from_stats_json` only when
                ``needs_channel_order`` is True; ignored otherwise.

        Returns:
            The transformer created by
            :meth:`.StatsTransformer.from_stats_json`.
        """
        if self.needs_channel_order:
            return StatsTransformer.from_stats_json(
                self.stats_uri, channel_order=channel_order)
        return StatsTransformer.from_stats_json(self.stats_uri)

    def update_root(self, root_dir: str) -> None:
        """Re-point ``stats_uri`` to a location under ``root_dir``.

        Any existing ``stats_uri`` value is overwritten.

        Args:
            root_dir: Directory under which the stats file is located.
        """
        if self.scene_group == '__N/A__':
            # backward compatibility: use old location of stats.json
            # ('__N/A__' is the sentinel set by
            # stats_transformer_config_upgrader for version-2 configs)
            self.stats_uri = join(root_dir, 'stats.json')
        else:
            self.stats_uri = join(root_dir, 'analyze', 'stats',
                                  self.scene_group, 'stats.json')