Source code for rastervision.core.data.vector_transformer.class_inference_transformer

from typing import TYPE_CHECKING, Dict, Optional
from copy import deepcopy
import logging

from rastervision.core.data.vector_transformer import VectorTransformer
from rastervision.core.data.vector_transformer.label_maker.filter import (
    create_filter)
from rastervision.core.data.utils.geojson import features_to_geojson

if TYPE_CHECKING:
    from rastervision.core.data import ClassConfig, CRSTransformer

log = logging.getLogger(__name__)


[docs]class ClassInferenceTransformer(VectorTransformer): """Infers missing class_ids from GeoJSON features. Rules: 1) If class_id is in feature['properties'], use it. 2) If class_config is set and class_name or label are in feature['properties'] and in class_config, use corresponding class_id. 3) If class_id_to_filter is set and filter is true when applied to feature, use corresponding class_id. 4) Otherwise, return the default_class_id """
[docs] def __init__(self, default_class_id: Optional[int], class_config: Optional['ClassConfig'] = None, class_id_to_filter: Optional[Dict[int, list]] = None): self.class_config = class_config self.class_id_to_filter = class_id_to_filter self.default_class_id = default_class_id if self.class_id_to_filter is not None: self.class_id_to_filter = {} for class_id, filter_exp in class_id_to_filter.items(): self.class_id_to_filter[int(class_id)] = create_filter( filter_exp)
[docs] @staticmethod def infer_feature_class_id( feature: dict, default_class_id: Optional[int], class_config: Optional['ClassConfig'] = None, class_id_to_filter: Optional[Dict[int, list]] = None ) -> Optional[int]: """Infer the class_id for a GeoJSON feature. Rules: 1) If class_id is in feature['properties'], use it. 2) If class_config is set and class_name or label are in feature['properties'] and in class_config, use corresponding class_id. 3) If class_id_to_filter is set and filter is true when applied to feature, use corresponding class_id. 4) Otherwise, return the default_class_id. Args: feature (dict): GeoJSON feature. Returns: Optional[int]: Inferred class ID. """ class_id = feature.get('properties', {}).get('class_id') if class_id is not None: return class_id if class_config is not None: class_name = feature.get('properties', {}).get('class_name') if class_name in class_config.names: return class_config.names.index(class_name) label = feature.get('properties', {}).get('label') if label in class_config.names: return class_config.names.index(label) if class_id_to_filter is not None: for class_id, filter_fn in class_id_to_filter.items(): if filter_fn(feature): return class_id return default_class_id
[docs] def transform(self, geojson: dict, crs_transformer: Optional['CRSTransformer'] = None) -> dict: """Add class_id to feature properties and drop features with no class. For each feature in geojson, the class_id is inferred and is set into feature['properties']. If the class_id is None (because none of the rules apply and the default_class_id is None), the feature is dropped. """ new_features = [] warned = False for feature in geojson['features']: class_id = self.infer_feature_class_id( feature, default_class_id=self.default_class_id, class_config=self.class_config, class_id_to_filter=self.class_id_to_filter) if class_id is not None: feature = deepcopy(feature) properties = feature.get('properties', {}) properties['class_id'] = class_id feature['properties'] = properties new_features.append(feature) elif not warned: log.warning( 'ClassInferenceTransformer is dropping vector features because ' 'class_id cannot be inferred. To avoid this behavior, ' 'set default_class_id to a non-None value in ' 'ClassInferenceTransformer.') warned = True new_geojson = features_to_geojson(new_features) return new_geojson