PyTorchChipClassificationConfig

Note

All Configs are derived from rastervision.pipeline.config.Config, which is itself a Pydantic model.

pydantic model PyTorchChipClassificationConfig

Configure a PyTorchChipClassification backend.

JSON schema:
{
   "title": "PyTorchChipClassificationConfig",
   "description": "Configure a :class:`.PyTorchChipClassification` backend.",
   "type": "object",
   "properties": {
      "type_hint": {
         "const": "pytorch_chip_classification_backend",
         "default": "pytorch_chip_classification_backend",
         "enum": [
            "pytorch_chip_classification_backend"
         ],
         "title": "Type Hint",
         "type": "string"
      },
      "model": {
         "$ref": "#/$defs/ClassificationModelConfig"
      },
      "solver": {
         "$ref": "#/$defs/SolverConfig"
      },
      "data": {
         "$ref": "#/$defs/DataConfig"
      },
      "log_tensorboard": {
         "default": true,
         "description": "If True, log events to Tensorboard log files.",
         "title": "Log Tensorboard",
         "type": "boolean"
      },
      "run_tensorboard": {
         "default": false,
         "description": "If True, run Tensorboard server pointing at log files.",
         "title": "Run Tensorboard",
         "type": "boolean"
      },
      "save_all_checkpoints": {
         "default": false,
         "description": "If True, all checkpoints will be saved. The latest checkpoint is saved as `last-model.pth`, and the checkpoints from earlier epochs are stored as `model-ckpt-epoch-{N}.pth`, where `N` is the epoch number.",
         "title": "Save All Checkpoints",
         "type": "boolean"
      }
   },
   "$defs": {
      "Backbone": {
         "enum": [
            "alexnet",
            "densenet121",
            "densenet169",
            "densenet201",
            "densenet161",
            "googlenet",
            "inception_v3",
            "mnasnet0_5",
            "mnasnet0_75",
            "mnasnet1_0",
            "mnasnet1_3",
            "mobilenet_v2",
            "resnet18",
            "resnet34",
            "resnet50",
            "resnet101",
            "resnet152",
            "resnext50_32x4d",
            "resnext101_32x8d",
            "wide_resnet50_2",
            "wide_resnet101_2",
            "shufflenet_v2_x0_5",
            "shufflenet_v2_x1_0",
            "shufflenet_v2_x1_5",
            "shufflenet_v2_x2_0",
            "squeezenet1_0",
            "squeezenet1_1",
            "vgg11",
            "vgg11_bn",
            "vgg13",
            "vgg13_bn",
            "vgg16",
            "vgg16_bn",
            "vgg19_bn",
            "vgg19"
         ],
         "title": "Backbone",
         "type": "string"
      },
      "ClassConfig": {
         "additionalProperties": false,
         "description": "Configure class information for a machine learning task.",
         "properties": {
            "names": {
               "description": "Names of classes. The i-th class in this list will have class ID = i.",
               "items": {
                  "type": "string"
               },
               "title": "Names",
               "type": "array"
            },
            "colors": {
               "anyOf": [
                  {
                     "items": {
                        "anyOf": [
                           {
                              "type": "string"
                           },
                           {
                              "items": {},
                              "type": "array"
                           }
                        ]
                     },
                     "type": "array"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "description": "Colors used to visualize classes. Can be color strings accepted by matplotlib or RGB tuples. If None, a random color will be auto-generated for each class.",
               "title": "Colors"
            },
            "null_class": {
               "anyOf": [
                  {
                     "type": "string"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "description": "Optional name of class in `names` to use as the null class. This is used in semantic segmentation to represent the label for imagery pixels that are NODATA or that are missing a label. If None and the class names include \"null\", it will automatically be used as the null class. If None, and this Config is part of a SemanticSegmentationConfig, a null class will be added automatically.",
               "title": "Null Class"
            },
            "type_hint": {
               "const": "class_config",
               "default": "class_config",
               "enum": [
                  "class_config"
               ],
               "title": "Type Hint",
               "type": "string"
            }
         },
         "required": [
            "names"
         ],
         "title": "ClassConfig",
         "type": "object"
      },
      "ClassificationModelConfig": {
         "additionalProperties": false,
         "description": "Configure a classification model.",
         "properties": {
            "backbone": {
               "allOf": [
                  {
                     "$ref": "#/$defs/Backbone"
                  }
               ],
               "default": "resnet18",
               "description": "The torchvision.models backbone to use."
            },
            "pretrained": {
               "default": true,
               "description": "If True, use ImageNet weights. If False, use random initialization.",
               "title": "Pretrained",
               "type": "boolean"
            },
            "init_weights": {
               "anyOf": [
                  {
                     "type": "string"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "description": "URI of PyTorch model weights used to initialize model. If set, this supersedes the pretrained option.",
               "title": "Init Weights"
            },
            "load_strict": {
               "default": true,
               "description": "If True, the keys in the state dict referenced by init_weights must match exactly. Setting this to False can be useful if you just want to load the backbone of a model.",
               "title": "Load Strict",
               "type": "boolean"
            },
            "external_def": {
               "anyOf": [
                  {
                     "$ref": "#/$defs/ExternalModuleConfig"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "description": "If specified, the model will be built from the definition in this external source, using Torch Hub."
            },
            "extra_args": {
               "default": {},
               "description": "Other implementation-specific args that might be useful for constructing the default model. This is ignored if using an external model.",
               "title": "Extra Args",
               "type": "object"
            },
            "type_hint": {
               "const": "classification_model",
               "default": "classification_model",
               "enum": [
                  "classification_model"
               ],
               "title": "Type Hint",
               "type": "string"
            }
         },
         "title": "ClassificationModelConfig",
         "type": "object"
      },
      "DataConfig": {
         "additionalProperties": false,
         "description": "Config related to dataset for training and testing.",
         "properties": {
            "class_config": {
               "anyOf": [
                  {
                     "$ref": "#/$defs/ClassConfig"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "description": "Class config."
            },
            "img_channels": {
               "anyOf": [
                  {
                     "exclusiveMinimum": 0,
                     "type": "integer"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "description": "The number of channels of the training images.",
               "title": "Img Channels"
            },
            "img_sz": {
               "default": 256,
               "description": "Length of a side of each image in pixels. This is the size to transform it to during training, not the size in the raw dataset.",
               "exclusiveMinimum": 0,
               "title": "Img Sz",
               "type": "integer"
            },
            "train_sz": {
               "anyOf": [
                  {
                     "type": "integer"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "description": "If set, the number of training images to use. If fewer images exist, then an exception will be raised.",
               "title": "Train Sz"
            },
            "train_sz_rel": {
               "anyOf": [
                  {
                     "type": "number"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "description": "If set, the proportion of training images to use.",
               "title": "Train Sz Rel"
            },
            "num_workers": {
               "default": 4,
               "description": "Number of workers to use when DataLoader makes batches.",
               "title": "Num Workers",
               "type": "integer"
            },
            "augmentors": {
               "default": [
                  "RandomRotate90",
                  "HorizontalFlip",
                  "VerticalFlip"
               ],
               "description": "Names of albumentations augmentors to use for training batches. Choices include: ['Blur', 'RandomRotate90', 'HorizontalFlip', 'VerticalFlip', 'GaussianBlur', 'GaussNoise', 'RGBShift', 'ToGray']. Alternatively, a custom transform can be provided via the aug_transform option.",
               "items": {
                  "type": "string"
               },
               "title": "Augmentors",
               "type": "array"
            },
            "base_transform": {
               "anyOf": [
                  {
                     "type": "object"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "description": "An Albumentations transform serialized as a dict that will be applied to all datasets: training, validation, and test. This transformation is in addition to the resizing due to img_sz. This is useful, for example, for applying the same normalization to all datasets.",
               "title": "Base Transform"
            },
            "aug_transform": {
               "anyOf": [
                  {
                     "type": "object"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "description": "An Albumentations transform serialized as a dict that will be applied as data augmentation to the training dataset. This transform is applied before base_transform. If provided, the augmentors option is ignored.",
               "title": "Aug Transform"
            },
            "plot_options": {
               "anyOf": [
                  {
                     "$ref": "#/$defs/PlotOptions"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": {
                  "transform": {
                     "__version__": "1.4.14",
                     "transform": {
                        "__class_fullname__": "rastervision.pytorch_learner.utils.utils.MinMaxNormalize",
                        "dtype": 5,
                        "max_val": 1.0,
                        "min_val": 0.0,
                        "p": 1.0
                     }
                  },
                  "channel_display_groups": null,
                  "type_hint": "plot_options"
               },
               "description": "Options to control plotting."
            },
            "preview_batch_limit": {
               "anyOf": [
                  {
                     "type": "integer"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "description": "Optional limit on the number of items in the preview plots produced during training.",
               "title": "Preview Batch Limit"
            },
            "type_hint": {
               "const": "data",
               "default": "data",
               "enum": [
                  "data"
               ],
               "title": "Type Hint",
               "type": "string"
            }
         },
         "title": "DataConfig",
         "type": "object"
      },
      "ExternalModuleConfig": {
         "additionalProperties": false,
         "description": "Config describing an object to be loaded via Torch Hub.",
         "properties": {
            "uri": {
               "anyOf": [
                  {
                     "minLength": 1,
                     "type": "string"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "description": "Local URI of a zip file, local URI of a directory, or remote URI of a zip file.",
               "title": "Uri"
            },
            "github_repo": {
               "anyOf": [
                  {
                     "pattern": ".+/.+",
                     "type": "string"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "description": "<repo-owner>/<repo-name>[:tag]",
               "title": "Github Repo"
            },
            "name": {
               "anyOf": [
                  {
                     "minLength": 1,
                     "type": "string"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "description": "Name of the folder in which to extract/copy the definition files.",
               "title": "Name"
            },
            "entrypoint": {
               "description": "Name of a Callable present in ``hubconf.py``. See docs for ``torch.hub`` for details.",
               "minLength": 1,
               "title": "Entrypoint",
               "type": "string"
            },
            "entrypoint_args": {
               "default": [],
               "description": "Args to pass to the entrypoint. Must be serializable.",
               "items": {},
               "title": "Entrypoint Args",
               "type": "array"
            },
            "entrypoint_kwargs": {
               "default": {},
               "description": "Keyword args to pass to the entrypoint. Must be serializable.",
               "title": "Entrypoint Kwargs",
               "type": "object"
            },
            "force_reload": {
               "default": false,
               "description": "Force reload of module definition.",
               "title": "Force Reload",
               "type": "boolean"
            },
            "type_hint": {
               "const": "external-module",
               "default": "external-module",
               "enum": [
                  "external-module"
               ],
               "title": "Type Hint",
               "type": "string"
            }
         },
         "required": [
            "entrypoint"
         ],
         "title": "ExternalModuleConfig",
         "type": "object"
      },
      "PlotOptions": {
         "additionalProperties": false,
         "description": "Config related to plotting.",
         "properties": {
            "transform": {
               "anyOf": [
                  {
                     "type": "object"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": {
                  "__version__": "1.4.14",
                  "transform": {
                     "__class_fullname__": "rastervision.pytorch_learner.utils.utils.MinMaxNormalize",
                     "dtype": 5,
                     "max_val": 1.0,
                     "min_val": 0.0,
                     "p": 1.0
                  }
               },
               "description": "An Albumentations transform serialized as a dict that will be applied to each image before it is plotted. Mainly useful for undoing any data transformation that you do not want included in the plot, such as normalization. The default value will shift and scale the image so the values range from 0.0 to 1.0 which is the expected range for the plotting function. This default is useful for cases where the values after normalization are close to zero which makes the plot difficult to see.",
               "title": "Transform"
            },
            "channel_display_groups": {
               "anyOf": [
                  {
                     "additionalProperties": {
                        "items": {
                           "minimum": 0,
                           "type": "integer"
                        },
                        "type": "array"
                     },
                     "type": "object"
                  },
                  {
                     "items": {
                        "items": {
                           "minimum": 0,
                           "type": "integer"
                        },
                        "type": "array"
                     },
                     "type": "array"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "description": "Groups of image channels to display together as a subplot when plotting the data and predictions. Can be a list or tuple of groups (e.g. [(0, 1, 2), (3,)]) or a dict containing title-to-group mappings (e.g. {\"RGB\": [0, 1, 2], \"IR\": [3]}), where each group is a list or tuple of channel indices and title is a string that will be used as the title of the subplot for that group.",
               "title": "Channel Display Groups"
            },
            "type_hint": {
               "const": "plot_options",
               "default": "plot_options",
               "enum": [
                  "plot_options"
               ],
               "title": "Type Hint",
               "type": "string"
            }
         },
         "title": "PlotOptions",
         "type": "object"
      },
      "SolverConfig": {
         "additionalProperties": false,
         "description": "Config related to solver aka optimizer.",
         "properties": {
            "lr": {
               "default": 0.0001,
               "description": "Learning rate.",
               "exclusiveMinimum": 0.0,
               "title": "Lr",
               "type": "number"
            },
            "num_epochs": {
               "default": 10,
               "description": "Number of epochs (i.e. sweeps through the whole training set).",
               "exclusiveMinimum": 0,
               "title": "Num Epochs",
               "type": "integer"
            },
            "sync_interval": {
               "default": 1,
               "description": "The interval in epochs for each sync to the cloud.",
               "exclusiveMinimum": 0,
               "title": "Sync Interval",
               "type": "integer"
            },
            "batch_sz": {
               "default": 32,
               "description": "Batch size.",
               "exclusiveMinimum": 0,
               "title": "Batch Sz",
               "type": "integer"
            },
            "one_cycle": {
               "default": true,
               "description": "If True, use triangular LR scheduler with a single cycle across all epochs with start and end LR being lr/10 and the peak being lr.",
               "title": "One Cycle",
               "type": "boolean"
            },
            "multi_stage": {
               "default": [],
               "description": "List of epoch indices at which to divide LR by 10.",
               "items": {
                  "type": "integer"
               },
               "title": "Multi Stage",
               "type": "array"
            },
            "class_loss_weights": {
               "anyOf": [
                  {
                     "items": {
                        "type": "number"
                     },
                     "type": "array"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "description": "Class weights for weighted loss.",
               "title": "Class Loss Weights"
            },
            "ignore_class_index": {
               "anyOf": [
                  {
                     "type": "integer"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "description": "If specified, this index is ignored when computing the loss. See the PyTorch documentation for nn.CrossEntropyLoss for more details. This can also be negative, in which case it is treated as a negative slice index, i.e. -1 = last index, -2 = second-last index, and so on.",
               "title": "Ignore Class Index"
            },
            "external_loss_def": {
               "anyOf": [
                  {
                     "$ref": "#/$defs/ExternalModuleConfig"
                  },
                  {
                     "type": "null"
                  }
               ],
               "default": null,
               "description": "If specified, the loss will be built from the definition in this external source, using Torch Hub."
            },
            "type_hint": {
               "const": "solver",
               "default": "solver",
               "enum": [
                  "solver"
               ],
               "title": "Type Hint",
               "type": "string"
            }
         },
         "title": "SolverConfig",
         "type": "object"
      }
   },
   "additionalProperties": false,
   "required": [
      "model",
      "solver",
      "data"
   ]
}
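The one_cycle and multi_stage options of SolverConfig each describe a learning-rate shape. The following is a minimal sketch of both, assuming the one-cycle triangle rises linearly from lr/10 to lr over the first half of training and falls back linearly over the second half, and that every multi_stage epoch index divides the rate by 10 from that epoch onward. The function names are hypothetical, and the sketch steps once per epoch for readability; Raster Vision's own scheduler may step per batch instead.

```python
def one_cycle_lr(step: int, total_steps: int, lr: float) -> float:
    """Triangular one-cycle LR: lr/10 at the start and end, lr at the midpoint."""
    base = lr / 10
    half = total_steps / 2
    if step <= half:
        frac = step / half  # rising edge of the triangle
    else:
        frac = (total_steps - step) / half  # falling edge of the triangle
    return base + (lr - base) * frac


def multi_stage_lr(epoch: int, lr: float, milestones: list[int]) -> float:
    """Divide lr by 10 once for every milestone epoch already reached."""
    num_drops = sum(1 for m in milestones if epoch >= m)
    return lr / (10 ** num_drops)


lr = 1e-4  # SolverConfig default
# Hypothetical run of 10 epochs, one scheduler step per epoch.
print([one_cycle_lr(s, 10, lr) for s in range(11)])
# Hypothetical multi_stage=[3, 6].
print([multi_stage_lr(e, lr, [3, 6]) for e in range(8)])
```

Note that the description of one_cycle implies it shapes the whole run on its own, so combining it with multi_stage is not something this sketch tries to model.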

Config:
  • extra: str = forbid

  • validate_assignment: bool = True
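With extra set to forbid (mirrored by "additionalProperties": false in the schema), an unknown key such as a misspelled field name is rejected rather than silently ignored, and model, solver and data must always be supplied. The stdlib-only sketch below reproduces just those two top-level checks on a plain dict. check_top_level is a hypothetical helper; Raster Vision performs this validation through Pydantic, and the nested sub-configs here are abbreviated and not checked.

```python
REQUIRED = {"model", "solver", "data"}
ALLOWED = REQUIRED | {
    "type_hint",
    "log_tensorboard",
    "run_tensorboard",
    "save_all_checkpoints",
}


def check_top_level(cfg: dict) -> None:
    """Raise ValueError on missing required keys or unknown (extra) keys."""
    keys = set(cfg)
    missing = REQUIRED - keys
    if missing:
        raise ValueError(f"missing required fields: {sorted(missing)}")
    extra = keys - ALLOWED
    if extra:
        raise ValueError(f"extra fields not permitted: {sorted(extra)}")


# Abbreviated serialized config; nested dicts are not validated here.
cfg = {
    "type_hint": "pytorch_chip_classification_backend",
    "model": {"type_hint": "classification_model", "backbone": "resnet18"},
    "solver": {"type_hint": "solver", "lr": 1e-4, "num_epochs": 10},
    "data": {"type_hint": "data", "img_sz": 256},
}
check_top_level(cfg)  # passes

try:
    check_top_level({**cfg, "log_tensorbord": True})  # misspelled key
except ValueError as err:
    print(err)
```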

Fields:
field data: DataConfig [Required]
field log_tensorboard: bool = True

If True, log events to Tensorboard log files.

field model: ClassificationModelConfig [Required]
field run_tensorboard: bool = False

If True, run Tensorboard server pointing at log files.

field save_all_checkpoints: bool = False

If True, all checkpoints will be saved. The latest checkpoint is saved as last-model.pth, and the checkpoints from earlier epochs are stored as model-ckpt-epoch-{N}.pth, where N is the epoch number.

field solver: SolverConfig [Required]
field type_hint: Literal['pytorch_chip_classification_backend'] = 'pytorch_chip_classification_backend'
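The save_all_checkpoints naming scheme can be sketched as a function that lists the files expected after training. It assumes epochs are numbered from 0 and that last-model.pth is written whether or not the flag is set; checkpoint_names is a hypothetical helper, so verify the exact numbering against the files in an actual train/ directory.

```python
def checkpoint_names(num_epochs: int, save_all_checkpoints: bool) -> list[str]:
    """List the checkpoint filenames expected after training (sketch)."""
    # The latest checkpoint is kept as last-model.pth in either case (assumed).
    names = ["last-model.pth"]
    if save_all_checkpoints:
        # Every epoch before the last one gets its own file.
        names += [f"model-ckpt-epoch-{n}.pth" for n in range(num_epochs - 1)]
    return names


print(checkpoint_names(3, save_all_checkpoints=True))
```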
build(pipeline, tmp_dir)

Build an instance of the corresponding type of object using this config.

For example, BackendConfig will build a Backend object. The arguments to this method will vary depending on the type of Config.

classmethod deserialize(inp: str | dict | Config) → Self

Deserialize Config from a JSON file or dict, upgrading if possible.

If inp is already a Config, it is returned as is.

Parameters:

inp (str | dict | Config) – a URI to a JSON file, a dict, or a Config.

Return type:

Self
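The dispatch that deserialize describes can be sketched with the standard library: a Config is returned unchanged, a dict is converted directly (as from_dict does), and a string is read as a JSON file (as from_file does). FakeConfig and deserialize_sketch are hypothetical stand-ins; the real method also upgrades older configs and accepts remote URIs, neither of which this sketch attempts.

```python
import json
import os
import tempfile


class FakeConfig:
    """Hypothetical stand-in for a Config built from a dict."""

    def __init__(self, fields: dict):
        self.fields = fields


def deserialize_sketch(inp):
    if isinstance(inp, FakeConfig):
        return inp  # already a Config: returned as is
    if isinstance(inp, dict):
        return FakeConfig(inp)  # analogous to from_dict
    if isinstance(inp, str):
        with open(inp) as f:  # analogous to from_file, local paths only
            return FakeConfig(json.load(f))
    raise TypeError(f"unsupported input type: {type(inp).__name__}")


# Usage: round-trip a small dict through a temporary JSON file.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "backend.json")
    with open(path, "w") as f:
        json.dump({"type_hint": "pytorch_chip_classification_backend"}, f)
    cfg = deserialize_sketch(path)

print(cfg.fields["type_hint"])
```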

filter_commands(commands: list[str]) → list[str]

Filter out any commands that are not needed or supported.

Parameters:

commands (list[str]) –

Return type:

list[str]

classmethod from_dict(cfg_dict: dict) → Self

Deserialize Config from a dict.

Parameters:

cfg_dict (dict) – Dict to deserialize.

Return type:

Self

classmethod from_file(uri: str) → Self

Deserialize Config from a JSON file, upgrading if possible.

Parameters:

uri (str) – URI to load from.

Return type:

Self

get_bundle_filenames()

Returns the names of files that should be included in a model bundle.

The files are assumed to be in the train/ directory generated by the train command. Note that only the names, not the full paths, should be returned.

get_img_channels(pipeline_cfg: RVPipelineConfig) → int

Determine img_channels from scenes.

Parameters:

pipeline_cfg (RVPipelineConfig) –

Return type:

int

get_learner_config(pipeline)
recursive_validate_config()

Recursively validate hierarchies of Configs.

This uses reflection to call validate_config on a hierarchy of Configs using a depth-first pre-order traversal.
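In a depth-first pre-order traversal each config is validated before any of the configs nested inside it. The sketch below mimics recursive_validate_config on a small hypothetical hierarchy: it finds child configs by inspecting each object's attributes (the reflection step) and logs the order in which validate_config runs. Node and its attribute layout are assumptions made for illustration.

```python
class Node:
    """Hypothetical config-like object with a validate_config hook."""

    def __init__(self, name: str, **children):
        self.name = name
        for key, child in children.items():
            setattr(self, key, child)

    def validate_config(self, log: list) -> None:
        log.append(self.name)


def recursive_validate(cfg: Node, log: list) -> None:
    # Pre-order: validate this node first ...
    cfg.validate_config(log)
    # ... then recurse into every attribute that is itself a config,
    # including configs held in lists.
    for value in vars(cfg).values():
        items = value if isinstance(value, list) else [value]
        for item in items:
            if isinstance(item, Node):
                recursive_validate(item, log)


backend = Node(
    "backend",
    model=Node("model", external_def=Node("external_def")),
    solver=Node("solver"),
    data=Node("data", class_config=Node("class_config")),
)
order = []
recursive_validate(backend, order)
print(order)
```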

revalidate()

Re-validate an instantiated Config.

Runs all Pydantic validators plus self.validate_config().

to_file(uri: str, with_rv_metadata: bool = True) → None

Save a Config to a JSON file, optionally with RV metadata.

Parameters:
  • uri (str) – URI to save to.

  • with_rv_metadata (bool) – If True, inject Raster Vision metadata such as plugin_versions, so that the config can be upgraded when loaded.

Return type:

None
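The effect of with_rv_metadata can be pictured as one extra key added to the serialized dict before it is written. In this sketch the key is named plugin_versions after the parameter description above, but the version mapping is made up and to_file_sketch is a hypothetical helper that only handles local paths.

```python
import json
import os
import tempfile
from typing import Optional


def to_file_sketch(cfg_dict: dict, path: str, with_rv_metadata: bool = True,
                   plugin_versions: Optional[dict] = None) -> None:
    """Write cfg_dict as JSON, optionally adding a plugin_versions entry."""
    out = dict(cfg_dict)  # shallow copy so the caller's dict is untouched
    if with_rv_metadata:
        # Recording versions lets a later load decide whether to upgrade.
        out["plugin_versions"] = dict(plugin_versions or {})
    with open(path, "w") as f:
        json.dump(out, f, indent=2)


cfg_dict = {"type_hint": "pytorch_chip_classification_backend"}
versions = {"rastervision.pytorch_backend": 1}  # made-up example mapping

with tempfile.TemporaryDirectory() as tmp:
    with_meta = os.path.join(tmp, "with_meta.json")
    without_meta = os.path.join(tmp, "without_meta.json")
    to_file_sketch(cfg_dict, with_meta, plugin_versions=versions)
    to_file_sketch(cfg_dict, without_meta, with_rv_metadata=False)
    with open(with_meta) as f:
        loaded_with = json.load(f)
    with open(without_meta) as f:
        loaded_without = json.load(f)

print(sorted(loaded_with), sorted(loaded_without))
```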

update(pipeline: rastervision.core.rv_pipeline.rv_pipeline_config.RVPipelineConfig | None = None)

Update any fields before validation.

Subclasses should override this to provide complex default behavior, for example, setting default values as a function of the values of other fields. The arguments to this method will vary depending on the type of Config.

Parameters:

pipeline (rastervision.core.rv_pipeline.rv_pipeline_config.RVPipelineConfig | None) –
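As an illustration of a default that update could compute from other fields, the sketch below fills in a missing img_channels from a pipeline's scenes, in the spirit of get_img_channels. Every class here is a hypothetical stand-in, and counting a scene's channels through a channel_order list (with the first scene taken as representative) is an assumption rather than a description of Raster Vision's code.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SceneSketch:
    # Hypothetical stand-in for a scene; channel_order lists band indices.
    channel_order: list[int]


@dataclass
class PipelineSketch:
    scenes: list[SceneSketch]


@dataclass
class DataSketch:
    img_channels: Optional[int] = None


def get_img_channels_sketch(pipeline: PipelineSketch) -> int:
    # Assumption: all scenes share a channel count, so the first one is used.
    return len(pipeline.scenes[0].channel_order)


def update_sketch(data: DataSketch,
                  pipeline: Optional[PipelineSketch] = None) -> None:
    # Only fill the field when the user left it unset and a pipeline is given.
    if data.img_channels is None and pipeline is not None:
        data.img_channels = get_img_channels_sketch(pipeline)


data = DataSketch()
update_sketch(data, PipelineSketch([SceneSketch([0, 1, 2])]))
print(data.img_channels)
```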

validate_config()

Validate fields that should be checked after update is called.

This complements the built-in validation that Pydantic performs when the object is constructed.

validate_list(field: str, valid_options: list[str])

Validate a list field.

Parameters:
  • field (str) – name of field to validate

  • valid_options (list[str]) – values that field is allowed to take

Raises:

ConfigError – if field is invalid
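A stdlib sketch of the kind of check validate_list performs, applied here to the augmentors field using the choices listed in its schema description. ConfigError is defined locally as a stand-in for Raster Vision's exception, and validate_list_sketch and DataStub are hypothetical; the real method may differ in details such as the error message.

```python
class ConfigError(ValueError):
    """Local stand-in for Raster Vision's ConfigError."""


AUGMENTOR_CHOICES = [
    "Blur", "RandomRotate90", "HorizontalFlip", "VerticalFlip",
    "GaussianBlur", "GaussNoise", "RGBShift", "ToGray",
]


class DataStub:
    """Hypothetical object holding a list-valued augmentors field."""

    def __init__(self, augmentors: list[str]):
        self.augmentors = augmentors


def validate_list_sketch(cfg, field: str, valid_options: list[str]) -> None:
    # Look the field up by name, then check each element against the options.
    for value in getattr(cfg, field):
        if value not in valid_options:
            raise ConfigError(
                f"{value!r} is not a valid option for {field!r}; "
                f"choose from {valid_options}")


validate_list_sketch(DataStub(["RandomRotate90", "HorizontalFlip"]),
                     "augmentors", AUGMENTOR_CHOICES)  # passes
try:
    validate_list_sketch(DataStub(["Rotate180"]), "augmentors",
                         AUGMENTOR_CHOICES)
except ConfigError as err:
    print(err)
```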