ModelConfig
Note
All Configs are derived from rastervision.pipeline.config.Config, which itself is a pydantic model.
- pydantic model ModelConfig
Config related to models.
JSON schema
{
  "title": "ModelConfig",
  "description": "Config related to models.",
  "type": "object",
  "properties": {
    "backbone": { "description": "The torchvision.models backbone to use.", "default": "resnet18", "allOf": [ { "$ref": "#/definitions/Backbone" } ] },
    "pretrained": { "title": "Pretrained", "description": "If True, use ImageNet weights. If False, use random initialization.", "default": true, "type": "boolean" },
    "init_weights": { "title": "Init Weights", "description": "URI of PyTorch model weights used to initialize model. If set, this supersedes the pretrained option.", "type": "string" },
    "load_strict": { "title": "Load Strict", "description": "If True, the keys in the state dict referenced by init_weights must match exactly. Setting this to False can be useful if you just want to load the backbone of a model.", "default": true, "type": "boolean" },
    "external_def": { "title": "External Def", "description": "If specified, the model will be built from the definition from this external source, using Torch Hub.", "allOf": [ { "$ref": "#/definitions/ExternalModuleConfig" } ] },
    "extra_args": { "title": "Extra Args", "description": "Other implementation-specific args that might be useful for constructing the default model. This is ignored if using an external model.", "default": {}, "type": "object" },
    "type_hint": { "title": "Type Hint", "default": "model", "enum": [ "model" ], "type": "string" }
  },
  "additionalProperties": false,
  "definitions": {
    "Backbone": {
      "title": "Backbone",
      "description": "An enumeration.",
      "enum": [ "alexnet", "densenet121", "densenet169", "densenet201", "densenet161", "googlenet", "inception_v3", "mnasnet0_5", "mnasnet0_75", "mnasnet1_0", "mnasnet1_3", "mobilenet_v2", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2", "shufflenet_v2_x0_5", "shufflenet_v2_x1_0", "shufflenet_v2_x1_5", "shufflenet_v2_x2_0", "squeezenet1_0", "squeezenet1_1", "vgg11", "vgg11_bn", "vgg13", "vgg13_bn", "vgg16", "vgg16_bn", "vgg19_bn", "vgg19" ]
    },
    "ExternalModuleConfig": {
      "title": "ExternalModuleConfig",
      "description": "Config describing an object to be loaded via Torch Hub.",
      "type": "object",
      "properties": {
        "uri": { "title": "Uri", "description": "Local uri of a zip file, or local uri of a directory, or remote uri of zip file.", "minLength": 1, "type": "string" },
        "github_repo": { "title": "Github Repo", "description": "<repo-owner>/<repo-name>[:tag]", "pattern": ".+/.+", "type": "string" },
        "name": { "title": "Name", "description": "Name of the folder in which to extract/copy the definition files.", "minLength": 1, "type": "string" },
        "entrypoint": { "title": "Entrypoint", "description": "Name of a callable present in hubconf.py. See docs for torch.hub for details.", "minLength": 1, "type": "string" },
        "entrypoint_args": { "title": "Entrypoint Args", "description": "Args to pass to the entrypoint. Must be serializable.", "default": [], "type": "array", "items": {} },
        "entrypoint_kwargs": { "title": "Entrypoint Kwargs", "description": "Keyword args to pass to the entrypoint. Must be serializable.", "default": {}, "type": "object" },
        "force_reload": { "title": "Force Reload", "description": "Force reload of module definition.", "default": false, "type": "boolean" },
        "type_hint": { "title": "Type Hint", "default": "external-module", "enum": [ "external-module" ], "type": "string" }
      },
      "required": [ "entrypoint" ],
      "additionalProperties": false
    }
  }
}
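The backbone property accepts only the names in the Backbone enumeration, and each name matches a constructor in torchvision.models. A minimal sketch of that restriction using Python's enum module follows; BackboneSketch and parse_backbone are hypothetical names, the three members are an arbitrary subset of the schema's list, and mixing in str (so members compare equal to their string values) is a choice of this sketch, not something the schema states.

```python
from enum import Enum


class BackboneSketch(str, Enum):
    """Hypothetical three-member subset of the Backbone enumeration."""

    resnet18 = "resnet18"
    resnet50 = "resnet50"
    mobilenet_v2 = "mobilenet_v2"


def parse_backbone(name: str) -> BackboneSketch:
    """Look up a backbone by its string value.

    Enum lookup by value raises ValueError for any name outside the
    enumeration, which is how an unsupported backbone would be rejected.
    """
    return BackboneSketch(name)
```

With the full enumeration, a default-model builder could presumably resolve a member to its torchvision constructor by name, but that lookup is an assumption and is not shown here.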
- Config
extra: str = forbid
validate_assignment: bool = True
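The two Config settings above are standard pydantic options: extra = forbid rejects fields that are not declared, and validate_assignment = True re-runs validation when a field is assigned after construction. The plain-Python class below is a hypothetical stand-in that mimics both behaviors without depending on pydantic. Unlike pydantic, it does no type coercion (pydantic accepts some strings, such as "true", for a bool field) and it raises plain ValueError/TypeError rather than pydantic's ValidationError.

```python
from typing import Any, Dict


class StrictConfigSketch:
    """Hypothetical stand-in mimicking extra='forbid' and validate_assignment=True."""

    # Declared fields and their expected types (an illustrative subset of ModelConfig).
    FIELDS: Dict[str, type] = {"pretrained": bool, "load_strict": bool}

    def __init__(self, **values: Any) -> None:
        # extra = 'forbid': undeclared keyword arguments are rejected at construction.
        unknown = sorted(set(values) - set(self.FIELDS))
        if unknown:
            raise ValueError(f"extra fields not permitted: {unknown}")
        for name, value in values.items():
            setattr(self, name, value)  # routed through __setattr__ below

    def __setattr__(self, name: str, value: Any) -> None:
        # extra = 'forbid' also blocks assigning attributes that are not fields.
        if name not in self.FIELDS:
            raise ValueError(f"object has no field {name!r}")
        # validate_assignment = True: the type is checked on every assignment,
        # not only when the object is constructed.
        expected = self.FIELDS[name]
        if not isinstance(value, expected):
            raise TypeError(f"{name} must be {expected.__name__}, got {type(value).__name__}")
        super().__setattr__(name, value)
```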
- Fields
- field backbone: Backbone = resnet18
The torchvision.models backbone to use.
- field external_def: Optional[ExternalModuleConfig] = None
If specified, the model will be built from the model definition provided by this external source, using Torch Hub.
- field extra_args: dict = {}
Other implementation-specific args that might be useful for constructing the default model. This is ignored if using an external model.
- field init_weights: Optional[str] = None
URI of PyTorch model weights used to initialize the model. If set, this supersedes the pretrained option.
- field load_strict: bool = True
If True, the keys in the state dict referenced by init_weights must match exactly. Setting this to False can be useful if you just want to load the backbone of a model.
- field pretrained: bool = True
If True, use ImageNet weights. If False, use random initialization.
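Assuming load_strict is forwarded to PyTorch's Module.load_state_dict(strict=...), the difference is about key matching: strict mode fails if any key is missing or unexpected, while non-strict mode loads the overlapping keys and reports the rest. The pure-Python function below (match_state_dict_keys is a hypothetical name) reproduces that rule on plain dicts so it runs without PyTorch. One caveat from PyTorch itself: even with strict=False, a key that matches but holds a tensor of a different shape still raises an error, so loading only a backbone works when its parameter names and shapes agree.

```python
from typing import Dict, Iterable, List, Tuple


def match_state_dict_keys(
    model_keys: Iterable[str], state_dict: Dict[str, object], strict: bool = True
) -> Tuple[Dict[str, object], List[str], List[str]]:
    """Hypothetical sketch of strict vs. non-strict state-dict key matching.

    Returns (entries that can be loaded, missing_keys, unexpected_keys).
    """
    model_keys = list(model_keys)
    known = set(model_keys)
    # Keys the model expects but the checkpoint lacks.
    missing = [k for k in model_keys if k not in state_dict]
    # Keys the checkpoint has but the model does not.
    unexpected = [k for k in state_dict if k not in known]
    if strict and (missing or unexpected):
        raise RuntimeError(f"missing_keys={missing}, unexpected_keys={unexpected}")
    # Non-strict: load only the entries whose keys exist in the model.
    loadable = {k: v for k, v in state_dict.items() if k in known}
    return loadable, missing, unexpected
```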
- build(num_classes: int, in_channels: int, save_dir: Optional[str] = None, hubconf_dir: Optional[str] = None, **kwargs) → torch.nn.Module
Build and return a model based on the config.
- Parameters
num_classes (int) – Number of classes.
in_channels (int, optional) – Number of channels in the images that will be fed into the model. Defaults to 3.
save_dir (Optional[str], optional) – Used for building external_def if specified. Defaults to None.
hubconf_dir (Optional[str], optional) – Used for building external_def if specified. Defaults to None.
**kwargs – Extra args for build_default_model().
- Returns
A PyTorch nn.Module.
- Return type
torch.nn.Module
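Going by the parameter descriptions above and the field docs, build appears to dispatch on external_def: if it is set, the model comes from the external definition (and extra_args are ignored); otherwise the default model is built and **kwargs go to build_default_model(). The sketch below illustrates that dispatch with a hypothetical ModelConfigSketch dataclass whose builders return tuples instead of nn.Module objects. Merging extra_args underneath the explicit **kwargs is an assumption of the sketch, not documented behavior.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple


@dataclass
class ModelConfigSketch:
    """Hypothetical stand-in holding two of ModelConfig's fields."""

    external_def: Optional[Dict[str, Any]] = None
    extra_args: Dict[str, Any] = field(default_factory=dict)

    def build_default_model(self, num_classes: int, in_channels: int, **kwargs: Any) -> Tuple:
        # Placeholder for constructing the torchvision-based default model.
        return ("default", num_classes, in_channels, kwargs)

    def build_external_model(self, save_dir: Optional[str], hubconf_dir: Optional[str] = None) -> Tuple:
        # Placeholder for loading the model definition via Torch Hub.
        return ("external", save_dir, hubconf_dir)

    def build(self, num_classes: int, in_channels: int, save_dir: Optional[str] = None,
              hubconf_dir: Optional[str] = None, **kwargs: Any) -> Tuple:
        if self.external_def is not None:
            # External definition: extra_args are ignored, as the field docs state.
            return self.build_external_model(save_dir, hubconf_dir)
        # Assumption: extra_args act as defaults that explicit **kwargs override.
        return self.build_default_model(num_classes, in_channels, **{**self.extra_args, **kwargs})
```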
- build_default_model(num_classes: int, in_channels: int, **kwargs) → torch.nn.Module
Build and return the default model.
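- Parameters
num_classes (int) – Number of classes.
in_channels (int) – Number of channels in the images that will be fed into the model.
**kwargs – Extra args, passed on from build().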
- Returns
A PyTorch nn.Module.
- Return type
torch.nn.Module
- build_external_model(save_dir: str, hubconf_dir: Optional[str] = None) → torch.nn.Module
Build and return an external model.
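- Parameters
save_dir (str) – Used for building external_def.
hubconf_dir (Optional[str], optional) – Used for building external_def. Defaults to None.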
- Returns
A PyTorch nn.Module.
- Return type
torch.nn.Module
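The ExternalModuleConfig fields in the JSON schema line up with the arguments of torch.hub.load(repo_or_dir, model, *args, source=..., force_reload=..., **kwargs). The sketch below shows one plausible mapping without importing torch. hub_load_arguments and its local_dir parameter are hypothetical: the real config has a uri field, which is presumably downloaded or unpacked (using save_dir/hubconf_dir) before loading. Requiring exactly one of github_repo or local_dir is a choice of this sketch, and whether Raster Vision passes force_reload through to torch.hub is an assumption.

```python
from typing import Any, Dict, List, Optional, Tuple


def hub_load_arguments(
    entrypoint: str,
    github_repo: Optional[str] = None,
    local_dir: Optional[str] = None,
    entrypoint_args: Optional[List[Any]] = None,
    entrypoint_kwargs: Optional[Dict[str, Any]] = None,
    force_reload: bool = False,
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Hypothetical mapping from ExternalModuleConfig-style fields to torch.hub.load arguments.

    The real call would be torch.hub.load(*args, **kwargs); it is not made here.
    """
    if (github_repo is None) == (local_dir is None):
        raise ValueError("specify exactly one of github_repo or local_dir")
    if github_repo is not None:
        # torch.hub accepts '<repo-owner>/<repo-name>[:ref]' directly with source='github'.
        repo_or_dir, source = github_repo, "github"
    else:
        # Assumes the uri has already been fetched/extracted to a local directory.
        repo_or_dir, source = local_dir, "local"
    # Positional: repo_or_dir, the entrypoint name, then the entrypoint's own args.
    args = (repo_or_dir, entrypoint, *(entrypoint_args or []))
    # Keyword: torch.hub's own options, then the entrypoint's kwargs.
    kwargs = {"source": source, "force_reload": force_reload, **(entrypoint_kwargs or {})}
    return args, kwargs
```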
- recursive_validate_config()
Recursively validate hierarchies of Configs.
This uses reflection to call validate_config on a hierarchy of Configs using a depth-first pre-order traversal.
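A plain-Python sketch of the traversal described above follows, using a hypothetical ConfigSketch class instead of reflecting over pydantic fields. Each config's validate_config runs before any of its children (pre-order), and each child subtree is finished before the next sibling starts (depth-first). Descending into lists and dicts is an assumption about how nested collections of configs are handled, and the sketch returns the visit order for inspection, which the real method does not.

```python
from typing import Any, List


class ConfigSketch:
    """Hypothetical config node: a name plus arbitrary (possibly nested) fields."""

    def __init__(self, name: str, **fields: Any) -> None:
        self.name = name
        self.fields = fields

    def validate_config(self) -> None:
        # Stand-in for a real per-config check.
        if not self.name:
            raise ValueError("name must be non-empty")

    def recursive_validate_config(self) -> List[str]:
        """Call validate_config on every nested config, depth-first, pre-order."""
        visited: List[str] = []

        def visit(obj: Any) -> None:
            if isinstance(obj, ConfigSketch):
                obj.validate_config()  # pre-order: a node before its children
                visited.append(obj.name)
                for value in obj.fields.values():
                    visit(value)  # depth-first: finish each subtree in turn
            elif isinstance(obj, (list, tuple)):
                for item in obj:
                    visit(item)
            elif isinstance(obj, dict):
                for item in obj.values():
                    visit(item)

        visit(self)
        return visited
```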
- revalidate()
Re-validate an instantiated Config.
Runs all Pydantic validators plus self.validate_config().
Adapted from: https://github.com/samuelcolvin/pydantic/issues/1864#issuecomment-679044432
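With validate_assignment enabled, pydantic validates an assignment such as cfg.field = value, but not an in-place change to a mutable value such as appending to a list, so a config can drift into an invalid state; re-validating is one way to catch that afterwards. The plain-Python sketch below (EntrypointArgsSketch is hypothetical) borrows the ExternalModuleConfig rule that entrypoint args must be serializable and reads "serializable" as JSON-serializable, which is an assumption.

```python
import json
from typing import Any, Dict, List


class EntrypointArgsSketch:
    """Hypothetical config whose entrypoint args must be JSON-serializable."""

    def __init__(self, entrypoint_args: List[Any], entrypoint_kwargs: Dict[str, Any]) -> None:
        self.entrypoint_args = entrypoint_args
        self.entrypoint_kwargs = entrypoint_kwargs
        self.revalidate()  # validate once at construction time

    def _run_field_validators(self) -> None:
        # Stand-in for pydantic's field validators: each value must survive json.dumps.
        for name in ("entrypoint_args", "entrypoint_kwargs"):
            try:
                json.dumps(getattr(self, name))
            except TypeError as err:
                raise ValueError(f"{name} must be serializable: {err}") from err

    def validate_config(self) -> None:
        # Hook for checks spanning several fields; nothing extra is needed here.
        pass

    def revalidate(self) -> None:
        # Re-run the field validators on the current values, then validate_config().
        self._run_field_validators()
        self.validate_config()
```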
- update(*args, **kwargs)
Update any fields before validation.
Subclasses should override this to provide complex default behavior, for example, setting default values as a function of the values of other fields. The arguments to this method will vary depending on the type of Config.
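A minimal sketch of the kind of override described above, using a hypothetical dataclass whose fields are invented for illustration and are not ModelConfig fields: one default is computed from another field, another from an argument passed to update (for example by a parent config), and values that were set explicitly are left alone.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ChipConfigSketch:
    """Hypothetical config; these fields are invented and are not ModelConfig fields."""

    chip_sz: int = 256
    img_sz: Optional[int] = None
    channel_order: Optional[List[int]] = None

    def update(self, num_channels: int = 3) -> None:
        # Default computed from another field of the same config.
        if self.img_sz is None:
            self.img_sz = self.chip_sz
        # Default computed from an argument, e.g. one supplied by a parent config.
        if self.channel_order is None:
            self.channel_order = list(range(num_channels))
```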
- validate_config()
Validate fields that should be checked after update is called.
This complements the built-in validation that Pydantic performs when the object is constructed.
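validate_config suits checks that involve more than one field, or that depend on values update has just filled in. The hypothetical example below borrows two field names from ExternalModuleConfig and requires exactly one of them to be set; whether Raster Vision enforces this particular rule, and in which method, is an assumption.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class HubSourceSketch:
    """Hypothetical config; field names borrowed from ExternalModuleConfig."""

    entrypoint: str
    uri: Optional[str] = None
    github_repo: Optional[str] = None

    def validate_config(self) -> None:
        # A cross-field rule that no single-field validator can express.
        if (self.uri is None) == (self.github_repo is None):
            raise ValueError("specify exactly one of uri or github_repo")
```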