SolverConfig
Note: All Configs are derived from rastervision.pipeline.config.Config, which is itself a pydantic Model.
- pydantic model SolverConfig
Config related to the solver, i.e. the optimizer and learning-rate schedule.
JSON schema:
{
  "title": "SolverConfig",
  "description": "Config related to solver aka optimizer.",
  "type": "object",
  "properties": {
    "lr": { "default": 0.0001, "description": "Learning rate.", "exclusiveMinimum": 0.0, "title": "Lr", "type": "number" },
    "num_epochs": { "default": 10, "description": "Number of epochs (ie. sweeps through the whole training set).", "exclusiveMinimum": 0, "title": "Num Epochs", "type": "integer" },
    "sync_interval": { "default": 1, "description": "The interval in epochs for each sync to the cloud.", "exclusiveMinimum": 0, "title": "Sync Interval", "type": "integer" },
    "batch_sz": { "default": 32, "description": "Batch size.", "exclusiveMinimum": 0, "title": "Batch Sz", "type": "integer" },
    "one_cycle": { "default": true, "description": "If True, use triangular LR scheduler with a single cycle across all epochs with start and end LR being lr/10 and the peak being lr.", "title": "One Cycle", "type": "boolean" },
    "multi_stage": { "default": [], "description": "List of epoch indices at which to divide LR by 10.", "items": { "type": "integer" }, "title": "Multi Stage", "type": "array" },
    "class_loss_weights": { "anyOf": [ { "items": { "type": "number" }, "type": "array" }, { "type": "null" } ], "default": null, "description": "Class weights for weighted loss.", "title": "Class Loss Weights" },
    "ignore_class_index": { "anyOf": [ { "type": "integer" }, { "type": "null" } ], "default": null, "description": "If specified, this index is ignored when computing the loss. See pytorch documentation for nn.CrossEntropyLoss for more details. This can also be negative, in which case it is treated as a negative slice index i.e. -1 = last index, -2 = second-last index, and so on.", "title": "Ignore Class Index" },
    "external_loss_def": { "anyOf": [ { "$ref": "#/$defs/ExternalModuleConfig" }, { "type": "null" } ], "default": null, "description": "If specified, the loss will be built from the definition from this external source, using Torch Hub." },
    "type_hint": { "const": "solver", "default": "solver", "enum": [ "solver" ], "title": "Type Hint", "type": "string" }
  },
  "$defs": {
    "ExternalModuleConfig": {
      "additionalProperties": false,
      "description": "Config describing an object to be loaded via Torch Hub.",
      "properties": {
        "uri": { "anyOf": [ { "minLength": 1, "type": "string" }, { "type": "null" } ], "default": null, "description": "Local uri of a zip file, or local uri of a directory, or remote uri of zip file.", "title": "Uri" },
        "github_repo": { "anyOf": [ { "pattern": ".+/.+", "type": "string" }, { "type": "null" } ], "default": null, "description": "<repo-owner>/<repo-name>[:tag]", "title": "Github Repo" },
        "name": { "anyOf": [ { "minLength": 1, "type": "string" }, { "type": "null" } ], "default": null, "description": "Name of the folder in which to extract/copy the definition files.", "title": "Name" },
        "entrypoint": { "description": "Name of a Callable present in ``hubconf.py``. See docs for ``torch.hub`` for details.", "minLength": 1, "title": "Entrypoint", "type": "string" },
        "entrypoint_args": { "default": [], "description": "Args to pass to the entrypoint. Must be serializable.", "items": {}, "title": "Entrypoint Args", "type": "array" },
        "entrypoint_kwargs": { "default": {}, "description": "Keyword args to pass to the entrypoint. Must be serializable.", "title": "Entrypoint Kwargs", "type": "object" },
        "force_reload": { "default": false, "description": "Force reload of module definition.", "title": "Force Reload", "type": "boolean" },
        "type_hint": { "const": "external-module", "default": "external-module", "enum": [ "external-module" ], "title": "Type Hint", "type": "string" }
      },
      "required": [ "entrypoint" ],
      "title": "ExternalModuleConfig",
      "type": "object"
    }
  },
  "additionalProperties": false
}
- Config:
extra: str = forbid
validate_assignment: bool = True
- Fields:
batch_sz (PosInt)
class_loss_weights (Sequence[float] | None)
external_loss_def (ExternalModuleConfig | None)
ignore_class_index (int | None)
lr (PositiveFloat)
multi_stage (List[int])
num_epochs (PosInt)
one_cycle (bool)
sync_interval (PosInt)
- Validators:
check_no_loss_opts_if_external » all fields
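As a usage sketch (the import path is assumed from the Raster Vision package layout), a SolverConfig can be constructed directly; because extra is set to forbid, misspelled field names raise a validation error rather than being silently ignored:

```python
from rastervision.pytorch_learner import SolverConfig

cfg = SolverConfig(lr=1e-3, num_epochs=20, batch_sz=16)
print(cfg.lr)  # 0.001

# extra = forbid: unknown keys are rejected at construction time.
try:
    SolverConfig(learning_rate=1e-3)  # misspelled field name
except Exception as err:
    print(type(err).__name__)  # pydantic's ValidationError
```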
- field batch_sz: PosInt = 32
Batch size.
- Constraints:
gt = 0
- Validated by:
check_no_loss_opts_if_external
- field class_loss_weights: Sequence[float] | None = None
Class weights for weighted loss (see the sketch after this field list).
- Validated by:
check_no_loss_opts_if_external
- field external_loss_def: ExternalModuleConfig | None = None
If specified, the loss will be built from this external definition, using Torch Hub.
- Validated by:
check_no_loss_opts_if_external
- field ignore_class_index: int | None = None
If specified, this index is ignored when computing the loss. See the PyTorch documentation for nn.CrossEntropyLoss for more details. This can also be negative, in which case it is treated as a negative slice index, i.e. -1 = last index, -2 = second-to-last index, and so on (see the sketch after this field list).
- Validated by:
check_no_loss_opts_if_external
- field lr: PositiveFloat = 0.0001
Learning rate.
- Constraints:
gt = 0
- Validated by:
check_no_loss_opts_if_external
- field multi_stage: List[int] = []
List of epoch indices at which to divide LR by 10.
- Validated by:
check_no_loss_opts_if_external
- field num_epochs: PosInt = 10
Number of epochs (i.e. sweeps through the whole training set).
- Constraints:
gt = 0
- Validated by:
check_no_loss_opts_if_external
- field one_cycle: bool = True
If True, use a triangular LR scheduler with a single cycle across all epochs, with the start and end LR being lr/10 and the peak being lr.
- Validated by:
check_no_loss_opts_if_external
- field sync_interval: PosInt = 1
The interval in epochs for each sync to the cloud.
- Constraints:
gt = 0
- Validated by:
check_no_loss_opts_if_external
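To make the loss-related fields concrete, here is a hedged sketch of how class_loss_weights and ignore_class_index map onto nn.CrossEntropyLoss (which the ignore_class_index docstring points to). The negative-index resolution is an illustration of the documented slice-index behavior, not the library's exact code:

```python
import torch
import torch.nn as nn

def resolve_loss(num_classes: int, class_loss_weights=None,
                 ignore_class_index=None) -> nn.CrossEntropyLoss:
    """Illustrative stand-in for how the loss options could be resolved."""
    weight = None
    if class_loss_weights is not None:
        weight = torch.tensor(class_loss_weights, dtype=torch.float)
    ignore_index = -100  # nn.CrossEntropyLoss's default sentinel
    if ignore_class_index is not None:
        # Negative values behave like negative slice indices:
        # -1 -> num_classes - 1 (the last class), -2 -> second-to-last, etc.
        ignore_index = (ignore_class_index if ignore_class_index >= 0
                        else num_classes + ignore_class_index)
    return nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index)

loss_fn = resolve_loss(num_classes=3, class_loss_weights=[1.0, 2.0, 5.0],
                       ignore_class_index=-1)  # ignores class index 2
```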
- build()
Build an instance of the corresponding type of object using this config.
For example, BackendConfig will build a Backend object. The arguments to this method will vary depending on the type of Config.
- build_epoch_scheduler(optimizer: Optimizer, last_epoch: int = -1, **kwargs) → torch.optim.lr_scheduler._LRScheduler | None
Returns an LR scheduler that changes the LR each epoch.
This is used to divide the learning rate by 10 at certain epochs.
- Parameters:
optimizer (optim.Optimizer) – Optimizer to build scheduler for.
last_epoch (int) – Last epoch. Defaults to -1.
**kwargs – Extra args for the scheduler constructor.
- Returns:
An epoch scheduler, if applicable. Otherwise, None.
- Return type:
torch.optim.lr_scheduler._LRScheduler | None
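A usage sketch, assuming the multi-stage behavior documented above (dividing the LR by 10 at the listed epoch indices, as MultiStepLR-style schedulers do):

```python
import torch.nn as nn
from rastervision.pytorch_learner import SolverConfig

model = nn.Linear(4, 2)  # stand-in model
cfg = SolverConfig(lr=0.01, num_epochs=60, multi_stage=[30, 50])
optimizer = cfg.build_optimizer(model)
scheduler = cfg.build_epoch_scheduler(optimizer)

for epoch in range(cfg.num_epochs):
    # ... train for one epoch ...
    if scheduler is not None:
        scheduler.step()  # LR is divided by 10 at epochs 30 and 50
```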
- build_loss(num_classes: int, save_dir: str | None = None, hubconf_dir: str | None = None) → Callable[..., Tensor]
Build and return a loss function based on the config.
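A hedged usage sketch; the ExternalModuleConfig values below ('owner/repo', 'my_loss') are placeholders, not a real Torch Hub entrypoint, and the import path is assumed:

```python
import torch
from rastervision.pytorch_learner import ExternalModuleConfig, SolverConfig

# Built-in loss with class weights:
cfg = SolverConfig(class_loss_weights=[1.0, 2.0, 5.0])
loss_fn = cfg.build_loss(num_classes=3)
logits = torch.randn(8, 3)           # (batch, num_classes)
targets = torch.randint(0, 3, (8,))  # integer class indices
loss = loss_fn(logits, targets)

# Loss loaded from an external definition via Torch Hub
# ('owner/repo' and 'my_loss' are hypothetical placeholders):
ext = ExternalModuleConfig(github_repo='owner/repo', entrypoint='my_loss')
cfg_ext = SolverConfig(external_loss_def=ext)
```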
- build_optimizer(model: Module, **kwargs) → Adam
Build and return an Adam optimizer for the given model.
- Parameters:
model (nn.Module) – Model to be trained.
**kwargs – Extra args for the optimizer constructor.
- Returns:
An Adam optimizer instance.
- Return type:
Adam
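For example (extra keyword args are passed through to the optimizer constructor, per the parameter description above):

```python
import torch.nn as nn
from rastervision.pytorch_learner import SolverConfig

model = nn.Linear(10, 2)
cfg = SolverConfig(lr=3e-4)
optimizer = cfg.build_optimizer(model)                     # Adam with lr=3e-4
optimizer = cfg.build_optimizer(model, weight_decay=1e-5)  # kwargs forwarded
```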
- build_step_scheduler(optimizer: Optimizer, train_ds_sz: int, last_epoch: int = -1, **kwargs) → torch.optim.lr_scheduler._LRScheduler | None
Returns an LR scheduler that changes the LR each step.
This is used to implement the “one cycle” schedule popularized by FastAI.
- Parameters:
optimizer (optim.Optimizer) – Optimizer to build scheduler for.
train_ds_sz (int) – Size of the training dataset.
last_epoch (int) – Last epoch. Defaults to -1.
**kwargs – Extra args for the scheduler constructor.
- Returns:
A step scheduler, if applicable. Otherwise, None.
- Return type:
torch.optim.lr_scheduler._LRScheduler | None
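A usage sketch; note that a step scheduler advances once per batch, not once per epoch. The ceil-division step count is an assumption about how the dataset size maps to steps:

```python
import math
import torch.nn as nn
from rastervision.pytorch_learner import SolverConfig

model = nn.Linear(4, 2)  # stand-in model
cfg = SolverConfig(lr=1e-3, num_epochs=10, batch_sz=32, one_cycle=True)
optimizer = cfg.build_optimizer(model)
scheduler = cfg.build_step_scheduler(optimizer, train_ds_sz=10_000)

steps_per_epoch = math.ceil(10_000 / cfg.batch_sz)  # assumed step count
for epoch in range(cfg.num_epochs):
    for _ in range(steps_per_epoch):
        # ... forward pass, loss.backward(), optimizer.step() ...
        if scheduler is not None:
            scheduler.step()  # the one-cycle LR curve advances each step
```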
- classmethod deserialize(inp: str | dict | Config) → Self
Deserialize Config from a JSON file or dict, upgrading if possible.
If inp is already a Config, it is returned as is.
- classmethod from_dict(cfg_dict: dict) → Self
Deserialize Config from a dict.
- Parameters:
cfg_dict (dict) – Dict to deserialize.
- Return type:
Self
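For example (a sketch; the type_hint key identifies the Config subclass in serialized form, per the schema above):

```python
from rastervision.pytorch_learner import SolverConfig

cfg = SolverConfig.from_dict({
    'type_hint': 'solver',  # marks the dict as a serialized SolverConfig
    'lr': 5e-4,
    'num_epochs': 25,
})
```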
- classmethod from_file(uri: str) → Self
Deserialize Config from a JSON file, upgrading if possible.
- Parameters:
uri (str) – URI to load from.
- Return type:
Self
- recursive_validate_config()
Recursively validate hierarchies of Configs.
This uses reflection to call validate_config on a hierarchy of Configs using a depth-first pre-order traversal.
- revalidate()
Re-validate an instantiated Config.
Runs all Pydantic validators plus self.validate_config().
- to_file(uri: str, with_rv_metadata: bool = True) → None
Save a Config to a JSON file, optionally with RV metadata.
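A round-trip sketch using from_file from above ('solver.json' is an arbitrary local path):

```python
from rastervision.pytorch_learner import SolverConfig

cfg = SolverConfig(lr=1e-3, num_epochs=20)
cfg.to_file('solver.json')  # writes JSON, plus RV metadata by default
cfg2 = SolverConfig.from_file('solver.json')
assert cfg2.lr == cfg.lr
```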
- update(*args, **kwargs)
Update any fields before validation.
Subclasses should override this to provide complex default behavior, for example, setting default values as a function of the values of other fields. The arguments to this method will vary depending on the type of Config.
- validate_config()
Validate fields that should be checked after update is called.
This is to complement the builtin validation that Pydantic performs at the time of object construction.
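A sketch of how a hypothetical Config subclass might use this hook; MySolverConfig and warmup_epochs are illustrative, not part of the library (Raster Vision subclasses are also normally registered via the register_config decorator):

```python
from rastervision.pipeline.config import ConfigError
from rastervision.pytorch_learner import SolverConfig

class MySolverConfig(SolverConfig):  # hypothetical subclass
    warmup_epochs: int = 0

    def validate_config(self):
        super().validate_config()
        # Cross-field checks that can only run after all fields are set:
        if self.warmup_epochs >= self.num_epochs:
            raise ConfigError('warmup_epochs must be less than num_epochs.')
```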