

All Configs are derived from rastervision.pipeline.config.Config, which itself is a pydantic Model.

pydantic model ClassificationModelConfig

Configure a classification model.

JSON schema:

{
   "title": "ClassificationModelConfig",
   "description": "Configure a classification model.",
   "type": "object",
   "properties": {
      "backbone": {
         "$ref": "#/$defs/Backbone",
         "default": "resnet18",
         "description": "The torchvision.models backbone to use."
      },
      "pretrained": {
         "default": true,
         "description": "If True, use ImageNet weights. If False, use random initialization.",
         "title": "Pretrained",
         "type": "boolean"
      },
      "init_weights": {
         "anyOf": [
            {
               "type": "string"
            },
            {
               "type": "null"
            }
         ],
         "default": null,
         "description": "URI of PyTorch model weights used to initialize model. If set, this supersedes the pretrained option.",
         "title": "Init Weights"
      },
      "load_strict": {
         "default": true,
         "description": "If True, the keys in the state dict referenced by init_weights must match exactly. Setting this to False can be useful if you just want to load the backbone of a model.",
         "title": "Load Strict",
         "type": "boolean"
      },
      "external_def": {
         "anyOf": [
            {
               "$ref": "#/$defs/ExternalModuleConfig"
            },
            {
               "type": "null"
            }
         ],
         "default": null,
         "description": "If specified, the model will be built from the definition from this external source, using Torch Hub."
      },
      "extra_args": {
         "default": {},
         "description": "Other implementation-specific args that might be useful for constructing the default model. This is ignored if using an external model.",
         "title": "Extra Args",
         "type": "object"
      },
      "type_hint": {
         "const": "classification_model",
         "default": "classification_model",
         "title": "Type Hint",
         "type": "string"
      }
   },
   "$defs": {
      "Backbone": {
         "enum": [
            ...
         ],
         "title": "Backbone",
         "type": "string"
      },
      "ExternalModuleConfig": {
         "additionalProperties": false,
         "description": "Config describing an object to be loaded via Torch Hub.",
         "properties": {
            "uri": {
               "anyOf": [
                  {
                     "minLength": 1,
                     "type": "string"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "description": "Local uri of a zip file, or local uri of a directory, or remote uri of zip file.",
               "title": "Uri"
            },
            "github_repo": {
               "anyOf": [
                  {
                     "pattern": ".+/.+",
                     "type": "string"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "description": "<repo-owner>/<repo-name>[:tag]",
               "title": "Github Repo"
            },
            "name": {
               "anyOf": [
                  {
                     "minLength": 1,
                     "type": "string"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "description": "Name of the folder in which to extract/copy the definition files.",
               "title": "Name"
            },
            "entrypoint": {
               "description": "Name of a Callable present in ````. See docs for ``torch.hub`` for details.",
               "minLength": 1,
               "title": "Entrypoint",
               "type": "string"
            },
            "entrypoint_args": {
               "default": [],
               "description": "Args to pass to the entrypoint. Must be serializable.",
               "items": {},
               "title": "Entrypoint Args",
               "type": "array"
            },
            "entrypoint_kwargs": {
               "default": {},
               "description": "Keyword args to pass to the entrypoint. Must be serializable.",
               "title": "Entrypoint Kwargs",
               "type": "object"
            },
            "force_reload": {
               "default": false,
               "description": "Force reload of module definition.",
               "title": "Force Reload",
               "type": "boolean"
            },
            "type_hint": {
               "const": "external-module",
               "default": "external-module",
               "title": "Type Hint",
               "type": "string"
            }
         },
         "required": [
            ...
         ],
         "title": "ExternalModuleConfig",
         "type": "object"
      }
   },
   "additionalProperties": false
}

Config:

  • extra: str = forbid

  • validate_assignment: bool = True
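With extra = forbid, unknown keyword arguments are rejected when the config is built, and with validate_assignment = True, a value assigned to a field after construction is validated as well. The stdlib-only sketch below imitates both behaviors for two of the boolean fields. `StrictConfigSketch` is a hypothetical stand-in, not pydantic or Raster Vision code, and its plain `isinstance` check is stricter than pydantic, which may coerce some inputs.

```python
from typing import Any


class StrictConfigSketch:
    """Hypothetical stand-in imitating extra='forbid' and validate_assignment=True."""

    # Declared fields and their types (a two-field subset, for illustration only).
    _fields = {"pretrained": bool, "load_strict": bool}

    def __init__(self, **kwargs: Any) -> None:
        # extra='forbid': reject keywords that are not declared fields.
        unknown = sorted(set(kwargs) - set(self._fields))
        if unknown:
            raise ValueError(f"Extra fields not permitted: {unknown}")
        for name in self._fields:
            # Both fields default to True, as documented.
            setattr(self, name, kwargs.get(name, True))

    def __setattr__(self, name: str, value: Any) -> None:
        # validate_assignment=True: every assignment is checked, not just __init__.
        if name not in self._fields:
            raise ValueError(f"Extra fields not permitted: {name!r}")
        expected = self._fields[name]
        if not isinstance(value, expected):
            raise TypeError(
                f"{name} must be {expected.__name__}, got {type(value).__name__}"
            )
        super().__setattr__(name, value)
```

A misspelled keyword such as `pretrianed=True` fails at construction, and `cfg.load_strict = "no"` fails at assignment, instead of either silently producing a misconfigured object.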

field backbone: Backbone = Backbone.resnet18

The torchvision.models backbone to use.

field external_def: ExternalModuleConfig | None = None

If specified, the model will be built from the definition from this external source, using Torch Hub.
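The ExternalModuleConfig schema above places a few constraints on external_def: no keys beyond the listed properties, non-empty strings for uri and name when they are given, and a github_repo that matches the pattern `.+/.+`. entrypoint is the only property without a default, so this sketch assumes it is required. The helper `check_external_module_dict` is hypothetical and only illustrates those constraints on a plain dict; in practice pydantic enforces them when the config is constructed.

```python
import re
from typing import Any

# Keys listed under ExternalModuleConfig "properties" in the JSON schema.
ALLOWED_KEYS = {
    "uri", "github_repo", "name", "entrypoint", "entrypoint_args",
    "entrypoint_kwargs", "force_reload", "type_hint",
}


def check_external_module_dict(d: dict[str, Any]) -> list[str]:
    """Hypothetical checker: return the list of schema violations (empty if none)."""
    errors = []
    # "additionalProperties": false
    for key in sorted(set(d) - ALLOWED_KEYS):
        errors.append(f"unexpected key: {key}")
    # Assumed required (no default), with "minLength": 1.
    entrypoint = d.get("entrypoint")
    if not isinstance(entrypoint, str) or len(entrypoint) < 1:
        errors.append("entrypoint must be a non-empty string")
    # "uri" and "name": null or a string with "minLength": 1.
    for key in ("uri", "name"):
        value = d.get(key)
        if value is not None and (not isinstance(value, str) or len(value) < 1):
            errors.append(f"{key} must be null or a non-empty string")
    # "github_repo": null or a string matching ".+/.+" (JSON Schema patterns are unanchored).
    repo = d.get("github_repo")
    if repo is not None and (not isinstance(repo, str) or re.search(r".+/.+", repo) is None):
        errors.append("github_repo must look like <repo-owner>/<repo-name>[:tag]")
    return errors
```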

field extra_args: dict = {}

Other implementation-specific args that might be useful for constructing the default model. This is ignored if using an external model.

field init_weights: str | None = None

URI of PyTorch model weights used to initialize model. If set, this supersedes the pretrained option.

field load_strict: bool = True

If True, the keys in the state dict referenced by init_weights must match exactly. Setting this to False can be useful if you just want to load the backbone of a model.
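The strict/non-strict distinction appears to follow the key-matching rule of PyTorch's `Module.load_state_dict`: with strict matching, any key missing from the checkpoint, or present only in the checkpoint, is an error, while non-strict loading copies the keys that match and reports the rest. The sketch below models state dicts as plain dicts. `load_state_dict_sketch` is hypothetical and, unlike PyTorch, does not check tensor shapes.

```python
from typing import Any


def load_state_dict_sketch(
    model_state: dict[str, Any], loaded: dict[str, Any], strict: bool = True
) -> tuple[list[str], list[str]]:
    """Hypothetical helper: copy matching entries of `loaded` into `model_state`.

    Returns (missing_keys, unexpected_keys). With strict=True, any mismatch
    raises KeyError before anything is copied.
    """
    missing = sorted(set(model_state) - set(loaded))
    unexpected = sorted(set(loaded) - set(model_state))
    if strict and (missing or unexpected):
        raise KeyError(f"missing={missing}, unexpected={unexpected}")
    # Non-strict (or perfectly matching) case: copy only the shared keys.
    for key in set(model_state) & set(loaded):
        model_state[key] = loaded[key]
    return missing, unexpected
```

This is the backbone-only case the description mentions: a checkpoint whose head does not match can still supply its backbone weights once strict matching is turned off.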

field pretrained: bool = True

If True, use ImageNet weights. If False, use random initialization.
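Taken together, the init_weights and pretrained descriptions imply a simple precedence, sketched below under a literal reading of "supersedes": init_weights wins whenever it is set; otherwise pretrained chooses between ImageNet weights and random initialization. `resolve_init_source` and its return strings are hypothetical; the sketch only decides where the weights come from and does not load anything.

```python
from typing import Optional


def resolve_init_source(pretrained: bool = True, init_weights: Optional[str] = None) -> str:
    """Hypothetical helper: describe where the initial model weights come from."""
    if init_weights is not None:
        # init_weights supersedes pretrained, whatever pretrained is set to.
        return f"weights from {init_weights}"
    if pretrained:
        return "ImageNet weights"
    return "random initialization"
```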

field type_hint: Literal['classification_model'] = 'classification_model'

build(num_classes: int, in_channels: int, save_dir: str | None = None, hubconf_dir: str | None = None, ddp_rank: int | None = None, **kwargs) → Module

Build and return a model based on the config.

Parameters:

  • num_classes (int) – Number of classes.

  • in_channels (int) – Number of channels in the images that will be fed into the model.

  • save_dir (str|None) – Used for building external_def if specified. Defaults to None.

  • hubconf_dir (str|None) – Used for building external_def if specified. Defaults to None.

  • ddp_rank (int | None) – Defaults to None.

  • **kwargs – Extra args for build_default_model().


Returns:

A PyTorch nn.Module.

Return type:

Module
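Based on the parameter descriptions (save_dir and hubconf_dir are only used for external_def, and **kwargs go to build_default_model()), build() can be read as a dispatch between the two builders documented below. The sketch takes the builders as plain callables so it runs on its own; `build_model_sketch` is hypothetical, whereas build() is a method on the config and has external_def available as a field.

```python
from typing import Any, Callable, Optional


def build_model_sketch(
    external_def: Optional[Any],
    build_default_model: Callable[..., Any],
    build_external_model: Callable[..., Any],
    num_classes: int,
    in_channels: int,
    save_dir: Optional[str] = None,
    hubconf_dir: Optional[str] = None,
    ddp_rank: Optional[int] = None,
    **kwargs: Any,
) -> Any:
    """Hypothetical dispatch: external definition if given, else the default model."""
    if external_def is not None:
        # num_classes, in_channels and **kwargs are not forwarded, matching the
        # build_external_model(save_dir, hubconf_dir, ddp_rank) signature.
        return build_external_model(
            save_dir=save_dir, hubconf_dir=hubconf_dir, ddp_rank=ddp_rank
        )
    # Extra keyword args are passed only to the default builder.
    return build_default_model(num_classes, in_channels, **kwargs)
```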


build_default_model(num_classes: int, in_channels: int) → Module

Build and return the default model.

Parameters:

  • num_classes (int) – Number of classes.

  • in_channels (int) – Number of channels in the images that will be fed into the model.


Returns:

A PyTorch nn.Module.

Return type:

Module


build_external_model(save_dir: str, hubconf_dir: str | None = None, ddp_rank: int | None = None) → Module

Build and return an external model.

Parameters:

  • save_dir (str) – The module def will be saved here.

  • hubconf_dir (str|None) – Path to existing definition. Defaults to None.

  • ddp_rank (int | None) – Defaults to None.


Returns:

A PyTorch nn.Module.

Return type:

Module


classmethod deserialize(inp: str | dict | Config) → Self

Deserialize Config from a JSON file or dict, upgrading if possible.

If inp is already a Config, it is returned as is.


Parameters:

inp (str | dict | Config) – A URI to a JSON file, a dict, or a Config.

Return type:

Self
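deserialize() accepts three kinds of input and, per the description, returns an existing Config unchanged. The dataclass below is a hypothetical two-field stand-in that shows only that dispatch, routing strings to from_file() and dicts to from_dict(). It omits the upgrading step, reads only local paths, and drops a type_hint key before construction; all three are simplifications of this sketch.

```python
import json
from dataclasses import dataclass
from typing import Union


@dataclass
class SketchConfig:
    """Hypothetical two-field config, used only to illustrate deserialize()."""

    pretrained: bool = True
    load_strict: bool = True

    @classmethod
    def from_dict(cls, cfg_dict: dict) -> "SketchConfig":
        # Drop the type_hint discriminator; any other unknown key raises TypeError.
        fields_only = {k: v for k, v in cfg_dict.items() if k != "type_hint"}
        return cls(**fields_only)

    @classmethod
    def from_file(cls, uri: str) -> "SketchConfig":
        # Local paths only; no version upgrading in this sketch.
        with open(uri) as f:
            return cls.from_dict(json.load(f))

    @classmethod
    def deserialize(cls, inp: Union[str, dict, "SketchConfig"]) -> "SketchConfig":
        if isinstance(inp, cls):
            return inp  # already a Config: returned as is
        if isinstance(inp, str):
            return cls.from_file(inp)  # treated as a URI to a JSON file
        if isinstance(inp, dict):
            return cls.from_dict(inp)
        raise TypeError(f"Cannot deserialize {type(inp).__name__}")
```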


classmethod from_dict(cfg_dict: dict) → Self

Deserialize Config from a dict.


Parameters:

cfg_dict (dict) – Dict to deserialize.

Return type:

Self


classmethod from_file(uri: str) → Self

Deserialize Config from a JSON file, upgrading if possible.


Parameters:

uri (str) – URI to load from.

Return type:

Self



Recursively validate hierarchies of Configs.

This uses reflection to call validate_config on a hierarchy of Configs using a depth-first pre-order traversal.
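A depth-first pre-order traversal validates a parent before any of its children and finishes one child's whole subtree before moving on to the next child. The sketch below shows that order using attribute reflection (`vars`) over hypothetical `ToyConfig` nodes, descending into attributes and list elements that have a validate_config method. The real implementation's reflection over pydantic fields may differ in detail.

```python
from typing import Any, List

VISIT_ORDER: List[str] = []  # records the order in which validate_config runs


class ToyConfig:
    """Hypothetical config node: keyword args become attributes (possibly child configs)."""

    def __init__(self, name: str, **fields: Any) -> None:
        self.name = name
        self.__dict__.update(fields)

    def validate_config(self) -> None:
        VISIT_ORDER.append(self.name)


def recursive_validate_sketch(cfg: Any) -> None:
    """Depth-first pre-order: validate cfg itself, then each child config in field order."""
    cfg.validate_config()
    for value in vars(cfg).values():  # reflection over the instance's attributes
        children = value if isinstance(value, (list, tuple)) else [value]
        for child in children:
            if hasattr(child, "validate_config"):
                recursive_validate_sketch(child)
```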


Re-validate an instantiated Config.

Runs all Pydantic validators plus self.validate_config().

to_file(uri: str, with_rv_metadata: bool = True) → None

Save a Config to a JSON file, optionally with RV metadata.

Parameters:

  • uri (str) – URI to save to.

  • with_rv_metadata (bool) – If True, inject Raster Vision metadata such as plugin_versions, so that the config can be upgraded when loaded.

Return type:

None


update(*args, **kwargs)

Update any fields before validation.

Subclasses should override this to provide complex default behavior, for example, setting default values as a function of the values of other fields. The arguments to this method will vary depending on the type of Config.


Validate fields that should be checked after update is called.

This is to complement the builtin validation that Pydantic performs at the time of object construction.
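The intended sequence is update() first, to fill in defaults that depend on other fields, followed by this post-update validation. In the sketch below, `ToyChipConfig` and its `chip_sz`/`eval_chip_sz` fields are hypothetical, chosen only to show a default derived from another field and a check that can only succeed once that default exists.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyChipConfig:
    """Hypothetical config; field names are illustrative, not Raster Vision's."""

    chip_sz: int = 256
    eval_chip_sz: Optional[int] = None

    def update(self) -> None:
        # Default that depends on another field: fall back to chip_sz.
        if self.eval_chip_sz is None:
            self.eval_chip_sz = self.chip_sz

    def validate_config(self) -> None:
        # Post-update check: eval_chip_sz is only guaranteed to be set after update().
        if self.eval_chip_sz is None or self.eval_chip_sz <= 0:
            raise ValueError("eval_chip_sz must be a positive int after update()")
```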

validate_list(field: str, valid_options: list[str])

Validate a list field.

Parameters:

  • field (str) – name of field to validate

  • valid_options (list[str]) – values that field is allowed to take

Raises:

ConfigError – if field is invalid
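A minimal sketch of the check validate_list describes, assuming it means that every element of the list field must be one of valid_options. `ConfigError` is defined locally as a stand-in for Raster Vision's exception so the example runs on its own, and the hypothetical `validate_list_sketch` takes the config as an explicit argument instead of being a method.

```python
from typing import Any, List


class ConfigError(Exception):
    """Local stand-in for Raster Vision's ConfigError, so the sketch runs alone."""


def validate_list_sketch(cfg: Any, field: str, valid_options: List[str]) -> None:
    """Raise ConfigError if any element of the list field is not a valid option."""
    values = getattr(cfg, field)
    invalid = [v for v in values if v not in valid_options]
    if invalid:
        raise ConfigError(
            f"{field} contains {invalid}; valid options are {valid_options}"
        )
```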