Visualizer

class Visualizer

Bases: ABC

Base class for plotting samples from computer vision PyTorch Datasets.

Attributes

class_colors

class_names

scale

__init__(class_names: list[str], class_colors: list[str | tuple[int, int, int]] | None = None, transform: dict | None = None, channel_display_groups: Optional[Union[dict[str, Sequence[int]], Sequence[Sequence[int]]]] = None)

Constructor.

Parameters:
  • class_names (list[str]) – Names of the classes.

  • class_colors (list[str | tuple[int, int, int]] | None) – Colors used to display the classes. Each entry can be a color string or an RGB color 3-tuple.

  • transform (dict | None) – An Albumentations transform, serialized as a dict, that is applied to each image before it is plotted. This is mainly useful for undoing data transformations that you do not want reflected in the plot, such as normalization. By default, the image is shifted and scaled so that its values range from 0.0 to 1.0, which is the range the plotting function expects. This default helps when the values after normalization are close to zero, which would otherwise make the plot hard to see.

  • channel_display_groups (Optional[Union[dict[str, Sequence[int]], Sequence[Sequence[int]]]]) – Groups of image channels to display together as a subplot when plotting the data and predictions. Can be either a list or tuple of groups (e.g. [(0, 1, 2), (3,)]) or a dict mapping titles to groups (e.g. {"RGB": [0, 1, 2], "IR": [3]}). Each group is a list or tuple of channel indices, and each title is a string used as the title of that group's subplot.
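The two accepted forms of channel_display_groups, and the rescaling the default transform is described as performing, can be illustrated with a small standard-library sketch. Both helpers (normalize_display_groups, min_max_scale) are hypothetical and not part of the library; the placeholder titles given to list-form groups and the all-zeros result for constant input are likewise assumptions.

```python
from collections.abc import Mapping


def normalize_display_groups(groups):
    """Convert either accepted form of channel_display_groups into a dict.

    Hypothetical helper: a list/tuple of groups gets placeholder titles
    ("Group 0", "Group 1", ...); a dict keeps its own titles.
    """
    if isinstance(groups, Mapping):
        return {str(title): tuple(chans) for title, chans in groups.items()}
    return {f"Group {i}": tuple(chans) for i, chans in enumerate(groups)}


def min_max_scale(values):
    """Shift and scale values so that they span [0.0, 1.0].

    Mimics the effect described for the default transform; this is not
    the library's implementation.
    """
    lo, hi = min(values), max(values)
    if hi == lo:  # constant input: avoid dividing by zero (assumed behavior)
        return [0.0 for _ in values]
    return [(v - lo) / (hi - lo) for v in values]


# Both forms from the parameter description express a two-group layout.
print(normalize_display_groups([(0, 1, 2), (3,)]))
print(normalize_display_groups({"RGB": [0, 1, 2], "IR": [3]}))

# Normalized values clustered near zero are spread over [0, 1] after rescaling.
print(min_max_scale([-0.02, 0.0, 0.03]))
```

The list form only fixes which channels go together; the dict form additionally names each subplot.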

Methods

__init__(class_names[, class_colors, ...])

Constructor.

get_batch(dataset[, batch_sz])

Generate a batch from a dataset.

get_channel_display_groups(nb_img_channels)

get_collate_fn()

Returns a custom collate_fn to use in DataLoader.

get_plot_ncols(**kwargs)

get_plot_nrows(**kwargs)

get_plot_params(**kwargs)

plot_batch(x[, y, output_path, z, ...])

Plot a whole batch in a grid using plot_xyz.

plot_xyz(axs, x[, y, z, plot_title])

Plot image, ground truth labels, and predicted labels.


get_batch(dataset: Dataset, batch_sz: int = 4, **kwargs) -> tuple[torch.Tensor, Any]

Generate a batch from a dataset.

This is a convenience method for generating a batch of data to plot.

Parameters:
  • dataset (Dataset) – A PyTorch Dataset.

  • batch_sz (int) – Batch size. Defaults to 4.

  • **kwargs – Extra keyword arguments passed to the DataLoader.

Returns:

An (x, y) tuple, where x is a batch of images and y holds the corresponding labels.

Return type:

tuple[Tensor, Any]
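A rough, standard-library sketch of what get_batch does, assuming it takes the first batch from a DataLoader built with batch_sz and the collate_fn returned by get_collate_fn (with shuffling off). The names get_batch_sketch and default_collate are hypothetical; a real DataLoader using PyTorch's default collate would stack the images into a single Tensor instead of returning a list.

```python
def default_collate(samples):
    """Group a list of (x, y) samples into ([x, ...], [y, ...]).

    Stand-in for PyTorch's default collate, which would stack tensors.
    """
    xs, ys = zip(*samples)
    return list(xs), list(ys)


def get_batch_sketch(dataset, batch_sz=4, collate_fn=None):
    """Return the first batch of up to batch_sz samples as an (x, y) tuple.

    Hypothetical helper, roughly comparable to
    next(iter(DataLoader(dataset, batch_size=batch_sz, collate_fn=collate_fn))).
    """
    n = min(batch_sz, len(dataset))
    if n == 0:
        raise ValueError("dataset is empty")
    samples = [dataset[i] for i in range(n)]
    collate = collate_fn or default_collate
    return collate(samples)


# A toy "dataset": any object with __len__ and __getitem__ works here.
toy_dataset = [([i, i], i % 2) for i in range(10)]
x, y = get_batch_sketch(toy_dataset, batch_sz=4)
print(x)  # [[0, 0], [1, 1], [2, 2], [3, 3]]
print(y)  # [0, 1, 0, 1]
```

As with a DataLoader that keeps its last partial batch, a dataset shorter than batch_sz yields a smaller batch rather than an error.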

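get_channel_display_groups(nb_img_channels: int) -> Union[dict[str, Sequence[int]], Sequence[Sequence[int]]]

Return the channel display groups to use for images with nb_img_channels channels.

Parameters:

nb_img_channels (int) – Number of channels in the images being plotted.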

Return type:

Union[dict[str, Sequence[int]], Sequence[Sequence[int]]]
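Because the number of image channels is only known once data is seen, the display groups are presumably resolved per call from nb_img_channels. The sketch below assumes that groups passed to the constructor take precedence, and otherwise falls back to a hypothetical default that shows up to the first three channels as a single "Input" group; the library's actual default may differ. get_channel_display_groups_sketch and its configured argument are illustrative names only.

```python
def get_channel_display_groups_sketch(nb_img_channels, configured=None):
    """Pick channel display groups for an image with nb_img_channels channels.

    Assumption: configured groups (from the constructor) win; otherwise a
    hypothetical default shows up to the first three channels together.
    """
    if configured is not None:
        return configured
    if nb_img_channels < 1:
        raise ValueError("nb_img_channels must be at least 1")
    return {"Input": tuple(range(min(nb_img_channels, 3)))}


print(get_channel_display_groups_sketch(1))  # {'Input': (0,)}
print(get_channel_display_groups_sketch(4))  # {'Input': (0, 1, 2)}
print(get_channel_display_groups_sketch(4, {"RGB": [0, 1, 2], "IR": [3]}))
```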

get_collate_fn() -> collections.abc.Callable | None

Returns a custom collate_fn to use in a DataLoader, or None if the default collate_fn should be used.

See https://pytorch.org/docs/stable/data.html#working-with-collate-fn

Return type:

collections.abc.Callable | None
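A custom collate_fn matters when labels cannot be stacked into one tensor, for example when each image has a different number of boxes. In this sketch, keep_labels_as_list, ClassificationLike, and DetectionLike are hypothetical. Returning None fits the documented contract because PyTorch's DataLoader falls back to its default collate_fn when collate_fn=None.

```python
def keep_labels_as_list(samples):
    """Hypothetical collate_fn: gather images and labels into plain lists.

    Labels with a per-sample size (e.g. a variable number of boxes) are
    left unstacked, which a stacking collate could not do.
    """
    images = [x for x, _ in samples]
    labels = [y for _, y in samples]
    return images, labels


class ClassificationLike:
    def get_collate_fn(self):
        return None  # use the DataLoader's default collate_fn


class DetectionLike:
    def get_collate_fn(self):
        return keep_labels_as_list


# Two samples with one and two boxes respectively.
samples = [("img0", [(0, 0, 5, 5)]), ("img1", [(1, 1, 2, 2), (3, 3, 4, 4)])]
collate = DetectionLike().get_collate_fn()
images, labels = collate(samples)
print([len(boxes) for boxes in labels])  # [1, 2]
```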

get_plot_ncols(**kwargs) -> int
Return type:

int

get_plot_nrows(**kwargs) -> int
Return type:

int

get_plot_params(**kwargs) -> dict
Return type:

dict

plot_batch(x: Tensor, y: Optional[Sequence] = None, output_path: str | None = None, z: Optional[Sequence] = None, batch_limit: int | None = None, show: bool = False)

Plot a whole batch in a grid using plot_xyz.

Parameters:
  • x (Tensor) – Batch of images.

  • y (Optional[Sequence]) – Ground truth labels.

  • output_path (str | None) – Local path at which to save the plot image.

  • z (Optional[Sequence]) – Optional predicted labels.

  • batch_limit (int | None) – Optional limit on the number of samples from the batch that are rendered.

  • show (bool) – Whether to display the plot.
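The control flow of plot_batch can be sketched as one grid row per sample, with batch_limit capping how many rows are rendered and each row handed to plot_xyz. plot_batch_sketch is a hypothetical stand-in: to stay runnable without a display, its "axes" are (row, col) tuples rather than Matplotlib Axes, and the column count is a fixed argument instead of coming from get_plot_ncols.

```python
def plot_batch_sketch(plot_xyz, x, y=None, z=None, batch_limit=None, ncols=2):
    """Call plot_xyz once per rendered sample, giving each its own grid row.

    Hypothetical: "axes" are (row, col) tuples, not Matplotlib Axes.
    """
    n = len(x) if batch_limit is None else min(len(x), batch_limit)
    grid = [[(r, c) for c in range(ncols)] for r in range(n)]
    for i in range(n):
        plot_xyz(
            grid[i],
            x[i],
            y=None if y is None else y[i],
            z=None if z is None else z[i],
        )
    return n  # number of rows actually rendered


calls = []


def record(axs, x, y=None, z=None):
    # Record the row index of the first "axis" plus the sample it received.
    calls.append((axs[0][0], x, y, z))


rows = plot_batch_sketch(record, x=["a", "b", "c"], y=[0, 1, 0], batch_limit=2)
print(rows)   # 2
print(calls)  # [(0, 'a', 0, None), (1, 'b', 1, None)]
```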

abstract plot_xyz(axs, x: Tensor, y: Optional[Sequence] = None, z: Optional[Sequence] = None, plot_title: bool = True)

Plot image, ground truth labels, and predicted labels.

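Parameters:
  • axs – Matplotlib axes on which to plot the sample.

  • x (Tensor) – Image.

  • y (Optional[Sequence]) – Ground truth labels.

  • z (Optional[Sequence]) – Predicted labels.

  • plot_title (bool) – Whether to add titles to the subplots.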
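Since Visualizer derives from ABC and plot_xyz is abstract, a concrete visualizer must implement plot_xyz before it can be instantiated. The classes below (MiniVisualizer, TextVisualizer) are hypothetical, trimmed-down analogues that show this mechanism with the standard abc module; they do not reproduce the library's plotting.

```python
from abc import ABC, abstractmethod


class MiniVisualizer(ABC):
    """Hypothetical, trimmed-down analogue of the Visualizer base class."""

    def __init__(self, class_names):
        self.class_names = list(class_names)

    @abstractmethod
    def plot_xyz(self, axs, x, y=None, z=None, plot_title=True):
        """Draw one sample; every concrete subclass must override this."""


class TextVisualizer(MiniVisualizer):
    """Hypothetical subclass that 'plots' by returning a text summary."""

    def plot_xyz(self, axs, x, y=None, z=None, plot_title=True):
        # y and z are class indices here; 0 is a valid index, so test for None.
        true = self.class_names[y] if y is not None else "?"
        pred = self.class_names[z] if z is not None else "?"
        return f"{x}: true={true}, pred={pred}"


# The abstract base cannot be instantiated until plot_xyz is implemented.
try:
    MiniVisualizer(["background"])
except TypeError:
    print("MiniVisualizer is abstract")

viz = TextVisualizer(["background", "building"])
print(viz.plot_xyz(None, "img0", y=0, z=1))  # img0: true=background, pred=building
```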
property class_colors
property class_names
scale: float = 3.0