Source code for rastervision.core.data.label_source.object_detection_label_source_config

from rastervision.core.data.label_source import (LabelSourceConfig,
                                                 ObjectDetectionLabelSource)
from rastervision.core.data.vector_source import VectorSourceConfig
from rastervision.core.data.vector_transformer import (
    ClassInferenceTransformerConfig, BufferTransformerConfig)
from rastervision.pipeline.config import register_config, validator


@register_config('object_detection_label_source')
class ObjectDetectionLabelSourceConfig(LabelSourceConfig):
    """Configure an :class:`.ObjectDetectionLabelSource`.

    The config wraps a single vector source config. When the
    ``vector_source`` field is validated, a class-inference transformer
    and a pair of buffer transformers are added to it if it does not
    already contain transformers of those types.
    """

    # Config of the vector source handed to the label source in build().
    vector_source: VectorSourceConfig

    @validator('vector_source')
    def ensure_required_transformers(
            cls, v: VectorSourceConfig) -> VectorSourceConfig:
        """Add class-inference and buffer transformers if absent.

        The transformer list of ``v`` is extended in place, so the
        returned config is the same object that was passed in.

        Args:
            v: The vector source config being validated.

        Returns:
            ``v``, with any missing transformers appended.
        """
        transformers = v.transformers

        def lacks(tf_type) -> bool:
            # True when no transformer in the list is an instance of tf_type.
            return not any(isinstance(tf, tf_type) for tf in transformers)

        # Class inference: appended only if no such transformer is present.
        if lacks(ClassInferenceTransformerConfig):
            transformers += [
                ClassInferenceTransformerConfig(default_class_id=None)
            ]

        # Buffering: a single existing buffer transformer of any geometry
        # type suppresses both defaults (Point first, then LineString).
        if lacks(BufferTransformerConfig):
            transformers += [
                BufferTransformerConfig(geom_type=geom_type, default_buf=1)
                for geom_type in ('Point', 'LineString')
            ]

        return v

    def update(self, pipeline=None, scene=None):
        """Update this config, then the nested vector source config.

        Both arguments are forwarded unchanged, first to the parent
        class's ``update`` and then to ``vector_source.update``.
        """
        super().update(pipeline, scene)
        self.vector_source.update(pipeline, scene)

    def build(self, class_config, crs_transformer, extent,
              tmp_dir=None) -> ObjectDetectionLabelSource:
        """Build the label source described by this config.

        The vector source is built from ``class_config`` and
        ``crs_transformer`` and passed, together with ``extent``, to
        :class:`.ObjectDetectionLabelSource`. ``tmp_dir`` is accepted but
        not used by this method.
        """
        return ObjectDetectionLabelSource(
            self.vector_source.build(class_config, crs_transformer), extent)