RegressionVisualizer

class RegressionVisualizer

Bases: Visualizer

Plots samples from image regression Datasets.

Attributes

  • scale

__init__(class_names: List[str], class_colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None, transform: Optional[Dict] = None, channel_display_groups: Optional[Union[Dict[str, Sequence[ConstrainedIntValue]], Sequence[Sequence[ConstrainedIntValue]]]] = None)

Constructor.

Parameters
  • class_names (List[str]) – Names of classes.

  • class_colors (Optional[List[Union[str, Tuple[int, int, int]]]]) – Colors used to display classes. Each entry can be a color string or a color 3-tuple, and 3-tuples may also be given in list form.

  • transform (Optional[Dict]) – An Albumentations transform serialized as a dict that will be applied to each image before it is plotted. This is mainly useful for undoing any data transformation that you do not want included in the plot, such as normalization. By default, the image is shifted and scaled so that its values range from 0.0 to 1.0, which is the range the plotting function expects. This default is useful when the values after normalization are close to zero, which makes the plot difficult to see.

  • channel_display_groups (Optional[Union[Dict[str, Sequence[ConstrainedIntValue]], Sequence[Sequence[ConstrainedIntValue]]]]) – Groups of image channels to display together as a subplot when plotting the data and predictions. Can be a list or tuple of groups (e.g. [(0, 1, 2), (3,)]) or a dict mapping titles to groups (e.g. {"RGB": [0, 1, 2], "IR": [3]}), where each group is a list or tuple of channel indices and each title is a string used as the title of that group's subplot.
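The two accepted forms of channel_display_groups, and the default rescaling to the 0.0 to 1.0 range described for transform, can be sketched with the standard library alone. This is not the library's implementation: normalize_channel_groups and min_max_scale are hypothetical helpers, and the sketch assumes the rescaling uses the global minimum and maximum of the image, which the description above does not specify.

```python
from collections.abc import Mapping
from typing import List, Optional, Sequence, Tuple, Union

ChannelGroups = Union[Mapping, Sequence[Sequence[int]]]


def normalize_channel_groups(groups: ChannelGroups) -> List[Tuple[Optional[str], Tuple[int, ...]]]:
    """Hypothetical helper: turn either accepted form into (title, channels) pairs."""
    if isinstance(groups, Mapping):
        # Dict form: the caller supplies a subplot title for each group.
        return [(title, tuple(chans)) for title, chans in groups.items()]
    # List/tuple form: no titles are given, so the title is left as None.
    return [(None, tuple(chans)) for chans in groups]


def min_max_scale(image: List[List[float]]) -> List[List[float]]:
    """Hypothetical helper: shift and scale a 2D image so its values span [0.0, 1.0]."""
    # Assumption: global min/max over the whole image (not per channel).
    lo = min(min(row) for row in image)
    hi = max(max(row) for row in image)
    span = hi - lo
    if span == 0:
        # A constant image cannot be stretched; map it to all zeros.
        return [[0.0 for _ in row] for row in image]
    return [[(v - lo) / span for v in row] for row in image]
```

Representing both forms as (title, channels) pairs, with None for untitled groups, lets downstream plotting code treat them uniformly.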

Methods

  • __init__(class_names[, class_colors, ...]) – Constructor.

  • get_batch(dataset[, batch_sz]) – Generate a batch from a dataset.

  • get_channel_display_groups(nb_img_channels)

  • get_collate_fn() – Returns a custom collate_fn to use in DataLoader.

  • get_plot_ncols(**kwargs)

  • get_plot_nrows(**kwargs)

  • get_plot_params(**kwargs)

  • plot_batch(x[, y, output_path, z, ...]) – Plot a whole batch in a grid using plot_xyz.

  • plot_gt(ax, class_names, y) – Plot targets as a horizontal bar plot with values at the tips.

  • plot_pred(ax, class_names, z[, y]) – Plot targets and predictions as a grouped horizontal bar plot.

  • plot_xyz(axs, x, y[, z, plot_title]) – Plot image, ground truth labels, and predicted labels.


get_batch(dataset: Dataset, batch_sz: int = 4, **kwargs) → Tuple[torch.Tensor, Any]

Generate a batch from a dataset.

This is a convenience method for generating a batch of data to plot.

Parameters
  • dataset (Dataset) – A PyTorch Dataset.

  • batch_sz (int) – Batch size. Defaults to 4.

  • **kwargs – Extra keyword arguments passed to the DataLoader.

Returns

An (x, y) tuple, where x is a batch of images and y is the corresponding labels.

Return type

Tuple[Tensor, Any]
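Per its parameters, get_batch hands extra **kwargs to a PyTorch DataLoader and returns one (x, y) batch. The stdlib-only sketch below mimics that idea without PyTorch: lists stand in for tensors, get_batch_sketch and default_collate_pairs are hypothetical names, and treating collate_fn=None as "use the default collate" follows the convention described for get_collate_fn. Unlike a real DataLoader, it does not shuffle or use worker processes.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple

Sample = Tuple[Any, Any]


def default_collate_pairs(samples: List[Sample]) -> Tuple[List[Any], List[Any]]:
    """Hypothetical default collate: a list of (x, y) samples becomes (xs, ys)."""
    xs = [x for x, _ in samples]
    ys = [y for _, y in samples]
    return xs, ys


def get_batch_sketch(dataset: Sequence[Sample],
                     batch_sz: int = 4,
                     collate_fn: Optional[Callable[[List[Sample]], Any]] = None) -> Any:
    """Hypothetical sketch: collate the first batch_sz samples of a dataset."""
    if collate_fn is None:
        # None means "use the default collate", as with get_collate_fn().
        collate_fn = default_collate_pairs
    n = min(batch_sz, len(dataset))
    samples = [dataset[i] for i in range(n)]
    return collate_fn(samples)
```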

get_channel_display_groups(nb_img_channels: int) → Union[Dict[str, Sequence[ConstrainedIntValue]], Sequence[Sequence[ConstrainedIntValue]]]
Parameters

nb_img_channels (int) –

Return type

Union[Dict[str, Sequence[ConstrainedIntValue]], Sequence[Sequence[ConstrainedIntValue]]]

get_collate_fn() → Optional[callable]

Returns a custom collate_fn to use in DataLoader.

Returns None if the default collate_fn should be used.

See https://pytorch.org/docs/stable/data.html#working-with-collate-fn

Return type

Optional[callable]

get_plot_ncols(**kwargs) → int
Return type

int

get_plot_nrows(**kwargs) → int
Return type

int

get_plot_params(**kwargs) → dict
Return type

dict

plot_batch(x: torch.Tensor, y: Optional[Sequence] = None, output_path: Optional[str] = None, z: Optional[Sequence] = None, batch_limit: Optional[int] = None, show: bool = False)

Plot a whole batch in a grid using plot_xyz.

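Parameters
  • x (torch.Tensor) – Batch of images.

  • y (Optional[Sequence]) – Ground truth labels.

  • output_path (Optional[str]) –

  • z (Optional[Sequence]) – Predicted labels.

  • batch_limit (Optional[int]) –

  • show (bool) –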
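The batching logic of plot_batch can be sketched without matplotlib. The sketch assumes one grid row per sample and that batch_limit caps how many samples are drawn; plot_row is a hypothetical stand-in for plot_xyz filling one row, and saving to output_path or showing the figure is omitted.

```python
from typing import Any, Callable, Optional, Sequence


def plot_batch_sketch(plot_row: Callable[[Any, Optional[Any], Optional[Any]], None],
                      x: Sequence[Any],
                      y: Optional[Sequence[Any]] = None,
                      z: Optional[Sequence[Any]] = None,
                      batch_limit: Optional[int] = None) -> int:
    """Hypothetical sketch: draw one grid row per sample and return the row count."""
    # Assumption: batch_limit caps the number of samples (rows) drawn.
    nrows = len(x) if batch_limit is None else min(len(x), batch_limit)
    for i in range(nrows):
        # plot_row stands in for plot_xyz drawing sample i into row i of the grid.
        plot_row(x[i],
                 None if y is None else y[i],
                 None if z is None else z[i])
    return nrows
```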
plot_gt(ax: Axes, class_names: Sequence[str], y: torch.Tensor)

Plot targets as a horizontal bar plot with values at the tips.

Parameters
  • ax (Axes) –

  • class_names (Sequence[str]) –

  • y (torch.Tensor) – Targets.
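A horizontal bar plot with values at the tips, as plot_gt describes, can be drawn with matplotlib's Axes.barh and Axes.text. plot_gt_sketch is a hypothetical function, not the library's code; it takes plain floats, so a torch.Tensor target would first be converted with y.tolist(). The two-decimal label format is an arbitrary choice of this sketch.

```python
from typing import Sequence

import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.axes import Axes


def plot_gt_sketch(ax: Axes, class_names: Sequence[str], y: Sequence[float]) -> None:
    """Hypothetical sketch: one horizontal bar per class, value written at its tip."""
    positions = list(range(len(class_names)))
    bars = ax.barh(positions, y)
    ax.set_yticks(positions)
    ax.set_yticklabels(class_names)
    for bar, value in zip(bars, y):
        # Start the label at the bar's tip, vertically centred on the bar.
        ax.text(bar.get_width(), bar.get_y() + bar.get_height() / 2,
                f"{value:.2f}", va="center", ha="left")


fig, ax = plt.subplots()
plot_gt_sketch(ax, ["a", "b"], [0.3, 1.2])
```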
plot_pred(ax: Axes, class_names: Sequence[str], z: torch.Tensor, y: Optional[torch.Tensor] = None)

Plot targets and predictions as a grouped horizontal bar plot.

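Parameters
  • ax (Axes) –

  • class_names (Sequence[str]) –

  • z (torch.Tensor) – Predictions.

  • y (Optional[torch.Tensor]) – Targets.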
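A grouped horizontal bar plot of targets and predictions, as plot_pred describes, can be sketched by offsetting two Axes.barh calls within each class's slot. plot_pred_sketch, the 0.4 bar height, and the legend labels are assumptions of this sketch rather than the library's choices; when y is omitted, only the prediction bars are drawn.

```python
from typing import Optional, Sequence

import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.axes import Axes


def plot_pred_sketch(ax: Axes, class_names: Sequence[str], z: Sequence[float],
                     y: Optional[Sequence[float]] = None) -> None:
    """Hypothetical sketch: targets and predictions as grouped horizontal bars."""
    height = 0.4  # each bar takes 0.4 of the 1.0-wide slot reserved per class
    positions = [float(i) for i in range(len(class_names))]
    if y is None:
        # No targets: draw only the predictions, centred in each slot.
        ax.barh(positions, z, height=height, label="prediction")
    else:
        # Offset targets and predictions by half a bar height in opposite
        # directions so each pair sits side by side.
        ax.barh([p - height / 2 for p in positions], y, height=height, label="target")
        ax.barh([p + height / 2 for p in positions], z, height=height, label="prediction")
    ax.set_yticks(positions)
    ax.set_yticklabels(class_names)
    ax.legend()
```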
plot_xyz(axs: Sequence, x: torch.Tensor, y: int, z: Optional[int] = None, plot_title: bool = True) → None

Plot image, ground truth labels, and predicted labels.

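Parameters
  • axs (Sequence) –

  • x (torch.Tensor) – Image.

  • y (int) – Ground truth labels.

  • z (Optional[int]) – Predicted labels.

  • plot_title (bool) –
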
Return type

None

scale: float = 3.0