Note

This page was generated from prechipped_datasets.ipynb.

Note

If running outside of the Docker image, you may need to set some environment variables manually. You can do it like so:

import os
from subprocess import check_output

os.environ['GDAL_DATA'] = check_output('pip show rasterio | grep Location | awk \'{print $NF"/rasterio/gdal_data/"}\'', shell=True).decode().strip()
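If `pip`, `grep`, and `awk` are not available (for example, on Windows), the same setting can be computed in pure Python. The sketch below assumes, as the command above does, that the installed rasterio package ships its GDAL data in a `gdal_data/` subdirectory. The helper name `find_gdal_data_dir` is hypothetical.

```python
import os
from importlib.util import find_spec
from typing import Optional


def find_gdal_data_dir(package: str = 'rasterio') -> Optional[str]:
    """Hypothetical helper: return <package dir>/gdal_data, or None.

    Returns None if the package is not installed or is a plain module
    rather than a package. Assumes GDAL data lives in gdal_data/.
    """
    spec = find_spec(package)
    if spec is None or not spec.submodule_search_locations:
        return None
    pkg_dir = list(spec.submodule_search_locations)[0]
    return os.path.join(pkg_dir, 'gdal_data')


# Only set the variable if the directory actually exists.
gdal_data = find_gdal_data_dir()
if gdal_data is not None and os.path.isdir(gdal_data):
    os.environ['GDAL_DATA'] = gdal_data
```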

Working with pre-chipped datasets

It is not uncommon for geospatial datasets to be released in a “pre-chipped” form: rather than as large GeoTIFFs, they are distributed as non-georeferenced chips (also called tiles or patches) in ordinary image formats such as PNG or JPEG.

In such scenarios, you do not necessarily need to use Raster Vision to read this data and train a model. In fact, you can train a model outside of Raster Vision and then use Raster Vision to run it over GeoTIFFs, as shown in the “Using Raster Vision with Lightning” tutorial.

Nevertheless, Raster Vision can train on ordinary images, and this notebook walks you through an example for each of the three supported computer vision tasks: semantic segmentation, object detection, and classification. The notebook only demonstrates how to read these datasets, but that is all you need: once you have instantiated these dataset classes, the rest of the training and prediction procedure is identical to that for non-chipped, georeferenced data.


The ImageDataset class

The ImageDataset is a PyTorch-compatible Dataset implementation that allows reading non-geospatial image datasets.

Just like GeoDataset, it is based on AlbumentationsDataset.
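To make the idea concrete, here is a much-simplified, hypothetical analogue of an Albumentations-based dataset wrapper: it reads an (image, mask) pair from an underlying dataset, applies an Albumentations-style transform (called as `transform(image=..., mask=...)` and returning a dict), and moves the channel axis first, which matches the `(3, H, W)` image shape seen in the outputs below. It is not Raster Vision's implementation; `ToyAlbumentationsDataset` is an illustrative name, and it returns NumPy arrays rather than tensors.

```python
from typing import Callable, Optional, Sequence, Tuple

import numpy as np


class ToyAlbumentationsDataset:
    """Hypothetical, simplified stand-in for an Albumentations-based dataset."""

    def __init__(self,
                 orig_dataset: Sequence[Tuple[np.ndarray, np.ndarray]],
                 transform: Optional[Callable] = None):
        self.orig_dataset = orig_dataset
        self.transform = transform

    def __len__(self) -> int:
        return len(self.orig_dataset)

    def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray]:
        img, mask = self.orig_dataset[idx]  # img: (H, W, C), mask: (H, W)
        if self.transform is not None:
            # Albumentations transforms take keyword args and return a dict.
            out = self.transform(image=img, mask=mask)
            img, mask = out['image'], out['mask']
        # Channels-first layout, as PyTorch models expect: (C, H, W).
        img = np.transpose(img, (2, 0, 1))
        return img, mask


# Usage with a dummy transform that crops the top-left 2x2 corner.
def crop_top_left(image, mask):
    return {'image': image[:2, :2], 'mask': mask[:2, :2]}


data = [(np.zeros((4, 4, 3), dtype=np.uint8), np.ones((4, 4), dtype=np.uint8))]
toy_ds = ToyAlbumentationsDataset(data, transform=crop_top_left)
x, y = toy_ds[0]
print(x.shape, y.shape)  # (3, 2, 2) (2, 2)
```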

Supported image formats

Each ImageDataset subclass can read all image formats supported by Pillow (.png, .jp[e]g, .tif[f], .bmp, and more), as well as the NumPy format, .npy.
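As a rough illustration of what "Pillow formats plus .npy" means in practice, the hypothetical helper below loads a file with `np.load` if it has a `.npy` extension and with Pillow otherwise, returning a NumPy array in both cases. It is a sketch under that assumption, not Raster Vision's actual loading code.

```python
from pathlib import Path

import numpy as np
from PIL import Image


def load_image_array(path) -> np.ndarray:
    """Hypothetical helper: read an image file into a NumPy array."""
    path = Path(path)
    if path.suffix.lower() == '.npy':
        return np.load(path)
    # Pillow identifies the format (PNG, JPEG, TIFF, BMP, ...) from the file.
    with Image.open(path) as img:
        return np.array(img)
```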


Semantic segmentation – SemanticSegmentationImageDataset

The SemanticSegmentationImageDataset (which internally uses a SemanticSegmentationDataReader) expects a path to a directory containing images and a path to a directory containing segmentation masks. Images are matched to their respective label masks by filename (excluding the extension). The masks can be in any supported image format.
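The matching rule can be sketched with `pathlib`: pair each image with the mask that has the same filename stem, whatever the extensions (for instance, a hypothetical `0001.jpg` would pair with `0001.png`). The function `match_images_to_masks` is illustrative only; SemanticSegmentationDataReader may differ in details such as how it handles unmatched files.

```python
from pathlib import Path
from typing import List, Tuple


def match_images_to_masks(img_dir, label_dir) -> List[Tuple[Path, Path]]:
    """Hypothetical helper: pair images and masks by filename stem."""
    # Index masks by filename without extension.
    masks = {p.stem: p for p in Path(label_dir).iterdir() if p.is_file()}
    pairs = []
    for img_path in sorted(Path(img_dir).iterdir()):
        if not img_path.is_file():
            continue
        mask_path = masks.get(img_path.stem)
        if mask_path is None:
            raise FileNotFoundError(f'No mask found for {img_path.name}')
        pairs.append((img_path, mask_path))
    return pairs
```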

Dataset

For this example, we will use the Extended Optical Remote Sensing Saliency Detection (EORSSD) Dataset.

Zhang, Qijian, Runmin Cong, Chongyi Li, Ming-Ming Cheng, Yuming Fang, Xiaochun Cao, Yao Zhao, and Sam Kwong. “Dense attention fluid network for salient object detection in optical remote sensing images.” IEEE Transactions on Image Processing 30 (2020): 1305-1317.

Below we download (63 MB) and unzip the data.

[14]:
!wget "https://github.com/rmcong/EORSSD-dataset/raw/master/EORSSD.zip"
!apt-get install unzip -y
!unzip -q "EORSSD.zip" -d "EORSSD"
!ls "EORSSD"
test-images  test-labels  train-images  train-labels

Usage

[1]:
import albumentations as A

from rastervision.pytorch_learner import SemanticSegmentationImageDataset

ds = SemanticSegmentationImageDataset(
    img_dir='EORSSD/train-images/',
    label_dir='EORSSD/train-labels/',
    transform=A.Resize(256, 256),
)
len(ds)
[1]:
1400

We can read a data sample and the corresponding ground truth from the Dataset like so:

[2]:
x, y = ds[0]
x.shape, y.shape
[2]:
(torch.Size([3, 256, 256]), torch.Size([256, 256]))

And then plot it using the SemanticSegmentationVisualizer:

[3]:
from rastervision.pytorch_learner import SemanticSegmentationVisualizer

viz = SemanticSegmentationVisualizer(
    class_names=['background', 'foreground'], class_colors=['black', 'white'])
viz.plot_batch(x.unsqueeze(0), y.unsqueeze(0), show=True)
[Figure: the sample image plotted with its segmentation mask]

Object detection – ObjectDetectionImageDataset

The ObjectDetectionImageDataset (which internally uses a CocoDataset) expects a path to a directory containing images and a URI to a JSON file containing annotations in the COCO format.
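For orientation, a minimal COCO-style annotation file has an `images` list and an `annotations` list; each annotation points to an image via `image_id`, stores its box as `[x, y, width, height]` in `bbox`, and carries a `category_id`. The hypothetical example below uses only a subset of the COCO fields (the full format also defines keys such as `categories`, `area`, and `iscrowd`), and, like the conversion later in this section, uses file names as image IDs. The consistency check and per-image grouping are illustrative, not a description of what CocoDataset does internally.

```python
import json
from collections import defaultdict

# Hypothetical minimal COCO-style annotations.
coco = {
    'images': [
        {'id': 'scene_01.jpg', 'file_name': 'scene_01.jpg'},
    ],
    'annotations': [
        # bbox is [x_min, y_min, width, height] in pixel coordinates.
        {'id': 1, 'image_id': 'scene_01.jpg', 'bbox': [10, 20, 30, 40],
         'category_id': 0},
        {'id': 2, 'image_id': 'scene_01.jpg', 'bbox': [100, 50, 25, 25],
         'category_id': 1},
    ],
}

# Sanity check: every annotation refers to a known image.
image_ids = {img['id'] for img in coco['images']}
assert all(ann['image_id'] in image_ids for ann in coco['annotations'])

# Group (bbox, category_id) pairs by image.
boxes_per_image = defaultdict(list)
for ann in coco['annotations']:
    boxes_per_image[ann['image_id']].append((ann['bbox'], ann['category_id']))
```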

Dataset

For this example, we will use the Airbus Aircraft Detection dataset.

You will need to download the dataset (92 MB) manually from https://www.kaggle.com/datasets/airbusgeo/airbus-aircrafts-sample-dataset.

The cells below assume that the data has been unzipped into an airbus/ directory.

[1]:
!ls "airbus/"
LICENSE.txt  README.md  annotations.csv  annotations.json  extras  images

Note

Strictly speaking, this dataset is not “pre-chipped”, since the individual images are still fairly large and would benefit from further chipping. Nevertheless, the images are provided as non-georeferenced JPEGs and thus serve the purpose of this tutorial.

Transform annotations into COCO format

  • Add a bbox column representing bounding boxes in xywh format and delete the old geometry column.

  • Add a category_id column representing class IDs.
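The `bbox` step amounts to taking the axis-aligned bounding box of each polygon and expressing it as (x_min, y_min, width, height), the layout COCO uses. Below is a dependency-free sketch of that computation, assuming each geometry string is a Python-literal list of (x, y) vertex tuples; it is intended to mirror what `Box.from_shapely(...).to_xywh()` computes in the code below, and the helper name `polygon_str_to_xywh` is hypothetical.

```python
import ast
from typing import Tuple


def polygon_str_to_xywh(geometry: str) -> Tuple[float, float, float, float]:
    """Hypothetical helper: polygon vertex string -> COCO-style (x, y, w, h)."""
    vertices = ast.literal_eval(geometry)  # safely parse the literal
    xs = [x for x, _ in vertices]
    ys = [y for _, y in vertices]
    x_min, y_min = min(xs), min(ys)
    return (x_min, y_min, max(xs) - x_min, max(ys) - y_min)


print(polygon_str_to_xywh('[(10, 20), (40, 20), (40, 60), (10, 60), (10, 20)]'))
# (10, 20, 30, 40)
```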

[1]:
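import ast

import pandas as pd
from shapely.geometry import Polygon

from rastervision.core.box import Box
from rastervision.pipeline.file_system.utils import json_to_file
[31]:
class_names = ['Airplane', 'Truncated_airplane']
ann_df = pd.read_csv('airbus/annotations.csv')
# Each geometry is a string holding a Python literal of polygon vertices.
# Parse it with ast.literal_eval (safer than eval) and convert the polygon
# to an axis-aligned (x, y, width, height) box.
ann_df['bbox'] = [
    Box.from_shapely(Polygon(ast.literal_eval(g))).to_xywh()
    for g in ann_df.geometry
]
ann_df = ann_df.drop(columns='geometry')
# Map class names to integer class IDs.
ann_df['category_id'] = [class_names.index(c) for c in ann_df['class']]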

Convert to JSON and save to file:

[24]:
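ann_json = {
    # Each image's filename doubles as its ID.
    'images': [dict(id=image_id, file_name=image_id) for image_id in ann_df.image_id.unique()],
    # One record per annotation, including image_id, bbox, and category_id.
    'annotations': ann_df.to_dict(orient='records'),
}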
json_to_file(ann_json, 'airbus/annotations.json')

Usage

The annotations are now ready to be used:

[4]:
import albumentations as A

from rastervision.pytorch_learner import ObjectDetectionImageDataset

ds = ObjectDetectionImageDataset(
    img_dir='airbus/images/',
    annotation_uri='airbus/annotations.json',
    transform=A.Resize(1024, 1024),
)
len(ds)
[4]:
103

We can read a data sample and the corresponding ground truth from the Dataset like so:

[5]:
x, y = ds[0]
x.shape, y
[5]:
(torch.Size([3, 1024, 1024]),
 <rastervision.pytorch_learner.object_detection_utils.BoxList at 0x7f22380d5460>)

Note that y is a BoxList.

And then plot it using the ObjectDetectionVisualizer:

[6]:
from rastervision.pytorch_learner import ObjectDetectionVisualizer

viz = ObjectDetectionVisualizer(
    class_names=['Airplane', 'Truncated_airplane'],
    class_colors=['red', 'green'])
viz.scale = 8  # scale factor for the plot size; larger values give a bigger figure
viz.plot_batch(x.unsqueeze(0), [y], show=True)
[Figure: the sample image plotted with its ground-truth bounding boxes]

Classification – ClassificationImageDataset

The ClassificationImageDataset wraps a dataset based on torchvision’s DatasetFolder and expects a path to a directory containing images in the following structure:

<data_dir>/
    class_1/
        <images>
    class_2/
        <images>
    ...
    class_N/
        <images>
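By default, torchvision's DatasetFolder assigns class IDs by sorting the class subdirectory names and numbering them from 0. The snippet below reproduces that rule for the EuroSAT folder names listed by `ls` further down, without touching the file system. It sketches the default behavior only; the code below notes that passing an explicit `class_names` list lets you enforce a different mapping.

```python
# EuroSAT class folder names, as listed by `ls` further below.
folder_names = [
    'AnnualCrop', 'HerbaceousVegetation', 'Industrial', 'PermanentCrop',
    'River', 'Forest', 'Highway', 'Pasture', 'Residential', 'SeaLake',
]

# DatasetFolder-style default: sort the names, then number them from 0.
sorted_names = sorted(folder_names)
class_to_idx = {name: i for i, name in enumerate(sorted_names)}

print(class_to_idx['Highway'])  # 3
```

Assuming this default mapping is in effect (`class_names=None`), a label of `tensor(3)` would correspond to `Highway`.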

Dataset

For this example, we will use the EuroSAT dataset.

Helber, Patrick, Benjamin Bischke, Andreas Dengel, and Damian Borth. “Eurosat: A novel dataset and deep learning benchmark for land use and land cover classification.” IEEE Journal of Selected Topics in Applied Earth Observations and Remote Sensing 12, no. 7 (2019): 2217-2226.

Below we download (94 MB) and unzip the data.

[11]:
!wget "https://madm.dfki.de/files/sentinel/EuroSAT.zip"
!apt-get install unzip -y
!unzip -q "EuroSAT.zip" -d "EuroSAT"
!mv "EuroSAT/2750/"* "EuroSAT/" && rm -rf "EuroSAT/2750"
!ls "EuroSAT"
AnnualCrop  HerbaceousVegetation  Industrial  PermanentCrop  River
Forest      Highway               Pasture     Residential    SeaLake

Usage

[13]:
import albumentations as A

from rastervision.pytorch_learner import ClassificationImageDataset

ds = ClassificationImageDataset(
    data_dir='EuroSAT',
    # You can pass in a list explicitly if you want to enforce a specific
    # class-name to class-ID mapping.
    class_names=None,
    transform=A.Resize(256, 256),
)
len(ds)
[13]:
27000

We can read a data sample and the corresponding ground truth from the Dataset like so:

[22]:
x, y = ds[10_000]
x.shape, y
[22]:
(torch.Size([3, 256, 256]), tensor(3))

And then plot it using the ClassificationVisualizer:

[23]:
from rastervision.pytorch_learner import ClassificationVisualizer

viz = ClassificationVisualizer(class_names=ds.orig_dataset.classes)
viz.plot_batch(x.unsqueeze(0), y.unsqueeze(0), show=True)
[Figure: the sample image plotted with its class label]