Source code for rastervision.pytorch_backend.pytorch_semantic_segmentation

from typing import TYPE_CHECKING
from os.path import join
import uuid

import numpy as np

from rastervision.pipeline.file_system.utils import make_dir
from rastervision.core.data import SemanticSegmentationLabels
from rastervision.core.data_sample import DataSample
from rastervision.pytorch_backend.pytorch_learner_backend import (
    PyTorchLearnerSampleWriter, PyTorchLearnerBackend)
from rastervision.pytorch_backend.utils import chip_collate_fn_ss
from rastervision.pytorch_learner.utils import predict_scene_ss

if TYPE_CHECKING:
    from rastervision.core.data import DatasetConfig, Scene
    from rastervision.core.rv_pipeline import (
        ChipOptions, SemanticSegmentationPredictOptions)
    from rastervision.pytorch_learner import SemanticSegmentationGeoDataConfig


[docs]class PyTorchSemanticSegmentationSampleWriter(PyTorchLearnerSampleWriter):
[docs] def write_sample(self, sample: 'DataSample'): """Write sample. This writes a training or validation sample to ``(train|valid)/img/{scene_id}-{ind}.png`` and ``(train|valid)/labels/{scene_id}-{ind}.png`` """ img = sample.chip img_path = self.get_image_path(sample) self.write_chip(img, img_path) if sample.label is not None: label_arr: np.ndarray = sample.label label_path = self.get_label_path(sample, label_arr) self.write_chip(label_arr, label_path) self.sample_ind += 1
[docs] def get_label_path(self, sample: 'DataSample', label_arr: np.ndarray) -> str: split = '' if sample.split is None else sample.split img_dir = join(self.sample_dir, split, 'labels') make_dir(img_dir) if sample.scene_id is not None: sample_name = f'{sample.scene_id}-{self.sample_ind}' else: sample_name = f'{self.sample_ind}' ext = self.get_image_ext(label_arr) label_path = join(img_dir, f'{sample_name}.{ext}') return label_path
[docs]class PyTorchSemanticSegmentation(PyTorchLearnerBackend):
[docs] def get_sample_writer(self): output_uri = join(self.pipeline_cfg.chip_uri, f'{uuid.uuid4()}.zip') return PyTorchSemanticSegmentationSampleWriter( output_uri, self.pipeline_cfg.dataset.class_config, self.tmp_dir)
[docs] def chip_dataset(self, dataset: 'DatasetConfig', chip_options: 'ChipOptions', dataloader_kw: dict = {}) -> None: dataloader_kw = dict(**dataloader_kw, collate_fn=chip_collate_fn_ss) return super().chip_dataset(dataset, chip_options, dataloader_kw)
[docs] def predict_scene(self, scene: 'Scene', predict_options: 'SemanticSegmentationPredictOptions' ) -> 'SemanticSegmentationLabels': if self.learner is None: self.load_model() labels = predict_scene_ss(self.learner, scene, predict_options) return labels
def _make_chip_data_config(self, dataset: 'DatasetConfig', chip_options: 'ChipOptions' ) -> 'SemanticSegmentationGeoDataConfig': from rastervision.pytorch_learner import ( SemanticSegmentationGeoDataConfig) data_config = SemanticSegmentationGeoDataConfig( scene_dataset=dataset, sampling=chip_options.sampling) return data_config