Source code for rastervision.core.data.utils.rasterio

from typing import TYPE_CHECKING, Sequence
import os
from os.path import join
import subprocess
import logging

import numpy as np
import rasterio as rio
import rasterio.windows as rio_windows
from rasterio.transform import from_origin
from rasterio.enums import (ColorInterp, MaskFlags, Resampling)
from rasterio.session import AWSSession

from rastervision.pipeline.file_system.utils import (
    file_to_json, get_local_path, get_tmp_dir, make_dir, upload_or_copy,
    download_if_needed, uri_to_vsi_path)
from rastervision.core.box import Box

if TYPE_CHECKING:
    from rasterio.io import DatasetReader
    from rasterio.session import Session

log = logging.getLogger(__name__)


def write_window(dataset: 'DatasetReader',
                 arr: np.ndarray,
                 window: Box | None = None) -> None:
    """Write an (H, W) or (H, W, C) array into a rasterio dataset.

    A 2D array is written to band 1. For a 3D array the last axis is treated
    as channels, and channel ``i`` (0-indexed) is written to band ``i + 1``.

    Args:
        dataset (DatasetReader): Rasterio dataset, opened for writing.
        arr (np.ndarray): Array to write.
        window (Box | None): Window (in pixel coords) to write to. It is
            converted with ``Box.rasterio_format()`` before being passed to
            ``write_band``. Defaults to None.
    """
    rio_window = None if window is None else window.rasterio_format()

    if arr.ndim == 2:
        dataset.write_band(1, arr, window=rio_window)
        return

    # Put channels first so that iterating yields one 2D band at a time.
    bands_first = arr.transpose(2, 0, 1)
    for band_idx in range(len(bands_first)):
        dataset.write_band(
            band_idx + 1, bands_first[band_idx], window=rio_window)
def write_bbox(path: str, arr: np.ndarray, bbox: Box, crs_wkt: str,
               **kwargs) -> None:
    """Write a (H, W[, C]) array to a GeoTIFF, georeferenced to the given bbox.

    The pixel size is computed independently along each axis so that the
    array exactly spans ``bbox``. The x-resolution is ``bbox width / W`` and
    the y-resolution is ``bbox height / H``.

    Args:
        path (str): GeoTiff path.
        arr (np.ndarray): (H, W[, C]) Array to write.
        bbox (Box): Bounding box in map coords to georeference the GeoTiff to.
        crs_wkt (str): CRS in WKT format.
        **kwargs: Extra rasterio profile entries. These override the defaults
            built here (driver, height, width, crs, count, dtype, transform).
    """
    if arr.ndim == 2:
        h_arr, w_arr = arr.shape
        num_channels = 1
    else:
        h_arr, w_arr, num_channels = arr.shape

    h_bbox, w_bbox = bbox.size
    # from_origin(west, north, xsize, ysize): xsize is the pixel *width* and
    # ysize the pixel *height*. Passing them in (height, width) order would
    # skew the georeferencing whenever pixels are not square.
    x_res = w_bbox / w_arr
    y_res = h_bbox / h_arr
    transform = from_origin(bbox.xmin, bbox.ymax, x_res, y_res)

    out_profile = dict(
        driver='GTiff',
        height=h_arr,
        width=w_arr,
        crs=crs_wkt,
        count=num_channels,
        dtype=arr.dtype,
        transform=transform,
    )
    out_profile.update(kwargs)
    with rio.open(path, 'w', **out_profile) as ds:
        write_window(ds, arr)
def write_geotiff_like_geojson(path: str,
                               arr: np.ndarray,
                               geojson_path: str,
                               crs: str | None = None,
                               **kwargs) -> None:
    """Write an array to a GeoTIFF spanning the bounds of a GeoJSON's geometry.

    The bounding box is taken from the union of all geometries in the
    GeoJSON file. The array is then written with ``write_bbox``.

    Args:
        path (str): GeoTiff path.
        arr (np.ndarray): (H, W[, C]) Array to write.
        geojson_path (str): GeoJSON path.
        crs (str): CRS name. If None, it is read from
            ``geojson['crs']['properties']['name']``. If that key path is
            missing, "epsg:4326" is used. Defaults to None.
        **kwargs: Passed through to ``write_bbox``.
    """
    from rastervision.core.data.utils.geojson import geojson_to_geoms
    import pyproj
    from shapely.ops import unary_union

    geojson_dict = file_to_json(geojson_path)

    crs_name = crs
    if crs_name is None:
        try:
            crs_name = geojson_dict['crs']['properties']['name']
        except KeyError:
            crs_name = 'epsg:4326'
    wkt = pyproj.CRS(crs_name).to_wkt()

    merged_geom = unary_union(list(geojson_to_geoms(geojson_dict)))
    extent = Box.from_shapely(merged_geom)
    write_bbox(path, arr, bbox=extent, crs_wkt=wkt, **kwargs)
def crop_geotiff(src_uri: str, window: Box, dst_uri: str):
    """Create a new GeoTIFF from a crop of an existing GeoTIFF.

    The source metadata and band color interpretation are copied. Height,
    width and transform are adjusted to the crop window. The crop is written
    to a local (possibly temporary) path and then uploaded or copied to
    ``dst_uri``.

    Args:
        src_uri (str): Source GeoTIFF URI to read from.
        window (Box): Window specifying the crop bounds.
        dst_uri (str): Crop GeoTIFF URI to write to.
    """
    rio_window = window.rasterio_format()
    with rio.open(src_uri) as src, get_tmp_dir() as tmp_dir:
        local_dst = get_local_path(dst_uri, tmp_dir)
        make_dir(local_dst, use_dirname=True)

        profile = src.meta
        band_colors = src.colorinterp
        pixels = src.read(window=rio_window)

        crop_h, crop_w = window.size
        profile['height'] = crop_h
        profile['width'] = crop_w
        profile['transform'] = rio_windows.transform(rio_window,
                                                     src.transform)

        with rio.open(local_dst, 'w', **profile) as dst:
            dst.colorinterp = band_colors
            dst.write(pixels)

        # Must run while tmp_dir still exists.
        upload_or_copy(local_dst, dst_uri)
def build_vrt(vrt_path: str, image_uris: list[str]) -> None:
    """Build a VRT for a set of TIFF files using ``gdalbuildvrt``.

    Args:
        vrt_path (str): Local path for the VRT to be created. Its parent
            directory is created if needed.
        image_uris (list[str]): Image URIs. Each is converted to a GDAL VSI
            path with ``uri_to_vsi_path`` before being passed to the tool.

    Raises:
        subprocess.CalledProcessError: If ``gdalbuildvrt`` exits with a
            non-zero status.
    """
    log.info('Building VRT...')
    cmd = ['gdalbuildvrt', vrt_path]
    cmd.extend(uri_to_vsi_path(uri) for uri in image_uris)
    make_dir(vrt_path, use_dirname=True)
    # check=True: without it a failed gdalbuildvrt run goes unnoticed and
    # callers are handed a path to a VRT that was never written.
    subprocess.run(cmd, env=os.environ, check=True)
def download_and_build_vrt(image_uris: list[str],
                           vrt_dir: str,
                           stream: bool = False) -> str:
    """Optionally download images, then build a VRT named ``index.vrt``.

    Args:
        image_uris (list[str]): Image URIs.
        vrt_dir (str): Dir where the VRT will be created.
        stream (bool): If true, the URIs are passed to the VRT as-is instead
            of being downloaded first. Defaults to False.

    Returns:
        str: The path to the created VRT file (``<vrt_dir>/index.vrt``).
    """
    if stream:
        sources = image_uris
    else:
        sources = [download_if_needed(uri) for uri in image_uris]

    out_path = join(vrt_dir, 'index.vrt')
    build_vrt(out_path, sources)
    return out_path
def read_window(dataset: 'DatasetReader',
                bands: int | Sequence[int] | None = None,
                window: tuple[tuple[int, int], tuple[int, int]] | None = None,
                is_masked: bool = False,
                out_shape: tuple[int, ...] | None = None,
                session: 'Session | None' = None) -> np.ndarray:
    """Load a window of an image using Rasterio.

    Pixels equal to a band's NODATA value are set to 0 when that value is
    non-zero, including NaN.

    Args:
        dataset: a Rasterio dataset.
        bands: Band index or indices to read. Must be 1-indexed. A single
            int is treated like a one-element sequence, so the output still
            has a channel axis. If ``None``, all bands are read.
        window: ``((row_start, row_stop), (col_start, col_stop))`` or
            ``((y_min, y_max), (x_min, x_max))``. If ``None``, reads the
            entire raster. Defaults to ``None``.
        is_masked: If ``True``, read a masked array from rasterio and fill
            masked pixels with 0. Defaults to ``False``.
        out_shape: (height, width) of the output chip. If ``None``, no
            resizing is done. Defaults to ``None``.
        session: Rasterio :class:`.Session`.

    Returns:
        np.ndarray: array of shape (height, width, channels).
    """
    if bands is not None:
        # tuple(<int>) raises TypeError, so wrap a lone index. This also
        # ensures dataset.read() always returns a 3D (C, H, W) array.
        if isinstance(bands, (int, np.integer)):
            bands = (bands, )
        else:
            bands = tuple(bands)

    with rio.Env(session=session):
        im = dataset.read(
            indexes=bands,
            window=window,
            boundless=True,
            masked=is_masked,
            out_shape=out_shape,
            resampling=Resampling.bilinear)

    if is_masked:
        im = np.ma.filled(im, fill_value=0)

    # Handle non-zero NODATA values by setting the data to 0.
    # nodatavals is indexed by 0-indexed source band, so map each output
    # channel to the NODATA value of the band it was read from.
    all_nodatavals = dataset.nodatavals
    if bands is None:
        channel_nodatavals = list(all_nodatavals)
    else:
        channel_nodatavals = [all_nodatavals[band - 1] for band in bands]

    for channel, nodataval in enumerate(channel_nodatavals):
        if nodataval is None or nodataval == 0:
            continue
        if np.isnan(nodataval):
            # NaN never compares equal to anything, so it needs isnan.
            im[channel, np.isnan(im[channel])] = 0
        else:
            im[channel, im[channel] == nodataval] = 0

    im = np.transpose(im, axes=[1, 2, 0])
    return im
def get_channel_order_from_dataset(dataset: 'DatasetReader') -> list[int]:
    """Return the 0-indexed channel order for a rasterio image dataset.

    If the dataset defines ``colorinterp``, every channel except those
    interpreted as alpha is kept, in their original order. Otherwise all
    ``dataset.count`` channels are returned.

    Args:
        dataset (DatasetReader): Rasterio image dataset.

    Returns:
        list[int]: List of channel indices.
    """
    interps = dataset.colorinterp
    if not interps:
        return list(range(dataset.count))

    return [
        idx for idx, interp in enumerate(interps)
        if not interp == ColorInterp.alpha
    ]
def is_masked(dataset: 'DatasetReader') -> bool:
    """Check if dataset has any masks defined.

    ``dataset.mask_flag_enums`` holds one list of
    :class:`~rasterio.enums.MaskFlags` per band. A band has no mask only when
    its flags include ``MaskFlags.all_valid``. Any other combination
    (e.g. ``nodata``, ``alpha``, ``per_dataset``) means a mask is defined.

    Args:
        dataset (DatasetReader): Rasterio dataset.

    Returns:
        bool: True if at least one band has a mask.
    """
    # Each element is a *list* of flags for one band, so test membership.
    # Comparing the list itself to MaskFlags.all_valid is always unequal and
    # would report every dataset as masked.
    return any(MaskFlags.all_valid not in band_flags
               for band_flags in dataset.mask_flag_enums)
def get_aws_session() -> 'Session':
    """Build a rasterio AWS session from environment variables.

    Requester-pays is taken from ``S3FileSystem.get_request_payer()`` when
    ``rastervision.aws_s3`` can be imported. Otherwise it is True iff the
    ``AWS_REQUEST_PAYER`` environment variable equals ``'requestor'``
    (case-insensitive).
    """
    try:
        from rastervision.aws_s3 import S3FileSystem
        pays = S3FileSystem.get_request_payer()
    except ModuleNotFoundError:
        payer_env = os.getenv('AWS_REQUEST_PAYER', '')
        pays = payer_env.lower() == 'requestor'
    return AWSSession.from_environ(requester_pays=pays)