# Source code for rastervision.core.data.label_source.semantic_segmentation_label_source_config

from typing import Union

from rastervision.core.data.raster_source import (RasterSourceConfig,
                                                  RasterizedSourceConfig)
from rastervision.core.data.label_source import (
    LabelSourceConfig, SemanticSegmentationLabelSource)
from rastervision.pipeline.config import (register_config, Field)


def ss_label_source_config_upgrader(cfg_dict: dict, version: int) -> dict:
    """Upgrade a serialized semantic segmentation label source config.

    For configs older than version 4, the ``rgb_class_config`` entry (removed
    in version 4) is dropped if it is present. Configs at version 4 or later
    are returned untouched.

    Args:
        cfg_dict: The serialized config. It is modified in place.
        version: The version that ``cfg_dict`` was written with.

    Returns:
        The same ``cfg_dict`` object that was passed in.
    """
    if version >= 4:
        return cfg_dict

    # 'rgb_class_config' was removed in version 4; a missing key is fine.
    cfg_dict.pop('rgb_class_config', None)
    return cfg_dict


@register_config(
    'semantic_segmentation_label_source',
    upgrader=ss_label_source_config_upgrader)
class SemanticSegmentationLabelSourceConfig(LabelSourceConfig):
    """Configure a :class:`.SemanticSegmentationLabelSource`."""

    # Config of the raster source that supplies the labels; it may be either
    # a plain raster source config or a rasterized source config.
    raster_source: Union[RasterSourceConfig, RasterizedSourceConfig] = Field(
        ..., description='The labels in the form of rasters.')

    def build(self, class_config, crs_transformer, extent, tmp_dir):
        """Build the label source described by this config.

        The underlying raster source is built first. A
        :class:`.RasterizedSourceConfig` is built from ``class_config``,
        ``crs_transformer`` and ``extent``; any other raster source config is
        built from ``tmp_dir`` only.

        Args:
            class_config: Passed to the rasterized source's ``build`` (when
                applicable) and to the resulting label source.
            crs_transformer: Passed to the rasterized source's ``build``;
                unused for other raster source configs.
            extent: Passed to the rasterized source's ``build``; unused for
                other raster source configs.
            tmp_dir: Passed to ``build`` of non-rasterized raster source
                configs; unused for rasterized ones.

        Returns:
            A :class:`.SemanticSegmentationLabelSource` wrapping the built
            raster source and ``class_config``.
        """
        if isinstance(self.raster_source, RasterizedSourceConfig):
            # Rasterized sources take a different set of build arguments.
            rs = self.raster_source.build(class_config, crs_transformer,
                                          extent)
        else:
            rs = self.raster_source.build(tmp_dir)
        return SemanticSegmentationLabelSource(rs, class_config)