from typing import TYPE_CHECKING, Iterator
import logging
from rastervision.core.data import (ChipClassificationLabels,
ObjectDetectionLabels,
SemanticSegmentationLabels)
from rastervision.core.utils import calculate_required_padding
if TYPE_CHECKING:
import numpy as np
from rastervision.core.data import Scene, SemanticSegmentationLabelStore
from rastervision.core.rv_pipeline import (
PredictOptions, ObjectDetectionPredictOptions,
SemanticSegmentationPredictOptions)
from rastervision.pytorch_learner import (ClassificationLearner,
ObjectDetectionLearner,
SemanticSegmentationLearner)
# Module-level logger named after this module (standard `getLogger(__name__)`
# pattern). Not referenced by the prediction helpers visible in this file.
log = logging.getLogger(__name__)
def predict_scene_cc(
        learner: 'ClassificationLearner', scene: 'Scene',
        predict_options: 'PredictOptions') -> 'ChipClassificationLabels':
    """Generate chip classification predictions for a :class:`.Scene`.

    The scene is covered by a sliding window of square chips
    (``predict_options.chip_sz`` pixels, spaced ``predict_options.stride``
    apart). Each chip is passed through the learner, and the raw per-chip
    outputs are packed into chip classification labels keyed by window.

    Args:
        learner: Trained classification learner used for inference.
        scene: Scene to predict on.
        predict_options: Supplies ``chip_sz``, ``stride`` and ``batch_sz``.

    Returns:
        ChipClassificationLabels built from the dataset windows and the
        learner's raw numpy outputs.
    """
    # Imported lazily, mirroring the other predict_scene_* helpers.
    from rastervision.pytorch_learner.dataset import (
        ClassificationSlidingWindowGeoDataset)

    opts = predict_options
    # Only the first of the two transforms returned by the data config is
    # applied at prediction time.
    base_tf, _aug_tf = learner.cfg.data.get_data_transforms()
    dataset = ClassificationSlidingWindowGeoDataset(
        scene, size=opts.chip_sz, stride=opts.stride, transform=base_tf)

    raw_preds: Iterator['np.ndarray'] = learner.predict_dataset(
        dataset,
        raw_out=True,
        numpy_out=True,
        dataloader_kw={'batch_size': opts.batch_sz},
        progress_bar=True,
        progress_bar_kw={'desc': f'Making predictions on {scene.id}'})
    return ChipClassificationLabels.from_predictions(dataset.windows,
                                                     raw_preds)
def predict_scene_od(learner: 'ObjectDetectionLearner', scene: 'Scene',
                     predict_options: 'ObjectDetectionPredictOptions'
                     ) -> ObjectDetectionLabels:
    """Generate object detection predictions for a :class:`.Scene`.

    Square chips of ``predict_options.chip_sz`` pixels are slid over the
    scene with step ``predict_options.stride``. The learner is asked for
    raw numpy outputs at chip resolution. The per-window predictions are
    collected into ObjectDetectionLabels and then passed through
    ``ObjectDetectionLabels.prune_duplicates``. That step presumably drops
    low-score boxes and merges overlapping ones across windows; see that
    method for the exact semantics.

    Args:
        learner: Trained object detection learner used for inference.
        scene: Scene to predict on.
        predict_options: Supplies ``chip_sz``, ``stride``, ``batch_sz``,
            ``score_thresh`` and ``merge_thresh``.

    Returns:
        The pruned ObjectDetectionLabels for the whole scene.
    """
    # Imported lazily, mirroring the other predict_scene_* helpers.
    from rastervision.pytorch_learner.dataset import (
        ObjectDetectionSlidingWindowGeoDataset)

    opts = predict_options
    window_sz = opts.chip_sz
    base_tf, _aug_tf = learner.cfg.data.get_data_transforms()
    dataset = ObjectDetectionSlidingWindowGeoDataset(
        scene, size=window_sz, stride=opts.stride, transform=base_tf)

    raw_preds: Iterator[dict[str, 'np.ndarray']] = learner.predict_dataset(
        dataset,
        raw_out=True,
        numpy_out=True,
        # Request outputs sized to the chip so boxes map onto the window.
        predict_kw={'out_shape': (window_sz, window_sz)},
        dataloader_kw={'batch_size': opts.batch_sz},
        progress_bar=True,
        progress_bar_kw={'desc': f'Making predictions on {scene.id}'})

    unpruned = ObjectDetectionLabels.from_predictions(dataset.windows,
                                                      raw_preds)
    return ObjectDetectionLabels.prune_duplicates(
        unpruned,
        score_thresh=opts.score_thresh,
        merge_thresh=opts.merge_thresh)
def predict_scene_ss(learner: 'SemanticSegmentationLearner', scene: 'Scene',
                     predict_options: 'SemanticSegmentationPredictOptions'
                     ) -> 'SemanticSegmentationLabels':
    """Generate semantic segmentation predictions for a :class:`.Scene`.

    The scene's label store decides whether smooth (raw) outputs are
    requested. It also fixes the number of classes via its
    ``class_config``. When ``predict_options.crop_sz`` is set, the
    sliding-window dataset is padded on both sides by the amount computed
    with ``calculate_required_padding``. The same ``crop_sz`` is then
    forwarded to ``SemanticSegmentationLabels.from_predictions``.

    Args:
        learner: Trained semantic segmentation learner used for inference.
        scene: Scene to predict on; its ``label_store`` must be set.
        predict_options: Supplies ``chip_sz``, ``stride``, ``crop_sz`` and
            ``batch_sz``.

    Returns:
        SemanticSegmentationLabels covering ``scene.extent``.

    Raises:
        ValueError: If ``scene.label_store`` is None.
    """
    # Imported lazily, mirroring the other predict_scene_* helpers.
    from rastervision.pytorch_learner.dataset import (
        SemanticSegmentationSlidingWindowGeoDataset)

    label_store: 'SemanticSegmentationLabelStore' = scene.label_store
    if label_store is None:
        raise ValueError(f'Scene.label_store is not set for scene {scene.id}')

    opts = predict_options
    window_sz, step, crop_sz = opts.chip_sz, opts.stride, opts.crop_sz
    smooth = label_store.smooth_output
    base_tf, _aug_tf = learner.cfg.data.get_data_transforms()

    dataset_kw = dict(size=window_sz, stride=step, transform=base_tf)
    if crop_sz is not None:
        # Cropping trims chip edges, so the scene is padded on both sides.
        # Presumably this keeps border pixels inside some chip's uncropped
        # region; confirm against calculate_required_padding.
        dataset_kw['padding'] = calculate_required_padding(
            extent_sz=scene.extent.size,
            chip_sz=(window_sz, window_sz),
            stride=(step, step),
            pad_direction='both',
            crop_sz=crop_sz,
        )
        dataset_kw['pad_direction'] = 'both'
    dataset = SemanticSegmentationSlidingWindowGeoDataset(scene, **dataset_kw)

    raw_preds: Iterator['np.ndarray'] = learner.predict_dataset(
        dataset,
        raw_out=smooth,
        numpy_out=True,
        predict_kw={'out_shape': (window_sz, window_sz)},
        dataloader_kw={'batch_size': opts.batch_sz},
        progress_bar=True,
        progress_bar_kw={'desc': f'Making predictions on {scene.id}'})

    return SemanticSegmentationLabels.from_predictions(
        dataset.windows,
        raw_preds,
        smooth=smooth,
        extent=scene.extent,
        num_classes=len(label_store.class_config),
        crop_sz=crop_sz)