# Source code for rastervision.core.data.label_source.semantic_segmentation_label_source_config

from typing import Union

from rastervision.core.data.raster_source import (RasterSourceConfig,
                                                  RasterizedSourceConfig)
from rastervision.core.data.label_source import (
    LabelSourceConfig, SemanticSegmentationLabelSource)
from rastervision.pipeline.config import (register_config, Field)


def ss_label_source_config_upgrader(cfg_dict: dict,
                                    version: int) -> dict:  # pragma: no cover
    """Upgrade a serialized semantic segmentation label source config.

    Only version 3 configs are altered: their ``'rgb_class_config'`` entry,
    if present, is dropped. Configs of any other version pass through
    untouched.

    Args:
        cfg_dict: The serialized config. It is modified in place.
        version: The version of the config held in ``cfg_dict``.

    Returns:
        The same ``cfg_dict`` object that was passed in.
    """
    if version != 3:
        return cfg_dict

    # 'rgb_class_config' was removed in version 4, so discard it here;
    # a missing key is not an error.
    cfg_dict.pop('rgb_class_config', None)
    return cfg_dict


@register_config(
    'semantic_segmentation_label_source',
    upgrader=ss_label_source_config_upgrader)
class SemanticSegmentationLabelSourceConfig(LabelSourceConfig):
    """Configure a :class:`.SemanticSegmentationLabelSource`."""

    # Config of the raster source that provides the labels. The ``...``
    # default presumably marks the field as required (pydantic-style
    # ``Field`` semantics) -- confirm against rastervision.pipeline.config.
    raster_source: Union[RasterSourceConfig, RasterizedSourceConfig] = Field(
        ..., description='The labels in the form of rasters.')

    def build(self, class_config, crs_transformer, bbox=None,
              tmp_dir=None) -> SemanticSegmentationLabelSource:
        """Build the label source described by this config.

        The raster source is built first; the arguments forwarded to it
        depend on the type of ``self.raster_source``.

        Args:
            class_config: Passed to the raster source's ``build`` when it is
                a :class:`.RasterizedSourceConfig`, and always passed to the
                returned label source.
            crs_transformer: Only forwarded when ``self.raster_source`` is a
                :class:`.RasterizedSourceConfig`.
            bbox: Only forwarded when ``self.raster_source`` is a
                :class:`.RasterizedSourceConfig`. Defaults to None.
            tmp_dir: Only forwarded when ``self.raster_source`` is not a
                :class:`.RasterizedSourceConfig`. Defaults to None.

        Returns:
            A :class:`.SemanticSegmentationLabelSource` constructed from the
            built raster source and ``class_config``.
        """
        # The two raster source config types have different build()
        # signatures, so dispatch on the concrete config type.
        if isinstance(self.raster_source, RasterizedSourceConfig):
            rs = self.raster_source.build(class_config, crs_transformer, bbox)
        else:
            rs = self.raster_source.build(tmp_dir)
        return SemanticSegmentationLabelSource(rs, class_config)