Source code for rastervision.core.rv_pipeline.chip_classification

from typing import List
import logging

from rastervision.core.rv_pipeline.rv_pipeline import RVPipeline
from rastervision.core.rv_pipeline.utils import nodata_below_threshold
from rastervision.core.box import Box
from rastervision.core.data import Scene

log = logging.getLogger(__name__)


def get_train_windows(scene: Scene,
                      chip_size: int,
                      chip_nodata_threshold: float = 1.) -> List[Box]:
    """Select the windows of a scene that may be used as training chips.

    Candidate windows are produced by calling ``get_windows`` on the extent
    of the scene's raster source, with ``chip_size`` used as both the window
    size and the stride. If the scene has AOI polygons, the candidates are
    first narrowed with ``Box.filter_by_aoi``. Each remaining window's chip
    is then read from the raster source and the window is kept only when
    ``nodata_below_threshold`` accepts that chip (using 0 as the NODATA
    value). The number of windows accepted by each filtering stage is logged
    at INFO level.

    Args:
        scene: Scene providing ``raster_source`` and ``aoi_polygons``.
        chip_size: Size and stride passed to ``extent.get_windows``.
        chip_nodata_threshold: Threshold forwarded to
            ``nodata_below_threshold``. Presumably the maximum allowed
            proportion of NODATA pixels in a chip -- confirm against that
            helper.

    Returns:
        The windows that passed both the AOI filter (when applicable) and
        the NODATA filter, in the order they were generated.
    """
    # Stride equals the chip size, mirroring the original call.
    candidates = scene.raster_source.extent.get_windows(chip_size, chip_size)
    num_candidates = len(candidates)

    if scene.aoi_polygons:
        candidates = Box.filter_by_aoi(candidates, scene.aoi_polygons)
        log.info(f'AOI filtering: {len(candidates)}/{num_candidates} '
                 'chips accepted')

    # The chip for each window is read only to evaluate its NODATA content;
    # the window itself is what gets returned.
    accepted = [
        window for window in candidates
        if nodata_below_threshold(
            scene.raster_source.get_chip(window),
            chip_nodata_threshold,
            nodata_val=0)
    ]

    log.info('NODATA filtering: '
             f'{len(accepted)}/{len(candidates)} chips accepted')
    return accepted
class ChipClassification(RVPipeline):
    """Pipeline for chip classification.

    Supplies the training-window selection and per-window label lookup on
    top of the behavior inherited from ``RVPipeline``.
    """

    def get_train_windows(self, scene: Scene) -> List[Box]:
        """Return the training windows for ``scene``.

        Delegates to the module-level ``get_train_windows`` function, taking
        the chip size and NODATA threshold from this pipeline's config.

        Args:
            scene: Scene to generate training windows for.

        Returns:
            The windows accepted by the module-level ``get_train_windows``.
        """
        chip_size = self.config.train_chip_sz
        nodata_threshold = self.config.chip_nodata_threshold
        # Inside a method body this name resolves to the module-level
        # function, not to this method.
        return get_train_windows(
            scene, chip_size, chip_nodata_threshold=nodata_threshold)

    def get_train_labels(self, window: Box, scene: Scene):
        """Return the labels of ``scene`` restricted to ``window``.

        Args:
            window: Window for which labels are requested.
            scene: Scene whose ``label_source`` is queried.

        Returns:
            Whatever ``scene.label_source.get_labels(window=window)``
            returns.
        """
        label_source = scene.label_source
        return label_source.get_labels(window=window)