ObjectDetectionLearner

class ObjectDetectionLearner

Bases: Learner

Attributes

__init__(cfg: LearnerConfig, output_dir: Optional[str] = None, train_ds: Optional[Dataset] = None, valid_ds: Optional[Dataset] = None, test_ds: Optional[Dataset] = None, model: Optional[torch.nn.Module] = None, loss: Optional[Callable] = None, optimizer: Optional[Optimizer] = None, epoch_scheduler: Optional[_LRScheduler] = None, step_scheduler: Optional[_LRScheduler] = None, tmp_dir: Optional[str] = None, model_weights_path: Optional[str] = None, model_def_path: Optional[str] = None, loss_def_path: Optional[str] = None, training: bool = True)

Constructor.

Parameters
  • cfg (LearnerConfig) – LearnerConfig.

  • train_ds (Optional[Dataset], optional) – The dataset to use for training. If None, will be generated from cfg.data. Defaults to None.

  • valid_ds (Optional[Dataset], optional) – The dataset to use for validation. If None, will be generated from cfg.data. Defaults to None.

  • test_ds (Optional[Dataset], optional) – The dataset to use for testing. If None, will be generated from cfg.data. Defaults to None.

  • model (Optional[nn.Module], optional) – The model. If None, will be generated from cfg.model. Defaults to None.

  • loss (Optional[Callable], optional) – The loss function. If None, will be generated from cfg.solver. Defaults to None.

  • optimizer (Optional[Optimizer], optional) – The optimizer. If None, will be generated from cfg.solver. Defaults to None.

  • epoch_scheduler (Optional[_LRScheduler], optional) – The scheduler that updates after each epoch. If None, will be generated from cfg.solver. Defaults to None.

  • step_scheduler (Optional[_LRScheduler], optional) – The scheduler that updates after each optimizer-step. If None, will be generated from cfg.solver. Defaults to None.

  • tmp_dir (Optional[str], optional) – A temporary directory to use for downloads etc. If None, will be auto-generated. Defaults to None.

  • model_weights_path (Optional[str], optional) – URI of model weights to initialize the model with. Defaults to None.

  • model_def_path (Optional[str], optional) – A local path to a directory with a hubconf.py. If provided, the model definition is imported from here. This is used when loading an external model from a model-bundle. Defaults to None.

  • loss_def_path (Optional[str], optional) – A local path to a directory with a hubconf.py. If provided, the loss function definition is imported from here. This is used when loading an external loss function from a model-bundle. Defaults to None.

  • training (bool, optional) – If False, the training apparatus (loss, optimizer, scheduler, logging, etc.) will not be set up and the model will be put into eval mode. If True, the training apparatus will be set up and the model will be put into training mode. Defaults to True.

  • output_dir (Optional[str]) –

Methods

__init__(cfg[, output_dir, train_ds, ...])

Constructor.

build_dataloader(split[, distributed])

Build DataLoader for split.

build_dataloaders([distributed])

Build DataLoaders for train, validation, and test splits.

build_dataset(split)

Build Dataset for split.

build_datasets()

Build Datasets for train, validation, and test splits.

build_epoch_scheduler([start_epoch])

Returns an LR scheduler that changes the LR each epoch.

build_loss([loss_def_path])

Build a loss Callable.

build_model([model_def_path])

Override to pass img_sz.

build_optimizer()

Returns optimizer.

build_sampler(ds, split[, distributed])

Build an optional sampler for the split's dataloader.

build_step_scheduler([start_epoch])

Returns an LR scheduler that changes the LR each step.

ddp([rank, world_size])

Return a DDPContextManager.

export_to_onnx(path[, model, sample_input])

Export model to ONNX format via torch.onnx.export().

from_model_bundle(model_bundle_uri[, ...])

Create a Learner from a model bundle.

get_collate_fn()

Returns a custom collate_fn to use in DataLoader.

get_dataloader(split)

Get the DataLoader for a split.

get_dataset(split)

Get the Dataset for a split.

get_start_and_end_epochs([epochs])

Get start and end epochs given epochs.

get_start_epoch()

Get start epoch.

get_visualizer_class()

Returns a Visualizer class object for plotting data samples.

load_checkpoint()

Load last weights from previous run if available.

load_init_weights([model_weights_path])

Load the weights to initialize model.

load_onnx_model(model_path)

load_weights(uri, **kwargs)

Load model weights from a file.

log_data_stats()

Log stats about each Dataset.

main()

Main training sequence.

normalize_input(x)

Normalize x to [0, 1].

on_epoch_end(curr_epoch, metrics)

Hook that is called at end of epoch.

on_train_start()

Hook that is called at start of train routine.

output_to_numpy(out)

Convert output of model to numpy format.

plot_dataloader(dl, output_path[, ...])

Plot images and ground truth labels for a DataLoader.

plot_dataloaders([batch_limit, show])

Plot images and ground truth labels for all DataLoaders.

plot_predictions(split[, batch_limit, show])

Plot predictions for a split.

post_forward(x)

Post process output of call to model().

postprocess_model_output(x, out_batch, out_shape)

predict(x[, raw_out, out_shape])

Make prediction for an image or batch of images.

predict_dataloader(dl[, batched_output, ...])

Returns an iterator over predictions on the given dataloader.

predict_dataset(dataset[, return_format, ...])

Returns an iterator over predictions on the given dataset.

predict_onnx(x[, raw_out, out_shape])

Alternative to predict() for ONNX inference.

prob_to_pred(x)

Convert a Tensor with prediction probabilities to class ids.

reduce_distributed_metrics(metrics)

Average numeric metrics across processes.

run_tensorboard()

Run a TensorBoard server that serves the logged stats.

save_model_bundle([export_onnx])

Save a model bundle.

save_weights(path)

Save model weights to a local file.

setup_data([distributed])

Set up datasets and DataLoaders for the train, validation, and test splits.

setup_ddp_params()

Set up and validate params related to PyTorch DDP.

setup_loss([loss_def_path])

Set up self.loss.

setup_model([model_weights_path, model_def_path])

Override to apply the TorchVisionODAdapter wrapper.

setup_tensorboard()

Set up logging of stats to TensorBoard.

setup_training([loss_def_path])

Set up model, data, loss, optimizers and various paths.

stop_tensorboard()

Stop TensorBoard logging and the server, if it is running.

sync_from_cloud()

Sync any previous output in the cloud to output_dir.

sync_to_cloud()

Sync any output to the cloud at output_uri.

to_batch(x)

Ensure that image array has batch dimension.

to_device(x, device)

Load Tensors onto a device.

train([epochs])

Run training loop, resuming training if appropriate.

train_end(outputs)

Aggregate the output of train_step at the end of the epoch.

train_epoch(optimizer[, dataloader, ...])

Train for a single epoch.

train_step(batch, batch_ind)

Compute loss for a single training batch.

validate([split])

Evaluate model on a particular data split.

validate_end(outputs)

Aggregate the output of validate_step at the end of the epoch.

validate_epoch(dl)

Validate for a single epoch.

validate_step(batch, batch_ind)

Compute metrics on validation batch.

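build_dataloader(split: Literal['train', 'valid', 'test'], distributed: Optional[bool] = None, **kwargs) → torch.utils.data.DataLoader

Build DataLoader for split.

Parameters
  • split (Literal['train', 'valid', 'test']) – a split name which can be train, valid, or test

  • distributed (Optional[bool]) –

  • **kwargs –

Return type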

torch.utils.data.DataLoader

build_dataloaders(distributed: Optional[bool] = None) → Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader]

Build DataLoaders for train, validation, and test splits.

Parameters

distributed (Optional[bool]) –

Return type

Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader]

build_dataset(split: Literal['train', 'valid', 'test']) → Dataset

Build Dataset for split.

Parameters

split (Literal['train', 'valid', 'test']) –

Return type

Dataset

build_datasets() → Tuple[Dataset, Dataset, Dataset]

Build Datasets for train, validation, and test splits.

Return type

Tuple[Dataset, Dataset, Dataset]

build_epoch_scheduler(start_epoch: int = 0) → _LRScheduler

Returns an LR scheduler that changes the LR each epoch.

Parameters

start_epoch (int) –

Return type

_LRScheduler

build_loss(loss_def_path: Optional[str] = None) → Callable

Build a loss Callable.

Parameters

loss_def_path (Optional[str]) –

Return type

Callable

build_model(model_def_path: Optional[str] = None) → nn.Module

Override to pass img_sz.

Parameters

model_def_path (Optional[str]) –

Return type

nn.Module

build_optimizer() → Optimizer

Returns optimizer.

Return type

Optimizer

build_sampler(ds: Dataset, split: Literal['train', 'valid', 'test'], distributed: bool = False) → Optional[Sampler]

Build an optional sampler for the split’s dataloader.

Parameters
  • ds (Dataset) –

  • split (Literal['train', 'valid', 'test']) –

  • distributed (bool) –

Return type

Optional[Sampler]

build_step_scheduler(start_epoch: int = 0) → _LRScheduler

Returns an LR scheduler that changes the LR each step.

Parameters

start_epoch (int) –

Return type

_LRScheduler

ddp(rank: Optional[int] = None, world_size: Optional[int] = None) → DDPContextManager

Return a DDPContextManager.

This should be used to wrap code that needs to be executed in parallel. It is safe to call this recursively; recursive calls will have no effect.

Note that DDPContextManager does not start processes itself, but merely initializes and destroys DDP process groups.

Usage:

with learner.ddp(rank, world_size):  # rank and world_size are both optional
    ...

Parameters
  • rank (Optional[int]) –

  • world_size (Optional[int]) –

Return type

DDPContextManager
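The recursion safety described above ("recursive calls will have no effect") can be illustrated with a small depth-counting context manager that sets up a resource only on the outermost entry and tears it down only on the outermost exit. This is a hypothetical sketch: the class name ReentrantGroupManager and the setup/teardown callbacks are illustrative and not part of this library's API, and the real DDPContextManager may detect nesting differently (for example, by checking whether a process group is already initialized).

```python
from typing import Callable, List


class ReentrantGroupManager:
    """Hypothetical stand-in for DDPContextManager.

    Runs `setup` on the outermost `with` entry and `teardown` on the
    outermost exit; nested entries only adjust a depth counter.
    """

    def __init__(self, setup: Callable[[], None], teardown: Callable[[], None]):
        self._setup = setup
        self._teardown = teardown
        self._depth = 0  # number of currently open `with` blocks

    def __enter__(self) -> "ReentrantGroupManager":
        if self._depth == 0:
            self._setup()  # e.g. initialize the process group
        self._depth += 1
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        self._depth -= 1
        if self._depth == 0:
            self._teardown()  # e.g. destroy the process group
        return False  # never swallow exceptions


events: List[str] = []
manager = ReentrantGroupManager(lambda: events.append("init"),
                                lambda: events.append("destroy"))
with manager:
    with manager:  # nested entry: setup/teardown are not repeated
        events.append("work")
print(events)  # ['init', 'work', 'destroy']
```

Like the documented context manager, this sketch does not spawn any processes; it only brackets the code that runs inside it.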

export_to_onnx(path: str, model: Optional[nn.Module] = None, sample_input: Optional[torch.Tensor] = None, **kwargs) → None

Export model to ONNX format via torch.onnx.export().

Parameters
  • path (str) – File path to save to.

  • model (Optional[nn.Module]) – The model to export. If None, self.model will be used. Defaults to None.

  • sample_input (Optional[Tensor]) – Sample input to the model. If None, a single batch from any available DataLoader in this Learner will be used. Defaults to None.

  • validate_export (bool) – If True, use onnx.checker.check_model() to validate exported model. An exception is raised if the check fails. Defaults to True.

  • **kwargs (dict) – Keyword args to pass to torch.onnx.export(). These override the default values used in the function definition.

Raises

ValueError – If sample_input is None and the Learner has no valid DataLoaders.

Return type

None

classmethod from_model_bundle(model_bundle_uri: str, tmp_dir: Optional[str] = None, cfg: Optional[LearnerConfig] = None, training: bool = False, use_onnx_model: Optional[bool] = None, **kwargs) → Learner

Create a Learner from a model bundle.

Note

This is the bundle saved in train/model-bundle.zip and not bundle/model-bundle.zip.

Parameters
  • model_bundle_uri (str) – URI of the model bundle.

  • tmp_dir (Optional[str], optional) – Optional temporary directory. Will be used for unzipping bundle and also passed to the default constructor. If None, will be auto-generated. Defaults to None.

  • cfg (Optional[LearnerConfig], optional) – If None, will be read from the bundle. Defaults to None.

  • training (bool, optional) – If False, the training apparatus (loss, optimizer, scheduler, logging, etc.) will not be set up and the model will be put into eval mode. If True, the training apparatus will be set up and the model will be put into training mode. Defaults to False.

  • use_onnx_model (Optional[bool]) – If True and training=False and a model.onnx file is available in the bundle, use that for inference rather than the PyTorch weights. Defaults to the boolean environment variable RASTERVISION_USE_ONNX if set, False otherwise.

  • **kwargs – Extra args for __init__().

Raises

FileNotFoundError – If using custom Albumentations transforms and definition file is not found in bundle.

Returns

Object of the Learner subclass on which this was called.

Return type

Learner
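The use_onnx_model fallback order (explicit argument, then the RASTERVISION_USE_ONNX environment variable, then False) can be sketched as follows. The helper name resolve_use_onnx is hypothetical, and the set of strings treated as "true" is an assumption; the library's own parsing of the environment variable may accept different spellings.

```python
import os
from typing import Optional


def resolve_use_onnx(use_onnx_model: Optional[bool] = None,
                     env_var: str = "RASTERVISION_USE_ONNX") -> bool:
    """Hypothetical helper: decide whether to use the bundled ONNX model."""
    # An explicit argument always wins.
    if use_onnx_model is not None:
        return use_onnx_model
    value = os.environ.get(env_var)
    if value is None:
        return False  # variable not set: default to the PyTorch weights
    # Assumed truthy spellings; the real parsing may differ.
    return value.strip().lower() in {"1", "true", "yes", "on"}


os.environ.pop("RASTERVISION_USE_ONNX", None)
print(resolve_use_onnx())  # False
```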

get_collate_fn()

Returns a custom collate_fn to use in DataLoader.

None is returned if default collate_fn should be used.

See https://pytorch.org/docs/stable/data.html#working-with-collate-fn
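Object detection usually needs a custom collate_fn because each image can contain a different number of boxes, so the targets cannot be stacked into a single array the way PyTorch's default collation stacks same-shaped tensors. The sketch below uses NumPy arrays as stand-ins for tensors; the function name detection_collate and the (image, target) sample layout are assumptions, not the collate_fn this class actually returns.

```python
from typing import Any, List, Sequence, Tuple

import numpy as np


def detection_collate(
        batch: Sequence[Tuple[np.ndarray, Any]]) -> Tuple[np.ndarray, List[Any]]:
    """Hypothetical collate_fn for (image, target) samples.

    Same-sized images are stacked into one (N, C, H, W) array; targets are
    kept as a plain list because their box counts can differ per image.
    """
    images, targets = zip(*batch)
    return np.stack(images), list(targets)


batch = [
    (np.zeros((3, 4, 4)), {"boxes": np.zeros((2, 4))}),  # two boxes
    (np.ones((3, 4, 4)), {"boxes": np.zeros((0, 4))}),   # no boxes
]
images, targets = detection_collate(batch)
print(images.shape, len(targets))  # (2, 3, 4, 4) 2
```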

get_dataloader(split: Literal['train', 'valid', 'test']) → torch.utils.data.DataLoader

Get the DataLoader for a split.

Parameters

split (Literal['train', 'valid', 'test']) – a split name which can be train, valid, or test

Return type

torch.utils.data.DataLoader

get_dataset(split: Literal['train', 'valid', 'test']) → Optional[torch.utils.data.DataLoader]

Get the Dataset for a split.

Parameters

split (Literal['train', 'valid', 'test']) – a split name which can be train, valid, or test

Return type

Optional[torch.utils.data.DataLoader]

get_start_and_end_epochs(epochs: Optional[int] = None) → Tuple[int, int]

Get start and end epochs given epochs.

Parameters

epochs (Optional[int]) –

Return type

Tuple[int, int]

get_start_epoch() → int

Get start epoch.

If training was interrupted, this returns the last complete epoch + 1.

Return type

int
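The resume rule above ("the last complete epoch + 1") can be sketched against a per-epoch log. The CSV layout with an "epoch" column and the helper name start_epoch_from_log are assumptions for illustration; the Learner's actual bookkeeping for completed epochs may be stored differently.

```python
import csv
import io
from typing import TextIO


def start_epoch_from_log(log_file: TextIO) -> int:
    """Hypothetical helper: return last completed epoch + 1, or 0 if none."""
    epochs = [int(row["epoch"]) for row in csv.DictReader(log_file)]
    return max(epochs) + 1 if epochs else 0


# Epochs 0, 1 and 2 finished before the run was interrupted.
log = io.StringIO("epoch,train_loss\n0,1.2\n1,0.9\n2,0.7\n")
print(start_epoch_from_log(log))  # 3
```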

get_visualizer_class()

Returns a Visualizer class object for plotting data samples.

load_checkpoint()

Load last weights from previous run if available.

load_init_weights(model_weights_path: Optional[str] = None) → None

Load the weights to initialize model.

Parameters

model_weights_path (Optional[str]) –

Return type

None

load_onnx_model(model_path: str) → ONNXRuntimeAdapterForFasterRCNN

Parameters

model_path (str) –

Return type

ONNXRuntimeAdapterForFasterRCNN

load_weights(uri: str, **kwargs) → None

Load model weights from a file.

Parameters
  • uri (str) – URI.

  • **kwargs – Extra args for nn.Module.load_state_dict().

Return type

None

log_data_stats()

Log stats about each Dataset.

main()

Main training sequence.

This plots the dataset, runs a training and validation loop (which will resume if interrupted), logs stats, plots predictions, and syncs results to the cloud.

normalize_input(x: ndarray) → ndarray

Normalize x to [0, 1].

If x.dtype is a subtype of np.unsignedinteger, normalize it to [0, 1] using the max possible value of that dtype. Otherwise, assume it is in [0, 1] already and do nothing.

Parameters

x (np.ndarray) – an image or batch of images

Returns

the same array scaled to [0, 1].

Return type

ndarray
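The normalization rule above can be sketched with NumPy: unsigned-integer inputs are divided by the largest value their dtype can hold (255 for uint8, 65535 for uint16), and everything else is returned untouched. This is a standalone illustration, not the method's source; the float32 output dtype is an assumption.

```python
import numpy as np


def normalize_input(x: np.ndarray) -> np.ndarray:
    """Sketch: scale unsigned-integer images to [0, 1]; pass others through."""
    if np.issubdtype(x.dtype, np.unsignedinteger):
        # Read the dtype's max before casting, since astype() changes x.dtype.
        max_val = np.float32(np.iinfo(x.dtype).max)
        return x.astype(np.float32) / max_val
    # Assume any other dtype is already in [0, 1].
    return x


img = np.array([[0, 128, 255]], dtype=np.uint8)
print(normalize_input(img).max())  # 1.0
```

Note that signed integers fall into the "do nothing" branch, matching the description, so an int16 image would not be rescaled.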

on_epoch_end(curr_epoch: int, metrics: Dict[str, float]) → None

Hook that is called at end of epoch.

Writes metrics to CSV and TensorBoard, and saves model.

Parameters
  • curr_epoch (int) –

  • metrics (Dict[str, float]) –

Return type

None

on_train_start()

Hook that is called at start of train routine.

output_to_numpy(out: Iterable[BoxList]) → Union[Dict[str, ndarray], List[Dict[str, ndarray]]]

Convert output of model to numpy format.

Parameters

out (Iterable[BoxList]) – the output of the model in PyTorch format

Returns

the output of the model in numpy format

Return type

Union[Dict[str, ndarray], List[Dict[str, ndarray]]]

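plot_dataloader(dl: torch.utils.data.DataLoader, output_path: str, batch_limit: Optional[int] = None, show: bool = False)

Plot images and ground truth labels for a DataLoader.

Parameters
  • dl (torch.utils.data.DataLoader) –

  • output_path (str) –

  • batch_limit (Optional[int]) – optional limit on (rendered) batch size

  • show (bool) –

plot_dataloaders(batch_limit: Optional[int] = None, show: bool = False)

Plot images and ground truth labels for all DataLoaders.

Parameters
  • batch_limit (Optional[int]) – optional limit on (rendered) batch size

  • show (bool) –

plot_predictions(split: Literal['train', 'valid', 'test'], batch_limit: Optional[int] = None, show: bool = False)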

Plot predictions for a split.

Uses the first batch for the corresponding DataLoader.

Parameters
  • split (Literal['train', 'valid', 'test']) – dataset split. Can be train, valid, or test.

  • batch_limit (Optional[int]) – optional limit on (rendered) batch size

  • show (bool) –

post_forward(x: Any) → Any

Post process output of call to model().

Useful when predictions are nested inside a structure returned by model().

Parameters

x (Any) –

Return type

Any

postprocess_model_output(x: Tensor, out_batch: torch.Tensor, out_shape: Tuple[int, int])

Parameters
  • x (Tensor) –

  • out_batch (torch.Tensor) –

  • out_shape (Tuple[int, int]) –

predict(x: Tensor, raw_out: bool = False, out_shape: Optional[Tuple[int, int]] = None) → BoxList

Make prediction for an image or batch of images.

Parameters
  • x (Tensor) – Image or batch of images as a float Tensor with pixel values normalized to [0, 1].

  • raw_out (bool, optional) – If True, return prediction probabilities. Defaults to False.

  • out_shape (Optional[Tuple[int, int]], optional) – If provided, boxes are resized such that they reference pixel coordinates in an image of this shape. Defaults to None.

Returns

Predicted boxes.

Return type

BoxList
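Resizing boxes for out_shape amounts to scaling the row coordinates by out_height / in_height and the column coordinates by out_width / in_width. The sketch below assumes boxes in (ymin, xmin, ymax, xmax) pixel order and a hypothetical helper name rescale_boxes; the coordinate order and the exact way predict() applies this internally are assumptions.

```python
from typing import Tuple

import numpy as np


def rescale_boxes(boxes: np.ndarray, in_shape: Tuple[int, int],
                  out_shape: Tuple[int, int]) -> np.ndarray:
    """Hypothetical helper: map (ymin, xmin, ymax, xmax) boxes from an image
    of in_shape (H, W) to pixel coordinates in an image of out_shape (H, W)."""
    in_h, in_w = in_shape
    out_h, out_w = out_shape
    sy, sx = out_h / in_h, out_w / in_w
    # Row coordinates scale with height, column coordinates with width.
    return boxes * np.array([sy, sx, sy, sx])


boxes = np.array([[10.0, 20.0, 30.0, 40.0]])
print(rescale_boxes(boxes, (100, 200), (200, 400)))  # [[20. 40. 60. 80.]]
```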

predict_dataloader(dl: torch.utils.data.DataLoader, batched_output: bool = True, return_format: Literal['xyz', 'yz', 'z'] = 'z', raw_out: bool = True, predict_kw: dict = {}) → Union[Iterator[Any], Iterator[Tuple[Any, ...]]]

Returns an iterator over predictions on the given dataloader.

Parameters
  • dl (DataLoader) – The dataloader to make predictions on.

  • batched_output (bool, optional) – If True, return batches of x, y, z as defined by the dataloader. If False, unroll the batches into individual items. Defaults to True.

  • return_format (Literal['xyz', 'yz', 'z'], optional) – Format of the return elements of the returned iterator. Must be one of: ‘xyz’, ‘yz’, and ‘z’. If ‘xyz’, elements are 3-tuples of x, y, and z. If ‘yz’, elements are 2-tuples of y and z. If ‘z’, elements are (non-tuple) values of z. Where x = input image, y = ground truth, and z = prediction. Defaults to ‘z’.

  • raw_out (bool, optional) – If True, return raw predicted scores. Defaults to True.

  • predict_kw (dict) – Dict with keywords passed to Learner.predict(). Useful if a Learner subclass implements a custom predict() method.

Raises

ValueError – If return_format is not one of the allowed values.

Returns

If return_format is ‘z’, the returned value is an iterator of whatever type the predictions are. Otherwise, the returned value is an iterator of tuples.

Return type

Union[Iterator[Any], Iterator[Tuple[Any, …]]]
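The interaction between return_format and batched_output can be sketched over pre-computed (x, y, z) batches, with plain Python lists standing in for tensors. The function name format_predictions is hypothetical and this is not the method's implementation; it validates return_format before building the generator so that the ValueError is raised at call time rather than on first iteration.

```python
from typing import Any, Iterable, Iterator, Literal, Tuple


def _iter_predictions(batches: Iterable[Tuple[Any, Any, Any]],
                      return_format: str,
                      batched_output: bool) -> Iterator[Any]:
    for x, y, z in batches:
        # Either yield whole batches or unroll them into individual items.
        items = [(x, y, z)] if batched_output else zip(x, y, z)
        for xi, yi, zi in items:
            if return_format == "xyz":
                yield xi, yi, zi
            elif return_format == "yz":
                yield yi, zi
            else:
                yield zi


def format_predictions(batches: Iterable[Tuple[Any, Any, Any]],
                       return_format: Literal["xyz", "yz", "z"] = "z",
                       batched_output: bool = True) -> Iterator[Any]:
    """Hypothetical helper mirroring the documented return_format options."""
    if return_format not in ("xyz", "yz", "z"):
        raise ValueError(f"return_format must be 'xyz', 'yz', or 'z', "
                         f"got {return_format!r}")
    return _iter_predictions(batches, return_format, batched_output)


batches = [([1, 2], ["a", "b"], [0.1, 0.2])]  # one batch of two items
print(list(format_predictions(batches, "yz", batched_output=False)))
# [('a', 0.1), ('b', 0.2)]
```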

predict_dataset(dataset: Dataset, return_format: Literal['xyz', 'yz', 'z'] = 'z', raw_out: bool = True, numpy_out: bool = False, predict_kw: dict = {}, dataloader_kw: dict = {}, progress_bar: bool = True, progress_bar_kw: dict = {}) → Union[Iterator[Any], Iterator[Tuple[Any, ...]]]

Returns an iterator over predictions on the given dataset.

Parameters
  • dataset (Dataset) – The dataset to make predictions on.

  • return_format (Literal['xyz', 'yz', 'z'], optional) – Format of the return elements of the returned iterator. Must be one of: ‘xyz’, ‘yz’, and ‘z’. If ‘xyz’, elements are 3-tuples of x, y, and z. If ‘yz’, elements are 2-tuples of y and z. If ‘z’, elements are (non-tuple) values of z. Where x = input image, y = ground truth, and z = prediction. Defaults to ‘z’.

  • raw_out (bool, optional) – If True, return raw predicted scores. Defaults to True.

  • numpy_out (bool, optional) – If True, convert predictions to numpy arrays before returning. Defaults to False.

  • predict_kw (dict) – Dict with keywords passed to Learner.predict(). Useful if a Learner subclass implements a custom predict() method.

  • dataloader_kw (dict) – Dict with keywords passed to the DataLoader constructor.

  • progress_bar (bool, optional) – If True, display a progress bar. Since this function returns an iterator, the progress bar won’t be visible until the iterator is consumed. Defaults to True.

  • progress_bar_kw (dict) – Dict with keywords passed to tqdm.

Raises

ValueError – If return_format is not one of the allowed values.

Returns

If return_format is ‘z’, the returned value is an iterator of whatever type the predictions are. Otherwise, the returned value is an iterator of tuples.

Return type

Union[Iterator[Any], Iterator[Tuple[Any, …]]]

predict_onnx(x: Tensor, raw_out: bool = False, out_shape: Optional[Tuple[int, int]] = None) → BoxList

Alternative to predict() for ONNX inference.

Parameters
  • x (Tensor) –

  • raw_out (bool) –

  • out_shape (Optional[Tuple[int, int]]) –

Return type

BoxList

prob_to_pred(x)

Convert a Tensor with prediction probabilities to class ids.

The class ids should be the classes with the maximum probability.

reduce_distributed_metrics(metrics: dict)

Average numeric metrics across processes.

Parameters

metrics (dict) –
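Averaging numeric metrics across processes can be simulated in a single process by treating a list of per-process metric dicts as already gathered. In a real DDP run this is typically done with a collective operation such as torch.distributed.all_reduce; the helper name average_metrics and the choice to keep non-numeric values from the first process are assumptions for illustration.

```python
from numbers import Number
from typing import Dict, List


def average_metrics(per_process: List[Dict[str, object]]) -> Dict[str, object]:
    """Hypothetical single-process simulation of cross-process averaging.

    Numeric values are averaged key by key; non-numeric values (and bools)
    are copied from the first process unchanged.
    """
    reduced = dict(per_process[0])
    n = len(per_process)
    for key, value in per_process[0].items():
        if isinstance(value, Number) and not isinstance(value, bool):
            reduced[key] = sum(m[key] for m in per_process) / n
    return reduced


gathered = [{"loss": 1.0, "name": "a"}, {"loss": 3.0, "name": "b"}]
print(average_metrics(gathered))  # {'loss': 2.0, 'name': 'a'}
```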

run_tensorboard()

Run a TensorBoard server that serves the logged stats.

save_model_bundle(export_onnx: bool = True)

Save a model bundle.

This is a zip file with the model weights in .pth format and a serialized copy of the LearnerConfig, which makes it possible to make predictions later.

Parameters

export_onnx (bool) –
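The bundle layout described above (a zip holding the weights file plus a serialized config) can be sketched with the standard library. The helper name save_bundle, the archive member names "model.pth" and "learner-config.json", and the placeholder config are assumptions; the real bundle may use other names and include further files, such as an ONNX export when export_onnx is True.

```python
import json
import tempfile
import zipfile
from pathlib import Path
from typing import Dict


def save_bundle(bundle_path: str, weights_path: str, config: Dict) -> None:
    """Hypothetical helper: zip a weights file together with a JSON config."""
    with zipfile.ZipFile(bundle_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        zf.write(weights_path, arcname="model.pth")
        zf.writestr("learner-config.json", json.dumps(config))


with tempfile.TemporaryDirectory() as tmp:
    weights = Path(tmp) / "weights.pth"
    weights.write_bytes(b"placeholder")  # not a real state_dict
    bundle = Path(tmp) / "model-bundle.zip"
    save_bundle(str(bundle), str(weights), {"placeholder": True})
    with zipfile.ZipFile(bundle) as zf:
        names = sorted(zf.namelist())
        restored = json.loads(zf.read("learner-config.json"))

print(names)  # ['learner-config.json', 'model.pth']
```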

save_weights(path: str)

Save model weights to a local file.

Parameters

path (str) –

setup_data(distributed: Optional[bool] = None)

Set up datasets and DataLoaders for the train, validation, and test splits.

Parameters

distributed (Optional[bool]) –

setup_ddp_params()

Set up and validate params related to PyTorch DDP.

setup_loss(loss_def_path: Optional[str] = None) → None

Set up self.loss.

Parameters

loss_def_path (Optional[str]) – Loss definition path. Will be available when loading from a bundle. Defaults to None.

Return type

None

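setup_model(model_weights_path: Optional[str] = None, model_def_path: Optional[str] = None) → None

Override to apply the TorchVisionODAdapter wrapper.

Parameters
  • model_weights_path (Optional[str]) – URI of model weights to initialize the model with. Defaults to None.

  • model_def_path (Optional[str]) – A local path to a directory with a hubconf.py. If provided, the model definition is imported from here. Defaults to None.

Return type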

None

setup_tensorboard()

Set up logging of stats to TensorBoard.

setup_training(loss_def_path: Optional[str] = None) → None

Set up model, data, loss, optimizers and various paths.

The exact behavior differs based on whether this method is called in a distributed scenario.

Parameters

loss_def_path (Optional[str]) – A local path to a directory with a hubconf.py. If provided, the loss function definition is imported from here. This is used when loading an external loss function from a model-bundle. Defaults to None.

Return type

None

stop_tensorboard()

Stop TensorBoard logging and the server, if it is running.

sync_from_cloud()

Sync any previous output in the cloud to output_dir.

sync_to_cloud()

Sync any output to the cloud at output_uri.

to_batch(x: torch.Tensor) → torch.Tensor

Ensure that the image tensor has a batch dimension.

Parameters

x (torch.Tensor) – assumed to be either image or batch of images

Returns

x with extra batch dimension of length 1 if needed

Return type

torch.Tensor
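Adding a batch dimension of length 1 only when it is missing can be sketched with NumPy arrays standing in for tensors (torch.Tensor supports the same None-indexing and .ndim). The assumption here is that a single image is 3D (C, H, W) and a batch is 4D (N, C, H, W); raising on any other rank is a choice made for this sketch, not documented behavior.

```python
import numpy as np


def to_batch(x: np.ndarray) -> np.ndarray:
    """Sketch: return x with a leading batch dimension of length 1 if needed."""
    if x.ndim == 3:   # single (C, H, W) image
        return x[np.newaxis]
    if x.ndim == 4:   # already a (N, C, H, W) batch
        return x
    raise ValueError(f"expected a 3D image or 4D batch, got {x.ndim}D")


print(to_batch(np.zeros((3, 8, 8))).shape)  # (1, 3, 8, 8)
```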

to_device(x: Any, device: Union[str, torch.device]) → Any

Load Tensors onto a device.

Parameters
  • x (Any) –

  • device (Union[str, torch.device]) –

Returns

x but with any Tensors in it on the device

Return type

Any
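Moving "any Tensors in" a possibly nested input is a recursive map over containers that applies a function to each leaf. The generic sketch below uses a hypothetical helper map_leaves and NumPy arrays with a dtype cast as a stand-in for tensor.to(device); with PyTorch one would pass is_leaf=lambda v: isinstance(v, torch.Tensor) and fn=lambda t: t.to(device). The library's own to_device may handle fewer container types than this sketch.

```python
from typing import Any, Callable

import numpy as np


def map_leaves(x: Any, is_leaf: Callable[[Any], bool],
               fn: Callable[[Any], Any]) -> Any:
    """Hypothetical helper: apply fn to every leaf inside dicts/lists/tuples."""
    if is_leaf(x):
        return fn(x)
    if isinstance(x, dict):
        return {k: map_leaves(v, is_leaf, fn) for k, v in x.items()}
    if isinstance(x, list):
        return [map_leaves(v, is_leaf, fn) for v in x]
    if isinstance(x, tuple):
        return tuple(map_leaves(v, is_leaf, fn) for v in x)
    return x  # non-container, non-leaf values are returned unchanged


data = {"image": np.zeros((2, 2), dtype=np.uint8),
        "targets": [np.ones(3, dtype=np.int64), (np.zeros(1), "label")],
        "count": 5}
# The dtype cast stands in for moving each array to a device.
out = map_leaves(data, lambda v: isinstance(v, np.ndarray),
                 lambda a: a.astype(np.float32))
print(out["image"].dtype, out["count"])  # float32 5
```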

train(epochs: Optional[int] = None)

Run training loop, resuming training if appropriate.

Parameters

epochs (Optional[int]) –

train_end(outputs: List[Dict[str, Union[float, torch.Tensor]]]) → Dict[str, float]

Aggregate the output of train_step at the end of the epoch.

Parameters

outputs (List[Dict[str, Union[float, torch.Tensor]]]) – a list of outputs of train_step

Return type

Dict[str, float]
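One common way to aggregate per-batch train_step outputs into per-epoch metrics is to average each key across batches and convert the result to a plain float. The sketch below assumes that averaging strategy and uses a hypothetical helper name aggregate_step_outputs; float() also works on 0-d torch tensors, but the exact reduction train_end performs is not specified above.

```python
from typing import Dict, List, Union


def aggregate_step_outputs(
        outputs: List[Dict[str, Union[float, int]]]) -> Dict[str, float]:
    """Hypothetical helper: mean of each metric key over all batches."""
    keys = outputs[0].keys()
    return {k: sum(float(o[k]) for o in outputs) / len(outputs) for k in keys}


step_outputs = [{"train_loss": 1.0}, {"train_loss": 2.0}, {"train_loss": 3.0}]
print(aggregate_step_outputs(step_outputs))  # {'train_loss': 2.0}
```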

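train_epoch(optimizer: Optimizer, dataloader: Optional[torch.utils.data.DataLoader] = None, step_scheduler: Optional[_LRScheduler] = None) → Dict[str, float]

Train for a single epoch.

Parameters
  • optimizer (Optimizer) –

  • dataloader (Optional[torch.utils.data.DataLoader]) –

  • step_scheduler (Optional[_LRScheduler]) –

Return type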

Dict[str, float]

train_step(batch, batch_ind)

Compute loss for a single training batch.

Parameters
  • batch – batch data needed to compute loss

  • batch_ind – index of batch within epoch

Returns

dict with ‘train_loss’ as key and possibly other losses

validate(split: Literal['train', 'valid', 'test'] = 'valid')

Evaluate model on a particular data split.

Parameters

split (Literal['train', 'valid', 'test']) –

validate_end(outputs)

Aggregate the output of validate_step at the end of the epoch.

Parameters

outputs – a list of outputs of validate_step

validate_epoch(dl: torch.utils.data.DataLoader) → Dict[str, float]

Validate for a single epoch.

Parameters

dl (torch.utils.data.DataLoader) –

Return type

Dict[str, float]

validate_step(batch, batch_ind)

Compute metrics on validation batch.

Parameters
  • batch – batch data needed to compute validation metrics

  • batch_ind – index of batch within epoch

Returns

dict with metric names mapped to metric values

property onnx_mode: bool