from typing import List, Type, Union, Optional, Callable, Dict, TYPE_CHECKING
import inspect
from pydantic import ( # noqa
BaseModel, create_model, Field, root_validator, validate_model,
ValidationError, validator)
from typing_extensions import Literal
from rastervision.pipeline import (registry_ as registry, rv_config_ as
rv_config)
from rastervision.pipeline.file_system import str_to_file
if TYPE_CHECKING:
from rastervision.pipeline.pipeline_config import PipelineConfig
class ConfigError(ValueError):
    """Error signalling that a configuration is invalid."""
class Config(BaseModel):
    """Base class to extend when defining custom configurations.

    Adds a handful of extra methods on top of Pydantic's BaseModel.
    See https://pydantic-docs.helpmanual.io/

    A configuration schema is defined by subclassing this class and declaring
    class attributes with a type and a default value for each field. Schemas
    can be nested, i.e. a Config may have fields that are themselves Configs.

    Validation, serialization, deserialization, and IDE support all follow
    automatically from the schema.
    """

    # Pydantic model settings. extra='forbid' rejects fields that are not
    # declared in the schema when a Config is instantiated, which helps avoid
    # a common source of bugs. validate_assignment=True is Pydantic's switch
    # for validating values when they are assigned to attributes.
    class Config:
        extra = 'forbid'
        validate_assignment = True

    def update(self, *args, **kwargs):
        """Update any fields before validation.

        Subclasses should override this to provide complex default behavior,
        for example, setting default values as a function of the values of
        other fields. The arguments accepted by this method vary with the
        type of Config. The base implementation does nothing.
        """

    def build(self):
        """Build an instance of the object type that this config describes.

        For example, BackendConfig will build a Backend object. The arguments
        accepted by this method vary with the type of Config. The base
        implementation does nothing and returns None.
        """

    def validate_config(self):
        """Validate fields that should be checked after update() is called.

        Complements the builtin validation that Pydantic performs when the
        object is constructed. The base implementation does nothing.
        """

    def revalidate(self):
        """Re-validate an already instantiated Config.

        Runs all Pydantic validators and then self.validate_config().

        Adapted from:
        https://github.com/samuelcolvin/pydantic/issues/1864#issuecomment-679044432

        Raises:
            ValidationError: if Pydantic validation of the current field
                values reports an error.
        """
        # validate_model() returns a tuple whose last element is the
        # validation error (falsy when validation succeeded).
        validation_error = validate_model(self.__class__, self.__dict__)[-1]
        if validation_error:
            raise validation_error
        self.validate_config()

    def recursive_validate_config(self):
        """Recursively validate hierarchies of Configs.

        Uses reflection to call validate_config on a hierarchy of Configs
        using a depth-first pre-order traversal.
        """
        # Invoke the validate_config of every Config class in this object's
        # MRO, so that checks defined on base classes run as well.
        for klass in type(self).mro():
            if issubclass(klass, Config):
                klass.validate_config(self)

        # Descend into field values that are themselves Configs. The list is
        # built up front, before any child is validated.
        nested_configs = [
            value for value in self.__dict__.values()
            if isinstance(value, Config)
        ]
        for nested in nested_configs:
            nested.recursive_validate_config()

    def validate_list(self, field: str, valid_options: List[str]):
        """Validate a list field.

        A non-list value is checked as a single item.

        Args:
            field (str): name of field to validate
            valid_options (List[str]): values that field is allowed to take

        Raises:
            ConfigError: if field is invalid
        """
        val = getattr(self, field)
        # Treat a scalar as a one-element list so both cases share one check.
        candidates = val if isinstance(val, list) else [val]
        for v in candidates:
            if v not in valid_options:
                raise ConfigError(f'{v} is not a valid option for {field}')

    def __repr_args__(self):
        """Override to leave the 'type_hint' field out of the repr."""
        args = dict(super().__repr_args__())
        # 'type_hint' may be absent; dropping it is best-effort.
        args.pop('type_hint', None)
        return args.items()
def save_pipeline_config(cfg: 'PipelineConfig', output_uri: str) -> None:
    """Save a PipelineConfig to a JSON file.

    Before serializing, the current rv_config and the registry's
    plugin_versions are injected into ``cfg`` (``cfg`` is modified in place).

    Args:
        cfg: the PipelineConfig to save
        output_uri: URI that the JSON string is written to
    """
    rv_config_dict = rv_config.get_config_dict(registry.rv_config_schema)
    cfg.rv_config = rv_config_dict
    cfg.plugin_versions = registry.plugin_versions
    str_to_file(cfg.json(), output_uri)
def build_config(x: Union[dict, List[Union[dict, Config]], Config]
                 ) -> Union[Config, List[Config]]:
    """Build a Config from various types of input.

    Useful for deserializing from JSON. Polymorphic deserialization is
    implemented by looking up the `type_hint` found in each dict in the
    registry to obtain the corresponding Config class.

    Args:
        x: some representation of Config(s)

    Returns:
        Config: the corresponding Config(s). A dict without a `type_hint`
            (or whose `type_hint` is None) is returned as a dict with its
            values built recursively; anything that is neither a dict nor
            a list is returned unchanged.
    """
    if isinstance(x, list):
        return [build_config(item) for item in x]
    if not isinstance(x, dict):
        return x

    # Build children first so that nested dicts are already Configs by the
    # time the enclosing Config class is instantiated.
    built = {key: build_config(value) for key, value in x.items()}
    type_hint = built.get('type_hint')
    if type_hint is None:
        return built

    config_cls = registry.get_config(type_hint)
    return config_cls(**built)
def _upgrade_config(x: Union[dict, List[dict]], plugin_versions: Dict[str, int]
) -> Union[dict, List[dict]]:
"""Upgrade serialized Config(s) to the latest version.
Used to implement backward compatibility of Configs using upgraders stored
in the registry.
Args:
x: serialized Config(s) which are potentially of a non-current version
plugin_versions: dict mapping from plugin module name to the latest version
Returns:
the corresponding serialized Config(s) that have been upgraded to the
current version
"""
if isinstance(x, dict):
new_x = {}
for k, v in x.items():
new_x[k] = _upgrade_config(v, plugin_versions)
type_hint = new_x.get('type_hint')
if type_hint is not None:
type_hint_lineage = registry.get_type_hint_lineage(type_hint)
for th in type_hint_lineage:
plugin = registry.get_plugin(th)
old_version = plugin_versions[plugin]
curr_version = registry.get_plugin_version(plugin)
upgrader = registry.get_upgrader(th)
if upgrader:
for version in range(old_version, curr_version):
new_x = upgrader(new_x, version)
return new_x
elif isinstance(x, list):
return [_upgrade_config(v, plugin_versions) for v in x]
else:
return x
def upgrade_plugin_versions(plugin_versions: Dict[str, int]) -> Dict[str, int]:
    """Update the names of the plugins using the plugin aliases in the registry.

    This allows plugins to be renamed over time while keeping serialized
    PipelineConfigs backward compatible.

    Args:
        plugin_versions: maps from plugin name to version

    Returns:
        a new dict with the same versions, keyed by the plugin name that the
        registry returns for each alias

    Raises:
        ConfigError: if the registry does not recognize one of the names
    """
    resolved_versions = {}
    for alias, version in plugin_versions.items():
        plugin = registry.get_plugin_from_alias(alias)
        if not plugin:
            raise ConfigError(
                'The plugin_versions field contains an unrecognized '
                f'plugin name: {alias}.')
        resolved_versions[plugin] = version
    return resolved_versions
def upgrade_config(
        config_dict: Union[dict, List[dict]]) -> Union[dict, List[dict]]:
    """Upgrade serialized Config(s) to the latest version.

    Used to implement backward compatibility of Configs using upgraders stored
    in the registry.

    Args:
        config_dict: serialized PipelineConfig(s) which are potentially of a
            non-current version. The top-level ``plugin_versions`` entry is
            read with ``config_dict.get``, so the top-level value must be a
            dict.

    Returns:
        the corresponding serialized PipelineConfig(s) that have been upgraded
        to the current version

    Raises:
        ConfigError: if the ``plugin_versions`` field is missing, or if it
            contains a plugin name that the registry does not recognize.
    """
    plugin_versions = config_dict.get('plugin_versions')
    # Check for the missing field *before* resolving plugin aliases.
    # Resolving first would call .items() on None and fail with an
    # AttributeError, making this ConfigError unreachable.
    if plugin_versions is None:
        raise ConfigError(
            'Configuration is missing plugin_version field so is not backward '
            'compatible.')
    plugin_versions = upgrade_plugin_versions(plugin_versions)
    return _upgrade_config(config_dict, plugin_versions)
def get_plugin(config_cls: Type) -> str:
    """Infer the module path of the plugin where a Config class is defined.

    The result is ``'rastervision.'`` followed by the second dot-separated
    component of the name of the module that defines ``config_cls``, so this
    only works correctly if the plugin is in a module under rastervision.

    Args:
        config_cls: the class whose defining plugin should be inferred

    Returns:
        the inferred plugin module path
    """
    module_name = inspect.getmodule(config_cls).__name__
    plugin_component = module_name.split('.')[1]
    return f'rastervision.{plugin_component}'
def register_config(type_hint: str,
                    plugin: Optional[str] = None,
                    upgrader: Optional[Callable] = None) -> Callable:
    """Class decorator used to register Config classes with registry.

    All Configs must be registered! Registering a Config does the following:

    1. Associates Config classes with type_hint, plugin, and upgrader, which
       is necessary for polymorphic deserialization. See build_config() for
       more details.
    2. Adds a constant `type_hint` field to the Config which is set to
       type_hint.

    Args:
        type_hint (str): a type hint used to deserialize Configs. Must be
            unique across all registered Configs.
        plugin (Optional[str], optional): the module path of the plugin where
            the Config is defined. If None, will be inferred.
            Defaults to None.
        upgrader (Optional[Callable], optional): a function of the form
            upgrade(config_dict, version) which returns the corresponding
            config dict of version = version + 1. This can be useful for
            maintaining backward compatibility by allowing old configs using
            an outdated schema to be upgraded to the current schema.
            Defaults to None.

    Returns:
        Callable: A function that returns a new class that is identical to the
            input Config with an additional ``type_hint`` field.
    """

    def _register_config(cls: Type):
        # Field definition for the new "type_hint" field: its type is
        # Literal[type_hint] and its default value is type_hint.
        type_hint_field = (Literal[type_hint], type_hint)  # type: ignore
        registered_cls = create_model(
            cls.__name__,
            __base__=cls,
            __module__=cls.__module__,
            type_hint=type_hint_field,
        )

        # Fall back to inferring the plugin from the module of the class.
        resolved_plugin = plugin if plugin else get_plugin(cls)
        registry.add_config(type_hint, registered_cls, resolved_plugin,
                            upgrader)

        # retain docstring after wrapping
        registered_cls.__doc__ = cls.__doc__
        return registered_cls

    return _register_config