Source code for rastervision.core.data.crs_transformer.rasterio_crs_transformer

from typing import Any, Optional
from pyproj import Transformer

import rasterio as rio
from rasterio.transform import (rowcol, xy)
from rasterio import Affine

from rastervision.core.data.crs_transformer import (CRSTransformer,
                                                    IdentityCRSTransformer)


class RasterioCRSTransformer(CRSTransformer):
    """Transformer for a RasterioRasterSource."""

    def __init__(self,
                 transform: Affine,
                 image_crs: Any,
                 map_crs: Any = 'epsg:4326',
                 round_pixels: bool = True):
        """Constructor.

        Args:
            transform (Affine): Rasterio affine transform.
            image_crs (Any): CRS of image in format that PyProj can handle
                eg. wkt or init string.
            map_crs (Any): CRS of the labels. Defaults to "epsg:4326".
            round_pixels (bool): If True, round outputs of map_to_pixel and
                inputs of pixel_to_map to integers. Defaults to True.
        """
        if (image_crs is None) or (image_crs == map_crs):
            # No reprojection needed: pass the (x, y) pair through unchanged.
            self.map2image = lambda *args, **kws: args[:2]
            self.image2map = lambda *args, **kws: args[:2]
        else:
            # always_xy=True makes pyproj accept/return coordinates in
            # (x, y) order regardless of the axis order each CRS declares.
            self.map2image = Transformer.from_crs(
                map_crs, image_crs, always_xy=True).transform
            self.image2map = Transformer.from_crs(
                image_crs, map_crs, always_xy=True).transform
        self.round_pixels = round_pixels
        super().__init__(transform, image_crs, map_crs)

    def __repr__(self) -> str:  # pragma: no cover
        """Return a multi-line summary; long CRS strings are truncated."""

        def _shorten(text: str, max_len: int = 70) -> str:
            # WKT strings can be very long; keep the repr readable.
            if len(text) > max_len:
                return text[:max_len] + '...'
            return text

        cls_name = type(self).__name__
        image_crs_str = _shorten(str(self.image_crs))
        # Previously this was built from self.image_crs by mistake.
        map_crs_str = _shorten(str(self.map_crs))
        transform_str = (
            '\n\t\t' + (str(self.transform).replace('\n', '\n\t\t')))
        out = f"""{cls_name}(
            image_crs="{image_crs_str}",
            map_crs="{map_crs_str}",
            round_pixels="{self.round_pixels}",
            transform={transform_str})
        """
        return out

    def _map_to_pixel(self, map_point):
        """Transform point from map to pixel-based coordinates.

        Args:
            map_point: (x, y) tuple in map coordinates

        Returns:
            (x, y) tuple in pixel coordinates
        """
        image_point = self.map2image(*map_point)
        x, y = image_point
        if self.round_pixels:
            # rowcol's default op floors to integer pixel indices.
            row, col = rowcol(self.transform, x, y)
        else:
            # Identity op keeps fractional pixel coordinates.
            row, col = rowcol(self.transform, x, y, op=lambda x: x)
        # rowcol returns (row, col); callers expect (x, y) == (col, row).
        pixel_point = (col, row)
        return pixel_point

    def _pixel_to_map(self, pixel_point):
        """Transform point from pixel to map-based coordinates.

        Args:
            pixel_point: (x, y) tuple in pixel coordinates

        Returns:
            (x, y) tuple in map coordinates
        """
        col, row = pixel_point
        if self.round_pixels:
            # NOTE: int() truncates toward zero, whereas _map_to_pixel floors;
            # the two differ only for negative fractional pixel coordinates.
            col, row = int(col), int(row)
        # offset='center' maps to the center of the pixel, not its corner.
        image_point = xy(self.transform, row, col, offset='center')
        map_point = self.image2map(*image_point)
        return map_point

    @classmethod
    def from_dataset(cls,
                     dataset,
                     map_crs: Optional[str] = 'epsg:4326',
                     **kwargs) -> 'RasterioCRSTransformer':
        """Build a transformer from an open rasterio dataset.

        Args:
            dataset: Open rasterio dataset whose ``transform`` and ``crs``
                attributes are read.
            map_crs (Optional[str]): CRS of the labels. If None, the
                dataset's own CRS is used. Defaults to "epsg:4326".
            **kwargs: Extra keyword args forwarded to the constructor
                (e.g. ``round_pixels``). Ignored when an
                IdentityCRSTransformer is returned.

        Returns:
            A RasterioCRSTransformer, or an IdentityCRSTransformer when there
            is neither a CRS change nor a non-identity affine transform.
        """
        transform = dataset.transform
        image_crs = None if dataset.crs is None else dataset.crs.wkt
        map_crs = image_crs if map_crs is None else map_crs

        no_crs_tf = (image_crs is None) or (image_crs == map_crs)
        no_affine_tf = (transform is None) or (transform == Affine.identity())
        if no_crs_tf and no_affine_tf:
            return IdentityCRSTransformer()

        if transform is None:
            transform = Affine.identity()

        return cls(transform, image_crs, map_crs, **kwargs)

    @classmethod
    def from_uri(cls,
                 uri: str,
                 map_crs: Optional[str] = 'epsg:4326',
                 **kwargs) -> 'RasterioCRSTransformer':
        """Build a transformer from the raster file at ``uri``.

        Opens the file with rasterio, delegates to ``from_dataset``, and
        closes the dataset before returning.

        Args:
            uri (str): Location of the raster file, passed to ``rio.open``.
            map_crs (Optional[str]): CRS of the labels. If None, the
                raster's own CRS is used. Defaults to "epsg:4326".
            **kwargs: Extra keyword args forwarded to ``from_dataset``.

        Returns:
            The transformer produced by ``from_dataset``.
        """
        with rio.open(uri) as ds:
            return cls.from_dataset(ds, map_crs=map_crs, **kwargs)