Source code for rastervision.core.rv_pipeline.semantic_segmentation

from typing import TYPE_CHECKING, List, Sequence
import logging

import numpy as np

from rastervision.core.box import Box
from rastervision.core.rv_pipeline.rv_pipeline import RVPipeline
from rastervision.core.rv_pipeline.utils import nodata_below_threshold
from rastervision.core.rv_pipeline.semantic_segmentation_config import (
    SemanticSegmentationWindowMethod)

if TYPE_CHECKING:
    from rastervision.core.backend.backend import Backend
    from rastervision.core.data import (
        ClassConfig,
        Labels,
        Scene,
        SemanticSegmentationLabelSource,
    )
    from rastervision.core.rv_pipeline.semantic_segmentation_config import (
        SemanticSegmentationConfig, SemanticSegmentationChipOptions)

log = logging.getLogger(__name__)


def get_train_windows(scene: 'Scene',
                      class_config: 'ClassConfig',
                      chip_size: int,
                      chip_options: 'SemanticSegmentationChipOptions',
                      chip_nodata_threshold: float = 1.) -> List[Box]:
    """Get training windows covering a scene.

    Args:
        scene: The scene over which windows are to be generated.
        class_config: Class config; used to identify the null class.
        chip_size: Side length of the square windows, in pixels.
        chip_options: Options controlling how windows are generated.
        chip_nodata_threshold: Discard windows whose proportion of NODATA
            pixels is >= this value.

    Returns:
        A list of windows (List[Box]).
    """
    co = chip_options
    raster_source = scene.raster_source
    extent = raster_source.extent
    label_source: 'SemanticSegmentationLabelSource' = scene.label_source

    def filter_windows(windows: Sequence[Box]) -> List[Box]:
        """Filter out chips that
        (1) are outside the AOI
        (2) only consist of null labels
        (3) have NODATA proportion >= chip_nodata_threshold
        """
        total_windows = len(windows)
        if scene.aoi_polygons:
            windows = Box.filter_by_aoi(windows, scene.aoi_polygons)
            log.info(f'AOI filtering: {len(windows)}/{total_windows} '
                     'chips accepted')

        filt_windows = []
        for w in windows:
            chip = raster_source.get_chip(w)
            nodata_below_thresh = nodata_below_threshold(
                chip, chip_nodata_threshold, nodata_val=0)

            label_arr = label_source.get_labels(w).get_label_arr(w)
            null_labels = label_arr == class_config.null_class_id

            if not np.all(null_labels) and nodata_below_thresh:
                filt_windows.append(w)
        log.info('Label and NODATA filtering: '
                 f'{len(filt_windows)}/{len(windows)} chips accepted')
        windows = filt_windows
        return windows

    def should_use_window(window: Box) -> bool:
        if co.negative_survival_prob >= 1.0:
            return True
        else:
            is_positive = False
            if co.target_class_ids is not None:
                is_positive = label_source.enough_target_pixels(
                    window, co.target_count_threshold, co.target_class_ids)
            if is_positive:
                return True
            keep_negative = np.random.sample() < co.negative_survival_prob
            return keep_negative

    if co.window_method == SemanticSegmentationWindowMethod.sliding:
        stride = co.stride or int(round(chip_size / 2))
        unfiltered_windows = extent.get_windows(chip_size, stride)
        windows = list(filter_windows(unfiltered_windows))
        if len(windows) > 0:
            a_window = windows[0]
            windows = list(filter(should_use_window, windows))
            if len(windows) == 0:
                windows = [a_window]
        elif len(windows) == 0:
            return [unfiltered_windows[0]]
    elif co.window_method == SemanticSegmentationWindowMethod.random_sample:
        windows = []
        attempts = 0

        while attempts < co.chips_per_scene:
            window = extent.make_random_square(chip_size)
            if not filter_windows([window]):
                continue
            attempts += 1

            if co.negative_survival_prob >= 1.0:
                windows.append(window)
            elif attempts == co.chips_per_scene and len(windows) == 0:
                # Ensure there is at least one window per scene.
                windows.append(window)
            elif should_use_window(window):
                windows.append(window)

    return windows


class SemanticSegmentation(RVPipeline):
    """RVPipeline for the semantic segmentation task."""

    def chip(self, *args, **kwargs):
        log.info(f'Chip options: {self.config.chip_options}')
        return super().chip(*args, **kwargs)

    def get_train_windows(self, scene):
        return get_train_windows(
            scene,
            self.config.dataset.class_config,
            self.config.train_chip_sz,
            self.config.chip_options,
            chip_nodata_threshold=self.config.chip_nodata_threshold)

    def get_train_labels(self, window, scene):
        return scene.label_source.get_labels(window=window)

    def post_process_batch(self, windows, chips, labels):
        # Fill in null class for any NODATA pixels.
        null_class_id = self.config.dataset.class_config.null_class_id
        for window, chip in zip(windows, chips):
            nodata_mask = np.sum(chip, axis=2) == 0
            labels.mask_fill(window, nodata_mask, fill_value=null_class_id)
        return labels

    def predict_scene(self, scene: 'Scene', backend: 'Backend') -> 'Labels':
        cfg: 'SemanticSegmentationConfig' = self.config
        chip_sz = cfg.predict_chip_sz
        stride = cfg.predict_options.stride
        crop_sz = cfg.predict_options.crop_sz

        if stride is None:
            stride = chip_sz

        if crop_sz == 'auto':
            overlap_sz = chip_sz - stride
            if overlap_sz % 2 == 1:
                log.warn(
                    'Using crop_sz="auto" but overlap size (chip_sz minus '
                    'stride) is odd. This means that one pixel row/col will '
                    'still overlap after cropping.')
            crop_sz = overlap_sz // 2

        return backend.predict_scene(
            scene, chip_sz=chip_sz, stride=stride, crop_sz=crop_sz)
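
Example usage (an illustrative sketch, not part of the module above): assuming a Scene, ClassConfig, and SemanticSegmentationChipOptions named scene, class_config, and chip_options have already been built from a SemanticSegmentationConfig, training windows and their label arrays could be pulled roughly like this:

    # Hypothetical names: `scene`, `class_config`, and `chip_options` are
    # assumed to exist; chip_size=300 is an arbitrary example value.
    windows = get_train_windows(
        scene,
        class_config,
        chip_size=300,
        chip_options=chip_options,
        chip_nodata_threshold=1.)

    for window in windows:
        chip = scene.raster_source.get_chip(window)             # (H, W, C) image chip
        labels = scene.label_source.get_labels(window=window)
        label_arr = labels.get_label_arr(window)                # (H, W) class ids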