Note

This page was generated from train.ipynb.

Note

If running outside of the Docker image, you may need to set some environment variables manually. You can do it like so:

import os
from subprocess import check_output

os.environ['GDAL_DATA'] = check_output('pip show rasterio | grep Location | awk \'{print $NF"/rasterio/gdal_data/"}\'', shell=True).decode().strip()
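The pipeline above calls out to `pip show`, `grep` and `awk` through a shell. Below is a pure-Python sketch of the same lookup. It assumes, like the pipeline, that rasterio keeps its GDAL data in a `gdal_data` folder inside the installed package. It uses `importlib` to locate the package without importing it. `rasterio_gdal_data_dir` is a hypothetical helper, not part of rasterio:

```python
import importlib.util
import os


def rasterio_gdal_data_dir():
    """Return the gdal_data folder inside the installed rasterio package, or None."""
    # find_spec locates a top-level package without importing it.
    spec = importlib.util.find_spec('rasterio')
    if spec is None or not spec.submodule_search_locations:
        return None  # rasterio is not installed
    pkg_dir = list(spec.submodule_search_locations)[0]
    gdal_data = os.path.join(pkg_dir, 'gdal_data')
    # The folder may be absent (e.g. if rasterio was not installed from a wheel
    # that bundles GDAL); in that case leave GDAL_DATA unset.
    return gdal_data if os.path.isdir(gdal_data) else None


gdal_data = rasterio_gdal_data_dir()
if gdal_data is not None:
    os.environ['GDAL_DATA'] = gdal_data
```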

We will be accessing files on S3 in this notebook. Since those files are public, we set the AWS_NO_SIGN_REQUEST environment variable. This tells rasterio (via GDAL) to send unsigned requests, so no AWS credentials are needed.

[ ]:
%env AWS_NO_SIGN_REQUEST=YES
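`%env` is an IPython magic. In a plain Python script, a minimal equivalent is to set the variable through `os.environ` before rasterio opens any S3 file:

```python
import os

# Equivalent of `%env AWS_NO_SIGN_REQUEST=YES` for plain Python scripts:
# GDAL (used by rasterio) then sends unsigned requests to S3, so public
# buckets such as spacenet-dataset can be read without AWS credentials.
# Set it before any S3 file is opened.
os.environ['AWS_NO_SIGN_REQUEST'] = 'YES'
```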

Training a model#

Define ClassConfig#

[2]:
from rastervision.core.data import ClassConfig

class_config = ClassConfig(
    names=['background', 'building'],
    colors=['lightgray', 'darkred'],
    null_class='background')

Define training and validation datasets#

To keep things simple, we use one scene for training and one for validation. In a real workflow, we would normally use many more scenes.

[3]:
train_image_uri = 's3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0331E-1257N_1327_3160_13/images/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13.tif'
train_label_uri = 's3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0331E-1257N_1327_3160_13/labels/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13_Buildings.geojson'
[4]:
val_image_uri = 's3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0357E-1223N_1429_3296_13/images/global_monthly_2018_01_mosaic_L15-0357E-1223N_1429_3296_13.tif'
val_label_uri = 's3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0357E-1223N_1429_3296_13/labels/global_monthly_2018_01_mosaic_L15-0357E-1223N_1429_3296_13_Buildings.geojson'
[5]:
import albumentations as A

from rastervision.pytorch_learner import (
    SemanticSegmentationRandomWindowGeoDataset,
    SemanticSegmentationSlidingWindowGeoDataset,
    SemanticSegmentationVisualizer)

viz = SemanticSegmentationVisualizer(
    class_names=class_config.names, class_colors=class_config.colors)

Training dataset with random-window sampling and data augmentation#

[6]:
data_augmentation_transform = A.Compose([
    A.Flip(),
    A.ShiftScaleRotate(),
    A.OneOf([
        A.HueSaturationValue(hue_shift_limit=10),
        A.RGBShift(),
        A.ToGray(),
        A.ToSepia(),
        A.RandomBrightnessContrast(),
        A.RandomGamma(),
    ]),
    A.CoarseDropout(max_height=32, max_width=32, max_holes=5)
])

train_ds = SemanticSegmentationRandomWindowGeoDataset.from_uris(
    class_config=class_config,
    image_uri=train_image_uri,
    label_vector_uri=train_label_uri,
    label_vector_default_class_id=class_config.get_class_id('building'),
    size_lims=(150, 200),
    out_size=256,
    max_windows=400,
    transform=data_augmentation_transform,
)
len(train_ds)
2024-04-09 20:16:43:rastervision.pipeline.file_system.utils: INFO - Using cached file /opt/data/tmp/cache/s3/spacenet-dataset/spacenet/SN7_buildings/train/L15-0331E-1257N_1327_3160_13/images/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13.tif.
2024-04-09 20:16:43:rastervision.pipeline.file_system.utils: INFO - Using cached file /opt/data/tmp/cache/s3/spacenet-dataset/spacenet/SN7_buildings/train/L15-0331E-1257N_1327_3160_13/labels/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13_Buildings.geojson.
[6]:
400
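Conceptually, each item of train_ds is built in three steps:

1. A window whose side length falls within size_lims is cut from a random position in the scene.
2. The window is resized to out_size.
3. The image and its label mask are passed through the augmentation pipeline together.

The NumPy sketch below illustrates that idea under simplifying assumptions:

* The window is square.
* Resizing uses nearest-neighbor lookup.
* Only random flips stand in for the full albumentations pipeline.
* The 1024 × 1024 scene is a random toy array.

`sample_random_window` is a hypothetical helper, not Raster Vision's implementation.

```python
import numpy as np


def sample_random_window(image, label, size_lims=(150, 200), out_size=256, rng=None):
    """Cut a random square window from an (H, W, C) image and its (H, W) label mask,
    resize both to out_size x out_size and randomly flip them together.

    Hypothetical helper illustrating the idea; not Raster Vision's implementation.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = label.shape
    # Window side drawn from [size_lims[0], size_lims[1]) (upper bound exclusive in this sketch).
    size = int(rng.integers(size_lims[0], size_lims[1]))
    # Random top-left corner such that the window lies fully inside the scene.
    row = int(rng.integers(0, h - size + 1))
    col = int(rng.integers(0, w - size + 1))
    img = image[row:row + size, col:col + size]
    lbl = label[row:row + size, col:col + size]

    # Nearest-neighbor resize by index lookup. The label must not be interpolated,
    # otherwise blending could produce values that are not valid class IDs.
    idx = np.arange(out_size) * size // out_size
    img, lbl = img[idx][:, idx], lbl[idx][:, idx]

    # Geometric augmentations (here: flips) must be applied identically to both.
    if rng.random() < 0.5:
        img, lbl = img[:, ::-1], lbl[:, ::-1]
    if rng.random() < 0.5:
        img, lbl = img[::-1], lbl[::-1]
    return img, lbl


rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(1024, 1024, 3), dtype=np.uint8)
label = (image[..., 0] > 200).astype(np.uint8)  # toy "building" mask derived from the image
x, y = sample_random_window(image, label, rng=rng)
print(x.shape, y.shape)  # (256, 256, 3) (256, 256)
```

Because the toy mask is derived pixel by pixel from the image, it still matches the transformed image exactly. This is the property that albumentations preserves when it is given both `image=` and `mask=`.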

Visualize:

[8]:
x, y = viz.get_batch(train_ds, 4)
viz.plot_batch(x, y, show=True)
[Figure: a sample batch of 4 training chips with their building labels.]

Validation dataset with sliding-window sampling (and no data augmentation)#

[7]:
val_ds = SemanticSegmentationSlidingWindowGeoDataset.from_uris(
    class_config=class_config,
    image_uri=val_image_uri,
    label_vector_uri=val_label_uri,
    label_vector_default_class_id=class_config.get_class_id('building'),
    size=200,
    stride=100,
    out_size=256,
)
len(val_ds)
2024-04-09 20:16:46:rastervision.pipeline.file_system.utils: INFO - Using cached file /opt/data/tmp/cache/s3/spacenet-dataset/spacenet/SN7_buildings/train/L15-0357E-1223N_1429_3296_13/images/global_monthly_2018_01_mosaic_L15-0357E-1223N_1429_3296_13.tif.
2024-04-09 20:16:46:rastervision.pipeline.file_system.utils: INFO - Using cached file /opt/data/tmp/cache/s3/spacenet-dataset/spacenet/SN7_buildings/train/L15-0357E-1223N_1429_3296_13/labels/global_monthly_2018_01_mosaic_L15-0357E-1223N_1429_3296_13_Buildings.geojson.
[7]:
100
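The number of windows in a sliding-window dataset follows from the scene size, size and stride. With size=200 and stride=100, neighboring windows overlap by half. The sketch below counts full windows only, for a hypothetical scene size. The helper names are illustrative. Raster Vision may additionally pad the scene so that windows also cover its edges, in which case its count can be higher than this estimate.

```python
def sliding_window_starts(length, size, stride):
    """Offsets at which full windows of `size` pixels start along one axis."""
    return list(range(0, length - size + 1, stride))


def num_sliding_windows(height, width, size=200, stride=100):
    """Number of full square windows covering a height x width scene (no padding)."""
    rows = sliding_window_starts(height, size, stride)
    cols = sliding_window_starts(width, size, stride)
    return len(rows) * len(cols)


# Hypothetical 1000 x 1000 pixel scene with the settings used for val_ds.
print(sliding_window_starts(1000, 200, 100))  # [0, 100, 200, 300, 400, 500, 600, 700, 800]
print(num_sliding_windows(1000, 1000))        # 81
```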

Visualize:

[10]:
x, y = viz.get_batch(val_ds, 4)
viz.plot_batch(x, y, show=True)
[Figure: a sample batch of 4 validation chips with their building labels.]

Define model#

Use a lightweight Panoptic FPN model with a ResNet-18 backbone, loaded via torch.hub from the AdeelH/pytorch-fpn repository.

[ ]:
import torch

model = torch.hub.load(
    'AdeelH/pytorch-fpn:0.3',
    'make_fpn_resnet',
    name='resnet18',
    fpn_type='panoptic',
    num_classes=len(class_config),
    fpn_channels=128,
    in_channels=3,
    out_size=(256, 256),
    pretrained=True)

Configure the training#

SolverConfig – Configure the loss, optimizer, and scheduler(s)#

[10]:
from rastervision.pytorch_learner import SolverConfig

solver_cfg = SolverConfig(
    batch_sz=8,
    lr=3e-2,
    class_loss_weights=[1., 10.]
)
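class_loss_weights=[1., 10.] makes each building pixel count ten times as much as a background pixel in the loss. This counteracts the fact that most pixels are background. The sketch below assumes the weights are applied the same way as the `weight` argument of `torch.nn.CrossEntropyLoss` with mean reduction. Under that assumption, the batch loss is a weighted average of per-pixel negative log-likelihoods. `weighted_cross_entropy` is a hypothetical NumPy helper written only to make that formula explicit:

```python
import numpy as np


def weighted_cross_entropy(logits, targets, class_weights):
    """Class-weighted mean cross-entropy, using the same formula as
    torch.nn.CrossEntropyLoss(weight=..., reduction='mean'):
    sum_n w[y_n] * nll_n / sum_n w[y_n].

    logits: (N, C) raw scores, targets: (N,) class IDs, class_weights: (C,).
    """
    logits = np.asarray(logits, dtype=float)
    targets = np.asarray(targets)
    class_weights = np.asarray(class_weights, dtype=float)
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(targets)), targets]  # per-pixel loss
    w = class_weights[targets]                           # weight of each pixel's true class
    return float((w * nll).sum() / w.sum())


# Four pixels; the last one is a building pixel that the model scores as background.
logits = [[2.0, 0.0], [2.0, 0.0], [0.0, 2.0], [2.0, 0.0]]
targets = [0, 0, 1, 1]
print(weighted_cross_entropy(logits, targets, [1., 1.]))   # unweighted
print(weighted_cross_entropy(logits, targets, [1., 10.]))  # the misclassified building pixel dominates
```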

LearnerConfig – Combine DataConfig, SolverConfig (and optionally, ModelConfig)#

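[11]:
from rastervision.pytorch_learner import (
    SemanticSegmentationGeoDataConfig, SemanticSegmentationLearnerConfig)

# DataConfig: class names/colors plus dataloader options (num_workers=0 loads data in the main process)
data_cfg = SemanticSegmentationGeoDataConfig(
    class_names=class_config.names, class_colors=class_config.colors, num_workers=0)
learner_cfg = SemanticSegmentationLearnerConfig(data=data_cfg, solver=solver_cfg)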

Initialize Learner#

[15]:
from rastervision.pytorch_learner import SemanticSegmentationLearner

learner = SemanticSegmentationLearner(
    cfg=learner_cfg,
    output_dir='./train-demo/',
    model=model,
    train_ds=train_ds,
    valid_ds=val_ds,
)
2024-04-09 20:18:03:rastervision.pytorch_learner.learner: INFO - Building dataloaders
[16]:
learner.log_data_stats()
2024-04-09 20:18:03:rastervision.pytorch_learner.learner: INFO - train_ds: 400 items
2024-04-09 20:18:03:rastervision.pytorch_learner.learner: INFO - valid_ds: 100 items

Run Tensorboard for monitoring#

Note

  • If running inside the Raster Vision Docker image, you will need to pass --tensorboard to docker/run for this to work.

  • If the dashboard doesn't auto-reload, you can click the reload button on the top-right.

[13]:
%load_ext tensorboard

This will start an instance of tensorboard and embed it in the output of the cell:

[ ]:
%tensorboard --bind_all --logdir "./train-demo/tb-logs" --reload_interval 10

[Figure: A screenshot of the TensorBoard dashboard.]

Train – Learner.train()#

[17]:
learner.train(epochs=3)
2024-04-09 20:18:43:rastervision.pytorch_learner.learner: INFO - train_ds: 400 items
2024-04-09 20:18:43:rastervision.pytorch_learner.learner: INFO - valid_ds: 100 items
2024-04-09 20:18:43:rastervision.pytorch_learner.learner: INFO - Plotting sample training batch.
2024-04-09 20:18:43:rastervision.pytorch_learner.learner: INFO - Plotting sample validation batch.
2024-04-09 20:18:44:rastervision.pytorch_learner.learner: INFO - epoch: 0
2024-04-09 20:18:49:rastervision.pytorch_learner.learner: INFO - metrics:
{'epoch': 0,
 'train_loss': 0.2913476228713989,
 'train_time': '0:00:03.397020',
 'val_loss': 0.3804737329483032,
 'avg_precision': 0.9294803738594055,
 'avg_recall': 0.8374766707420349,
 'avg_f1': 0.8810832500457764,
 'background_precision': 0.9746540188789368,
 'background_recall': 0.8492995500564575,
 'background_f1': 0.9076691269874573,
 'building_precision': 0.21413441002368927,
 'building_recall': 0.6502557396888733,
 'building_f1': 0.3221742510795593,
 'valid_time': '0:00:01.604234'}
2024-04-09 20:18:49:rastervision.pytorch_learner.learner: INFO - epoch: 1
2024-04-09 20:18:53:rastervision.pytorch_learner.learner: INFO - metrics:
{'epoch': 1,
 'train_loss': 0.24667344987392426,
 'train_time': '0:00:02.883060',
 'val_loss': 0.4013445973396301,
 'avg_precision': 0.927025318145752,
 'avg_recall': 0.8806623816490173,
 'avg_f1': 0.9032492637634277,
 'background_precision': 0.9691619873046875,
 'background_recall': 0.9018215537071228,
 'background_f1': 0.9342798590660095,
 'building_precision': 0.25977078080177307,
 'building_recall': 0.5455965399742126,
 'building_f1': 0.35196369886398315,
 'valid_time': '0:00:01.320343'}
2024-04-09 20:18:53:rastervision.pytorch_learner.learner: INFO - epoch: 2
2024-04-09 20:18:57:rastervision.pytorch_learner.learner: INFO - metrics:
{'epoch': 2,
 'train_loss': 0.2440817952156067,
 'train_time': '0:00:02.781280',
 'val_loss': 0.39747753739356995,
 'avg_precision': 0.9319270849227905,
 'avg_recall': 0.8825308680534363,
 'avg_f1': 0.9065566062927246,
 'background_precision': 0.9732670783996582,
 'background_recall': 0.8998284935951233,
 'background_f1': 0.9351081252098083,
 'building_precision': 0.2772882282733917,
 'building_recall': 0.6086140275001526,
 'building_f1': 0.380993515253067,
 'valid_time': '0:00:01.286454'}
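The avg_* values in this log are consistent with two rules:

* avg_precision and avg_recall are averages of the per-class scores, weighted by each class's share of validation pixels. Background makes up roughly 94% of those pixels here.
* avg_f1 is computed from avg_precision and avg_recall rather than averaged directly.

For epoch 0, for instance, 0.94 × 0.849 + 0.06 × 0.650 ≈ 0.837 (avg_recall), and 2 × 0.929 × 0.837 / (0.929 + 0.837) ≈ 0.881 (avg_f1). This is inferred from the numbers, not taken from the Raster Vision source. It does mean that the high avg_* values mostly reflect the background class, so the building_* metrics are the more informative ones for this task. The sketch below computes metrics in that way from a pixel confusion matrix. `seg_metrics` is a hypothetical helper, and weighting by true-pixel counts is an assumption:

```python
import numpy as np


def seg_metrics(conf_mat, class_names):
    """Per-class precision/recall/F1 plus support-weighted averages.

    conf_mat[i, j] counts pixels whose true class is i and predicted class is j.
    """
    conf_mat = np.asarray(conf_mat, dtype=float)
    tp = np.diag(conf_mat)
    support = conf_mat.sum(axis=1)    # true pixels per class
    predicted = conf_mat.sum(axis=0)  # predicted pixels per class
    zeros = np.zeros_like(tp)
    precision = np.divide(tp, predicted, out=zeros.copy(), where=predicted > 0)
    recall = np.divide(tp, support, out=zeros.copy(), where=support > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=zeros.copy(), where=denom > 0)

    metrics = {}
    for name, p, r, f in zip(class_names, precision, recall, f1):
        metrics[f'{name}_precision'] = float(p)
        metrics[f'{name}_recall'] = float(r)
        metrics[f'{name}_f1'] = float(f)

    # Average per-class scores weighted by each class's share of true pixels,
    # then derive avg_f1 from the averaged precision and recall.
    weights = support / support.sum()
    avg_p = float((weights * precision).sum())
    avg_r = float((weights * recall).sum())
    metrics['avg_precision'] = avg_p
    metrics['avg_recall'] = avg_r
    metrics['avg_f1'] = 2 * avg_p * avg_r / (avg_p + avg_r) if avg_p + avg_r > 0 else 0.0
    return metrics


# Toy 2-class confusion matrix (rows: true background/building, cols: predicted).
m = seg_metrics([[90, 10], [2, 8]], ['background', 'building'])
for key, value in m.items():
    print(f'{key}: {value:.3f}')
```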

Train some more#

[18]:
learner.train(epochs=1)
2024-04-09 20:18:57:rastervision.pytorch_learner.learner: INFO - Resuming training from epoch 3
2024-04-09 20:18:57:rastervision.pytorch_learner.learner: INFO - train_ds: 400 items
2024-04-09 20:18:57:rastervision.pytorch_learner.learner: INFO - valid_ds: 100 items
2024-04-09 20:18:57:rastervision.pytorch_learner.learner: INFO - Plotting sample training batch.
2024-04-09 20:18:57:rastervision.pytorch_learner.learner: INFO - Plotting sample validation batch.
2024-04-09 20:18:58:rastervision.pytorch_learner.learner: INFO - epoch: 3
2024-04-09 20:19:02:rastervision.pytorch_learner.learner: INFO - metrics:
{'epoch': 3,
 'train_loss': 0.2168915718793869,
 'train_time': '0:00:02.736246',
 'val_loss': 0.43969860672950745,
 'avg_precision': 0.9304015040397644,
 'avg_recall': 0.8608637452125549,
 'avg_f1': 0.894282877445221,
 'background_precision': 0.9738898873329163,
 'background_recall': 0.8755510449409485,
 'background_f1': 0.9221059679985046,
 'building_precision': 0.24174126982688904,
 'building_recall': 0.6282839775085449,
 'building_f1': 0.3491442799568176,
 'valid_time': '0:00:01.285872'}

Examine predictions – Learner.plot_predictions()#

[19]:
learner.plot_predictions(split='valid', show=True)
2024-04-09 20:19:02:rastervision.pytorch_learner.learner: INFO - Making and plotting sample predictions on the valid set...
[Figure: sample predictions on the validation set.]
2024-04-09 20:19:03:rastervision.pytorch_learner.learner: INFO - Sample predictions written to ./train-demo/valid_preds.png.

Save as a model-bundle – Learner.save_model_bundle()#

Note the warning about ModelConfig in the output below. It becomes relevant when we load the Learner from the bundle later on.

[20]:
learner.save_model_bundle()
2024-04-09 20:19:03:rastervision.pytorch_learner.learner: WARNING - Model was not configured via ModelConfig, and therefore, will not be reconstructable form the model-bundle. You will need to initialize the model yourself and pass it to from_model_bundle().
2024-04-09 20:19:03:rastervision.pytorch_learner.learner: INFO - Creating bundle.
2024-04-09 20:19:03:rastervision.pytorch_learner.learner: INFO - Exporting to model to ONNX.
2024-04-09 20:19:04:rastervision.pytorch_learner.learner: INFO - Saving bundle to ./train-demo/model-bundle.zip.
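The model-bundle is an ordinary zip archive. The loading logs below show it being unzipped and a model.pth being read from it. Python's zipfile module can therefore list what was saved. The exact file list depends on the Raster Vision version. `list_bundle` is a hypothetical helper:

```python
import zipfile
from pathlib import Path


def list_bundle(bundle_path):
    """Return the names of the files stored in a model-bundle zip archive."""
    with zipfile.ZipFile(bundle_path) as zf:
        return zf.namelist()


bundle = Path('./train-demo/model-bundle.zip')
if bundle.exists():
    for name in list_bundle(bundle):
        print(name)
```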

Examine learner output#

The trained model weights are saved at ./train-demo/last-model.pth as well as inside the model-bundle.

[21]:
!tree "./train-demo/"
./train-demo/
├── checkpoints
├── dataloaders
│   ├── train.png
│   └── valid.png
├── last-model.pth
├── learner-config.json
├── log.csv
├── model-bundle.zip
├── tb-logs
│   ├── events.out.tfevents.1712693861.60eb69042744.1456.0
│   ├── events.out.tfevents.1712693868.60eb69042744.1456.1
│   └── events.out.tfevents.1712693883.60eb69042744.1456.2
└── valid_preds.png

3 directories, 10 files
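log.csv holds the per-epoch metrics. Assume it has a header row whose column names match the metric keys printed during training (worth checking in your version). Then the standard csv module is enough to find, say, the epoch with the highest building_f1. For the run above, which had completed epochs 0 to 3 at this point, that would be epoch 2. `best_epoch` is a hypothetical helper:

```python
import csv
from pathlib import Path


def best_epoch(log_path, metric='building_f1'):
    """Return (epoch, value) for the row with the highest `metric` in a per-epoch CSV log.

    Assumes a header row whose column names match the metric keys printed during
    training (e.g. 'epoch', 'val_loss', 'building_f1').
    """
    with open(log_path, newline='') as f:
        rows = list(csv.DictReader(f))
    best = max(rows, key=lambda row: float(row[metric]))
    return int(best['epoch']), float(best[metric])


log_path = Path('./train-demo/log.csv')
if log_path.exists():
    print(best_epoch(log_path))
```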

Using model-bundles#

For predictions – Learner.from_model_bundle()#

We can use the model-bundle to reconstruct our Learner and then use it to make predictions.

Note

Since we used a custom model instead of using ModelConfig, the model-bundle does not know how to construct the model; therefore, we need to pass in the model again.

[22]:
from rastervision.pytorch_learner import SemanticSegmentationLearner

learner = SemanticSegmentationLearner.from_model_bundle(
    model_bundle_uri='./train-demo/model-bundle.zip',
    output_dir='./train-demo/',
    model=model,
)
2024-04-09 20:19:16:rastervision.pytorch_learner.learner: INFO - Loading learner from bundle ./train-demo/model-bundle.zip.
2024-04-09 20:19:16:rastervision.pytorch_learner.learner: INFO - Unzipping model-bundle to /opt/data/tmp/tmpm5o2kvyl/model-bundle
2024-04-09 20:19:17:rastervision.pytorch_learner.learner: INFO - Loading model weights from: /opt/data/tmp/tmpm5o2kvyl/model-bundle/model.pth

For next steps, see the “Prediction and Evaluation” tutorial.

For fine-tuning – Learner.from_model_bundle()#

We can also reconstruct the Learner in order to continue training, perhaps on a different dataset. To do this, we pass in train_ds and val_ds (via the train_ds and valid_ds arguments) and set training=True.

Note

Since we used a custom model instead of using ModelConfig, the model-bundle does not know how to construct the model; therefore, we need to pass in the model again.

Note

Optimizers and schedulers are (currently) not stored in model-bundles.

[23]:
from rastervision.pytorch_learner import SemanticSegmentationLearner

learner = SemanticSegmentationLearner.from_model_bundle(
    model_bundle_uri='./train-demo/model-bundle.zip',
    output_dir='./train-demo/',
    model=model,
    train_ds=train_ds,
    valid_ds=val_ds,
    training=True,
)
2024-04-09 20:19:20:rastervision.pytorch_learner.learner: INFO - Loading learner from bundle ./train-demo/model-bundle.zip.
2024-04-09 20:19:20:rastervision.pytorch_learner.learner: INFO - Unzipping model-bundle to /opt/data/tmp/tmplp5b8748/model-bundle
2024-04-09 20:19:21:rastervision.pytorch_learner.learner: INFO - Loading model weights from: /opt/data/tmp/tmplp5b8748/model-bundle/model.pth
2024-04-09 20:19:21:rastervision.pytorch_learner.learner: INFO - Building dataloaders
2024-04-09 20:19:21:rastervision.pytorch_learner.learner: INFO - Loading checkpoint from ./train-demo/last-model.pth

Continue training:

[24]:
learner.train(epochs=1)
2024-04-09 20:19:35:rastervision.pytorch_learner.learner: INFO - Resuming training from epoch 4
2024-04-09 20:19:35:rastervision.pytorch_learner.learner: INFO - train_ds: 400 items
2024-04-09 20:19:35:rastervision.pytorch_learner.learner: INFO - valid_ds: 100 items
2024-04-09 20:19:35:rastervision.pytorch_learner.learner: INFO - Plotting sample training batch.
2024-04-09 20:19:36:rastervision.pytorch_learner.learner: INFO - Plotting sample validation batch.
2024-04-09 20:19:36:rastervision.pytorch_learner.learner: INFO - epoch: 4
2024-04-09 20:19:40:rastervision.pytorch_learner.learner: INFO - metrics:
{'epoch': 4,
 'train_loss': 0.34853240847587585,
 'train_time': '0:00:02.899091',
 'val_loss': 0.4300776422023773,
 'avg_precision': 0.9265687465667725,
 'avg_recall': 0.8659200668334961,
 'avg_f1': 0.8952184319496155,
 'background_precision': 0.9700926542282104,
 'background_recall': 0.8847286701202393,
 'background_f1': 0.9254463315010071,
 'building_precision': 0.2373460829257965,
 'building_recall': 0.5680769085884094,
 'building_f1': 0.33480748534202576,
 'valid_time': '0:00:01.288514'}