ClassificationLearnerConfig#

Note

All Configs are derived from rastervision.pipeline.config.Config, which itself is a pydantic Model.

pydantic model ClassificationLearnerConfig[source]#

Configure a ClassificationLearner.
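For orientation, below is a minimal sketch of assembling this config in Python. The import path follows the usual rastervision.pytorch_learner layout and the URIs are hypothetical placeholders; per the schema, only solver and data are required.

from rastervision.pytorch_learner import (
    ClassificationLearnerConfig, ClassificationModelConfig,
    ClassificationImageDataConfig, SolverConfig, Backbone)

# Minimal sketch: a learner config with a ResNet-18 backbone and an
# image-chip dataset. The URIs below are hypothetical placeholders.
cfg = ClassificationLearnerConfig(
    model=ClassificationModelConfig(backbone=Backbone.resnet18),
    solver=SolverConfig(lr=1e-4, num_epochs=10, batch_sz=32),
    data=ClassificationImageDataConfig(
        uri='s3://my-bucket/chips.zip',  # dir or zip with train/valid folders
        img_sz=256),
    output_uri='s3://my-bucket/train-output/')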

JSON schema:
{
   "title": "ClassificationLearnerConfig",
   "description": "Configure a :class:`.ClassificationLearner`.",
   "type": "object",
   "properties": {
      "model": {
         "$ref": "#/definitions/ClassificationModelConfig"
      },
      "solver": {
         "$ref": "#/definitions/SolverConfig"
      },
      "data": {
         "title": "Data",
         "anyOf": [
            {
               "$ref": "#/definitions/ClassificationImageDataConfig"
            },
            {
               "$ref": "#/definitions/ClassificationGeoDataConfig"
            }
         ]
      },
      "predict_mode": {
         "title": "Predict Mode",
         "description": "If True, skips training, loads model, and does final eval.",
         "default": false,
         "type": "boolean"
      },
      "test_mode": {
         "title": "Test Mode",
         "description": "If True, uses test_num_epochs, test_batch_sz, truncated datasets with only a single batch, image_sz that is cut in half, and num_workers = 0. This is useful for testing that code runs correctly on CPU without multithreading before running full job on GPU.",
         "default": false,
         "type": "boolean"
      },
      "overfit_mode": {
         "title": "Overfit Mode",
         "description": "If True, uses half image size, and instead of doing epoch-based training, optimizes the model using a single batch repeatedly for overfit_num_steps number of steps.",
         "default": false,
         "type": "boolean"
      },
      "eval_train": {
         "title": "Eval Train",
         "description": "If True, runs final evaluation on training set (in addition to test set). Useful for debugging.",
         "default": false,
         "type": "boolean"
      },
      "save_model_bundle": {
         "title": "Save Model Bundle",
         "description": "If True, saves a model bundle at the end of training which is zip file with model and this LearnerConfig which can be used to make predictions on new images at a later time.",
         "default": true,
         "type": "boolean"
      },
      "log_tensorboard": {
         "title": "Log Tensorboard",
         "description": "Save Tensorboard log files at the end of each epoch.",
         "default": true,
         "type": "boolean"
      },
      "run_tensorboard": {
         "title": "Run Tensorboard",
         "description": "run Tensorboard server during training",
         "default": false,
         "type": "boolean"
      },
      "output_uri": {
         "title": "Output Uri",
         "description": "URI of where to save output",
         "type": "string"
      },
      "save_all_checkpoints": {
         "title": "Save All Checkpoints",
         "description": "If True, all checkpoints would be saved. The latest checkpoint would be saved as `last-model.pth`. The checkpoints prior to last epoch are stored as `model-ckpt-epoch-{N}.pth` where `N` is the epoch number.",
         "default": false,
         "type": "boolean"
      },
      "type_hint": {
         "title": "Type Hint",
         "default": "classification_learner",
         "enum": [
            "classification_learner"
         ],
         "type": "string"
      }
   },
   "required": [
      "solver",
      "data"
   ],
   "additionalProperties": false,
   "definitions": {
      "Backbone": {
         "title": "Backbone",
         "description": "An enumeration.",
         "enum": [
            "alexnet",
            "densenet121",
            "densenet169",
            "densenet201",
            "densenet161",
            "googlenet",
            "inception_v3",
            "mnasnet0_5",
            "mnasnet0_75",
            "mnasnet1_0",
            "mnasnet1_3",
            "mobilenet_v2",
            "resnet18",
            "resnet34",
            "resnet50",
            "resnet101",
            "resnet152",
            "resnext50_32x4d",
            "resnext101_32x8d",
            "wide_resnet50_2",
            "wide_resnet101_2",
            "shufflenet_v2_x0_5",
            "shufflenet_v2_x1_0",
            "shufflenet_v2_x1_5",
            "shufflenet_v2_x2_0",
            "squeezenet1_0",
            "squeezenet1_1",
            "vgg11",
            "vgg11_bn",
            "vgg13",
            "vgg13_bn",
            "vgg16",
            "vgg16_bn",
            "vgg19_bn",
            "vgg19"
         ]
      },
      "ExternalModuleConfig": {
         "title": "ExternalModuleConfig",
         "description": "Config describing an object to be loaded via Torch Hub.",
         "type": "object",
         "properties": {
            "uri": {
               "title": "Uri",
               "description": "Local uri of a zip file, or local uri of a directory,or remote uri of zip file.",
               "minLength": 1,
               "type": "string"
            },
            "github_repo": {
               "title": "Github Repo",
               "description": "<repo-owner>/<repo-name>[:tag]",
               "pattern": ".+/.+",
               "type": "string"
            },
            "name": {
               "title": "Name",
               "description": "Name of the folder in which to extract/copy the definition files.",
               "minLength": 1,
               "type": "string"
            },
            "entrypoint": {
               "title": "Entrypoint",
               "description": "Name of a callable present in hubconf.py. See docs for torch.hub for details.",
               "minLength": 1,
               "type": "string"
            },
            "entrypoint_args": {
               "title": "Entrypoint Args",
               "description": "Args to pass to the entrypoint. Must be serializable.",
               "default": [],
               "type": "array",
               "items": {}
            },
            "entrypoint_kwargs": {
               "title": "Entrypoint Kwargs",
               "description": "Keyword args to pass to the entrypoint. Must be serializable.",
               "default": {},
               "type": "object"
            },
            "force_reload": {
               "title": "Force Reload",
               "description": "Force reload of module definition.",
               "default": false,
               "type": "boolean"
            },
            "type_hint": {
               "title": "Type Hint",
               "default": "external-module",
               "enum": [
                  "external-module"
               ],
               "type": "string"
            }
         },
         "required": [
            "entrypoint"
         ],
         "additionalProperties": false
      },
      "ClassificationModelConfig": {
         "title": "ClassificationModelConfig",
         "description": "Configure a classification model.",
         "type": "object",
         "properties": {
            "backbone": {
               "description": "The torchvision.models backbone to use.",
               "default": "resnet18",
               "allOf": [
                  {
                     "$ref": "#/definitions/Backbone"
                  }
               ]
            },
            "pretrained": {
               "title": "Pretrained",
               "description": "If True, use ImageNet weights. If False, use random initialization.",
               "default": true,
               "type": "boolean"
            },
            "init_weights": {
               "title": "Init Weights",
               "description": "URI of PyTorch model weights used to initialize model. If set, this supercedes the pretrained option.",
               "type": "string"
            },
            "load_strict": {
               "title": "Load Strict",
               "description": "If True, the keys in the state dict referenced by init_weights must match exactly. Setting this to False can be useful if you just want to load the backbone of a model.",
               "default": true,
               "type": "boolean"
            },
            "external_def": {
               "title": "External Def",
               "description": "If specified, the model will be built from the definition from this external source, using Torch Hub.",
               "allOf": [
                  {
                     "$ref": "#/definitions/ExternalModuleConfig"
                  }
               ]
            },
            "extra_args": {
               "title": "Extra Args",
               "description": "Other implementation-specific args that might be useful for constructing the default model. This is ignored if using an external model.",
               "default": {},
               "type": "object"
            },
            "type_hint": {
               "title": "Type Hint",
               "default": "classification_model",
               "enum": [
                  "classification_model"
               ],
               "type": "string"
            }
         },
         "additionalProperties": false
      },
      "SolverConfig": {
         "title": "SolverConfig",
         "description": "Config related to solver aka optimizer.",
         "type": "object",
         "properties": {
            "lr": {
               "title": "Lr",
               "description": "Learning rate.",
               "default": 0.0001,
               "exclusiveMinimum": 0,
               "type": "number"
            },
            "num_epochs": {
               "title": "Num Epochs",
               "description": "Number of epochs (ie. sweeps through the whole training set).",
               "default": 10,
               "exclusiveMinimum": 0,
               "type": "integer"
            },
            "test_num_epochs": {
               "title": "Test Num Epochs",
               "description": "Number of epochs to use in test mode.",
               "default": 2,
               "exclusiveMinimum": 0,
               "type": "integer"
            },
            "test_batch_sz": {
               "title": "Test Batch Sz",
               "description": "Batch size to use in test mode.",
               "default": 4,
               "exclusiveMinimum": 0,
               "type": "integer"
            },
            "overfit_num_steps": {
               "title": "Overfit Num Steps",
               "description": "Number of optimizer steps to use in overfit mode.",
               "default": 1,
               "exclusiveMinimum": 0,
               "type": "integer"
            },
            "sync_interval": {
               "title": "Sync Interval",
               "description": "The interval in epochs for each sync to the cloud.",
               "default": 1,
               "exclusiveMinimum": 0,
               "type": "integer"
            },
            "batch_sz": {
               "title": "Batch Sz",
               "description": "Batch size.",
               "default": 32,
               "exclusiveMinimum": 0,
               "type": "integer"
            },
            "one_cycle": {
               "title": "One Cycle",
               "description": "If True, use triangular LR scheduler with a single cycle across all epochs with start and end LR being lr/10 and the peak being lr.",
               "default": true,
               "type": "boolean"
            },
            "multi_stage": {
               "title": "Multi Stage",
               "description": "List of epoch indices at which to divide LR by 10.",
               "default": [],
               "type": "array",
               "items": {}
            },
            "class_loss_weights": {
               "title": "Class Loss Weights",
               "description": "Class weights for weighted loss.",
               "type": "array",
               "items": {
                  "type": "number"
               }
            },
            "ignore_class_index": {
               "title": "Ignore Class Index",
               "description": "If specified, this index is ignored when computing the loss. See pytorch documentation for nn.CrossEntropyLoss for more details. This can also be negative, in which case it is treated as a negative slice index i.e. -1 = last index, -2 = second-last index, and so on.",
               "type": "integer"
            },
            "external_loss_def": {
               "title": "External Loss Def",
               "description": "If specified, the loss will be built from the definition from this external source, using Torch Hub.",
               "allOf": [
                  {
                     "$ref": "#/definitions/ExternalModuleConfig"
                  }
               ]
            },
            "type_hint": {
               "title": "Type Hint",
               "default": "solver",
               "enum": [
                  "solver"
               ],
               "type": "string"
            }
         },
         "additionalProperties": false
      },
      "PlotOptions": {
         "title": "PlotOptions",
         "description": "Config related to plotting.",
         "type": "object",
         "properties": {
            "transform": {
               "title": "Transform",
               "description": "An Albumentations transform serialized as a dict that will be applied to each image before it is plotted. Mainly useful for undoing any data transformation that you do not want included in the plot, such as normalization. The default value will shift and scale the image so the values range from 0.0 to 1.0 which is the expected range for the plotting function. This default is useful for cases where the values after normalization are close to zero which makes the plot difficult to see.",
               "default": {
                  "__version__": "1.3.0",
                  "transform": {
                     "__class_fullname__": "rastervision.pytorch_learner.utils.utils.MinMaxNormalize",
                     "always_apply": false,
                     "p": 1.0,
                     "min_val": 0.0,
                     "max_val": 1.0,
                     "dtype": 5
                  }
               },
               "type": "object"
            },
            "channel_display_groups": {
               "title": "Channel Display Groups",
               "description": "Groups of image channels to display together as a subplot when plotting the data and predictions. Can be a list or tuple of groups (e.g. [(0, 1, 2), (3,)]) or a dict containing title-to-group mappings (e.g. {\"RGB\": [0, 1, 2], \"IR\": [3]}), where each group is a list or tuple of channel indices and title is a string that will be used as the title of the subplot for that group.",
               "anyOf": [
                  {
                     "type": "object",
                     "additionalProperties": {
                        "type": "array",
                        "items": {
                           "type": "integer",
                           "minimum": 0
                        }
                     }
                  },
                  {
                     "type": "array",
                     "items": {
                        "type": "array",
                        "items": {
                           "type": "integer",
                           "minimum": 0
                        }
                     }
                  }
               ]
            },
            "type_hint": {
               "title": "Type Hint",
               "default": "plot_options",
               "enum": [
                  "plot_options"
               ],
               "type": "string"
            }
         },
         "additionalProperties": false
      },
      "ClassificationDataFormat": {
         "title": "ClassificationDataFormat",
         "description": "An enumeration.",
         "enum": [
            "image_folder"
         ]
      },
      "ClassificationImageDataConfig": {
         "title": "ClassificationImageDataConfig",
         "description": "Configure :class:`ClassificationImageDatasets <.ClassificationImageDataset>`.",
         "type": "object",
         "properties": {
            "class_names": {
               "title": "Class Names",
               "description": "Names of classes.",
               "default": [],
               "type": "array",
               "items": {
                  "type": "string"
               }
            },
            "class_colors": {
               "title": "Class Colors",
               "description": "Colors used to display classes. Can be color 3-tuples in list form.",
               "type": "array",
               "items": {
                  "anyOf": [
                     {
                        "type": "string"
                     },
                     {
                        "type": "array",
                        "minItems": 3,
                        "maxItems": 3,
                        "items": [
                           {
                              "type": "integer"
                           },
                           {
                              "type": "integer"
                           },
                           {
                              "type": "integer"
                           }
                        ]
                     }
                  ]
               }
            },
            "img_channels": {
               "title": "Img Channels",
               "description": "The number of channels of the training images.",
               "exclusiveMinimum": 0,
               "type": "integer"
            },
            "img_sz": {
               "title": "Img Sz",
               "description": "Length of a side of each image in pixels. This is the size to transform it to during training, not the size in the raw dataset.",
               "default": 256,
               "exclusiveMinimum": 0,
               "type": "integer"
            },
            "train_sz": {
               "title": "Train Sz",
               "description": "If set, the number of training images to use. If fewer images exist, then an exception will be raised.",
               "type": "integer"
            },
            "train_sz_rel": {
               "title": "Train Sz Rel",
               "description": "If set, the proportion of training images to use.",
               "type": "number"
            },
            "num_workers": {
               "title": "Num Workers",
               "description": "Number of workers to use when DataLoader makes batches.",
               "default": 4,
               "type": "integer"
            },
            "augmentors": {
               "title": "Augmentors",
               "description": "Names of albumentations augmentors to use for training batches. Choices include: ['Blur', 'RandomRotate90', 'HorizontalFlip', 'VerticalFlip', 'GaussianBlur', 'GaussNoise', 'RGBShift', 'ToGray']. Alternatively, a custom transform can be provided via the aug_transform option.",
               "default": [
                  "RandomRotate90",
                  "HorizontalFlip",
                  "VerticalFlip"
               ],
               "type": "array",
               "items": {
                  "type": "string"
               }
            },
            "base_transform": {
               "title": "Base Transform",
               "description": "An Albumentations transform serialized as a dict that will be applied to all datasets: training, validation, and test. This transformation is in addition to the resizing due to img_sz. This is useful for, for example, applying the same normalization to all datasets.",
               "type": "object"
            },
            "aug_transform": {
               "title": "Aug Transform",
               "description": "An Albumentations transform serialized as a dict that will be applied as data augmentation to the training dataset. This transform is applied before base_transform. If provided, the augmentors option is ignored.",
               "type": "object"
            },
            "plot_options": {
               "title": "Plot Options",
               "description": "Options to control plotting.",
               "default": {
                  "transform": {
                     "__version__": "1.3.0",
                     "transform": {
                        "__class_fullname__": "rastervision.pytorch_learner.utils.utils.MinMaxNormalize",
                        "always_apply": false,
                        "p": 1.0,
                        "min_val": 0.0,
                        "max_val": 1.0,
                        "dtype": 5
                     }
                  },
                  "channel_display_groups": null,
                  "type_hint": "plot_options"
               },
               "allOf": [
                  {
                     "$ref": "#/definitions/PlotOptions"
                  }
               ]
            },
            "preview_batch_limit": {
               "title": "Preview Batch Limit",
               "description": "Optional limit on the number of items in the preview plots produced during training.",
               "type": "integer"
            },
            "type_hint": {
               "title": "Type Hint",
               "default": "classification_image_data",
               "enum": [
                  "classification_image_data"
               ],
               "type": "string"
            },
            "data_format": {
               "default": "image_folder",
               "allOf": [
                  {
                     "$ref": "#/definitions/ClassificationDataFormat"
                  }
               ]
            },
            "uri": {
               "title": "Uri",
               "description": "One of the following:\n(1) a URI of a directory containing \"train\", \"valid\", and (optinally) \"test\" subdirectories;\n(2) a URI of a zip file containing (1);\n(3) a list of (2);\n(4) a URI of a directory containing zip files containing (1).",
               "anyOf": [
                  {
                     "type": "string"
                  },
                  {
                     "type": "array",
                     "items": {
                        "type": "string"
                     }
                  }
               ]
            },
            "group_uris": {
               "title": "Group Uris",
               "description": "This can be set instead of uri in order to specify groups of chips. Each element in the list is expected to be an object of the same form accepted by the uri field. The purpose of separating chips into groups is to be able to use the group_train_sz field.",
               "type": "array",
               "items": {
                  "anyOf": [
                     {
                        "type": "string"
                     },
                     {
                        "type": "array",
                        "items": {
                           "type": "string"
                        }
                     }
                  ]
               }
            },
            "group_train_sz": {
               "title": "Group Train Sz",
               "description": "If group_uris is set, this can be used to specify the number of chips to use per group. Only applies to training chips. This can either be a single value that will be used for all groups or a list of values (one for each group).",
               "anyOf": [
                  {
                     "type": "integer"
                  },
                  {
                     "type": "array",
                     "items": {
                        "type": "integer"
                     }
                  }
               ]
            },
            "group_train_sz_rel": {
               "title": "Group Train Sz Rel",
               "description": "Relative version of group_train_sz. Must be a float in [0, 1]. If group_uris is set, this can be used to specify the proportion of the total chips in each group to use per group. Only applies to training chips. This can either be a single value that will be used for all groups or a list of values (one for each group).",
               "anyOf": [
                  {
                     "type": "number",
                     "minimum": 0,
                     "maximum": 1
                  },
                  {
                     "type": "array",
                     "items": {
                        "type": "number",
                        "minimum": 0,
                        "maximum": 1
                     }
                  }
               ]
            }
         },
         "additionalProperties": false
      },
      "ClassConfig": {
         "title": "ClassConfig",
         "description": "Configure class information for a machine learning task.",
         "type": "object",
         "properties": {
            "names": {
               "title": "Names",
               "description": "Names of classes. The i-th class in this list will have class ID = i.",
               "type": "array",
               "items": {
                  "type": "string"
               }
            },
            "colors": {
               "title": "Colors",
               "description": "Colors used to visualize classes. Can be color strings accepted by matplotlib or RGB tuples. If None, a random color will be auto-generated for each class.",
               "type": "array",
               "items": {
                  "anyOf": [
                     {
                        "type": "string"
                     },
                     {
                        "type": "array",
                        "items": {}
                     }
                  ]
               }
            },
            "null_class": {
               "title": "Null Class",
               "description": "Optional name of class in `names` to use as the null class. This is used in semantic segmentation to represent the label for imagery pixels that are NODATA or that are missing a label. If None and the class names include \"null\", it will automatically be used as the null class. If None, and this Config is part of a SemanticSegmentationConfig, a null class will be added automatically.",
               "type": "string"
            },
            "type_hint": {
               "title": "Type Hint",
               "default": "class_config",
               "enum": [
                  "class_config"
               ],
               "type": "string"
            }
         },
         "required": [
            "names"
         ],
         "additionalProperties": false
      },
      "RasterTransformerConfig": {
         "title": "RasterTransformerConfig",
         "description": "Configure a :class:`.RasterTransformer`.",
         "type": "object",
         "properties": {
            "type_hint": {
               "title": "Type Hint",
               "default": "raster_transformer",
               "enum": [
                  "raster_transformer"
               ],
               "type": "string"
            }
         },
         "additionalProperties": false
      },
      "RasterSourceConfig": {
         "title": "RasterSourceConfig",
         "description": "Configure a :class:`.RasterSource`.",
         "type": "object",
         "properties": {
            "channel_order": {
               "title": "Channel Order",
               "description": "The sequence of channel indices to use when reading imagery.",
               "type": "array",
               "items": {
                  "type": "integer"
               }
            },
            "transformers": {
               "title": "Transformers",
               "default": [],
               "type": "array",
               "items": {
                  "$ref": "#/definitions/RasterTransformerConfig"
               }
            },
            "bbox": {
               "title": "Bbox",
               "description": "User-specified bbox in pixel coords in the form (ymin, xmin, ymax, xmax). Useful for cropping the raster source so that only part of the raster is read from.",
               "type": "array",
               "minItems": 4,
               "maxItems": 4,
               "items": [
                  {
                     "type": "integer"
                  },
                  {
                     "type": "integer"
                  },
                  {
                     "type": "integer"
                  },
                  {
                     "type": "integer"
                  }
               ]
            },
            "type_hint": {
               "title": "Type Hint",
               "default": "raster_source",
               "enum": [
                  "raster_source"
               ],
               "type": "string"
            }
         },
         "additionalProperties": false
      },
      "LabelSourceConfig": {
         "title": "LabelSourceConfig",
         "description": "Configure a :class:`.LabelSource`.",
         "type": "object",
         "properties": {
            "type_hint": {
               "title": "Type Hint",
               "default": "label_source",
               "enum": [
                  "label_source"
               ],
               "type": "string"
            }
         },
         "additionalProperties": false
      },
      "LabelStoreConfig": {
         "title": "LabelStoreConfig",
         "description": "Configure a :class:`.LabelStore`.",
         "type": "object",
         "properties": {
            "type_hint": {
               "title": "Type Hint",
               "default": "label_store",
               "enum": [
                  "label_store"
               ],
               "type": "string"
            }
         },
         "additionalProperties": false
      },
      "SceneConfig": {
         "title": "SceneConfig",
         "description": "Configure a :class:`.Scene` comprising raster data & labels for an AOI.\n    ",
         "type": "object",
         "properties": {
            "id": {
               "title": "Id",
               "type": "string"
            },
            "raster_source": {
               "$ref": "#/definitions/RasterSourceConfig"
            },
            "label_source": {
               "$ref": "#/definitions/LabelSourceConfig"
            },
            "label_store": {
               "$ref": "#/definitions/LabelStoreConfig"
            },
            "aoi_uris": {
               "title": "Aoi Uris",
               "description": "List of URIs of GeoJSON files that define the AOIs for the scene. Each polygon defines an AOI which is a piece of the scene that is assumed to be fully labeled and usable for training or validation. The AOIs are assumed to be in EPSG:4326 coordinates.",
               "type": "array",
               "items": {
                  "type": "string"
               }
            },
            "type_hint": {
               "title": "Type Hint",
               "default": "scene",
               "enum": [
                  "scene"
               ],
               "type": "string"
            }
         },
         "required": [
            "id",
            "raster_source"
         ],
         "additionalProperties": false
      },
      "DatasetConfig": {
         "title": "DatasetConfig",
         "description": "Configure train, validation, and test splits for a dataset.",
         "type": "object",
         "properties": {
            "class_config": {
               "$ref": "#/definitions/ClassConfig"
            },
            "train_scenes": {
               "title": "Train Scenes",
               "type": "array",
               "items": {
                  "$ref": "#/definitions/SceneConfig"
               }
            },
            "validation_scenes": {
               "title": "Validation Scenes",
               "type": "array",
               "items": {
                  "$ref": "#/definitions/SceneConfig"
               }
            },
            "test_scenes": {
               "title": "Test Scenes",
               "default": [],
               "type": "array",
               "items": {
                  "$ref": "#/definitions/SceneConfig"
               }
            },
            "scene_groups": {
               "title": "Scene Groups",
               "description": "Groupings of scenes. Should be a dict of the form: {<group-name>: Set(scene_id_1, scene_id_2, ...)}. Three groups are added by default: \"train_scenes\", \"validation_scenes\", and \"test_scenes\"",
               "default": {},
               "type": "object",
               "additionalProperties": {
                  "type": "array",
                  "items": {
                     "type": "string"
                  },
                  "uniqueItems": true
               }
            },
            "type_hint": {
               "title": "Type Hint",
               "default": "dataset",
               "enum": [
                  "dataset"
               ],
               "type": "string"
            }
         },
         "required": [
            "class_config",
            "train_scenes",
            "validation_scenes"
         ],
         "additionalProperties": false
      },
      "GeoDataWindowMethod": {
         "title": "GeoDataWindowMethod",
         "description": "An enumeration.",
         "enum": [
            "sliding",
            "random"
         ]
      },
      "GeoDataWindowConfig": {
         "title": "GeoDataWindowConfig",
         "description": "Configure a :class:`.GeoDataset`.\n\nSee :mod:`rastervision.pytorch_learner.dataset.dataset`.",
         "type": "object",
         "properties": {
            "method": {
               "default": "sliding",
               "allOf": [
                  {
                     "$ref": "#/definitions/GeoDataWindowMethod"
                  }
               ]
            },
            "size": {
               "title": "Size",
               "description": "If method = sliding, this is the size of sliding window. If method = random, this is the size that all the windows are resized to before they are returned. If method = random and neither size_lims nor h_lims and w_lims have been specified, then size_lims is set to (size, size + 1).",
               "anyOf": [
                  {
                     "type": "integer",
                     "exclusiveMinimum": 0
                  },
                  {
                     "type": "array",
                     "minItems": 2,
                     "maxItems": 2,
                     "items": [
                        {
                           "type": "integer",
                           "exclusiveMinimum": 0
                        },
                        {
                           "type": "integer",
                           "exclusiveMinimum": 0
                        }
                     ]
                  }
               ]
            },
            "stride": {
               "title": "Stride",
               "description": "Stride of sliding window. Only used if method = sliding.",
               "anyOf": [
                  {
                     "type": "integer",
                     "exclusiveMinimum": 0
                  },
                  {
                     "type": "array",
                     "minItems": 2,
                     "maxItems": 2,
                     "items": [
                        {
                           "type": "integer",
                           "exclusiveMinimum": 0
                        },
                        {
                           "type": "integer",
                           "exclusiveMinimum": 0
                        }
                     ]
                  }
               ]
            },
            "padding": {
               "title": "Padding",
               "description": "How many pixels are windows allowed to overflow the edges of the raster source.",
               "anyOf": [
                  {
                     "type": "integer",
                     "minimum": 0
                  },
                  {
                     "type": "array",
                     "minItems": 2,
                     "maxItems": 2,
                     "items": [
                        {
                           "type": "integer",
                           "minimum": 0
                        },
                        {
                           "type": "integer",
                           "minimum": 0
                        }
                     ]
                  }
               ]
            },
            "pad_direction": {
               "title": "Pad Direction",
               "description": "If \"end\", only pad ymax and xmax (bottom and right). If \"start\", only pad ymin and xmin (top and left). If \"both\", pad all sides. Has no effect if paddiong is zero. Defaults to \"end\".",
               "default": "end",
               "enum": [
                  "both",
                  "start",
                  "end"
               ],
               "type": "string"
            },
            "size_lims": {
               "title": "Size Lims",
               "description": "[min, max) interval from which window sizes will be uniformly randomly sampled. The upper limit is exclusive. To fix the size to a constant value, use size_lims = (sz, sz + 1). Only used if method = random. Specify either size_lims or h_lims and w_lims, but not both. If neither size_lims nor h_lims and w_lims have been specified, then this will be set to (size, size + 1).",
               "type": "array",
               "minItems": 2,
               "maxItems": 2,
               "items": [
                  {
                     "type": "integer",
                     "exclusiveMinimum": 0
                  },
                  {
                     "type": "integer",
                     "exclusiveMinimum": 0
                  }
               ]
            },
            "h_lims": {
               "title": "H Lims",
               "description": "[min, max] interval from which window heights will be uniformly randomly sampled. Only used if method = random.",
               "type": "array",
               "minItems": 2,
               "maxItems": 2,
               "items": [
                  {
                     "type": "integer",
                     "exclusiveMinimum": 0
                  },
                  {
                     "type": "integer",
                     "exclusiveMinimum": 0
                  }
               ]
            },
            "w_lims": {
               "title": "W Lims",
               "description": "[min, max] interval from which window widths will be uniformly randomly sampled. Only used if method = random.",
               "type": "array",
               "minItems": 2,
               "maxItems": 2,
               "items": [
                  {
                     "type": "integer",
                     "exclusiveMinimum": 0
                  },
                  {
                     "type": "integer",
                     "exclusiveMinimum": 0
                  }
               ]
            },
            "max_windows": {
               "title": "Max Windows",
               "description": "Max allowed reads from a GeoDataset. Only used if method = random.",
               "default": 10000,
               "minimum": 0,
               "type": "integer"
            },
            "max_sample_attempts": {
               "title": "Max Sample Attempts",
               "description": "Max attempts when trying to find a window within the AOI of a scene. Only used if method = random and the scene has aoi_polygons specified.",
               "default": 100,
               "exclusiveMinimum": 0,
               "type": "integer"
            },
            "efficient_aoi_sampling": {
               "title": "Efficient Aoi Sampling",
               "description": "If the scene has AOIs, sampling windows at random anywhere in the extent and then checking if they fall within any of the AOIs can be very inefficient. This flag enables the use of an alternate algorithm that only samples window locations inside the AOIs. Only used if method = random and the scene has aoi_polygons specified. Defaults to True",
               "default": true,
               "type": "boolean"
            },
            "type_hint": {
               "title": "Type Hint",
               "default": "geo_data_window",
               "enum": [
                  "geo_data_window"
               ],
               "type": "string"
            }
         },
         "required": [
            "size"
         ],
         "additionalProperties": false
      },
      "ClassificationGeoDataConfig": {
         "title": "ClassificationGeoDataConfig",
         "description": "Configure classification :class:`GeoDatasets <.GeoDataset>`.\n\nSee :mod:`rastervision.pytorch_learner.dataset.classification_dataset`.",
         "type": "object",
         "properties": {
            "class_names": {
               "title": "Class Names",
               "description": "Names of classes.",
               "default": [],
               "type": "array",
               "items": {
                  "type": "string"
               }
            },
            "class_colors": {
               "title": "Class Colors",
               "description": "Colors used to display classes. Can be color 3-tuples in list form.",
               "type": "array",
               "items": {
                  "anyOf": [
                     {
                        "type": "string"
                     },
                     {
                        "type": "array",
                        "minItems": 3,
                        "maxItems": 3,
                        "items": [
                           {
                              "type": "integer"
                           },
                           {
                              "type": "integer"
                           },
                           {
                              "type": "integer"
                           }
                        ]
                     }
                  ]
               }
            },
            "img_channels": {
               "title": "Img Channels",
               "description": "The number of channels of the training images.",
               "exclusiveMinimum": 0,
               "type": "integer"
            },
            "img_sz": {
               "title": "Img Sz",
               "description": "Length of a side of each image in pixels. This is the size to transform it to during training, not the size in the raw dataset.",
               "default": 256,
               "exclusiveMinimum": 0,
               "type": "integer"
            },
            "train_sz": {
               "title": "Train Sz",
               "description": "If set, the number of training images to use. If fewer images exist, then an exception will be raised.",
               "type": "integer"
            },
            "train_sz_rel": {
               "title": "Train Sz Rel",
               "description": "If set, the proportion of training images to use.",
               "type": "number"
            },
            "num_workers": {
               "title": "Num Workers",
               "description": "Number of workers to use when DataLoader makes batches.",
               "default": 4,
               "type": "integer"
            },
            "augmentors": {
               "title": "Augmentors",
               "description": "Names of albumentations augmentors to use for training batches. Choices include: ['Blur', 'RandomRotate90', 'HorizontalFlip', 'VerticalFlip', 'GaussianBlur', 'GaussNoise', 'RGBShift', 'ToGray']. Alternatively, a custom transform can be provided via the aug_transform option.",
               "default": [
                  "RandomRotate90",
                  "HorizontalFlip",
                  "VerticalFlip"
               ],
               "type": "array",
               "items": {
                  "type": "string"
               }
            },
            "base_transform": {
               "title": "Base Transform",
               "description": "An Albumentations transform serialized as a dict that will be applied to all datasets: training, validation, and test. This transformation is in addition to the resizing due to img_sz. This is useful for, for example, applying the same normalization to all datasets.",
               "type": "object"
            },
            "aug_transform": {
               "title": "Aug Transform",
               "description": "An Albumentations transform serialized as a dict that will be applied as data augmentation to the training dataset. This transform is applied before base_transform. If provided, the augmentors option is ignored.",
               "type": "object"
            },
            "plot_options": {
               "title": "Plot Options",
               "description": "Options to control plotting.",
               "default": {
                  "transform": {
                     "__version__": "1.3.0",
                     "transform": {
                        "__class_fullname__": "rastervision.pytorch_learner.utils.utils.MinMaxNormalize",
                        "always_apply": false,
                        "p": 1.0,
                        "min_val": 0.0,
                        "max_val": 1.0,
                        "dtype": 5
                     }
                  },
                  "channel_display_groups": null,
                  "type_hint": "plot_options"
               },
               "allOf": [
                  {
                     "$ref": "#/definitions/PlotOptions"
                  }
               ]
            },
            "preview_batch_limit": {
               "title": "Preview Batch Limit",
               "description": "Optional limit on the number of items in the preview plots produced during training.",
               "type": "integer"
            },
            "type_hint": {
               "title": "Type Hint",
               "default": "classification_geo_data",
               "enum": [
                  "classification_geo_data"
               ],
               "type": "string"
            },
            "scene_dataset": {
               "$ref": "#/definitions/DatasetConfig"
            },
            "window_opts": {
               "title": "Window Opts",
               "default": {},
               "anyOf": [
                  {
                     "$ref": "#/definitions/GeoDataWindowConfig"
                  },
                  {
                     "type": "object",
                     "additionalProperties": {
                        "$ref": "#/definitions/GeoDataWindowConfig"
                     }
                  }
               ]
            }
         },
         "additionalProperties": false
      }
   }
}

Config
  • extra: str = forbid

  • validate_assignment: bool = True

Fields
  • data

  • eval_train

  • log_tensorboard

  • model

  • output_uri

  • overfit_mode

  • predict_mode

  • run_tensorboard

  • save_all_checkpoints

  • save_model_bundle

  • solver

  • test_mode

  • type_hint

Validators
  • update_for_mode » all fields

  • validate_class_loss_weights » all fields

  • validate_run_tensorboard » run_tensorboard

field data: Union[ClassificationImageDataConfig, ClassificationGeoDataConfig] [Required]#
Validated by
  • update_for_mode

  • validate_class_loss_weights

field eval_train: bool = False#

If True, runs the final evaluation on the training set (in addition to the test set). Useful for debugging.

Validated by
  • update_for_mode

  • validate_class_loss_weights

field log_tensorboard: bool = True#

If True, saves Tensorboard log files at the end of each epoch.

Validated by
  • update_for_mode

  • validate_class_loss_weights

field model: Optional[ClassificationModelConfig] = None#
Validated by
  • update_for_mode

  • validate_class_loss_weights

field output_uri: Optional[str] = None#

URI of where to save the output.

Validated by
  • update_for_mode

  • validate_class_loss_weights

field overfit_mode: bool = False#

If True, uses half the image size and, instead of epoch-based training, repeatedly optimizes the model on a single batch for overfit_num_steps steps.

Validated by
  • update_for_mode

  • validate_class_loss_weights

field predict_mode: bool = False#

If True, skips training, loads the model, and runs the final evaluation.

Validated by
  • update_for_mode

  • validate_class_loss_weights

field run_tensorboard: bool = False#

If True, runs a Tensorboard server during training.

Validated by
  • update_for_mode

  • validate_class_loss_weights

  • validate_run_tensorboard

field save_all_checkpoints: bool = False#

If True, all checkpoints will be saved. The latest checkpoint will be saved as last-model.pth; checkpoints from epochs prior to the last are stored as model-ckpt-epoch-{N}.pth, where N is the epoch number.

Validated by
  • update_for_mode

  • validate_class_loss_weights

field save_model_bundle: bool = True#

If True, saves a model bundle at the end of training. This is a zip file containing the model and this LearnerConfig, and can be used to make predictions on new images at a later time.

Validated by
  • update_for_mode

  • validate_class_loss_weights

field solver: SolverConfig [Required]#
Validated by
  • update_for_mode

  • validate_class_loss_weights

field test_mode: bool = False#

If True, uses test_num_epochs, test_batch_sz, truncated datasets with only a single batch, img_sz that is cut in half, and num_workers = 0. This is useful for testing that code runs correctly on CPU without multithreading before running the full job on a GPU.

Validated by
  • update_for_mode

  • validate_class_loss_weights

field type_hint: Literal['classification_learner'] = 'classification_learner'#
Validated by
  • update_for_mode

  • validate_class_loss_weights

build(tmp_dir=None, model_weights_path=None, model_def_path=None, loss_def_path=None, training=True)[source]#

Returns a Learner instantiated using this Config.

Parameters
  • tmp_dir (str) – Root of temp dirs.

  • model_weights_path (str, optional) – A local path to model weights. Defaults to None.

  • model_def_path (str, optional) – A local path to a directory with a hubconf.py. If provided, the model definition is imported from here. Defaults to None.

  • loss_def_path (str, optional) – A local path to a directory with a hubconf.py. If provided, the loss function definition is imported from here. Defaults to None.

  • training (bool, optional) – Whether the model is to be used for training or prediction. If False, the model is put in eval mode and the loss function, optimizer, etc. are not initialized. Defaults to True.
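
A usage sketch, assuming the cfg object from the earlier example and that the returned Learner exposes a main() method that runs training followed by the final evaluation (an assumption; check the Learner API for your version):

# Build a Learner from the config and run training plus final eval.
learner = cfg.build(tmp_dir='/tmp/rv')  # '/tmp/rv' is an arbitrary temp root
learner.main()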

get_model_bundle_uri() → str#

Returns the URI of where the model bundle is stored.

Return type

str

recursive_validate_config()#

Recursively validate hierarchies of Configs.

This uses reflection to call validate_config on a hierarchy of Configs using a depth-first pre-order traversal.

revalidate()#

Re-validate an instantiated Config.

Runs all Pydantic validators plus self.validate_config().

Adapted from: https://github.com/samuelcolvin/pydantic/issues/1864#issuecomment-679044432
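
For example, because this Config sets validate_assignment = True, a plain assignment is already checked by pydantic, and calling revalidate() afterwards additionally re-runs validate_config() (a sketch using the cfg object from earlier):

cfg.solver.lr = 1e-3   # checked by pydantic on assignment
cfg.revalidate()       # re-runs all validators plus validate_config()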

update(*args, **kwargs)#

Update any fields before validation.

Subclasses should override this to provide complex default behavior, for example, setting default values as a function of the values of other fields. The arguments to this method will vary depending on the type of Config.
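
As an illustration of such behavior, a subclass might derive one field's default from another before validation runs. The subclass and its experiment_name field below are hypothetical; register_config is the standard decorator from rastervision.pipeline.config, and ClassificationLearnerConfig is imported as in the earlier sketch:

from rastervision.pipeline.config import register_config

# Hypothetical subclass: fill in a default output_uri derived from
# another field before validation runs.
@register_config('my_classification_learner')
class MyClassificationLearnerConfig(ClassificationLearnerConfig):
    experiment_name: str = 'default'

    def update(self, *args, **kwargs):
        super().update(*args, **kwargs)
        if self.output_uri is None:
            self.output_uri = f'/opt/data/{self.experiment_name}'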

validator update_for_mode  »  all fields#
Parameters

values (dict) –

Return type

dict

validator validate_class_loss_weights  »  all fields#
Parameters

values (dict) –

Return type

dict

validate_config()#

Validate fields that should be checked after update is called.

This is to complement the builtin validation that Pydantic performs at the time of object construction.

validate_list(field: str, valid_options: List[str])#

Validate a list field.

Parameters
  • field (str) – name of field to validate

  • valid_options (List[str]) – values that field is allowed to take

Raises

ConfigError – if field is invalid
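
For example, a data config could guard its augmentors field as follows (a sketch; the option list is taken from the augmentors description in the schema above):

# Raises ConfigError if any entry of cfg.data.augmentors is not listed.
cfg.data.validate_list(
    'augmentors',
    valid_options=['Blur', 'RandomRotate90', 'HorizontalFlip',
                   'VerticalFlip', 'GaussianBlur', 'GaussNoise',
                   'RGBShift', 'ToGray'])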

validator validate_run_tensorboard  »  run_tensorboard#
Return type

bool