from typing import Iterable
from os import PathLike
from os.path import join, splitext
from pathlib import Path
from itertools import chain
import numpy as np
from torchvision.datasets.folder import (IMG_EXTENSIONS, DatasetFolder)
from PIL import Image
import rasterio as rio
# Extensions recognized by this module: torchvision's defaults plus '.npy'
# (numpy arrays, which ``load_image`` reads with ``np.load``).
IMG_EXTENSIONS = (*IMG_EXTENSIONS, '.npy')
class DatasetError(Exception):
    """Base class for dataset-related errors."""
class ImageDatasetError(DatasetError):
    """``DatasetError`` subclass for errors concerning image datasets."""
class GeoDatasetError(DatasetError):
    """``DatasetError`` subclass for errors concerning geospatial datasets."""
def discover_images(dir: PathLike,
                    extensions: Iterable[str] = IMG_EXTENSIONS) -> list[Path]:
    """Find all images with the given ``extensions`` in ``dir``.

    Only the top level of ``dir`` is searched (non-recursive). Each extension
    is matched with ``Path.glob(f'*{ext}')``, so case-sensitivity is that of
    ``Path.glob`` (platform dependent).

    Args:
        dir: Directory to search.
        extensions: File extensions to look for, e.g. ``('.png', '.tif')``.
            A single extension string is also accepted. Defaults to
            ``IMG_EXTENSIONS``.

    Returns:
        Matching paths, grouped by extension in the order the extensions are
        given. A path matched by more than one extension appears only once.
    """
    dir = Path(dir)
    if isinstance(extensions, str):
        # A bare string would otherwise be iterated character by character,
        # yielding bogus patterns such as '*.', '*p', '*n', '*g'.
        extensions = (extensions, )
    img_paths = chain.from_iterable(dir.glob(f'*{ext}') for ext in extensions)
    # Overlapping extensions (e.g. '.tif' and 'tif') can match the same file;
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(img_paths))
def load_image(path: PathLike) -> np.ndarray:
    """Read in image from path and return as a (H, W, C) numpy array.

    The reader is chosen by file extension, compared case-insensitively so
    that e.g. ``.NPY`` or ``.TIF`` files are handled like their lowercase
    forms:

    - ``.npy``: loaded with ``np.load``.
    - ``.tif`` / ``.tiff``: read with rasterio and transposed from
      band-first to channel-last.
    - anything else: opened with PIL.

    Args:
        path: Path to the image file.

    Returns:
        The image as a numpy array. A 2D result is given a trailing channel
        axis, i.e. (H, W) becomes (H, W, 1). Arrays of other ranks loaded
        from ``.npy`` are returned as stored.
    """
    ext = splitext(path)[-1].lower()
    if ext == '.npy':
        img = np.load(path)
    elif ext in ('.tif', '.tiff'):
        with rio.open(path, 'r') as f:
            # rasterio returns (C, H, W); convert to (H, W, C).
            img = f.read().transpose(1, 2, 0)
    else:
        # Use a context manager so PIL's file handle is closed; np.array()
        # forces the pixel data to load before the file is closed.
        with Image.open(path) as im:
            img = np.array(im)
    if img.ndim == 2:
        # (h, w) --> (h, w, 1)
        img = img[..., np.newaxis]
    return img
def make_image_folder_dataset(
        data_dir: str, classes: Iterable[str] | None = None) -> DatasetFolder:
    """Initialize and return a ``DatasetFolder`` over the images in ``data_dir``.

    Samples are read with ``load_image`` and filtered by ``IMG_EXTENSIONS``.

    If ``classes`` is None, ``DatasetFolder``'s default class discovery and
    class-to-index mapping are used. Otherwise, that behavior is overridden:
    each class is mapped to its index in ``classes``, and only classes whose
    sub-directory ``data_dir/<class>`` exists and has at least one entry
    (according to rastervision's ``file_exists``/``list_paths``) are kept.

    Args:
        data_dir: Root directory containing one sub-directory per class.
        classes: Optional ordered class names. Any iterable is accepted; it
            is materialized to a list once.

    Returns:
        A ``DatasetFolder`` (a subclass overriding ``find_classes`` when
        ``classes`` is given).
    """
    if classes is None:
        return DatasetFolder(
            data_dir, loader=load_image, extensions=IMG_EXTENSIONS)

    # NOTE(review): imported locally, presumably to avoid an import cycle or
    # a module-load-time dependency on rastervision.pipeline — confirm.
    from rastervision.pipeline.file_system.utils import (file_exists,
                                                         list_paths)

    # Materialize once: a one-shot iterable (e.g. a generator) would be
    # exhausted by the first pass below, and index lookup needs a sequence.
    classes = list(classes)
    # Index of the first occurrence of each class name (same result as
    # list.index, but O(n) overall instead of O(n^2)).
    class_index: dict[str, int] = {}
    for i, c in enumerate(classes):
        class_index.setdefault(c, i)

    class_dirs = [join(data_dir, c) for c in classes]
    classes_present = [
        c for c, class_dir in zip(classes, class_dirs)
        if file_exists(class_dir, include_dir=True)
        and len(list_paths(class_dir)) > 0
    ]
    class_to_id = {c: class_index[c] for c in classes_present}

    class ImageFolder(DatasetFolder):
        def find_classes(self,
                         directory: str) -> tuple[list[str], dict[str, int]]:
            """Override to force mapping from class name to class index."""
            return classes_present, class_to_id

    return ImageFolder(data_dir, loader=load_image, extensions=IMG_EXTENSIONS)