This page was generated from lightning_workflow.ipynb.


If running outside of the Docker image, you might need to set a couple of environment variables manually. You can do it like so:

import os
from subprocess import check_output

os.environ['GDAL_DATA'] = check_output('pip show rasterio | grep Location | awk \'{print $NF"/rasterio/gdal_data/"}\'', shell=True).decode().strip()
os.environ['AWS_NO_SIGN_REQUEST'] = 'YES'

See this Colab notebook for an example.

Using Raster Vision with Lightning#

The Lightning logo.

Lightning (formerly known as PyTorch Lightning) is a high-level library for training PyTorch models. In this tutorial, we demonstrate a complete workflow for doing semantic segmentation on SpaceNet Vegas using a combination of Raster Vision and Lightning. We use Raster Vision for reading data, Lightning for training a model, and then Raster Vision again for making predictions and evaluations on whole scenes.

Raster Vision has easy-to-use, built-in model training functionality implemented by the Learner class, which is shown in the “Training a model” tutorial. However, some users may prefer to use Lightning for training models, either because they already know and like it, or because they want more flexibility than the Learner class offers. This notebook shows how the two libraries can be used together, but does not attempt to use either in a particularly sophisticated manner.

First, we need to install pytorch-lightning since it is not a dependency of Raster Vision.

[ ]:
! pip install pytorch-lightning==1.8

Define training and validation datasets#

We use Raster Vision to create training and validation Dataset objects. To keep things simple, we use a single scene for training and a single, different scene for validation. In a real workflow, we would use many more scenes.

import albumentations as A

from rastervision.pytorch_learner import (
    SemanticSegmentationRandomWindowGeoDataset,
    SemanticSegmentationSlidingWindowGeoDataset,
    SemanticSegmentationVisualizer)
from rastervision.core.data import ClassConfig

scene_id = 5631
train_image_uri = f's3://spacenet-dataset/spacenet/SN2_buildings/train/AOI_2_Vegas/PS-RGB/SN2_buildings_train_AOI_2_Vegas_PS-RGB_img{scene_id}.tif'
train_label_uri = f's3://spacenet-dataset/spacenet/SN2_buildings/train/AOI_2_Vegas/geojson_buildings/SN2_buildings_train_AOI_2_Vegas_geojson_buildings_img{scene_id}.geojson'

class_config = ClassConfig(
    names=['building', 'background'],
    colors=['orange', 'black'],
    null_class='background')

data_augmentation_transform = A.Compose([
    A.Flip(),
    A.ShiftScaleRotate(),
    A.RGBShift(),
])

train_ds = SemanticSegmentationRandomWindowGeoDataset.from_uris(
    class_config=class_config,
    image_uri=train_image_uri,
    label_vector_uri=train_label_uri,
    label_vector_default_class_id=class_config.get_class_id('building'),
    size_lims=(300, 350),
    out_size=325,
    max_windows=10,
    transform=data_augmentation_transform)
2022-11-21 22:27:37:rastervision.pipeline.file_system.utils: INFO - Using cached file /opt/data/tmp/cache/s3/spacenet-dataset/spacenet/SN2_buildings/train/AOI_2_Vegas/PS-RGB/SN2_buildings_train_AOI_2_Vegas_PS-RGB_img5631.tif.
2022-11-21 WARNING - Raster block size (2, 650) is too non-square. This can slow down reading. Consider re-tiling using GDAL.
2022-11-21 22:27:37:rastervision.pipeline.file_system.utils: INFO - Using cached file /opt/data/tmp/cache/s3/spacenet-dataset/spacenet/SN2_buildings/train/AOI_2_Vegas/geojson_buildings/SN2_buildings_train_AOI_2_Vegas_geojson_buildings_img5631.geojson.

To check that data is being read correctly, we use the Visualizer to plot a batch.

viz = SemanticSegmentationVisualizer(
    class_names=class_config.names, class_colors=class_config.colors)
x, y = viz.get_batch(train_ds, 4)
viz.plot_batch(x, y, show=True)
scene_id = 5632
val_image_uri = f's3://spacenet-dataset/spacenet/SN2_buildings/train/AOI_2_Vegas/PS-RGB/SN2_buildings_train_AOI_2_Vegas_PS-RGB_img{scene_id}.tif'
val_label_uri = f's3://spacenet-dataset/spacenet/SN2_buildings/train/AOI_2_Vegas/geojson_buildings/SN2_buildings_train_AOI_2_Vegas_geojson_buildings_img{scene_id}.geojson'

val_ds = SemanticSegmentationSlidingWindowGeoDataset.from_uris(
    class_config=class_config,
    image_uri=val_image_uri,
    label_vector_uri=val_label_uri,
    label_vector_default_class_id=class_config.get_class_id('building'),
    size=325,
    stride=325)
2022-11-21 22:28:31:rastervision.pipeline.file_system.utils: INFO - Using cached file /opt/data/tmp/cache/s3/spacenet-dataset/spacenet/SN2_buildings/train/AOI_2_Vegas/PS-RGB/SN2_buildings_train_AOI_2_Vegas_PS-RGB_img5632.tif.
2022-11-21 WARNING - Raster block size (2, 650) is too non-square. This can slow down reading. Consider re-tiling using GDAL.
2022-11-21 22:28:31:rastervision.pipeline.file_system.utils: INFO - Using cached file /opt/data/tmp/cache/s3/spacenet-dataset/spacenet/SN2_buildings/train/AOI_2_Vegas/geojson_buildings/SN2_buildings_train_AOI_2_Vegas_geojson_buildings_img5632.geojson.

Train Model using Lightning#

Here we build a DeepLab-ResNet50 model and train it using Lightning. We train for only 3 epochs so that this can run in a minute or so on a CPU; because of this, the model will not be accurate at all. In a real workflow, we would train for 10-100 epochs on a GPU.

from tqdm.autonotebook import tqdm
import torch
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torchvision.models.segmentation import deeplabv3_resnet50
import pytorch_lightning as pl

from rastervision.pipeline.file_system import make_dir
batch_size = 8
lr = 1e-4
epochs = 3
output_dir = './lightning-demo/'
fast_dev_run = False
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
val_dl = DataLoader(val_ds, batch_size=batch_size, num_workers=4)

One of the main abstractions in Lightning is the LightningModule which extends a PyTorch nn.Module with extra methods that define how to train and validate the model. Here we define a LightningModule that does the bare minimum to train a DeepLab semantic segmentation model.

class SemanticSegmentation(pl.LightningModule):
    def __init__(self, deeplab, lr=1e-4):
        super().__init__()
        self.deeplab = deeplab
        self.lr = lr

    def forward(self, img):
        return self.deeplab(img)['out']

    def training_step(self, batch, batch_idx):
        img, mask = batch
        img = img.float()
        mask = mask.long()
        out = self.forward(img)
        loss = F.cross_entropy(out, mask)
        log_dict = {'train_loss': loss}
        self.log_dict(log_dict, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        img, mask = batch
        img = img.float()
        mask = mask.long()
        out = self.forward(img)
        loss = F.cross_entropy(out, mask)
        log_dict = {'validation_loss': loss}
        self.log_dict(log_dict, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.lr)
        return optimizer
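The F.cross_entropy calls above expect raw, unnormalized logits and integer class targets; the log-softmax is applied internally. As a sanity check on that semantics, here is a pure-Python sketch of the per-pixel computation (the logit values are made up for illustration):

```python
import math

def cross_entropy(logits, target_class):
    """Per-pixel cross-entropy: -log(softmax(logits)[target_class])."""
    exps = [math.exp(z) for z in logits]
    return -math.log(exps[target_class] / sum(exps))

# Two classes; the logits favor class 0, so targeting class 0 gives lower loss.
loss_correct = cross_entropy([2.0, -1.0], target_class=0)
loss_wrong = cross_entropy([2.0, -1.0], target_class=1)
print(loss_correct < loss_wrong)  # True
```

F.cross_entropy applies the same computation at every pixel of the (batch, classes, height, width) output and averages the result.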

The other main abstraction in Lightning is the Trainer, which is responsible for actually training a LightningModule. Here, it is configured to log metrics to Tensorboard.

from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

deeplab = deeplabv3_resnet50(num_classes=len(class_config) + 1)
model = SemanticSegmentation(deeplab, lr=lr)
tb_logger = TensorBoardLogger(save_dir=output_dir, flush_secs=10)
trainer = pl.Trainer(
    accelerator='auto',
    min_epochs=1,
    max_epochs=epochs + 1,
    default_root_dir=output_dir,
    logger=[tb_logger],
    fast_dev_run=fast_dev_run,
    log_every_n_steps=1)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

Monitor training using Tensorboard#

This runs an instance of Tensorboard inside this notebook.


  • If running inside the Raster Vision docker image, you will need to pass --tensorboard to docker/run for this to work.

  • If the dashboard doesn't auto-reload, you can click the reload button in the top-right.

[ ]:
%reload_ext tensorboard
%tensorboard --bind_all --logdir "./lightning-demo/lightning_logs" --reload_interval 10

A screenshot of the Tensorboard dashboard.

[15]:
trainer.fit(model, train_dl, val_dl)
Missing logger folder: ./lightning-demo/lightning_logs

  | Name    | Type      | Params
0 | deeplab | DeepLabV3 | 39.6 M
39.6 M    Trainable params
0         Non-trainable params
39.6 M    Total params
158.537   Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=4` reached.

Load saved model#

Since we trained the model for only 3 epochs, it will not make good predictions. To get some sensible-looking output, we will load weights from a model that was fully trained on SpaceNet Vegas.

weights_uri = ''
deeplab.load_state_dict(torch.hub.load_state_dict_from_url(weights_uri, map_location=torch.device('cpu')))
<All keys matched successfully>

Make predictions for scene#

We can now use Raster Vision’s SemanticSegmentationLabels class to make predictions over a whole scene. The SemanticSegmentationLabels.from_predictions() method takes an iterator over predictions. We create this using a get_predictions() helper function defined below.

def get_predictions(dataloader):
    # Put the model in eval mode so batchnorm layers use their running stats.
    model.eval()
    for x, _ in tqdm(dataloader):
        with torch.inference_mode():
            out_batch = model(x)
        # This needs to yield a single prediction, not a whole batch of them.
        for out in out_batch:
            yield out.numpy()
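The pattern in get_predictions() — iterate over batches but yield one prediction at a time — is just generator-based flattening. A toy version with plain lists standing in for tensors (the names here are illustrative):

```python
def flatten_batches(batches):
    # Lazily yield individual items from an iterable of batches.
    for batch in batches:
        for item in batch:
            yield item

batches = [[1, 2, 3], [4, 5]]
print(list(flatten_batches(batches)))  # [1, 2, 3, 4, 5]
```

Because it is a generator, predictions are produced on demand rather than all held in memory at once, which is what lets from_predictions() consume them as a stream.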
from rastervision.core.data import SemanticSegmentationLabels

predictions = get_predictions(val_dl)
pred_labels = SemanticSegmentationLabels.from_predictions(
    val_ds.windows,
    predictions,
    smooth=True,
    extent=val_ds.scene.extent,
    num_classes=len(class_config) + 1)
scores = pred_labels.get_score_arr(pred_labels.extent)
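get_score_arr() returns per-class scores with shape (num_classes, height, width); a discrete label map is the per-pixel argmax over the class axis (the labels object can also produce discrete labels directly). A pure-Python sketch of that argmax, with nested lists standing in for the array:

```python
def argmax_labels(scores):
    """Per-pixel argmax over the class axis of a (classes, h, w) nested list."""
    num_classes = len(scores)
    h, w = len(scores[0]), len(scores[0][0])
    labels = [[0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            pixel_scores = [scores[c][i][j] for c in range(num_classes)]
            labels[i][j] = pixel_scores.index(max(pixel_scores))
    return labels

# 2 classes over a 2x2 window: class 0 wins only at the top-left pixel.
scores_toy = [
    [[0.9, 0.2], [0.1, 0.4]],   # class 0 (building)
    [[0.1, 0.8], [0.9, 0.6]],   # class 1 (background)
]
print(argmax_labels(scores_toy))  # [[0, 1], [1, 1]]
```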

Visualize and then save predictions#

from matplotlib import pyplot as plt

scores_building = scores[0]
scores_background = scores[1]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
ax1.imshow(scores_building, cmap='plasma')
ax2.imshow(scores_background, cmap='plasma')
from os.path import join

pred_labels.save(
    uri=join(output_dir, 'predictions'),
    crs_transformer=val_ds.scene.raster_source.crs_transformer,
    class_config=class_config)

Evaluate predictions for a scene#

Now that we have predictions for the validation scene, we can evaluate them by comparing to ground truth using SemanticSegmentationEvaluator.

from rastervision.core.evaluation import SemanticSegmentationEvaluator

evaluator = SemanticSegmentationEvaluator(class_config)

evaluation = evaluator.evaluate_predictions(
    ground_truth=val_ds.scene.label_source.get_labels(),
    predictions=pred_labels)

SemanticSegmentationEvaluator.evaluate_predictions() returns a SemanticSegmentationEvaluation object which contains evaluations for each class as ClassEvaluationItem objects.

Here are the metrics for the building and background classes.

{'class_id': 0,
 'class_name': 'building',
 'conf_mat': [[336253.0, 5163.0], [3586.0, 77498.0]],
 'conf_mat_dict': {'FN': 3586.0, 'FP': 5163.0, 'TN': 336253.0, 'TP': 77498.0},
 'conf_mat_frac': [[0.7958650887573965, 0.012220118343195266],
                   [0.008487573964497041, 0.18342721893491123]],
 'conf_mat_frac_dict': {'FN': 0.008487573964497041,
                        'FP': 0.012220118343195266,
                        'TN': 0.7958650887573965,
                        'TP': 0.18342721893491123},
 'count_error': 1577.0,
 'gt_count': 81084.0,
 'metrics': {'f1': 0.946569360896516,
             'precision': 0.937540073311477,
             'recall': 0.9557742587933501,
             'sensitivity': 0.9557742587933501,
             'specificity': 0.9848776858729527},
 'pred_count': 82661.0,
 'relative_frequency': 0.19191479289940827}
{'class_id': 1,
 'class_name': 'background',
 'conf_mat': [[77498.0, 3586.0], [5163.0, 336253.0]],
 'conf_mat_dict': {'FN': 5163.0, 'FP': 3586.0, 'TN': 77498.0, 'TP': 336253.0},
 'conf_mat_frac': [[0.18342721893491123, 0.008487573964497041],
                   [0.012220118343195266, 0.7958650887573965]],
 'conf_mat_frac_dict': {'FN': 0.012220118343195266,
                        'FP': 0.008487573964497041,
                        'TN': 0.18342721893491123,
                        'TP': 0.7958650887573965},
 'count_error': 1577.0,
 'gt_count': 341416.0,
 'metrics': {'f1': 0.9871575254493545,
             'precision': 0.9894479444678215,
             'recall': 0.9848776858729527,
             'sensitivity': 0.9848776858729527,
             'specificity': 0.9557742587933501},
 'pred_count': 339839.0,
 'relative_frequency': 0.8080852071005917}
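The precision, recall, and f1 values in these items follow directly from the conf_mat_dict counts. A quick pure-Python sanity check using the building-class counts printed above:

```python
def prf1(tp, fp, fn):
    # Standard precision/recall/F1 from confusion-matrix counts.
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

# Building-class counts from the evaluation output above.
precision, recall, f1 = prf1(tp=77498.0, fp=5163.0, fn=3586.0)
print(precision, recall, f1)  # ≈ 0.937540, 0.955774, 0.946569
```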
[24]:
evaluation.save(join(output_dir, 'evaluation.json'))