Note

This page was generated from sampling_training_data.ipynb.

Note

If running outside of the Docker image, you may need to set some environment variables manually. You can do it like so:

import os
from subprocess import check_output

os.environ['GDAL_DATA'] = check_output('pip show rasterio | grep Location | awk \'{print $NF"/rasterio/gdal_data/"}\'', shell=True).decode().strip()

We will be accessing files on S3 in this notebook. Since those files are public, we set the AWS_NO_SIGN_REQUEST environment variable to tell rasterio to skip signing the AWS requests.

[1]:
%env AWS_NO_SIGN_REQUEST=YES
env: AWS_NO_SIGN_REQUEST=YES

Sampling training data#

The GeoDataset class#

The GeoDataset is a PyTorch-compatible Dataset implementation that allows sampling images and labels from a Scene.

It comes in two flavors:

  1. SlidingWindowGeoDataset

  2. RandomWindowGeoDataset

Below we explore both in the context of semantic segmentation.


First, let’s define a handy plotting function:

[15]:
from matplotlib import pyplot as plt
import matplotlib.patches as patches

def show_windows(img, windows, title='', ax=None, show=True):
    if ax is None:
        fig, ax = plt.subplots(1, 1, squeeze=True, figsize=(8, 8))
    ax.imshow(img)
    ax.axis('off')
    # draw windows on top of the image
    for w in windows:
        p = patches.Polygon(w.to_points(), color='r', linewidth=1, fill=False)
        ax.add_patch(p)
    ax.autoscale()
    ax.set_title(title)
    if show:
        plt.show()

SlidingWindowGeoDataset#

The SlidingWindowGeoDataset allows reading the scene left-to-right, top-to-bottom, using a sliding window.
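Conceptually, the window layout can be sketched in a few lines of plain Python. This is an illustrative sketch, not the Raster Vision implementation: it just enumerates `size` x `size` boxes at `stride`-pixel steps over an `h` x `w` scene, in the same left-to-right, top-to-bottom order.

```python
def sliding_windows(h: int, w: int, size: int, stride: int):
    """Yield (ymin, xmin, ymax, xmax) boxes covering an h x w scene,
    left-to-right, top-to-bottom. Edge windows may overflow the extent;
    that overflow is what padding (discussed below) deals with."""
    for ymin in range(0, h, stride):
        for xmin in range(0, w, stride):
            yield (ymin, xmin, ymin + size, xmin + size)

# A 500x500 scene with 200x200 windows at stride 200 yields a 3x3 grid;
# the last row and column extend past the extent.
windows = list(sliding_windows(500, 500, size=200, stride=200))
print(len(windows))   # 9
print(windows[-1])    # (400, 400, 600, 600)
```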

[3]:
image_uri = 's3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0331E-1257N_1327_3160_13/images/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13.tif'
label_uri = 's3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0331E-1257N_1327_3160_13/labels/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13_Buildings.geojson'

Here we make use of the convenience API, GeoDataset.from_uris() (specifically, SemanticSegmentationSlidingWindowGeoDataset.from_uris()), but we can also use the normal constructor if we want to manually define the RasterSource and LabelSource.

[20]:
from rastervision.core.data import ClassConfig
from rastervision.pytorch_learner import (
    SemanticSegmentationSlidingWindowGeoDataset, SemanticSegmentationVisualizer)

import albumentations as A

class_config = ClassConfig(
    names=['background', 'building'],
    colors=['lightgray', 'darkred'],
    null_class='background')

ds = SemanticSegmentationSlidingWindowGeoDataset.from_uris(
    class_config=class_config,
    image_uri=image_uri,
    label_vector_uri=label_uri,
    label_vector_default_class_id=class_config.get_class_id('building'),
    image_raster_source_kw=dict(allow_streaming=True),
    size=200,
    stride=200,
    out_size=256,
)
2024-07-03 21:12:01:rastervision.pipeline.file_system.utils: INFO - Using cached file /opt/data/tmp/cache/s3/spacenet-dataset/spacenet/SN7_buildings/train/L15-0331E-1257N_1327_3160_13/labels/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13_Buildings.geojson.

We can read a data sample and the corresponding ground truth from the Dataset like so:

[5]:
x, y = ds[0]
x.shape, y.shape
[5]:
(torch.Size([3, 256, 256]), torch.Size([256, 256]))

And then plot it using the SemanticSegmentationVisualizer:

[6]:
viz = SemanticSegmentationVisualizer(
    class_names=class_config.names, class_colors=class_config.colors)
viz.plot_batch(x.unsqueeze(0), y.unsqueeze(0), show=True)
../../_images/usage_tutorials_sampling_training_data_18_0.png

The above was the first sliding window in the dataset. We can visualize the full set of windows like so:

[7]:
img_full = ds.scene.raster_source[:, :]
show_windows(img_full, ds.windows, title='Sliding windows')
../../_images/usage_tutorials_sampling_training_data_20_0.png

Padding#

By default, Raster Vision adds sufficient padding around the scene so that every pixel in the scene is covered by at least one window.

What does the data look like when we sample one of the edge windows? The part of the window overflowing the scene’s extent will simply be filled with zeros, as seen below:

[23]:
x, y = ds[-2]
viz.plot_batch(x.unsqueeze(0), y.unsqueeze(0), show=True)
../../_images/usage_tutorials_sampling_training_data_24_0.png
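The zero-fill behavior can be illustrated with a small NumPy sketch. The `read_window` helper below is hypothetical (not Raster Vision's reader) and only handles overflow past the bottom/right edges, but it shows the idea: the in-bounds part of the chip comes from the image, and the rest stays zero.

```python
import numpy as np

def read_window(img: np.ndarray, ymin: int, xmin: int, size: int) -> np.ndarray:
    """Read a size x size window from img, zero-filling any part of the
    window that overflows the array's bottom/right bounds."""
    h, w = img.shape[:2]
    out = np.zeros((size, size) + img.shape[2:], dtype=img.dtype)
    ymax, xmax = min(ymin + size, h), min(xmin + size, w)
    out[:ymax - ymin, :xmax - xmin] = img[ymin:ymax, xmin:xmax]
    return out

img = np.ones((500, 500, 3), dtype=np.uint8)
chip = read_window(img, ymin=400, xmin=400, size=200)
print(chip.shape)               # (200, 200, 3)
print(chip[:100, :100].min())   # 1 -- in-bounds part preserved
print(chip[100:, 100:].max())   # 0 -- overflow region zero-filled
```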

You can control the padding behavior via the padding and pad_direction arguments. For example:

[13]:
ds_pad_all_sides = SemanticSegmentationSlidingWindowGeoDataset.from_uris(
    class_config=class_config,
    image_uri=image_uri,
    label_vector_uri=label_uri,
    label_vector_default_class_id=class_config.get_class_id('building'),
    image_raster_source_kw=dict(allow_streaming=True),
    size=200,
    stride=200,
    out_size=256,
    padding=None, # will be automatically calculated if None
    pad_direction='both',
)
ds_no_padding = SemanticSegmentationSlidingWindowGeoDataset.from_uris(
    class_config=class_config,
    image_uri=image_uri,
    label_vector_uri=label_uri,
    label_vector_default_class_id=class_config.get_class_id('building'),
    image_raster_source_kw=dict(allow_streaming=True),
    size=200,
    stride=200,
    out_size=256,
    padding=0,
)
2024-07-03 21:04:14:rastervision.pipeline.file_system.utils: INFO - Using cached file /opt/data/tmp/cache/s3/spacenet-dataset/spacenet/SN7_buildings/train/L15-0331E-1257N_1327_3160_13/labels/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13_Buildings.geojson.
2024-07-03 21:04:15:rastervision.pipeline.file_system.utils: INFO - Using cached file /opt/data/tmp/cache/s3/spacenet-dataset/spacenet/SN7_buildings/train/L15-0331E-1257N_1327_3160_13/labels/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13_Buildings.geojson.
[24]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True)
show_windows(img_full, ds_pad_all_sides.windows, title='Sliding windows with pad_direction=\'both\'', ax=ax1, show=False)
show_windows(img_full, ds_no_padding.windows, title='Sliding windows with padding=0', ax=ax2, show=False)
plt.show()
../../_images/usage_tutorials_sampling_training_data_27_0.png

RandomWindowGeoDataset#

The RandomWindowGeoDataset allows reading the scene by sampling random window sizes and locations.
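The sampling strategy can be sketched as follows. This is an illustrative sketch, not the actual GeoDataset logic: it draws a square side length from `size_lims` and a random location, keeping the window inside the extent for simplicity (the real dataset can also let windows overflow the extent, as the `padding=100` argument below does).

```python
import random

def sample_window(h, w, size_lims):
    """Sample one (ymin, xmin, ymax, xmax) square window with a side length
    drawn uniformly from [size_lims[0], size_lims[1]), fully inside the
    h x w extent."""
    size = random.randrange(*size_lims)
    ymin = random.randrange(0, h - size + 1)
    xmin = random.randrange(0, w - size + 1)
    return (ymin, xmin, ymin + size, xmin + size)

ymin, xmin, ymax, xmax = sample_window(500, 500, size_lims=(100, 300))
print(ymax - ymin)  # side length in [100, 300)
```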

[8]:
image_uri = 's3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0331E-1257N_1327_3160_13/images/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13.tif'
label_uri = 's3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0331E-1257N_1327_3160_13/labels/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13_Buildings.geojson'

As before, we make use of the convenience API, GeoDataset.from_uris() (specifically, SemanticSegmentationRandomWindowGeoDataset.from_uris()), but we can also use the normal constructor if we want to manually define the RasterSource and LabelSource.

[9]:
from rastervision.core.data import ClassConfig
from rastervision.pytorch_learner import SemanticSegmentationRandomWindowGeoDataset

import albumentations as A

class_config = ClassConfig(
    names=['background', 'building'],
    colors=['lightgray', 'darkred'],
    null_class='background')

ds = SemanticSegmentationRandomWindowGeoDataset.from_uris(
    class_config=class_config,
    image_uri=image_uri,
    label_vector_uri=label_uri,
    label_vector_default_class_id=class_config.get_class_id('building'),
    image_raster_source_kw=dict(allow_streaming=True),
    # window sizes will randomly vary from 100x100 to 300x300
    size_lims=(100, 300),
    # resize chips to 256x256 before returning
    out_size=256,
    # allow windows to overflow the extent by 100 pixels
    padding=100,
    max_windows=10
)

img_full = ds.scene.raster_source[:, :]
windows = [ds.sample_window() for _ in range(50)]
show_windows(img_full, windows, title='Random windows')
2024-04-09 20:06:18:rastervision.pipeline.file_system.utils: INFO - Using cached file /opt/data/tmp/cache/s3/spacenet-dataset/spacenet/SN7_buildings/train/L15-0331E-1257N_1327_3160_13/labels/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13_Buildings.geojson.
../../_images/usage_tutorials_sampling_training_data_33_1.png
[10]:
x, y = ds[0]
viz.plot_batch(x.unsqueeze(0), y.unsqueeze(0), show=True)
../../_images/usage_tutorials_sampling_training_data_34_0.png