Source code for rastervision.core.rv_pipeline.rv_pipeline_config

from typing import (TYPE_CHECKING, List, Optional)
from os.path import join

from rastervision.pipeline.pipeline_config import PipelineConfig
from rastervision.core.data import (DatasetConfig, StatsTransformerConfig,
                                    LabelStoreConfig, SceneConfig)
from rastervision.core.analyzer import StatsAnalyzerConfig
from rastervision.core.backend import BackendConfig
from rastervision.core.evaluation import EvaluatorConfig
from rastervision.core.analyzer import AnalyzerConfig
from rastervision.core.rv_pipeline.chip_options import ChipOptions
from rastervision.pipeline.config import (Config, Field, register_config,
                                          validator)

if TYPE_CHECKING:
    from rastervision.core.backend.backend import Backend  # noqa


[docs]@register_config('predict_options') class PredictOptions(Config): chip_sz: int = Field( 300, description='Size of predictions chips in pixels.') stride: Optional[int] = Field( None, description='Stride of the sliding window for generating chips.' 'Defaults to ``chip_sz``.') batch_sz: int = Field( 8, description='Batch size to use during prediction.')
[docs] @validator('stride', always=True) def validate_stride(cls, v: Optional[int], values: dict) -> dict: if v is None: chip_sz: int = values['chip_sz'] return chip_sz return v
def rv_pipeline_config_upgrader(cfg_dict: dict, version: int) -> dict: if version == 10: train_chip_sz = cfg_dict.pop('train_chip_sz', 300) nodata_threshold = cfg_dict.pop('chip_nodata_threshold', 1.) chip_options: dict = cfg_dict.get('chip_options', {}) method = chip_options.pop('method', 'sliding') if method != 'sliding': method = 'random' chip_options['sampling'] = dict(size=train_chip_sz, method=method) chip_options['nodata_threshold'] = nodata_threshold cfg_dict['chip_options'] = chip_options elif version == 11: predict_chip_sz = cfg_dict.pop('predict_chip_sz', 300) predict_batch_sz = cfg_dict.pop('predict_batch_sz', 8) predict_options = cfg_dict.get('predict_options', {}) predict_options['chip_sz'] = predict_chip_sz predict_options['batch_sz'] = predict_batch_sz cfg_dict['predict_options'] = predict_options return cfg_dict
[docs]@register_config('rv_pipeline', upgrader=rv_pipeline_config_upgrader) class RVPipelineConfig(PipelineConfig): """Configure an :class:`.RVPipeline`.""" dataset: DatasetConfig = Field( ..., description= 'Dataset containing train, validation, and optional test scenes.') backend: BackendConfig = Field( ..., description='Backend to use for interfacing with ML library.') evaluators: List[EvaluatorConfig] = Field( [], description=( 'Evaluators to run during analyzer command. If list is empty ' 'the default evaluator is added.')) analyzers: List[AnalyzerConfig] = Field( [], description= ('Analyzers to run during analyzer command. A StatsAnalyzer will be added ' 'automatically if any scenes have a RasterTransformer.')) analyze_uri: Optional[str] = Field( None, description= 'URI for output of analyze. If None, will be auto-generated.') chip_uri: Optional[str] = Field( None, description='URI for output of chip. If None, will be auto-generated.') train_uri: Optional[str] = Field( None, description='URI for output of train. If None, will be auto-generated.' ) predict_uri: Optional[str] = Field( None, description= 'URI for output of predict. If None, will be auto-generated.') eval_uri: Optional[str] = Field( None, description='URI for output of eval. If None, will be auto-generated.') bundle_uri: Optional[str] = Field( None, description='URI for output of bundle. If None, will be auto-generated.' ) source_bundle_uri: Optional[str] = Field( None, description='If provided, the model will be loaded from this bundle ' 'for the train stage. Useful for fine-tuning.') chip_options: Optional[ChipOptions] = Field( None, description='Config for chip stage.') predict_options: Optional[PredictOptions] = Field( None, description='Config for predict stage.')
[docs] def update(self): super().update() if self.analyze_uri is None: self.analyze_uri = join(self.root_uri, 'analyze') if self.chip_uri is None: self.chip_uri = join(self.root_uri, 'chip') if self.train_uri is None: self.train_uri = join(self.root_uri, 'train') if self.predict_uri is None: self.predict_uri = join(self.root_uri, 'predict') if self.eval_uri is None: self.eval_uri = join(self.root_uri, 'eval') if self.bundle_uri is None: self.bundle_uri = join(self.root_uri, 'bundle') self.dataset.update(pipeline=self) self.backend.update(pipeline=self) if not self.evaluators: self.evaluators.append(self.get_default_evaluator()) for evaluator in self.evaluators: evaluator.update(pipeline=self) self._insert_analyzers() for analyzer in self.analyzers: analyzer.update(pipeline=self)
[docs] def get_model_bundle_uri(self): return join(self.bundle_uri, 'model-bundle.zip')
def _insert_analyzers(self): # Inserts StatsAnalyzer if it's needed because a RasterSource has a # StatsTransformer, but there isn't a StatsAnalyzer in the list of Analyzers. has_stats_transformer = False for s in self.dataset.all_scenes: for t in s.raster_source.transformers: if isinstance(t, StatsTransformerConfig): has_stats_transformer = True has_stats_analyzer = False for a in self.analyzers: if isinstance(a, StatsAnalyzerConfig): has_stats_analyzer = True break if has_stats_transformer and not has_stats_analyzer: self.analyzers.append(StatsAnalyzerConfig())
[docs] def get_default_label_store(self, scene: SceneConfig) -> LabelStoreConfig: """Returns a default LabelStoreConfig to fill in any missing ones.""" raise NotImplementedError()
[docs] def get_default_evaluator(self) -> EvaluatorConfig: """Returns a default EvaluatorConfig to use if one isn't set.""" raise NotImplementedError()