Source code for rastervision.pytorch_learner.dataset.classification_dataset

from typing import TYPE_CHECKING, Iterable, List, Optional, Union
import logging

from rastervision.pytorch_learner.dataset import (
    ImageDataset, TransformType, SlidingWindowGeoDataset,
    RandomWindowGeoDataset, make_image_folder_dataset)
from rastervision.core.data.utils import make_cc_scene

if TYPE_CHECKING:
    from rastervision.core.data import ClassConfig, ChipClassificationLabelSource

# Module-level logger, named after this module.
log = logging.getLogger(__name__)


class ClassificationImageDataset(ImageDataset):
    """Dataset of images and class labels read from per-class folders.

    The expected layout has one directory per class, named after that class.
    For example, every image of class "A" lives in ``A/`` and every image of
    class "B" lives in ``B/``. All of these class directories sit inside a
    single shared parent directory.
    """

    def __init__(self, data_dir: str, class_names: Optional[Iterable[str]],
                 *args, **kwargs):
        """Build the dataset from a directory of class folders.

        Args:
            data_dir (str): Parent directory holding the class directories.
            class_names (Optional[Iterable[str]]): Class names. These should
                match the class directory names.
            *args: Forwarded to :meth:`.ImageDataset.__init__`.
            **kwargs: Forwarded to :meth:`.ImageDataset.__init__`.
        """
        folder_dataset = make_image_folder_dataset(
            data_dir, classes=class_names)
        # The transform type is always fixed to classification for this
        # dataset. A caller-supplied ``transform_type`` would clash here.
        classification = TransformType.classification
        super().__init__(
            folder_dataset, *args, **kwargs, transform_type=classification)
def make_cc_geodataset(cls,
                       image_uri: Union[str, List[str]],
                       label_vector_uri: Optional[str] = None,
                       class_config: Optional['ClassConfig'] = None,
                       aoi_uri: Optional[Union[str, List[str]]] = None,
                       label_vector_default_class_id: Optional[int] = None,
                       image_raster_source_kw: Optional[dict] = None,
                       label_vector_source_kw: Optional[dict] = None,
                       label_source_kw: Optional[dict] = None,
                       **kwargs):
    """Create an instance of ``cls`` from image and label URIs.

    This is a convenience method. For more fine-grained control, it is
    recommended to use the default constructor.

    Args:
        cls: The GeoDataset class to instantiate. This function is attached
            to classes as the ``from_uris`` classmethod, so ``cls`` is bound
            automatically.
        image_uri (Union[str, List[str]]): URI or list of URIs of GeoTIFFs to
            use as the source of image data.
        label_vector_uri (Optional[str], optional): URI of GeoJSON file to
            use as the source of chip classification label data. Defaults to
            None.
        class_config (Optional['ClassConfig']): The ClassConfig. Can be None
            if not using any labels.
        aoi_uri (Optional[Union[str, List[str]]], optional): URI or list of
            URIs of GeoJSONs that specify the area-of-interest. If provided,
            the dataset will only access data from this area. Defaults to
            None, which is treated as an empty list.
        label_vector_default_class_id (Optional[int], optional): If using
            label_vector_uri and all polygons in that file belong to the same
            class and they do not contain a `class_id` property, then use
            this argument to map all of the polygons to the appropriate class
            ID. See docs for ClassInferenceTransformer for more details.
            Defaults to None.
        image_raster_source_kw (Optional[dict], optional): Additional
            arguments to pass to the RasterioSource used for image data. See
            docs for RasterioSource for more details. Defaults to None, which
            is treated as an empty dict.
        label_vector_source_kw (Optional[dict], optional): Additional
            arguments to pass to the GeoJSONVectorSourceConfig used for label
            data, if label_vector_uri is set. See docs for
            GeoJSONVectorSourceConfig for more details. Defaults to None,
            which is treated as an empty dict.
        label_source_kw (Optional[dict], optional): Additional arguments to
            pass to the ChipClassificationLabelSourceConfig used for label
            data, if label_vector_uri is set. See docs for
            ChipClassificationLabelSourceConfig for more details. Defaults to
            None, which is treated as an empty dict.
        **kwargs: All other keyword args are passed to the default
            constructor for this class.

    Returns:
        An instance of this GeoDataset subclass (``cls``).
    """
    # Resolve ``None`` sentinels to fresh containers on every call. Using
    # ``[]``/``{}`` literals as defaults would share one object across all
    # calls, so any mutation downstream would leak into later calls.
    if aoi_uri is None:
        aoi_uri = []
    if image_raster_source_kw is None:
        image_raster_source_kw = {}
    if label_vector_source_kw is None:
        label_vector_source_kw = {}
    if label_source_kw is None:
        label_source_kw = {}

    scene = make_cc_scene(
        image_uri=image_uri,
        label_vector_uri=label_vector_uri,
        class_config=class_config,
        aoi_uri=aoi_uri,
        label_vector_default_class_id=label_vector_default_class_id,
        image_raster_source_kw=image_raster_source_kw,
        label_vector_source_kw=label_vector_source_kw,
        label_source_kw=label_source_kw)
    ds = cls(scene, **kwargs)
    return ds
class ClassificationSlidingWindowGeoDataset(SlidingWindowGeoDataset):
    """Sliding-window GeoDataset using the classification transform type.

    Can be built from URIs via ``from_uris``, which delegates to
    ``make_cc_geodataset``.
    """
    from_uris = classmethod(make_cc_geodataset)

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent, fixing the transform type.

        ``transform_type`` is always set to ``TransformType.classification``
        and must not be passed by the caller.
        """
        classification = TransformType.classification
        super().__init__(*args, **kwargs, transform_type=classification)

    def init_windows(self):
        """Generate the sliding windows, then populate labels for them.

        After the parent class creates ``self.windows``, the scene's label
        source, if it has one, is asked to populate labels using those
        windows as its cells.
        """
        super().init_windows()
        label_source: Optional['ChipClassificationLabelSource'] = (
            self.scene.label_source)
        # Scenes without labels (e.g. for prediction) have nothing to fill.
        if label_source is None:
            return
        label_source.populate_labels(cells=self.windows)
class ClassificationRandomWindowGeoDataset(RandomWindowGeoDataset):
    """Random-window GeoDataset using the classification transform type.

    Can be built from URIs via ``from_uris``, which delegates to
    ``make_cc_geodataset``.
    """
    from_uris = classmethod(make_cc_geodataset)

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent, fixing the transform type.

        ``transform_type`` is always set to ``TransformType.classification``
        and must not be passed by the caller.
        """
        classification = TransformType.classification
        super().__init__(*args, **kwargs, transform_type=classification)