Source code for rastervision.core.data.raster_transformer.stats_transformer_config
from typing import TYPE_CHECKING, Optional
from os.path import join
from rastervision.pipeline.config import register_config, Field
from rastervision.core.data.raster_transformer import (RasterTransformerConfig,
StatsTransformer)
from rastervision.core.raster_stats import RasterStats
if TYPE_CHECKING:
from rastervision.core.rv_pipeline import RVPipelineConfig
from rastervision.core.data import SceneConfig
def stats_transformer_config_upgrader(cfg_dict: dict, version: int) -> dict:
    """Upgrade a serialized ``stats_transformer`` config dict.

    Only version 2 configs are modified; every other version is returned
    untouched. The dict is mutated in place and also returned.

    Args:
        cfg_dict: Serialized config to upgrade.
        version: Version of the serialized config.

    Returns:
        The same ``cfg_dict`` object, possibly with ``scene_group`` set.
    """
    if version != 2:
        return cfg_dict

    # `scene_group` was introduced in version 3 and is not allowed to be
    # None. Mark version-2 configs with a sentinel value instead, so that
    # `update_root()` (called by the predictor) can recognize them and point
    # `stats_uri` at the old location of `stats.json`.
    cfg_dict['scene_group'] = '__N/A__'
    return cfg_dict
@register_config(
    'stats_transformer', upgrader=stats_transformer_config_upgrader)
class StatsTransformerConfig(RasterTransformerConfig):
    """Configure a :class:`.StatsTransformer`."""

    stats_uri: Optional[str] = Field(
        None,
        description='The URI of the output of the StatsAnalyzer. '
        'If None, and this Config is inside an RVPipeline, '
        'this field will be auto-generated.')
    # NOTE: the trailing space after "use." matters; adjacent string literals
    # are concatenated verbatim, with no separator inserted between them.
    scene_group: str = Field(
        'train_scenes',
        description='Name of the group of scenes whose stats to use. '
        'Defaults to "train_scenes".')

    def update(self,
               pipeline: Optional['RVPipelineConfig'] = None,
               scene: Optional['SceneConfig'] = None) -> None:
        """Auto-generate ``stats_uri`` from the pipeline if it is unset.

        An explicitly set ``stats_uri`` is never overwritten here.

        Args:
            pipeline: Pipeline config whose ``analyze_uri`` is used as the
                base of the generated URI. If None, nothing is changed.
            scene: Not used by this method.
        """
        if pipeline is not None and self.stats_uri is None:
            self.stats_uri = join(pipeline.analyze_uri, 'stats',
                                  self.scene_group, 'stats.json')

    def build(self) -> StatsTransformer:
        """Build a :class:`.StatsTransformer` from the stats at ``stats_uri``.

        Returns:
            A transformer initialized with the means and stds of the
            :class:`.RasterStats` loaded from ``stats_uri``.
        """
        stats = RasterStats.load(self.stats_uri)
        return StatsTransformer(means=stats.means, stds=stats.stds)

    def update_root(self, root_dir: str) -> None:
        """Re-point ``stats_uri`` at a ``stats.json`` under ``root_dir``.

        Unlike :meth:`update`, this always overwrites ``stats_uri``.

        Args:
            root_dir: Directory that the new ``stats_uri`` is relative to.
        """
        if self.scene_group == '__N/A__':
            # backward compatibility: the config upgrader sets this sentinel
            # for version-2 configs, which kept stats.json directly in the
            # root directory.
            self.stats_uri = join(root_dir, 'stats.json')
        else:
            self.stats_uri = join(root_dir, 'analyze', 'stats',
                                  self.scene_group, 'stats.json')