Source code for rastervision.core.data.vector_transformer.class_inference_transformer

from typing import TYPE_CHECKING, Dict, Optional
from copy import deepcopy
import logging

from rastervision.core.data.vector_transformer import VectorTransformer
from rastervision.core.data.vector_transformer.label_maker.filter import (
    create_filter)
from rastervision.core.data.utils.geojson import features_to_geojson

if TYPE_CHECKING:
    from rastervision.core.data import ClassConfig, CRSTransformer

log = logging.getLogger(__name__)


class ClassInferenceTransformer(VectorTransformer):
    """Infers missing class IDs from GeoJSON features.

    Rules:

    1) If ``class_id`` is in ``feature['properties']`` (and is not ``None``),
       use it.
    2) If ``class_config`` is set and ``"class_name"`` or ``"label"`` are in
       ``feature['properties']`` and in ``class_config``, use corresponding
       ``class_id``.
    3) If ``class_id_to_filter`` is set and filter is true when applied to
       feature, use corresponding ``class_id``.
    4) Otherwise, return the ``default_class_id``.
    """

    def __init__(self,
                 default_class_id: Optional[int],
                 class_config: Optional['ClassConfig'] = None,
                 class_id_to_filter: Optional[Dict[int, list]] = None,
                 class_name_mapping: Optional[dict[str, str]] = None):
        """Constructor.

        Args:
            default_class_id: The default ``class_id`` to use if class cannot
                be inferred using other mechanisms. If a feature has an
                inferred ``class_id`` of ``None``, then it will be deleted.
                This argument is required; pass ``None`` explicitly to drop
                features whose class cannot be inferred.
            class_config: ``ClassConfig`` to match the class names in the
                GeoJSON features to. Required if using
                ``class_name_mapping``. Defaults to ``None``.
            class_id_to_filter: Map from ``class_id`` to JSON filter used to
                infer missing class IDs. Each key should be a class ID, and
                its value should be a boolean expression which is run against
                the property field for each feature. This allows matching
                different features to different class IDs based on its
                properties. The expression schema is that described by
                https://docs.mapbox.com/mapbox-gl-js/style-spec/other/#other-filter.
                Defaults to ``None``.
            class_name_mapping: ``old_name --> new_name`` mapping for values
                in the ``class_name`` or ``label`` property of the GeoJSON
                features. The ``new_name`` must be a valid class name in the
                ``ClassConfig``. This can also be used to merge multiple
                classes into one e.g.:
                ``dict(car="vehicle", truck="vehicle")``.
                Defaults to ``None``.

        Raises:
            ValueError: If ``class_name_mapping`` is given without a
                ``class_config``.
        """
        if class_name_mapping is not None and class_config is None:
            raise ValueError(
                'class_config must be specified if class_name_mapping is.')
        self.class_config = class_config
        self.class_id_to_filter = class_id_to_filter
        self.default_class_id = default_class_id
        self.class_name_mapping = class_name_mapping

        if self.class_id_to_filter is not None:
            # Compile each JSON filter expression once, up front. Keys are
            # coerced to int since they may arrive as strings (e.g. from a
            # JSON-serialized config).
            self.class_id_to_filter = {}
            for class_id, filter_exp in class_id_to_filter.items():
                self.class_id_to_filter[int(class_id)] = create_filter(
                    filter_exp)

    @staticmethod
    def infer_feature_class_id(
            feature: dict,
            default_class_id: Optional[int],
            class_config: Optional['ClassConfig'] = None,
            class_id_to_filter: Optional[Dict[int, list]] = None,
            class_name_mapping: Optional[dict[str, str]] = None
    ) -> Optional[int]:
        """Infer the class ID for a GeoJSON feature.

        Rules:

        1) If ``class_id`` is in ``feature['properties']`` (and is not
           ``None``), use it.
        2) If ``class_config`` is set and ``"class_name"`` or ``"label"`` are
           in ``feature['properties']`` and in ``class_config``, use
           corresponding ``class_id``.
        3) If ``class_id_to_filter`` is set and filter is true when applied to
           feature, use corresponding ``class_id``.
        4) Otherwise, return the ``default_class_id``.

        Args:
            feature: GeoJSON feature. A missing or ``null`` ``properties``
                member is treated as an empty mapping.
            default_class_id: The default ``class_id`` to use if class cannot
                be inferred using other mechanisms. This argument is
                required; it may be ``None``.
            class_config: ``ClassConfig`` to match the class names in the
                GeoJSON features to. Required if using
                ``class_name_mapping``. Defaults to ``None``.
            class_id_to_filter: Map from ``class_id`` to a callable that is
                applied to the feature; the first entry (in iteration order)
                whose callable returns a truthy value wins.
                Defaults to ``None``.
            class_name_mapping: ``old_name --> new_name`` mapping for values
                in the ``class_name`` or ``label`` property of the GeoJSON
                features. The ``new_name`` must be a valid class name in the
                ``ClassConfig``. This can also be used to merge multiple
                classes into one e.g.:
                ``dict(car="vehicle", truck="vehicle")``.
                Defaults to ``None``.

        Returns:
            Optional[int]: Inferred class ID.

        Raises:
            ValueError: If ``class_name_mapping`` is given without a
                ``class_config``.
        """
        if class_name_mapping is not None and class_config is None:
            raise ValueError(
                'class_config must be specified if class_name_mapping is.')

        # GeoJSON permits "properties": null, in which case dict.get's
        # default is NOT used; fall back to an empty dict explicitly.
        properties: dict = feature.get('properties') or {}

        # Rule 1: an explicit class_id always wins.
        class_id = properties.get('class_id')
        if class_id is not None:
            return class_id

        # Rule 2: look the (possibly remapped) class name up in class_config.
        if class_config is not None:
            if class_name_mapping is None:
                class_name_mapping = {}
            class_name = properties.get('class_name')
            if class_name is None:
                class_name = properties.get('label')
            class_name = class_name_mapping.get(class_name, class_name)
            if class_name in class_config.names:
                return class_config.names.index(class_name)

        # Rule 3: first matching filter wins.
        if class_id_to_filter is not None:
            for class_id, filter_fn in class_id_to_filter.items():
                if filter_fn(feature):
                    return class_id

        # Rule 4: nothing matched.
        return default_class_id

    def transform(self,
                  geojson: dict,
                  crs_transformer: Optional['CRSTransformer'] = None) -> dict:
        """Add class_id to feature properties and drop features with no class.

        For each feature in geojson, the class_id is inferred and is set into
        feature['properties']. If the class_id is None (because none of the
        rules apply and the default_class_id is None), the feature is dropped.

        Args:
            geojson: GeoJSON dict with a ``'features'`` list. The input
                features are not mutated; kept features are deep-copied.
            crs_transformer: Not used by this transformer.
                Defaults to ``None``.

        Returns:
            dict: New GeoJSON built from the features that were kept.
        """
        new_features = []
        # Emit the "dropping features" warning at most once per call.
        warned = False
        for feature in geojson['features']:
            class_id = self.infer_feature_class_id(
                feature,
                default_class_id=self.default_class_id,
                class_config=self.class_config,
                class_id_to_filter=self.class_id_to_filter,
                class_name_mapping=self.class_name_mapping)
            if class_id is not None:
                feature = deepcopy(feature)
                # "properties" may be absent or null; normalize to a dict so
                # the class_id can be stored on it.
                properties = feature.get('properties') or {}
                properties['class_id'] = class_id
                feature['properties'] = properties
                new_features.append(feature)
            elif not warned:
                log.warning(
                    'ClassInferenceTransformer is dropping vector features because '
                    'class_id cannot be inferred. To avoid this behavior, '
                    'set default_class_id to a non-None value in '
                    'ClassInferenceTransformer.')
                warned = True
        new_geojson = features_to_geojson(new_features)
        return new_geojson