SemanticSegmentationLearner#

class SemanticSegmentationLearner[source]#

Bases: Learner

Attributes

__init__(cfg: LearnerConfig, output_dir: str | None = None, train_ds: Dataset | None = None, valid_ds: Dataset | None = None, test_ds: Dataset | None = None, model: torch.nn.modules.module.Module | None = None, loss: collections.abc.Callable[[...], torch.Tensor] | None = None, optimizer: Optimizer | None = None, epoch_scheduler: _LRScheduler | None = None, step_scheduler: _LRScheduler | None = None, tmp_dir: str | None = None, model_weights_path: str | None = None, model_def_path: str | None = None, loss_def_path: str | None = None, training: bool = True)#

Constructor.

Parameters:
  • cfg (LearnerConfig) – LearnerConfig.

  • train_ds (Dataset | None) – The dataset to use for training. If None, will be generated from cfg.data. Defaults to None.

  • valid_ds (Dataset | None) – The dataset to use for validation. If None, will be generated from cfg.data. Defaults to None.

  • test_ds (Dataset | None) – The dataset to use for testing. If None, will be generated from cfg.data. Defaults to None.

  • model (torch.nn.modules.module.Module | None) – The model. If None, will be generated from cfg.model. Defaults to None.

  • loss (collections.abc.Callable[[...], torch.Tensor] | None) – The loss function. If None, will be generated from cfg.solver. Defaults to None.

  • optimizer (Optimizer | None) – The optimizer. If None, will be generated from cfg.solver. Defaults to None.

  • epoch_scheduler (_LRScheduler | None) – The scheduler that updates after each epoch. If None, will be generated from cfg.solver. Defaults to None.

  • step_scheduler (_LRScheduler | None) – The scheduler that updates after each optimizer-step. If None, will be generated from cfg.solver. Defaults to None.

  • tmp_dir (str | None) – A temporary directory to use for downloads etc. If None, will be auto-generated. Defaults to None.

  • model_weights_path (str | None) – URI of model weights to initialize the model with. Defaults to None.

  • model_def_path (str | None) – A local path to a directory with a hubconf.py file. If provided, the model definition is imported from here. This is used when loading an external model from a model-bundle. Defaults to None.

  • loss_def_path (str | None) – A local path to a directory with a hubconf.py file. If provided, the loss function definition is imported from here. This is used when loading an external loss function from a model-bundle. Defaults to None.

  • training (bool) – If False, the training apparatus (loss, optimizer, scheduler, logging, etc.) will not be set up and the model will be put into eval mode. If True, the training apparatus will be set up and the model will be put into training mode. Defaults to True.

  • output_dir (str | None) –

Methods

__init__(cfg[, output_dir, train_ds, ...])

Constructor.

build_dataloader(split[, distributed])

Build DataLoader for split.

build_dataloaders([distributed])

Build DataLoaders for train, validation, and test splits.

build_dataset(split)

Build Dataset for split.

build_datasets()

Build Datasets for train, validation, and test splits.

build_epoch_scheduler([start_epoch])

Returns an LR scheduler that changes the LR each epoch.

build_loss([loss_def_path])

Build a loss Callable.

build_model([model_def_path])

Build a PyTorch model.

build_optimizer()

Returns optimizer.

build_sampler(ds, split[, distributed])

Build an optional sampler for the split's dataloader.

build_step_scheduler([start_epoch])

Returns an LR scheduler that changes the LR each step.

ddp([rank, world_size])

Return a DDPContextManager.

export_to_onnx(path[, model, sample_input])

Export model to ONNX format via torch.onnx.export().

from_model_bundle(model_bundle_uri[, ...])

Create a Learner from a model bundle.

get_collate_fn()

Returns a custom collate_fn to use in DataLoader.

get_dataloader(split)

Get the DataLoader for a split.

get_dataset(split)

Get the Dataset for a split.

get_start_and_end_epochs([epochs])

Get start and end epochs given epochs.

get_start_epoch()

Get start epoch.

get_visualizer_class()

Returns a Visualizer class object for plotting data samples.

load_checkpoint()

Load last weights from previous run if available.

load_init_weights([model_weights_path])

Load the weights to initialize model.

load_onnx_model(model_path)

load_weights(uri, **kwargs)

Load model weights from a file.

log_data_stats()

Log stats about each DataSet.

main()

Main training sequence.

normalize_input(x)

Normalize x to [0, 1].

on_epoch_end(curr_epoch, metrics)

Hook that is called at end of epoch.

on_train_start()

Hook that is called at start of train routine.

output_to_numpy(out)

Convert output of model to numpy format.

plot_dataloader(dl, output_path[, ...])

Plot images and ground truth labels for a DataLoader.

plot_dataloaders([batch_limit, show])

Plot images and ground truth labels for all DataLoaders.

plot_predictions(split[, batch_limit, show])

Plot predictions for a split.

post_forward(x)

Post process output of call to model().

postprocess_model_output(out, raw_out, out_shape)

predict(x[, raw_out, out_shape])

Make prediction for an image or batch of images.

predict_dataloader(dl[, batched_output, ...])

Returns an iterator over predictions on the given dataloader.

predict_dataset(dataset[, return_format, ...])

Returns an iterator over predictions on the given dataset.

predict_onnx(x[, raw_out, out_shape])

Alternative to predict() for ONNX inference.

prob_to_pred(x)

Convert a Tensor with prediction probabilities to class ids.

reduce_distributed_metrics(metrics)

Average numeric metrics across processes.

run_tensorboard()

Run TB server serving logged stats.

save_model_bundle([export_onnx])

Save a model bundle.

save_weights(path)

Save model weights to a local file.

setup_data([distributed])

Set datasets and dataLoaders for train, validation, and test sets.

setup_ddp_params()

Set up and validate params related to PyTorch DDP.

setup_loss([loss_def_path])

Setup self.loss.

setup_model([model_weights_path, model_def_path])

Setup self.model.

setup_tensorboard()

Setup for logging stats to TB.

setup_training([loss_def_path])

Set up model, data, loss, optimizers and various paths.

stop_tensorboard()

Stop TB logging and server if it's running.

sync_from_cloud()

Sync any previous output in the cloud to output_dir.

sync_to_cloud()

Sync any output to the cloud at output_uri.

to_batch(x)

Ensure that image array has batch dimension.

to_device(x, device)

Load Tensors onto a device.

train([epochs])

Run training loop, resuming training if appropriate

train_end(outputs)

Aggregate the output of train_step at the end of the epoch.

train_epoch(optimizer[, dataloader, ...])

Train for a single epoch.

train_step(batch, batch_ind)

Compute loss for a single training batch.

validate([split])

Evaluate model on a particular data split.

validate_end(outputs)

Aggregate the output of validate_step at the end of the epoch.

validate_epoch(dl)

Validate for a single epoch.

validate_step(batch, batch_ind)

Compute metrics on validation batch.

__init__(cfg: LearnerConfig, output_dir: str | None = None, train_ds: Dataset | None = None, valid_ds: Dataset | None = None, test_ds: Dataset | None = None, model: torch.nn.modules.module.Module | None = None, loss: collections.abc.Callable[[...], torch.Tensor] | None = None, optimizer: Optimizer | None = None, epoch_scheduler: _LRScheduler | None = None, step_scheduler: _LRScheduler | None = None, tmp_dir: str | None = None, model_weights_path: str | None = None, model_def_path: str | None = None, loss_def_path: str | None = None, training: bool = True)#

Constructor.

Parameters:
  • cfg (LearnerConfig) – LearnerConfig.

  • train_ds (Dataset | None) – The dataset to use for training. If None, will be generated from cfg.data. Defaults to None.

  • valid_ds (Dataset | None) – The dataset to use for validation. If None, will be generated from cfg.data. Defaults to None.

  • test_ds (Dataset | None) – The dataset to use for testing. If None, will be generated from cfg.data. Defaults to None.

  • model (torch.nn.modules.module.Module | None) – The model. If None, will be generated from cfg.model. Defaults to None.

  • loss (collections.abc.Callable[[...], torch.Tensor] | None) – The loss function. If None, will be generated from cfg.solver. Defaults to None.

  • optimizer (Optimizer | None) – The optimizer. If None, will be generated from cfg.solver. Defaults to None.

  • epoch_scheduler (_LRScheduler | None) – The scheduler that updates after each epoch. If None, will be generated from cfg.solver. Defaults to None.

  • step_scheduler (_LRScheduler | None) – The scheduler that updates after each optimizer-step. If None, will be generated from cfg.solver. Defaults to None.

  • tmp_dir (str | None) – A temporary directory to use for downloads etc. If None, will be auto-generated. Defaults to None.

  • model_weights_path (str | None) – URI of model weights to initialize the model with. Defaults to None.

  • model_def_path (str | None) – A local path to a directory with a hubconf.py file. If provided, the model definition is imported from here. This is used when loading an external model from a model-bundle. Defaults to None.

  • loss_def_path (str | None) – A local path to a directory with a hubconf.py file. If provided, the loss function definition is imported from here. This is used when loading an external loss function from a model-bundle. Defaults to None.

  • training (bool) – If False, the training apparatus (loss, optimizer, scheduler, logging, etc.) will not be set up and the model will be put into eval mode. If True, the training apparatus will be set up and the model will be put into training mode. Defaults to True.

  • output_dir (str | None) –

build_dataloader(split: Literal['train', 'valid', 'test'], distributed: bool | None = None, **kwargs) DataLoader#

Build DataLoader for split.

Parameters:
  • split (Literal['train', 'valid', 'test']) –

  • distributed (bool | None) –

Return type:

DataLoader

build_dataloaders(distributed: bool | None = None) tuple[torch.utils.data.dataloader.DataLoader, torch.utils.data.dataloader.DataLoader, torch.utils.data.dataloader.DataLoader | None]#

Build DataLoaders for train, validation, and test splits.

Parameters:

distributed (bool | None) –

Return type:

tuple[torch.utils.data.dataloader.DataLoader, torch.utils.data.dataloader.DataLoader, torch.utils.data.dataloader.DataLoader | None]

build_dataset(split: Literal['train', 'valid', 'test']) Dataset#

Build Dataset for split.

Parameters:

split (Literal['train', 'valid', 'test']) –

Return type:

Dataset

build_datasets() tuple['Dataset', 'Dataset', 'Dataset']#

Build Datasets for train, validation, and test splits.

Return type:

tuple[‘Dataset’, ‘Dataset’, ‘Dataset’]

build_epoch_scheduler(start_epoch: int = 0) _LRScheduler#

Returns an LR scheduler that changes the LR each epoch.

Parameters:

start_epoch (int) –

Return type:

_LRScheduler

build_loss(loss_def_path: str | None = None) Callable[[...], Tensor]#

Build a loss Callable.

Parameters:

loss_def_path (str | None) –

Return type:

Callable[[…], Tensor]

build_model(model_def_path: str | None = None) Module#

Build a PyTorch model.

Parameters:

model_def_path (str | None) –

Return type:

Module

build_optimizer() Optimizer#

Returns optimizer.

Return type:

Optimizer

build_sampler(ds: Dataset, split: Literal['train', 'valid', 'test'], distributed: bool = False) Sampler | None#

Build an optional sampler for the split’s dataloader.

Parameters:
  • ds (Dataset) –

  • split (Literal['train', 'valid', 'test']) –

  • distributed (bool) –

Return type:

Sampler | None

build_step_scheduler(start_epoch: int = 0) _LRScheduler#

Returns an LR scheduler that changes the LR each step.

Parameters:

start_epoch (int) –

Return type:

_LRScheduler

ddp(rank: int | None = None, world_size: int | None = None) DDPContextManager#

Return a DDPContextManager.

This should be used to wrap code that needs to be executed in parallel. It is safe to call this recursively; recursive calls will have no effect.

Note that DDPContextManager does not start processes itself, but merely initializes and destroys DDP process groups.

Usage:

with learner.ddp([rank], [world_size]):
    ...
Parameters:
  • rank (int | None) –

  • world_size (int | None) –

Return type:

DDPContextManager

export_to_onnx(path: str, model: nn.Module | None = None, sample_input: torch.Tensor | None = None, **kwargs) None[source]#

Export model to ONNX format via torch.onnx.export().

Parameters:
  • path (str) – File path to save to.

  • model (nn.Module | None) – The model to export. If None, self.model will be used. Defaults to None.

  • sample_input (torch.Tensor | None) – Sample input to the model. If None, a single batch from any available DataLoader in this Learner will be used. Defaults to None.

  • validate_export – If True, use onnx.checker.check_model() to validate exported model. An exception is raised if the check fails. Defaults to True.

  • **kwargs – Keyword args to pass to torch.onnx.export(). These override the default values used in the function definition.

Raises:

ValueError – If sample_input is None and the Learner has no valid DataLoaders.

Return type:

None

classmethod from_model_bundle(model_bundle_uri: str, tmp_dir: str | None = None, cfg: LearnerConfig | None = None, training: bool = False, use_onnx_model: bool | None = None, **kwargs) Self#

Create a Learner from a model bundle.

Note

This is the bundle saved in train/model-bundle.zip and not bundle/model-bundle.zip.

Parameters:
  • model_bundle_uri (str) – URI of the model bundle.

  • tmp_dir (str | None) – Optional temporary directory. Will be used for unzipping bundle and also passed to the default constructor. If None, will be auto-generated. Defaults to None.

  • cfg (LearnerConfig | None) – If None, will be read from the bundle. Defaults to None.

  • training (bool) – If False, the training apparatus (loss, optimizer, scheduler, logging, etc.) will not be set up and the model will be put into eval mode. If True, the training apparatus will be set up and the model will be put into training mode. Defaults to False.

  • use_onnx_model (bool | None) – If True and training=False and a model.onnx file is available in the bundle, use that for inference rather than the PyTorch weights. Defaults to the boolean environment variable RASTERVISION_USE_ONNX if set, False otherwise.

  • **kwargs – Extra args for __init__().

Raises:

FileNotFoundError – If using custom Albumentations transforms and definition file is not found in bundle.

Returns:

Object of the Learner subclass on which this was called.

Return type:

Learner

get_collate_fn() collections.abc.Callable | None#

Returns a custom collate_fn to use in DataLoader.

None is returned if default collate_fn should be used.

See https://pytorch.org/docs/stable/data.html#working-with-collate-fn

Return type:

collections.abc.Callable | None

get_dataloader(split: Literal['train', 'valid', 'test']) DataLoader#

Get the DataLoader for a split.

Parameters:

split (Literal['train', 'valid', 'test']) – a split name which can be train, valid, or test

Return type:

DataLoader

get_dataset(split: Literal['train', 'valid', 'test']) torch.utils.data.dataloader.DataLoader | None#

Get the Dataset for a split.

Parameters:

split (Literal['train', 'valid', 'test']) – a split name which can be train, valid, or test

Return type:

torch.utils.data.dataloader.DataLoader | None

get_start_and_end_epochs(epochs: int | None = None) tuple[int, int]#

Get start and end epochs given epochs.

Parameters:

epochs (int | None) –

Return type:

tuple[int, int]

get_start_epoch() int#

Get start epoch.

If training was interrupted, this returns the last complete epoch + 1.

Return type:

int

get_visualizer_class()[source]#

Returns a Visualizer class object for plotting data samples.

load_checkpoint()#

Load last weights from previous run if available.

load_init_weights(model_weights_path: str | None = None) None#

Load the weights to initialize model.

Parameters:

model_weights_path (str | None) –

Return type:

None

load_onnx_model(model_path: str) ONNXRuntimeAdapter#
Parameters:

model_path (str) –

Return type:

ONNXRuntimeAdapter

load_weights(uri: str, **kwargs) None#

Load model weights from a file.

Parameters:
  • uri (str) – URI.

  • **kwargs – Extra args for nn.Module.load_state_dict().

Return type:

None

log_data_stats()#

Log stats about each Dataset.

main()#

Main training sequence.

This plots the dataset, runs a training and validation loop (which will resume if interrupted), logs stats, plots predictions, and syncs results to the cloud.

normalize_input(x: ndarray) ndarray#

Normalize x to [0, 1].

If x.dtype is a subtype of np.unsignedinteger, normalize it to [0, 1] using the max possible value of that dtype. Otherwise, assume it is in [0, 1] already and do nothing.

Parameters:

x (np.ndarray) – an image or batch of images

Returns:

the same array scaled to [0, 1].

Return type:

ndarray

on_epoch_end(curr_epoch: int, metrics: dict[str, float]) None#

Hook that is called at end of epoch.

Writes metrics to CSV and TensorBoard, and saves model.

Parameters:
  • curr_epoch (int) –

  • metrics (dict[str, float]) –

Return type:

None

on_train_start()#

Hook that is called at start of train routine.

output_to_numpy(out: Tensor) ndarray#

Convert output of model to numpy format.

Parameters:

out (Tensor) – the output of the model in PyTorch format

Return type:

ndarray

Returns: the output of the model in numpy format

plot_dataloader(dl: DataLoader, output_path: str, batch_limit: int | None = None, show: bool = False)#

Plot images and ground truth labels for a DataLoader.

Parameters:
  • dl (DataLoader) –

  • output_path (str) –

  • batch_limit (int | None) –

  • show (bool) –

plot_dataloaders(batch_limit: int | None = None, show: bool = False)#

Plot images and ground truth labels for all DataLoaders.

Parameters:
  • batch_limit (int | None) –

  • show (bool) –

plot_predictions(split: Literal['train', 'valid', 'test'], batch_limit: int | None = None, show: bool = False)#

Plot predictions for a split.

Uses the first batch for the corresponding DataLoader.

Parameters:
  • split (Literal['train', 'valid', 'test']) – dataset split. Can be train, valid, or test.

  • batch_limit (int | None) – optional limit on (rendered) batch size

  • show (bool) –

post_forward(x)[source]#

Post process output of call to model().

Useful for when predictions are inside a structure returned by model().

postprocess_model_output(out: Tensor, raw_out: bool, out_shape: tuple[int, int])[source]#
Parameters:
  • out (Tensor) –

  • raw_out (bool) –

  • out_shape (tuple[int, int]) –

predict(x: Tensor, raw_out: bool = False, out_shape: tuple[int, int] | None = None) Tensor[source]#

Make prediction for an image or batch of images.

Parameters:
  • x (Tensor) – Image or batch of images as a float Tensor with pixel values normalized to [0, 1].

  • raw_out (bool) – if True, return prediction probabilities

  • out_shape (tuple[int, int] | None) –

Returns:

The predictions, in probability form if raw_out is True, in class_id form otherwise.

Return type:

Tensor

predict_dataloader(dl: DataLoader, batched_output: bool = True, return_format: Literal['xyz', 'yz', 'z'] = 'z', raw_out: bool = True, predict_kw: dict = {}) Union[Iterator[Any], Iterator[tuple[Any, ...]]]#

Returns an iterator over predictions on the given dataloader.

Parameters:
  • dl (DataLoader) – The dataloader to make predictions on.

  • batched_output (bool) – If True, return batches of x, y, z as defined by the dataloader. If False, unroll the batches into individual items. Defaults to True.

  • return_format (Literal['xyz', 'yz', 'z']) – Format of the return elements of the returned iterator. Must be one of: ‘xyz’, ‘yz’, and ‘z’. If ‘xyz’, elements are 3-tuples of x, y, and z. If ‘yz’, elements are 2-tuples of y and z. If ‘z’, elements are (non-tuple) values of z. Where x = input image, y = ground truth, and z = prediction. Defaults to ‘z’.

  • raw_out (bool) – If True, return raw predicted scores. Defaults to True.

  • predict_kw (dict) – Dict with keywords passed to Learner.predict(). Useful if a Learner subclass implements a custom predict() method.

Raises:

ValueError – If return_format is not one of the allowed values.

Returns:

If return_format is 'z', the returned value is an iterator of whatever type the predictions are. Otherwise, the returned value is an iterator of tuples.

Return type:

Union[Iterator[Any], Iterator[tuple[Any, …]]]

predict_dataset(dataset: Dataset, return_format: Literal['xyz', 'yz', 'z'] = 'z', raw_out: bool = True, numpy_out: bool = False, predict_kw: dict = {}, dataloader_kw: dict = {}, progress_bar: bool = True, progress_bar_kw: dict = {}) Union[Iterator[Any], Iterator[tuple[Any, ...]]]#

Returns an iterator over predictions on the given dataset.

Parameters:
  • dataset (Dataset) – The dataset to make predictions on.

  • return_format (Literal['xyz', 'yz', 'z']) – Format of the return elements of the returned iterator. Must be one of: ‘xyz’, ‘yz’, and ‘z’. If ‘xyz’, elements are 3-tuples of x, y, and z. If ‘yz’, elements are 2-tuples of y and z. If ‘z’, elements are (non-tuple) values of z. Where x = input image, y = ground truth, and z = prediction. Defaults to ‘z’.

  • raw_out (bool) – If True, return raw predicted scores. Defaults to True.

  • numpy_out (bool) – If True, convert predictions to numpy arrays before returning. Defaults to False.

  • predict_kw (dict) – Dict with keywords passed to Learner.predict(). Useful if a Learner subclass implements a custom predict() method.

  • dataloader_kw (dict) – Dict with keywords passed to the DataLoader constructor.

  • progress_bar (bool) – If True, display a progress bar. Since this function returns an iterator, the progress bar won’t be visible until the iterator is consumed. Defaults to True.

  • progress_bar_kw (dict) – Dict with keywords passed to tqdm.

Raises:

ValueError – If return_format is not one of the allowed values.

Returns:

If return_format is ‘z’, the returned value is an iterator of whatever type the predictions are. Otherwise, the returned value is an iterator of tuples.

Return type:

Union[Iterator[Any], Iterator[tuple[Any, …]]]

predict_onnx(x: Tensor, raw_out: bool = False, out_shape: tuple[int, int] | None = None) Tensor[source]#

Alternative to predict() for ONNX inference.

Parameters:
  • x (Tensor) –

  • raw_out (bool) –

  • out_shape (tuple[int, int] | None) –

Return type:

Tensor

prob_to_pred(x)[source]#

Convert a Tensor with prediction probabilities to class ids.

The class ids should be the classes with the maximum probability.

reduce_distributed_metrics(metrics: dict)#

Average numeric metrics across processes.

Parameters:

metrics (dict) –

run_tensorboard()#

Run TB server serving logged stats.

save_model_bundle(export_onnx: bool = True)#

Save a model bundle.

This is a zip file with the model weights in .pth format and a serialized copy of the LearnerConfig, which allows for making predictions in the future.

Parameters:

export_onnx (bool) –

save_weights(path: str)#

Save model weights to a local file.

Parameters:

path (str) –

setup_data(distributed: bool | None = None)#

Set up datasets and DataLoaders for train, validation, and test sets.

Parameters:

distributed (bool | None) –

setup_ddp_params()#

Set up and validate params related to PyTorch DDP.

setup_loss(loss_def_path: str | None = None) None#

Set up self.loss.

Parameters:
  • loss_def_path (str | None) – Loss definition path. Will be available when loading from a bundle. Defaults to None.

Return type:

None

setup_model(model_weights_path: str | None = None, model_def_path: str | None = None) None#

Set up self.model.

Parameters:
  • model_weights_path (str | None) – Path to model weights. Will be available when loading from a bundle. Defaults to None.

  • model_def_path (str | None) – Path to model definition. Will be available when loading from a bundle. Defaults to None.

Return type:

None

setup_tensorboard()#

Set up logging of stats to TB.

setup_training(loss_def_path: str | None = None) None#

Set up model, data, loss, optimizers and various paths.

The exact behavior differs based on whether this method is called in a distributed scenario.

Parameters:

loss_def_path (str | None) – A local path to a directory with a hubconf.py. If provided, the loss function definition is imported from here. This is used when loading an external loss function from a model-bundle. Defaults to None.

Return type:

None

stop_tensorboard()#

Stop TB logging and server if it’s running.

sync_from_cloud()#

Sync any previous output in the cloud to output_dir.

sync_to_cloud()#

Sync any output to the cloud at output_uri.

to_batch(x: Tensor) Tensor#

Ensure that image array has batch dimension.

Parameters:

x (Tensor) – assumed to be either image or batch of images

Returns:

x with extra batch dimension of length 1 if needed

Return type:

Tensor

to_device(x: Any, device: str | torch.device) Any#

Load Tensors onto a device.

Parameters:
  • x (Any) – some object with Tensors in it

  • device (str | torch.device) – ‘cpu’ or ‘cuda’

Returns:

x but with any Tensors in it on the device

Return type:

Any

train(epochs: int | None = None)#

Run training loop, resuming training if appropriate.

Parameters:

epochs (int | None) –

train_end(outputs: list[dict[str, float | torch.Tensor]]) dict[str, float]#

Aggregate the output of train_step at the end of the epoch.

Parameters:

outputs (list[dict[str, float | torch.Tensor]]) – a list of outputs of train_step

Return type:

dict[str, float]

train_epoch(optimizer: Optimizer, dataloader: torch.utils.data.dataloader.DataLoader | None = None, step_scheduler: _LRScheduler | None = None) dict[str, float]#

Train for a single epoch.

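Parameters:
  • optimizer (Optimizer) –

  • dataloader (torch.utils.data.dataloader.DataLoader | None) –

  • step_scheduler (_LRScheduler | None) –
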
Return type:

dict[str, float]

train_step(batch, batch_ind)[source]#

Compute loss for a single training batch.

Parameters:
  • batch – batch data needed to compute loss

  • batch_ind – index of batch within epoch

Returns:

dict with ‘train_loss’ as key and possibly other losses

validate(split: Literal['train', 'valid', 'test'] = 'valid')#

Evaluate model on a particular data split.

Parameters:

split (Literal['train', 'valid', 'test']) –

validate_end(outputs)[source]#

Aggregate the output of validate_step at the end of the epoch.

Parameters:

outputs – a list of outputs of validate_step

validate_epoch(dl: DataLoader) dict[str, float]#

Validate for a single epoch.

Parameters:

dl (DataLoader) –

Return type:

dict[str, float]

validate_step(batch, batch_ind)[source]#

Compute metrics on validation batch.

Parameters:
  • batch – batch data needed to compute validation metrics

  • batch_ind – index of batch within epoch

Returns:

dict with metric names mapped to metric values

property onnx_mode: bool#