Note
This page was generated from train.ipynb.
Note
If running outside of the Docker image, you may need to set some environment variables manually. You can do it like so:
import os
from subprocess import check_output
os.environ['GDAL_DATA'] = check_output('pip show rasterio | grep Location | awk \'{print $NF"/rasterio/gdal_data/"}\'', shell=True).decode().strip()
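The shell pipeline above depends on pip, grep, and awk being available. As an alternative, here is a minimal sketch that locates the same directory from Python. The helper name package_data_dir is hypothetical, and the sketch assumes (as the command above does) that rasterio ships its GDAL data in a gdal_data folder inside the package directory.

```python
import importlib.util
import os
from typing import Optional


def package_data_dir(package: str, subdir: str) -> Optional[str]:
    """Return <package directory>/<subdir>, or None if the package is not installed."""
    spec = importlib.util.find_spec(package)
    if spec is None or spec.origin is None:
        return None
    return os.path.join(os.path.dirname(spec.origin), subdir)


# Only set GDAL_DATA if rasterio is installed and the folder actually exists.
gdal_data = package_data_dir('rasterio', 'gdal_data')
if gdal_data is not None and os.path.isdir(gdal_data):
    os.environ['GDAL_DATA'] = gdal_data
```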
We will be accessing files on S3 in this notebook. Since those files are public, we set the AWS_NO_SIGN_REQUEST environment variable to tell rasterio to skip signing its requests with AWS credentials.
[ ]:
%env AWS_NO_SIGN_REQUEST=YES
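%env is an IPython magic, so it only works inside a notebook or an IPython session. In a plain Python script, a rough equivalent is to set the variable through os.environ before opening any S3 files:

```python
import os

# Equivalent of `%env AWS_NO_SIGN_REQUEST=YES` outside of IPython/Jupyter.
os.environ['AWS_NO_SIGN_REQUEST'] = 'YES'
```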
Training a model#
Define ClassConfig#
[2]:
from rastervision.core.data import ClassConfig
class_config = ClassConfig(
    names=['background', 'building'],
    colors=['lightgray', 'darkred'],
    null_class='background')
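Later cells call class_config.get_class_id('building'). As a plain-Python illustration, assuming class IDs are simply positions in the names list (so background is 0 and building is 1), the lookup behaves like the sketch below. The class_id helper is hypothetical and not part of Raster Vision.

```python
names = ['background', 'building']
colors = ['lightgray', 'darkred']


def class_id(name: str) -> int:
    """Return the position of `name` in the class list."""
    return names.index(name)


print(class_id('building'))  # 1
print(dict(zip(names, colors)))
```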
Define training and validation datasets#
To keep things simple, we use one scene for training and one for validation. In a real workflow, we would normally use many more scenes.
[3]:
train_image_uri = 's3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0331E-1257N_1327_3160_13/images/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13.tif'
train_label_uri = 's3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0331E-1257N_1327_3160_13/labels/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13_Buildings.geojson'
[4]:
val_image_uri = 's3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0357E-1223N_1429_3296_13/images/global_monthly_2018_01_mosaic_L15-0357E-1223N_1429_3296_13.tif'
val_label_uri = 's3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0357E-1223N_1429_3296_13/labels/global_monthly_2018_01_mosaic_L15-0357E-1223N_1429_3296_13_Buildings.geojson'
[5]:
import albumentations as A
from rastervision.pytorch_learner import (
    SemanticSegmentationRandomWindowGeoDataset,
    SemanticSegmentationSlidingWindowGeoDataset,
    SemanticSegmentationVisualizer)

viz = SemanticSegmentationVisualizer(
    class_names=class_config.names, class_colors=class_config.colors)
Training dataset with random-window sampling and data augmentation#
[6]:
data_augmentation_transform = A.Compose([
    A.Flip(),
    A.ShiftScaleRotate(),
    A.OneOf([
        A.HueSaturationValue(hue_shift_limit=10),
        A.RGBShift(),
        A.ToGray(),
        A.ToSepia(),
        A.RandomBrightnessContrast(),
        A.RandomGamma(),
    ]),
    A.CoarseDropout(max_height=32, max_width=32, max_holes=5)
])

train_ds = SemanticSegmentationRandomWindowGeoDataset.from_uris(
    class_config=class_config,
    image_uri=train_image_uri,
    label_vector_uri=train_label_uri,
    label_vector_default_class_id=class_config.get_class_id('building'),
    size_lims=(150, 200),
    out_size=256,
    max_windows=400,
    transform=data_augmentation_transform,
)
len(train_ds)
2024-04-09 20:16:43:rastervision.pipeline.file_system.utils: INFO - Using cached file /opt/data/tmp/cache/s3/spacenet-dataset/spacenet/SN7_buildings/train/L15-0331E-1257N_1327_3160_13/images/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13.tif.
2024-04-09 20:16:43:rastervision.pipeline.file_system.utils: INFO - Using cached file /opt/data/tmp/cache/s3/spacenet-dataset/spacenet/SN7_buildings/train/L15-0331E-1257N_1327_3160_13/labels/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13_Buildings.geojson.
[6]:
400
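To make the sampling parameters concrete, here is a minimal, library-independent sketch of random-window sampling. It assumes square windows whose side is drawn uniformly from size_lims with the upper limit exclusive, and a hypothetical 1024 x 1024 pixel scene. The sample_window helper is illustrative and is not Raster Vision's implementation. Each sampled window would then be resized to out_size x out_size (256 x 256) before being passed to the model.

```python
import random
from typing import Tuple


def sample_window(img_h: int, img_w: int, size_lims: Tuple[int, int],
                  rng: random.Random) -> Tuple[int, int, int, int]:
    """Sample a square window (ymin, xmin, ymax, xmax) lying fully inside the image."""
    lo, hi = size_lims
    size = rng.randrange(lo, hi)  # side length in [lo, hi)
    ymin = rng.randint(0, img_h - size)  # inclusive bounds keep the window inside
    xmin = rng.randint(0, img_w - size)
    return ymin, xmin, ymin + size, xmin + size


rng = random.Random(0)  # fixed seed so the sketch is reproducible
windows = [sample_window(1024, 1024, (150, 200), rng) for _ in range(400)]
print(len(windows))  # 400, mirroring max_windows=400
```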
[8]:
x, y = viz.get_batch(train_ds, 4)
viz.plot_batch(x, y, show=True)
Validation dataset with sliding-window sampling (and no data augmentation)#
[7]:
val_ds = SemanticSegmentationSlidingWindowGeoDataset.from_uris(
    class_config=class_config,
    image_uri=val_image_uri,
    label_vector_uri=val_label_uri,
    label_vector_default_class_id=class_config.get_class_id('building'),
    size=200,
    stride=100,
    out_size=256,
)
len(val_ds)
2024-04-09 20:16:46:rastervision.pipeline.file_system.utils: INFO - Using cached file /opt/data/tmp/cache/s3/spacenet-dataset/spacenet/SN7_buildings/train/L15-0357E-1223N_1429_3296_13/images/global_monthly_2018_01_mosaic_L15-0357E-1223N_1429_3296_13.tif.
2024-04-09 20:16:46:rastervision.pipeline.file_system.utils: INFO - Using cached file /opt/data/tmp/cache/s3/spacenet-dataset/spacenet/SN7_buildings/train/L15-0357E-1223N_1429_3296_13/labels/global_monthly_2018_01_mosaic_L15-0357E-1223N_1429_3296_13_Buildings.geojson.
[7]:
100
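With size=200 and stride=100, neighboring windows overlap by half their width. The sketch below enumerates such a grid in plain Python. It is not Raster Vision's implementation: the sliding_windows helper, the padding parameter, and the 1024 x 1024 image size are all assumptions made for illustration (the image size is not stated on this page).

```python
from typing import List, Tuple


def sliding_windows(img_h: int, img_w: int, size: int, stride: int,
                    padding: int = 0) -> List[Tuple[int, int, int, int]]:
    """Enumerate (ymin, xmin, ymax, xmax) windows on a regular grid.

    A window is kept if it overflows the image edge by at most `padding` pixels.
    """
    windows = []
    for ymin in range(0, img_h, stride):
        for xmin in range(0, img_w, stride):
            ymax, xmax = ymin + size, xmin + size
            if ymax <= img_h + padding and xmax <= img_w + padding:
                windows.append((ymin, xmin, ymax, xmax))
    return windows


print(len(sliding_windows(1024, 1024, size=200, stride=100)))  # 81
print(len(sliding_windows(1024, 1024, size=200, stride=100, padding=100)))  # 100
```

Under these assumptions, allowing windows to overflow the image by up to half a window gives 100 windows, the same number as len(val_ds) above. This agreement is suggestive only and does not confirm how the library computes its windows.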
[10]:
x, y = viz.get_batch(val_ds, 4)
viz.plot_batch(x, y, show=True)
Define model#
Use a light-weight panoptic FPN model with a ResNet-18 backbone.
[ ]:
import torch
model = torch.hub.load(
    'AdeelH/pytorch-fpn:0.3',
    'make_fpn_resnet',
    name='resnet18',
    fpn_type='panoptic',
    num_classes=len(class_config),
    fpn_channels=128,
    in_channels=3,
    out_size=(256, 256),
    pretrained=True)
Configure the training#
SolverConfig – Configure the loss, optimizer, and scheduler(s)#
[10]:
from rastervision.pytorch_learner import SolverConfig
solver_cfg = SolverConfig(
    batch_sz=8,
    lr=3e-2,
    class_loss_weights=[1., 10.]
)
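class_loss_weights=[1., 10.] weights the building class ten times as heavily as the background class. As a rough illustration, assuming the weights are applied as per-class weights in a cross-entropy loss (this page does not show the loss that is actually used), a single pixel's weighted loss could be computed as follows. weighted_cross_entropy is a hypothetical helper.

```python
import math
from typing import Sequence


def weighted_cross_entropy(logits: Sequence[float], target: int,
                           class_weights: Sequence[float]) -> float:
    """Cross-entropy for one pixel, scaled by the weight of its true class."""
    # Log-softmax, computed stably by subtracting the max logit.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    log_prob = logits[target] - log_sum_exp
    return -class_weights[target] * log_prob


weights = [1., 10.]  # [background, building]
logits = [0.0, 0.0]  # the model is undecided between the two classes
print(weighted_cross_entropy(logits, target=0, class_weights=weights))  # ~0.693
print(weighted_cross_entropy(logits, target=1, class_weights=weights))  # ~6.931
```

For the same prediction, an error on a building pixel costs ten times as much as an error on a background pixel, which pushes the model toward finding more of the rarer building pixels.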
LearnerConfig – Combine DataConfig, SolverConfig (and optionally, ModelConfig)#
[11]:
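from rastervision.pytorch_learner import SemanticSegmentationLearnerConfig
# NOTE: data_cfg must be a DataConfig instance created beforehand; the cell that defines it does not appear on this page.
learner_cfg = SemanticSegmentationLearnerConfig(data=data_cfg, solver=solver_cfg)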
Initialize Learner#
[15]:
from rastervision.pytorch_learner import SemanticSegmentationLearner
learner = SemanticSegmentationLearner(
    cfg=learner_cfg,
    output_dir='./train-demo/',
    model=model,
    train_ds=train_ds,
    valid_ds=val_ds,
)
2024-04-09 20:18:03:rastervision.pytorch_learner.learner: INFO - Building dataloaders
[16]:
learner.log_data_stats()
2024-04-09 20:18:03:rastervision.pytorch_learner.learner: INFO - train_ds: 400 items
2024-04-09 20:18:03:rastervision.pytorch_learner.learner: INFO - valid_ds: 100 items
Run Tensorboard for monitoring#
Note
If running inside the Raster Vision Docker image, you will need to pass --tensorboard to docker/run for this to work.
If the dashboard doesn't auto-reload, you can click the reload button on the top-right.
[13]:
%load_ext tensorboard
This will start an instance of tensorboard and embed it in the output of the cell:
[ ]:
%tensorboard --bind_all --logdir "./train-demo/tb-logs" --reload_interval 10
Train – Learner.train()#
[17]:
learner.train(epochs=3)
2024-04-09 20:18:43:rastervision.pytorch_learner.learner: INFO - train_ds: 400 items
2024-04-09 20:18:43:rastervision.pytorch_learner.learner: INFO - valid_ds: 100 items
2024-04-09 20:18:43:rastervision.pytorch_learner.learner: INFO - Plotting sample training batch.
2024-04-09 20:18:43:rastervision.pytorch_learner.learner: INFO - Plotting sample validation batch.
2024-04-09 20:18:44:rastervision.pytorch_learner.learner: INFO - epoch: 0
2024-04-09 20:18:49:rastervision.pytorch_learner.learner: INFO - metrics:
{'epoch': 0,
'train_loss': 0.2913476228713989,
'train_time': '0:00:03.397020',
'val_loss': 0.3804737329483032,
'avg_precision': 0.9294803738594055,
'avg_recall': 0.8374766707420349,
'avg_f1': 0.8810832500457764,
'background_precision': 0.9746540188789368,
'background_recall': 0.8492995500564575,
'background_f1': 0.9076691269874573,
'building_precision': 0.21413441002368927,
'building_recall': 0.6502557396888733,
'building_f1': 0.3221742510795593,
'valid_time': '0:00:01.604234'}
2024-04-09 20:18:49:rastervision.pytorch_learner.learner: INFO - epoch: 1
2024-04-09 20:18:53:rastervision.pytorch_learner.learner: INFO - metrics:
{'epoch': 1,
'train_loss': 0.24667344987392426,
'train_time': '0:00:02.883060',
'val_loss': 0.4013445973396301,
'avg_precision': 0.927025318145752,
'avg_recall': 0.8806623816490173,
'avg_f1': 0.9032492637634277,
'background_precision': 0.9691619873046875,
'background_recall': 0.9018215537071228,
'background_f1': 0.9342798590660095,
'building_precision': 0.25977078080177307,
'building_recall': 0.5455965399742126,
'building_f1': 0.35196369886398315,
'valid_time': '0:00:01.320343'}
2024-04-09 20:18:53:rastervision.pytorch_learner.learner: INFO - epoch: 2
2024-04-09 20:18:57:rastervision.pytorch_learner.learner: INFO - metrics:
{'epoch': 2,
'train_loss': 0.2440817952156067,
'train_time': '0:00:02.781280',
'val_loss': 0.39747753739356995,
'avg_precision': 0.9319270849227905,
'avg_recall': 0.8825308680534363,
'avg_f1': 0.9065566062927246,
'background_precision': 0.9732670783996582,
'background_recall': 0.8998284935951233,
'background_f1': 0.9351081252098083,
'building_precision': 0.2772882282733917,
'building_recall': 0.6086140275001526,
'building_f1': 0.380993515253067,
'valid_time': '0:00:01.286454'}
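The per-class F1 scores in these logs are consistent with the harmonic mean of precision and recall. A quick check using the epoch-2 building metrics copied from the output above:

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# building_precision and building_recall from epoch 2
building_f1 = f1_score(0.2772882282733917, 0.6086140275001526)
print(round(building_f1, 4))  # 0.381
```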
Train some more#
[18]:
learner.train(epochs=1)
2024-04-09 20:18:57:rastervision.pytorch_learner.learner: INFO - Resuming training from epoch 3
2024-04-09 20:18:57:rastervision.pytorch_learner.learner: INFO - train_ds: 400 items
2024-04-09 20:18:57:rastervision.pytorch_learner.learner: INFO - valid_ds: 100 items
2024-04-09 20:18:57:rastervision.pytorch_learner.learner: INFO - Plotting sample training batch.
2024-04-09 20:18:57:rastervision.pytorch_learner.learner: INFO - Plotting sample validation batch.
2024-04-09 20:18:58:rastervision.pytorch_learner.learner: INFO - epoch: 3
2024-04-09 20:19:02:rastervision.pytorch_learner.learner: INFO - metrics:
{'epoch': 3,
'train_loss': 0.2168915718793869,
'train_time': '0:00:02.736246',
'val_loss': 0.43969860672950745,
'avg_precision': 0.9304015040397644,
'avg_recall': 0.8608637452125549,
'avg_f1': 0.894282877445221,
'background_precision': 0.9738898873329163,
'background_recall': 0.8755510449409485,
'background_f1': 0.9221059679985046,
'building_precision': 0.24174126982688904,
'building_recall': 0.6282839775085449,
'building_f1': 0.3491442799568176,
'valid_time': '0:00:01.285872'}
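Over these four epochs the training loss keeps falling, while the validation loss is lowest at epoch 0 and highest at epoch 3, so the last epoch is not necessarily the best one. A small sketch for comparing epochs, using values copied (and rounded) from the logs above:

```python
# Metrics copied from the training logs above, rounded to 4 decimal places.
history = [
    {'epoch': 0, 'train_loss': 0.2913, 'val_loss': 0.3805, 'building_f1': 0.3222},
    {'epoch': 1, 'train_loss': 0.2467, 'val_loss': 0.4013, 'building_f1': 0.3520},
    {'epoch': 2, 'train_loss': 0.2441, 'val_loss': 0.3975, 'building_f1': 0.3810},
    {'epoch': 3, 'train_loss': 0.2169, 'val_loss': 0.4397, 'building_f1': 0.3491},
]

# Pick the best epoch under two different criteria.
best_by_loss = min(history, key=lambda m: m['val_loss'])
best_by_f1 = max(history, key=lambda m: m['building_f1'])
print(best_by_loss['epoch'])  # 0
print(best_by_f1['epoch'])  # 2
```

The two criteria disagree here, which is worth keeping in mind when deciding which metric to monitor.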
Examine predictions – Learner.plot_predictions()#
[19]:
learner.plot_predictions(split='valid', show=True)
2024-04-09 20:19:02:rastervision.pytorch_learner.learner: INFO - Making and plotting sample predictions on the valid set...
2024-04-09 20:19:03:rastervision.pytorch_learner.learner: INFO - Sample predictions written to ./train-demo/valid_preds.png.
Save as a model-bundle – Learner.save_model_bundle()#
Note the warning about ModelConfig. This is relevant when loading from the bundle, as we will see below.
[20]:
learner.save_model_bundle()
2024-04-09 20:19:03:rastervision.pytorch_learner.learner: WARNING - Model was not configured via ModelConfig, and therefore, will not be reconstructable form the model-bundle. You will need to initialize the model yourself and pass it to from_model_bundle().
2024-04-09 20:19:03:rastervision.pytorch_learner.learner: INFO - Creating bundle.
2024-04-09 20:19:03:rastervision.pytorch_learner.learner: INFO - Exporting to model to ONNX.
2024-04-09 20:19:04:rastervision.pytorch_learner.learner: INFO - Saving bundle to ./train-demo/model-bundle.zip.
Examine learner output#
The trained model weights are saved at ./train-demo/last-model.pth as well as inside the model-bundle.
[21]:
!tree "./train-demo/"
./train-demo/
├── checkpoints
├── dataloaders
│ ├── train.png
│ └── valid.png
├── last-model.pth
├── learner-config.json
├── log.csv
├── model-bundle.zip
├── tb-logs
│ ├── events.out.tfevents.1712693861.60eb69042744.1456.0
│ ├── events.out.tfevents.1712693868.60eb69042744.1456.1
│ └── events.out.tfevents.1712693883.60eb69042744.1456.2
└── valid_preds.png
3 directories, 10 files
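The model-bundle is an ordinary zip archive, so its contents can be inspected with Python's standard zipfile module. A minimal sketch follows; list_bundle is a hypothetical helper, and the path assumes the output directory used on this page.

```python
import os
import zipfile
from typing import List


def list_bundle(bundle_path: str) -> List[str]:
    """Return the names of the files stored in a zip archive."""
    with zipfile.ZipFile(bundle_path) as zf:
        return zf.namelist()


bundle_path = './train-demo/model-bundle.zip'
if os.path.exists(bundle_path):  # skip quietly if the bundle has not been created yet
    print(list_bundle(bundle_path))
```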
Using model-bundles#
For predictions – Learner.from_model_bundle()#
We can use the model-bundle to re-construct our Learner and then use it to make predictions.
Note
Since we used a custom model instead of using ModelConfig, the model-bundle does not know how to construct the model; therefore, we need to pass in the model again.
[22]:
from rastervision.pytorch_learner import SemanticSegmentationLearner
learner = SemanticSegmentationLearner.from_model_bundle(
    model_bundle_uri='./train-demo/model-bundle.zip',
    output_dir='./train-demo/',
    model=model,
)
2024-04-09 20:19:16:rastervision.pytorch_learner.learner: INFO - Loading learner from bundle ./train-demo/model-bundle.zip.
2024-04-09 20:19:16:rastervision.pytorch_learner.learner: INFO - Unzipping model-bundle to /opt/data/tmp/tmpm5o2kvyl/model-bundle
2024-04-09 20:19:17:rastervision.pytorch_learner.learner: INFO - Loading model weights from: /opt/data/tmp/tmpm5o2kvyl/model-bundle/model.pth
For next steps, see the “Prediction and Evaluation” tutorial.
For fine-tuning – Learner.from_model_bundle()#
We can also re-construct the Learner in order to continue training, perhaps on a different dataset. To do this, we pass in the training and validation datasets via train_ds and valid_ds and set training=True.
Note
Since we used a custom model instead of using ModelConfig, the model-bundle does not know how to construct the model; therefore, we need to pass in the model again.
Note
Optimizers and schedulers are (currently) not stored in model-bundles.
[23]:
from rastervision.pytorch_learner import SemanticSegmentationLearner
learner = SemanticSegmentationLearner.from_model_bundle(
    model_bundle_uri='./train-demo/model-bundle.zip',
    output_dir='./train-demo/',
    model=model,
    train_ds=train_ds,
    valid_ds=val_ds,
    training=True,
)
2024-04-09 20:19:20:rastervision.pytorch_learner.learner: INFO - Loading learner from bundle ./train-demo/model-bundle.zip.
2024-04-09 20:19:20:rastervision.pytorch_learner.learner: INFO - Unzipping model-bundle to /opt/data/tmp/tmplp5b8748/model-bundle
2024-04-09 20:19:21:rastervision.pytorch_learner.learner: INFO - Loading model weights from: /opt/data/tmp/tmplp5b8748/model-bundle/model.pth
2024-04-09 20:19:21:rastervision.pytorch_learner.learner: INFO - Building dataloaders
2024-04-09 20:19:21:rastervision.pytorch_learner.learner: INFO - Loading checkpoint from ./train-demo/last-model.pth
Continue training:
[24]:
learner.train(epochs=1)
2024-04-09 20:19:35:rastervision.pytorch_learner.learner: INFO - Resuming training from epoch 4
2024-04-09 20:19:35:rastervision.pytorch_learner.learner: INFO - train_ds: 400 items
2024-04-09 20:19:35:rastervision.pytorch_learner.learner: INFO - valid_ds: 100 items
2024-04-09 20:19:35:rastervision.pytorch_learner.learner: INFO - Plotting sample training batch.
2024-04-09 20:19:36:rastervision.pytorch_learner.learner: INFO - Plotting sample validation batch.
2024-04-09 20:19:36:rastervision.pytorch_learner.learner: INFO - epoch: 4
2024-04-09 20:19:40:rastervision.pytorch_learner.learner: INFO - metrics:
{'epoch': 4,
'train_loss': 0.34853240847587585,
'train_time': '0:00:02.899091',
'val_loss': 0.4300776422023773,
'avg_precision': 0.9265687465667725,
'avg_recall': 0.8659200668334961,
'avg_f1': 0.8952184319496155,
'background_precision': 0.9700926542282104,
'background_recall': 0.8847286701202393,
'background_f1': 0.9254463315010071,
'building_precision': 0.2373460829257965,
'building_recall': 0.5680769085884094,
'building_f1': 0.33480748534202576,
'valid_time': '0:00:01.288514'}