Note

This page was generated from lightning_workflow.ipynb.

Note

If running outside of the Docker image, you may need to set some environment variables manually. You can do it like so:

import os
from subprocess import check_output

os.environ['GDAL_DATA'] = check_output('pip show rasterio | grep Location | awk \'{print $NF"/rasterio/gdal_data/"}\'', shell=True).decode().strip()

Using Raster Vision with Lightning#

The Lightning logo.

Lightning (formerly known as PyTorch Lightning) is a high-level library for training PyTorch models. In this tutorial, we demonstrate a complete workflow for doing semantic segmentation on SpaceNet Vegas using a combination of Raster Vision and Lightning. We use Raster Vision for reading data, Lightning for training a model, and then Raster Vision again for making predictions and evaluations on whole scenes.

Raster Vision has easy-to-use, built-in model training functionality implemented by the Learner class, which is shown in the “Training a model” tutorial. However, some users may prefer to use Lightning for training models, either because they already know and like it, or because they want more flexibility than the Learner class offers. This notebook shows how the two libraries can be used together, but does not attempt to use either one in a particularly sophisticated manner.

First, we need to install pytorch-lightning since it is not a dependency of Raster Vision.

[ ]:
! pip install pytorch-lightning==2.0.5

Define training and validation datasets#

We use Raster Vision to create training and validation Dataset objects. To keep things simple, we use a single scene for training and another single scene for validation. In a real workflow we would use many more scenes.

[4]:
import os

os.environ['AWS_NO_SIGN_REQUEST'] = 'YES'
[12]:
import albumentations as A

from rastervision.pytorch_learner import (
    SemanticSegmentationRandomWindowGeoDataset,
    SemanticSegmentationSlidingWindowGeoDataset,
    SemanticSegmentationVisualizer)
from rastervision.core.data import ClassConfig, StatsTransformer
[50]:
scene_id = 5631
train_image_uri = f's3://spacenet-dataset/spacenet/SN2_buildings/train/AOI_2_Vegas/PS-RGB/SN2_buildings_train_AOI_2_Vegas_PS-RGB_img{scene_id}.tif'
train_label_uri = f's3://spacenet-dataset/spacenet/SN2_buildings/train/AOI_2_Vegas/geojson_buildings/SN2_buildings_train_AOI_2_Vegas_geojson_buildings_img{scene_id}.geojson'

class_config = ClassConfig(
    names=['building', 'background'],
    colors=['orange', 'black'],
    null_class='background')

data_augmentation_transform = A.Compose([
    A.Flip(),
    A.ShiftScaleRotate(),
    A.RGBShift()
])

train_ds = SemanticSegmentationRandomWindowGeoDataset.from_uris(
    class_config=class_config,
    image_uri=train_image_uri,
    label_vector_uri=train_label_uri,
    label_vector_default_class_id=class_config.get_class_id('building'),
    size_lims=(200, 250),
    out_size=325,
    max_windows=10,
    transform=data_augmentation_transform,
    padding=50,
    image_raster_source_kw=dict(raster_transformers=[
        StatsTransformer(means=[562.7, 716.6, 517.1], stds=[341.1, 341.5, 197.5])
    ])
)
2023-07-20 19:18:10:rastervision.pipeline.file_system.utils: INFO - Using cached file /opt/data/tmp/cache/s3/spacenet-dataset/spacenet/SN2_buildings/train/AOI_2_Vegas/PS-RGB/SN2_buildings_train_AOI_2_Vegas_PS-RGB_img5631.tif.
2023-07-20 19:18:10:rastervision.pipeline.file_system.utils: INFO - Using cached file /opt/data/tmp/cache/s3/spacenet-dataset/spacenet/SN2_buildings/train/AOI_2_Vegas/geojson_buildings/SN2_buildings_train_AOI_2_Vegas_geojson_buildings_img5631.geojson.

To check that data is being read correctly, we use the Visualizer to plot a batch.

[25]:
viz = SemanticSegmentationVisualizer(
    class_names=class_config.names, class_colors=class_config.colors)
x, y = viz.get_batch(train_ds, 4)
viz.plot_batch(x, y, show=True)
../../_images/usage_tutorials_lightning_workflow_9_0.png
[56]:
scene_id = 5632
val_image_uri = f's3://spacenet-dataset/spacenet/SN2_buildings/train/AOI_2_Vegas/PS-RGB/SN2_buildings_train_AOI_2_Vegas_PS-RGB_img{scene_id}.tif'
val_label_uri = f's3://spacenet-dataset/spacenet/SN2_buildings/train/AOI_2_Vegas/geojson_buildings/SN2_buildings_train_AOI_2_Vegas_geojson_buildings_img{scene_id}.geojson'

val_ds = SemanticSegmentationSlidingWindowGeoDataset.from_uris(
    class_config=class_config,
    image_uri=val_image_uri,
    label_vector_uri=val_label_uri,
    label_vector_default_class_id=class_config.get_class_id('building'),
    size=325,
    stride=100,
)
2023-07-20 19:20:20:rastervision.pipeline.file_system.utils: INFO - Using cached file /opt/data/tmp/cache/s3/spacenet-dataset/spacenet/SN2_buildings/train/AOI_2_Vegas/PS-RGB/SN2_buildings_train_AOI_2_Vegas_PS-RGB_img5632.tif.
2023-07-20 19:20:20:rastervision.pipeline.file_system.utils: INFO - Using cached file /opt/data/tmp/cache/s3/spacenet-dataset/spacenet/SN2_buildings/train/AOI_2_Vegas/geojson_buildings/SN2_buildings_train_AOI_2_Vegas_geojson_buildings_img5632.geojson.
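
As a quick sanity check (this cell is an addition to the notebook), we can see how many sliding windows the validation dataset produces; these same windows are reused later when assembling whole-scene predictions.

[ ]:
print(f'Number of sliding windows: {len(val_ds.windows)}')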

Train Model using Lightning#

Here we build a DeepLab-ResNet50 model and then train it using Lightning. We only train for 3 epochs so that this can run in a minute or so on a CPU; in a real workflow we would train for 10-100 epochs on a GPU. Because of the very short training run, the model will not be accurate at all.

[20]:
from tqdm.autonotebook import tqdm
import torch
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torchvision.models.segmentation import deeplabv3_resnet50
import pytorch_lightning as pl

from rastervision.pipeline.file_system import make_dir
[38]:
batch_size = 8
lr = 1e-4
epochs = 3
output_dir = './lightning-demo/'
make_dir(output_dir)
fast_dev_run = False
[57]:
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
val_dl = DataLoader(val_ds, batch_size=batch_size, num_workers=4)
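
As an optional check (not part of the original notebook), we can pull one batch from the training DataLoader and confirm that the images and masks have the expected shapes and dtypes.

[ ]:
# Expect images of shape (batch_size, 3, 325, 325) and masks of shape (batch_size, 325, 325).
x, y = next(iter(train_dl))
print(x.shape, x.dtype)
print(y.shape, y.dtype)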

One of the main abstractions in Lightning is the LightningModule which extends a PyTorch nn.Module with extra methods that define how to train and validate the model. Here we define a LightningModule that does the bare minimum to train a DeepLab semantic segmentation model.

[28]:
class SemanticSegmentation(pl.LightningModule):
    def __init__(self, deeplab, lr=1e-4):
        super().__init__()
        self.deeplab = deeplab
        self.lr = lr

    def forward(self, img):
        return self.deeplab(img)['out']

    def training_step(self, batch, batch_idx):
        img, mask = batch
        img = img.float()
        mask = mask.long()
        out = self.forward(img)
        loss = F.cross_entropy(out, mask)
        log_dict = {'train_loss': loss}
        self.log_dict(log_dict, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        img, mask = batch
        img = img.float()
        mask = mask.long()
        out = self.forward(img)
        loss = F.cross_entropy(out, mask)
        log_dict = {'validation_loss': loss}
        self.log_dict(log_dict, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.lr)
        return optimizer
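
The following smoke test is an addition to the notebook: it builds a throwaway instance of the module and runs a dummy batch through forward() to confirm that the output has shape (batch, num_classes, height, width). It assumes 3-band, 325x325 chips and 3 output classes, matching len(class_config) + 1 as used below.

[ ]:
# Hypothetical smoke test; not needed for training.
_test_module = SemanticSegmentation(deeplabv3_resnet50(num_classes=3)).eval()
with torch.inference_mode():
    _out = _test_module(torch.zeros(1, 3, 325, 325))
print(_out.shape)  # expected: torch.Size([1, 3, 325, 325])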

The other main abstraction in Lightning is the Trainer, which is responsible for actually training a LightningModule. Here, it is configured to log metrics to TensorBoard.

[ ]:
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

deeplab = deeplabv3_resnet50(num_classes=len(class_config) + 1)
model = SemanticSegmentation(deeplab, lr=lr)
tb_logger = TensorBoardLogger(save_dir=output_dir, flush_secs=10)
trainer = pl.Trainer(
    accelerator='auto',
    min_epochs=1,
    max_epochs=epochs+1,
    default_root_dir=output_dir,
    logger=[tb_logger],
    fast_dev_run=fast_dev_run,
    log_every_n_steps=1,
)
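
Optionally, we can also write metrics to CSV files using the CSVLogger imported above, by passing a second logger to the Trainer. This alternative Trainer configuration is an addition to the notebook and is otherwise identical to the one above.

[ ]:
# Optional: log to both TensorBoard and CSV.
csv_logger = CSVLogger(save_dir=output_dir)
trainer = pl.Trainer(
    accelerator='auto',
    min_epochs=1,
    max_epochs=epochs + 1,
    default_root_dir=output_dir,
    logger=[tb_logger, csv_logger],
    fast_dev_run=fast_dev_run,
    log_every_n_steps=1,
)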

Monitor training using TensorBoard#

This runs an instance of TensorBoard inside this notebook.

Note

  • If running inside the Raster Vision docker image, you will need to pass --tensorboard to docker/run for this to work.

  • If the dashboard doesn’t auto-reload, you can click the reload button on the top-right.

[ ]:
%reload_ext tensorboard
%tensorboard --bind_all --logdir "./lightning-demo/lightning_logs" --reload_interval 10

A screenshot of the TensorBoard dashboard.

[ ]:
trainer.fit(model, train_dl, val_dl)
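
If you want to reuse the trained weights outside of Lightning, you can save the state dict of the wrapped DeepLab model. This cell (and the deeplab.pth filename) is an addition to the notebook.

[ ]:
# Save the plain PyTorch weights of the underlying DeepLab model.
torch.save(model.deeplab.state_dict(), os.path.join(output_dir, 'deeplab.pth'))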

Load saved model#

After training the model for only 3 epochs, it will not make good predictions. In order to have some sensible-looking output, we will load weights from a model that was fully trained on SpaceNet Vegas.

[44]:
weights_uri = 'https://s3.amazonaws.com/azavea-research-public-data/raster-vision/examples/model-zoo-0.21/spacenet-vegas-buildings-ss/model.pth'
deeplab.load_state_dict(torch.hub.load_state_dict_from_url(weights_uri, map_location=torch.device('cpu')))
Downloading: "https://s3.amazonaws.com/azavea-research-public-data/raster-vision/examples/model-zoo-0.20/spacenet-vegas-buildings-ss/model.pth" to /root/.cache/torch/hub/checkpoints/model.pth
100%|████████████████████████████████████████████████████████████████████████████████| 152M/152M [00:05<00:00, 27.7MB/s]
[44]:
<All keys matched successfully>

Make predictions for scene#

We can now use Raster Vision’s SemanticSegmentationLabels class to make predictions over a whole scene. The SemanticSegmentationLabels.from_predictions() method takes an iterator over predictions. We create this using a get_predictions() helper function defined below.

[31]:
def get_predictions(dataloader):
    for x, _ in tqdm(dataloader):
        with torch.inference_mode():
            out_batch = model(x)
        # This needs to yield a single prediction, not a whole batch of them.
        for out in out_batch:
            yield out.numpy()
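
The helper above runs the model on the CPU. If a GPU is available, a variant like the following (an addition to the notebook) moves the model and each batch to the GPU and copies the outputs back to the CPU before converting them to NumPy arrays.

[ ]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def get_predictions_gpu(dataloader):
    # Assumes model.eval() has been called, as in the next cell.
    model.to(device)
    for x, _ in tqdm(dataloader):
        with torch.inference_mode():
            out_batch = model(x.to(device))
        for out in out_batch:
            yield out.cpu().numpy()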
[58]:
from rastervision.core.data import SemanticSegmentationLabels

model.eval()
predictions = get_predictions(val_dl)
pred_labels = SemanticSegmentationLabels.from_predictions(
    val_ds.windows,
    predictions,
    smooth=True,
    extent=val_ds.scene.extent,
    num_classes=len(class_config) + 1)
scores = pred_labels.get_score_arr(pred_labels.extent)

Visualize and then save predictions#

[59]:
from matplotlib import pyplot as plt

scores_building = scores[0]
scores_background = scores[1]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
fig.tight_layout(w_pad=-2)
ax1.imshow(scores_building, cmap='plasma')
ax1.axis('off')
ax1.set_title('building')
ax2.imshow(scores_background, cmap='plasma')
ax2.axis('off')
ax2.set_title('background')
plt.show()
../../_images/usage_tutorials_lightning_workflow_31_0.png
[60]:
from os.path import join

pred_labels.save(
    uri=join(output_dir, 'predictions'),
    crs_transformer=val_ds.scene.raster_source.crs_transformer,
    class_config=class_config)
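
As a quick check (an addition to the notebook), we can list the files that were written to the predictions directory; the exact filenames depend on the Raster Vision version.

[ ]:
print(os.listdir(join(output_dir, 'predictions')))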

Evaluate predictions for a scene#

Now that we have predictions for the validation scene, we can evaluate them by comparing to ground truth using SemanticSegmentationEvaluator.

[61]:
from rastervision.core.evaluation import SemanticSegmentationEvaluator

evaluator = SemanticSegmentationEvaluator(class_config)

evaluation = evaluator.evaluate_predictions(
    ground_truth=val_ds.scene.label_source.get_labels(),
    predictions=pred_labels)

SemanticSegmentationEvaluator.evaluate_predictions() returns a SemanticSegmentationEvaluation object which contains evaluations for each class as ClassEvaluationItem objects.

Here are the metrics for the building and background classes.

[62]:
evaluation.class_to_eval_item[0]
[62]:
{'class_id': 0,
 'class_name': 'building',
 'conf_mat': [[338610.0, 2806.0], [11267.0, 69817.0]],
 'conf_mat_dict': {'FN': 11267.0, 'FP': 2806.0, 'TN': 338610.0, 'TP': 69817.0},
 'conf_mat_frac': [[0.8014437869822485, 0.006641420118343195],
                   [0.026667455621301774, 0.1652473372781065]],
 'conf_mat_frac_dict': {'FN': 0.026667455621301774,
                        'FP': 0.006641420118343195,
                        'TN': 0.8014437869822485,
                        'TP': 0.1652473372781065},
 'count_error': 8461.0,
 'gt_count': 81084.0,
 'metrics': {'f1': 0.9084426864098577,
             'precision': 0.9613621029150544,
             'recall': 0.8610453357012481,
             'sensitivity': 0.8610453357012481,
             'specificity': 0.9917812873444712},
 'pred_count': 72623.0,
 'relative_frequency': 0.19191479289940827}
[63]:
evaluation.class_to_eval_item[1]
[63]:
{'class_id': 1,
 'class_name': 'background',
 'conf_mat': [[69817.0, 11267.0], [2806.0, 338610.0]],
 'conf_mat_dict': {'FN': 2806.0, 'FP': 11267.0, 'TN': 69817.0, 'TP': 338610.0},
 'conf_mat_frac': [[0.1652473372781065, 0.026667455621301774],
                   [0.006641420118343195, 0.8014437869822485]],
 'conf_mat_frac_dict': {'FN': 0.006641420118343195,
                        'FP': 0.026667455621301774,
                        'TN': 0.1652473372781065,
                        'TP': 0.8014437869822485},
 'count_error': 8461.0,
 'gt_count': 341416.0,
 'metrics': {'f1': 0.9796424960183309,
             'precision': 0.9677972544637116,
             'recall': 0.9917812873444712,
             'sensitivity': 0.9917812873444712,
             'specificity': 0.8610453357012481},
 'pred_count': 349877.0,
 'relative_frequency': 0.8080852071005917}
[64]:
evaluation.save(join(output_dir, 'evaluation.json'))
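
The saved evaluation is plain JSON, so (as an optional addition to the notebook) it can be inspected with the standard library.

[ ]:
import json

with open(join(output_dir, 'evaluation.json')) as f:
    eval_json = json.load(f)
# Print a short preview of the serialized evaluation.
print(str(eval_json)[:500])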