ClassificationVisualizer
- class ClassificationVisualizer
Bases: Visualizer
Plots samples from image classification Datasets.
Attributes
class_colors
class_names
Methods
__init__(class_names[, class_colors, ...]) – Constructor.
get_batch(dataset[, batch_sz]) – Generate a batch from a dataset.
get_channel_display_groups(nb_img_channels)
get_collate_fn() – Returns a custom collate_fn to use in DataLoader.
get_plot_ncols(**kwargs)
get_plot_nrows(**kwargs)
get_plot_params(**kwargs)
plot_batch(x[, y, output_path, z, ...]) – Plot a whole batch in a grid using plot_xyz.
plot_gt(ax, class_names, y) – Display ground truth class names as text.
plot_pred(ax, class_names, z[, y]) – Plot predictions.
plot_xyz(axs, x[, y, z, plot_title]) – Plot image, ground truth labels, and predicted labels.
- __init__(class_names: list[str], class_colors: list[str | tuple[int, int, int]] | None = None, transform: dict | None = None, channel_display_groups: Optional[Union[dict[str, Sequence[int]], Sequence[Sequence[int]]]] = None)
Constructor.
- Parameters:
class_names (list[str]) – Names of the classes.
class_colors (list[str | tuple[int, int, int]] | None) – Colors used to display the classes. Each color can be a color string or a 3-tuple of ints.
transform (dict | None) – An Albumentations transform, serialized as a dict, that is applied to each image before it is plotted. This is mainly useful for undoing data transformations, such as normalization, that you do not want reflected in the plot. By default, the image is shifted and scaled so that its values lie in the range 0.0 to 1.0, which is the range the plotting function expects. This default helps when the values after normalization are close to zero, which would otherwise make the plot hard to see.
channel_display_groups (Optional[Union[dict[str, Sequence[int]], Sequence[Sequence[int]]]]) – Groups of image channels to display together as a subplot when plotting the data and predictions. Can be a list or tuple of groups (e.g. [(0, 1, 2), (3,)]) or a dict mapping titles to groups (e.g. {"RGB": [0, 1, 2], "IR": [3]}), where each group is a list or tuple of channel indices and each title is a string used as the title of that group's subplot.
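The two accepted forms of channel_display_groups, and the default 0.0 to 1.0 scaling described for transform, can be illustrated with a small stdlib sketch. normalize_channel_groups and min_max_scale are hypothetical helpers, not part of this class; the auto-generated titles for the list form are an assumption, and the sketch scales a flat list of numbers, whereas the real default transform operates on image arrays.

```python
from collections.abc import Mapping, Sequence
from typing import Union

ChannelGroups = Union[Mapping[str, Sequence[int]], Sequence[Sequence[int]]]


def normalize_channel_groups(groups: ChannelGroups) -> dict[str, list[int]]:
    """Convert either form of channel_display_groups to {subplot title: channel indices}.

    Hypothetical helper: the titles generated for the list/tuple form are an assumption.
    """
    if isinstance(groups, Mapping):
        # Dict form: the caller already chose a title for each group.
        return {title: list(chs) for title, chs in groups.items()}
    # List/tuple form: no titles given, so derive a unique one from position and channels.
    return {f"Group {i}: channels {list(chs)}": list(chs) for i, chs in enumerate(groups)}


def min_max_scale(values: Sequence[float]) -> list[float]:
    """Shift and scale values into [0.0, 1.0] (hypothetical stand-in for the default transform)."""
    lo, hi = min(values), max(values)
    if hi == lo:
        # Constant input: avoid dividing by zero and map everything to 0.0.
        return [0.0 for _ in values]
    return [(v - lo) / (hi - lo) for v in values]
```

The dict form is the one to use when specific subplot titles such as "RGB" or "IR" matter; the list form leaves titling to the visualizer.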
- get_batch(dataset: Dataset, batch_sz: int = 4, **kwargs) → tuple[torch.Tensor, Any]
Generate a batch from a dataset.
This is a convenience method for generating a batch of data to plot.
- Parameters:
dataset (Dataset) – A PyTorch Dataset.
batch_sz (int) – Batch size. Defaults to 4.
**kwargs – Extra args for DataLoader.
- Returns:
An (x, y) tuple where x is a batch of images and y is the corresponding labels.
- Return type:
tuple[Tensor, Any]
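A minimal sketch of what get_batch produces, assuming it wraps the dataset in a PyTorch DataLoader and returns the first batch. Because DataLoader does not shuffle unless shuffle=True is passed (here, through **kwargs), that batch is simply the first batch_sz samples. The hypothetical first_batch helper uses plain lists instead of stacked tensors so that it runs with the standard library only.

```python
from collections.abc import Sequence
from typing import Any


def first_batch(dataset: Sequence[tuple[Any, Any]], batch_sz: int = 4) -> tuple[list, list]:
    """Return (x, y) for the first batch_sz samples, like an unshuffled DataLoader.

    Hypothetical helper: real batches would be collated into tensors.
    """
    if len(dataset) == 0:
        raise ValueError("dataset is empty; cannot build a batch")
    # A DataLoader with drop_last=False yields a smaller batch if the dataset is short.
    n = min(batch_sz, len(dataset))
    samples = [dataset[i] for i in range(n)]
    # Transpose [(x0, y0), (x1, y1), ...] into ([x0, x1, ...], [y0, y1, ...]).
    xs, ys = zip(*samples)
    return list(xs), list(ys)
```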
- get_channel_display_groups(nb_img_channels: int) → Union[dict[str, Sequence[int]], Sequence[Sequence[int]]]
- get_collate_fn() → collections.abc.Callable | None
Returns a custom collate_fn to use in DataLoader.
Returns None if the default collate_fn should be used.
See https://pytorch.org/docs/stable/data.html#working-with-collate-fn
- Return type:
collections.abc.Callable | None
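When the collate function is None, a PyTorch DataLoader falls back to its own default; otherwise the callable receives a list of samples and must build the batch itself. The sketch below shows a hypothetical custom collate_fn that splits (image, label) samples into two lists without stacking, plus a hypothetical resolve_collate_fn that mirrors the "None means default" rule. Neither function is part of this class.

```python
from collections.abc import Callable
from typing import Any, Optional


def keep_as_lists(samples: list[tuple[Any, Any]]) -> tuple[list, list]:
    """Hypothetical custom collate_fn: split (image, label) samples into two lists."""
    images = [x for x, _ in samples]
    labels = [y for _, y in samples]
    return images, labels


def resolve_collate_fn(custom: Optional[Callable], default: Callable) -> Callable:
    """Return the callable a DataLoader-style loader would use to build batches."""
    # None means "no custom collate_fn": use the default one instead.
    return default if custom is None else custom
```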
- plot_batch(x: Tensor, y: Optional[Sequence] = None, output_path: str | None = None, z: Optional[Sequence] = None, batch_limit: int | None = None, show: bool = False)
Plot a whole batch in a grid using plot_xyz.
- plot_gt(ax: Axes, class_names: Sequence[str], y: Tensor)
Display ground truth class names as text.
- plot_pred(ax: Axes, class_names: Sequence[str], z: Tensor, y: torch.Tensor | None = None)
Plot predictions.
Plots predicted class probabilities as a horizontal bar plot. If the ground truth y is provided, the bar colors mean: green = ground-truth class, dark red = wrong prediction, light gray = other classes. If the predicted class matches the ground truth, only one bar is green and the others are light gray.
- Parameters:
ax (Axes) – Axes to plot on.
class_names (Sequence[str]) – Names of the classes.
z (Tensor) – Predicted class probabilities.
y (torch.Tensor | None) – Optional ground truth class.
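The color scheme of plot_pred can be expressed as a small function that assigns one color per class bar when the ground truth is known. pred_bar_colors is a hypothetical helper; it assumes the predicted class is the one with the highest probability, and it uses the matplotlib color names "green", "darkred", and "lightgray" as stand-ins for the exact shades, which are not specified here.

```python
from collections.abc import Sequence

GREEN, DARK_RED, LIGHT_GRAY = "green", "darkred", "lightgray"


def pred_bar_colors(probs: Sequence[float], y: int) -> list[str]:
    """Pick one bar color per class following the plot_pred color scheme (sketch)."""
    # Assumption: the predicted class is the argmax of the probabilities.
    pred = max(range(len(probs)), key=lambda i: probs[i])
    colors = [LIGHT_GRAY] * len(probs)
    if pred != y:
        colors[pred] = DARK_RED  # wrong prediction
    colors[y] = GREEN  # the ground-truth class is always green
    return colors
```

The resulting list could be passed as the color argument of matplotlib's Axes.barh, e.g. ax.barh(class_names, probs, color=colors).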
- plot_xyz(axs: Sequence[Axes], x: Tensor, y: int | None = None, z: int | None = None, plot_title: bool = True) → None
Plot image, ground truth labels, and predicted labels.
- property class_colors
- property class_names