from typing import TYPE_CHECKING, Literal
from collections.abc import Callable
import inspect
import logging
from pydantic import ( # noqa
ConfigDict, BaseModel, create_model, Field, model_validator,
ValidationError, field_validator)
from rastervision.pipeline import (registry_ as registry, rv_config_ as
rv_config)
from rastervision.pipeline.file_system import (file_to_json, json_to_file,
str_to_file)
if TYPE_CHECKING:
from typing import Self
from rastervision.pipeline.pipeline_config import PipelineConfig
log = logging.getLogger(__name__)
class ConfigError(ValueError):
    """Raised when a Config holds an invalid value or value combination.

    Derives from ``ValueError``, so generic ``except ValueError`` handlers
    will also catch it.
    """
class Config(BaseModel):
    """Base class that can be extended to provide custom configurations.

    This adds some extra methods to Pydantic BaseModel.
    See https://docs.pydantic.dev/

    The general idea is that configuration schemas can be defined by
    subclassing this and adding class attributes with types and
    default values for each field. Configs can be defined hierarchically,
    ie. a Config can have fields which are of type Config.
    Validation, serialization, deserialization, and IDE support is
    provided automatically based on this schema.
    """
    # Reject unknown fields, and re-validate whenever a field is assigned.
    model_config = ConfigDict(extra='forbid', validate_assignment=True)

    def update(self, *args, **kwargs):
        """Update any fields before validation.

        Subclasses should override this to provide complex default behavior,
        for example, setting default values as a function of the values of
        other fields. The arguments to this method will vary depending on the
        type of Config. The base implementation does nothing.
        """

    def build(self):
        """Build an instance of the corresponding type of object using this
        config.

        For example, BackendConfig will build a Backend object. The arguments
        to this method will vary depending on the type of Config. The base
        implementation does nothing and returns None.
        """

    def validate_config(self):
        """Validate fields that should be checked after update is called.

        This is to complement the builtin validation that Pydantic performs at
        the time of object construction. The base implementation does nothing.
        """

    def revalidate(self):
        """Re-validate an instantiated Config.

        Runs all Pydantic validators plus self.validate_config().
        """
        # model_validate() constructs (and discards) a new instance from the
        # current field values, which runs the Pydantic validators on them.
        self.model_validate(self.__dict__)
        self.validate_config()

    def recursive_validate_config(self):
        """Recursively validate hierarchies of Configs.

        This uses reflection to call validate_config on a hierarchy of Configs
        using a depth-first pre-order traversal: for this object, each class in
        its MRO that defines its own validate_config has it called once (most
        derived first); then every field whose value is a Config is validated
        recursively in the same way.
        """
        for klass in type(self).mro():
            # Only call implementations a class defines itself. Looking up
            # ``klass.validate_config`` on a class that merely inherits it
            # (e.g. the wrapper class generated by register_config) would run
            # the inherited implementation again.
            if issubclass(klass, Config) and 'validate_config' in vars(klass):
                klass.validate_config(self)
        # NOTE(review): only direct Config-valued fields are recursed into;
        # Configs held inside lists or dicts are not visited.
        child_configs = [
            x for x in self.__dict__.values() if isinstance(x, Config)
        ]
        for c in child_configs:
            c.recursive_validate_config()

    def validate_list(self, field: str, valid_options: list[str]):
        """Validate a list field.

        Args:
            field (str): name of field to validate. If its value is a list,
                every element is checked; otherwise the value itself is.
            valid_options (list[str]): values that field is allowed to take

        Raises:
            ConfigError: if field is invalid
        """
        val = getattr(self, field)
        if isinstance(val, list):
            for v in val:
                if v not in valid_options:
                    raise ConfigError(f'{v} is not a valid option for {field}')
        else:
            if val not in valid_options:
                raise ConfigError(f'{val} is not a valid option for {field}')

    def copy(self) -> 'Self':
        """Return a copy of this Config (via ``model_copy()``)."""
        return self.model_copy()

    def dict(self, with_rv_metadata: bool = False, **kwargs) -> dict:
        """Serialize this Config to a dict.

        Args:
            with_rv_metadata: If True, add the registry's ``plugin_versions``
                under the ``'plugin_versions'`` key.
            **kwargs: Passed through to ``model_dump()``.

        Returns:
            The serialized dict.
        """
        # serialize_as_any=True serializes nested fields according to their
        # runtime (sub)class rather than their declared field type.
        cfg_dict = self.model_dump(serialize_as_any=True, **kwargs)
        if with_rv_metadata:
            cfg_dict['plugin_versions'] = registry.plugin_versions
        return cfg_dict

    def json(self, **kwargs) -> str:
        """Serialize this Config to a JSON string.

        Args:
            **kwargs: Passed through to ``model_dump_json()``.

        Returns:
            The JSON string.
        """
        return self.model_dump_json(serialize_as_any=True, **kwargs)

    def to_file(self, uri: str, with_rv_metadata: bool = True) -> None:
        """Save a Config to a JSON file, optionally with RV metadata.

        Args:
            uri: URI to save to.
            with_rv_metadata: If True, inject Raster Vision metadata such as
                ``plugin_versions``, so that the config can be upgraded when
                loaded.
        """
        # Using self.dict() followed by json_to_file() would make the code
        # simpler, but it doesn't work due to serialization errors.
        cfg_json = self.json()
        if with_rv_metadata:
            import json
            cfg_dict = json.loads(cfg_json)
            cfg_dict['plugin_versions'] = registry.plugin_versions
            json_to_file(cfg_dict, uri)
        else:
            str_to_file(cfg_json, uri)

    @classmethod
    def deserialize(cls, inp: 'str | dict | Config') -> 'Self':
        """Deserialize Config from a JSON file or dict, upgrading if possible.

        If ``inp`` is already a :class:`.Config`, it is returned as is.

        Args:
            inp: a URI to a JSON file, a dict, or a Config.

        Raises:
            TypeError: if ``inp`` is none of the supported types.
        """
        if isinstance(inp, Config):
            return inp
        if isinstance(inp, dict):
            return cls.from_dict(inp)
        if isinstance(inp, str):
            return cls.from_file(inp)
        raise TypeError(f'Cannot deserialize Config from type: {type(inp)}.')

    @classmethod
    def from_file(cls, uri: str) -> 'Self':
        """Deserialize Config from a JSON file, upgrading if possible.

        Args:
            uri: URI to load from.
        """
        cfg_dict = file_to_json(uri)
        cfg = cls.from_dict(cfg_dict)
        return cfg

    # The annotation is quoted because, inside this class body, the bare name
    # ``dict`` refers to the Config.dict method defined above, not the type.
    @classmethod
    def from_dict(cls, cfg_dict: 'dict') -> 'Self':
        """Deserialize Config from a dict, upgrading if possible.

        If ``cfg_dict`` has a ``plugin_versions`` key, it is first passed
        through upgrade_config(). The ``plugin_versions`` key is then dropped
        and the rest is built with build_config(). The input dict is not
        modified.

        Args:
            cfg_dict: Dict to deserialize.
        """
        if 'plugin_versions' in cfg_dict:
            cfg_dict = upgrade_config(cfg_dict)
        # Build a filtered copy instead of popping, so the caller's dict is
        # never mutated (upgrade_config may return its input unchanged).
        cfg_dict = {
            k: v
            for k, v in cfg_dict.items() if k != 'plugin_versions'
        }
        cfg = build_config(cfg_dict)
        return cfg

    def __repr_args__(self):
        """Override to omit the 'type_hint' field from the repr."""
        args = dict(super().__repr_args__())
        args.pop('type_hint', None)
        return args.items()
def save_pipeline_config(cfg: 'PipelineConfig', output_uri: str) -> None:
    """Write a PipelineConfig to a JSON file at ``output_uri``.

    Before serializing, ``cfg`` is modified in place: its ``rv_config`` is set
    from ``rv_config.get_config_dict(registry.rv_config_schema)`` and its
    ``plugin_versions`` is set to the registry's ``plugin_versions``.
    """
    schema = registry.rv_config_schema
    cfg.rv_config = rv_config.get_config_dict(schema)
    cfg.plugin_versions = registry.plugin_versions
    str_to_file(cfg.json(), output_uri)
def build_config(
        x: dict | list[dict | Config] | Config) -> Config | list[Config]:
    """Build a Config from various types of input.

    This is useful for deserializing from JSON. It implements polymorphic
    deserialization by using the `type_hint` in each dict to get the
    corresponding Config class from the registry.

    Lists are converted element-wise and dicts value-wise, recursively. A
    dict whose ``type_hint`` is not None then becomes an instance of the
    Config class registered under that type hint. Any other value (including
    an existing Config) is returned unchanged.

    Args:
        x: some representation of Config(s)

    Returns:
        The corresponding Config(s).
    """
    if isinstance(x, list):
        return [build_config(item) for item in x]
    if not isinstance(x, dict):
        return x
    # Convert children first so nested Configs are built bottom-up.
    fields = {key: build_config(value) for key, value in x.items()}
    type_hint = fields.get('type_hint')
    if type_hint is None:
        return fields
    config_cls = registry.get_config(type_hint)
    return config_cls(**fields)
def _upgrade_config(x: dict | list[dict],
plugin_versions: dict[str, int]) -> dict | list[dict]:
"""Upgrade serialized Config(s) to the latest version.
Used to implement backward compatibility of Configs using upgraders stored
in the registry.
Args:
x: serialized Config(s) which are potentially of a non-current version
plugin_versions: dict mapping from plugin module name to the latest version
Returns:
The corresponding serialized Config(s) that have been upgraded to the
current version.
"""
if isinstance(x, dict):
new_x = {}
for k, v in x.items():
new_x[k] = _upgrade_config(v, plugin_versions)
type_hint = new_x.get('type_hint')
if type_hint is None:
return new_x
if type_hint in registry.renamed_type_hints:
type_hint = registry.renamed_type_hints[type_hint]
new_x['type_hint'] = type_hint
type_hint_lineage = registry.get_type_hint_lineage(type_hint)
for th in type_hint_lineage:
plugin = registry.get_plugin(th)
old_version = plugin_versions[plugin]
curr_version = registry.get_plugin_version(plugin)
upgrader = registry.get_upgrader(th)
if upgrader:
for version in range(old_version, curr_version):
new_x = upgrader(new_x, version)
return new_x
elif isinstance(x, list):
return [_upgrade_config(v, plugin_versions) for v in x]
else:
return x
def upgrade_plugin_versions(plugin_versions: dict[str, int]) -> dict[str, int]:
    """Update the names of the plugins using the plugin aliases in the registry.

    This allows changing the names of plugins over time and maintaining backward
    compatibility of serialized PipelineConfigs.

    Args:
        plugin_versions: maps from plugin name (possibly an old alias) to
            version.

    Returns:
        A new dict mapping each plugin name, as resolved by
        ``registry.get_plugin_from_alias``, to its version. Names the registry
        cannot resolve are dropped, and a warning listing them is logged.
    """
    new_plugin_versions = {}
    missing_plugins = []
    for alias, version in plugin_versions.items():
        plugin = registry.get_plugin_from_alias(alias)
        if plugin:
            new_plugin_versions[plugin] = version
        else:
            missing_plugins.append(alias)
    if missing_plugins:
        # Lazy %-formatting; the rendered message is identical to before.
        log.warning(
            'There are plugins listed in the pipeline config that are '
            'not currently installed: %s', missing_plugins)
    return new_plugin_versions
[docs]def upgrade_config(config_dict: dict | list[dict]) -> dict | list[dict]:
"""Upgrade serialized Config(s) to the latest version.
Used to implement backward compatibility of Configs using upgraders stored
in the registry.
Args:
config_dict: serialized PipelineConfig(s) which are potentially of a
non-current version
Returns:
The corresponding serialized PipelineConfig(s) that have been upgraded
to the current version.
"""
plugin_versions = config_dict.get('plugin_versions')
if plugin_versions is None:
return config_dict
plugin_versions = upgrade_plugin_versions(plugin_versions)
out = _upgrade_config(config_dict, plugin_versions)
return out
def get_plugin(config_cls: type) -> str:
    """Infer the module path of the plugin where a Config class is defined.

    The plugin is ``'rastervision.'`` followed by the second component of the
    class's module path; e.g. a class defined in ``rastervision.core.foo``
    maps to ``rastervision.core``. This only works correctly if the plugin is
    in a module under rastervision.

    Args:
        config_cls: The Config class.

    Returns:
        The inferred plugin module path.

    Raises:
        ValueError: if the class's module path has fewer than two
            components, so no plugin name can be inferred.
    """
    # Read __module__ directly: inspect.getmodule() returns None when the
    # module is not in sys.modules, which then failed with AttributeError.
    module_name = config_cls.__module__
    parts = module_name.split('.')
    if len(parts) < 2:
        raise ValueError(
            f'Cannot infer plugin for {config_cls!r} from module '
            f'{module_name!r}; specify the plugin explicitly.')
    return 'rastervision.' + parts[1]
def register_config(type_hint: str,
                    plugin: str | None = None,
                    upgrader: Callable[[dict, int], dict] | None = None
                    ) -> Callable[[type], type]:
    """Class decorator used to register Config classes with registry.

    All Configs must be registered! Registering a Config does the following:

    1. Associates Config classes with type_hint, plugin, and upgrader, which
       is necessary for polymorphic deserialization. See build_config() for
       more details.
    2. Adds a constant `type_hint` field to the Config which is set to
       type_hint.

    Args:
        type_hint: a type hint used to deserialize Configs. Must be unique
            across all registered Configs.
        plugin: the module path of the plugin where the ``Config`` is defined.
            If ``None``, will be inferred with get_plugin(). Defaults to
            ``None``.
        upgrader: a function of the form ``upgrade(config_dict, version)``
            which returns the corresponding config dict of
            ``version = version + 1``. This can be useful for maintaining
            backward compatibility by allowing old configs using an outdated
            schema to be upgraded to the current schema. Defaults to ``None``.

    Returns:
        A decorator that takes a ``Config`` class and returns a new class
        that is identical to it except for an additional ``type_hint`` field.
        The new class (not the original) is what gets registered.
    """

    def _register_config(cls: type) -> type:
        new_cls = create_model(
            cls.__name__,
            __base__=cls,
            __module__=cls.__module__,
            # add a new field called "type_hint" with type Literal[type_hint]
            # and default value type_hint to the config
            type_hint=(Literal[type_hint], type_hint),  # type: ignore
        )
        _plugin = plugin or get_plugin(cls)
        registry.add_config(type_hint, new_cls, _plugin, upgrader)
        # retain the original docstring on the generated class
        new_cls.__doc__ = cls.__doc__
        return new_cls

    return _register_config