ObjectDetectionVisualizer
- class ObjectDetectionVisualizer
Bases: Visualizer
Plots samples from object detection Datasets.
Attributes
class_colors
class_names
Methods
__init__(class_names[, class_colors, ...]) – Constructor.
get_batch(dataset[, batch_sz]) – Generate a batch from a dataset.
get_channel_display_groups(nb_img_channels)
get_collate_fn() – Returns a custom collate_fn to use in DataLoader.
get_plot_ncols(**kwargs)
get_plot_nrows(**kwargs)
get_plot_params(**kwargs)
plot_batch(x[, y, output_path, z, ...]) – Plot a whole batch in a grid using plot_xyz.
plot_xyz(axs, x[, y, z, plot_title]) – Plot image, ground truth labels, and predicted labels.
- __init__(class_names: list[str], class_colors: list[str | tuple[int, int, int]] | None = None, transform: dict | None = None, channel_display_groups: Optional[Union[dict[str, Sequence[int]], Sequence[Sequence[int]]]] = None)
Constructor.
- Parameters:
class_names (list[str]) – Names of the classes.
class_colors (list[str | tuple[int, int, int]] | None) – Colors used to display the classes. Each color can be a color string or an RGB 3-tuple, and 3-tuples may be given in list form.
transform (dict | None) – An Albumentations transform, serialized as a dict, that is applied to each image before it is plotted. This is mainly useful for undoing data transformations, such as normalization, that you do not want reflected in the plot. If None (the default), a transform is used that shifts and scales each image so its values range from 0.0 to 1.0, the range the plotting function expects. This default helps when values after normalization are close to zero, which would otherwise make the plot hard to see.
channel_display_groups (Optional[Union[dict[str, Sequence[int]], Sequence[Sequence[int]]]]) – Groups of image channels to display together as a subplot when plotting the data and predictions. Can be either a list or tuple of groups (e.g. [(0, 1, 2), (3,)]) or a dict of title-to-group mappings (e.g. {"RGB": [0, 1, 2], "IR": [3]}), where each group is a list or tuple of channel indices and each title is a string used as the title of that group's subplot.
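The `transform` and `channel_display_groups` options both shape what ends up on screen, and their documented behavior can be illustrated in plain Python. The sketch below uses two hypothetical helpers: `min_max_scale`, which performs the kind of shift-and-scale to [0.0, 1.0] that the default `transform` is described as doing (it is not Raster Vision's actual transform), and `normalize_display_groups`, which turns either accepted form of `channel_display_groups` into a title-to-indices dict. The "Channels: [...]" title format used for list input is an assumption made only for this example.

```python
from typing import Sequence, Union

# Either accepted form of channel_display_groups (mirrors the documented type).
DisplayGroups = Union[dict[str, Sequence[int]], Sequence[Sequence[int]]]


def min_max_scale(channel: list[float]) -> list[float]:
    """Shift and scale one channel so its values span [0.0, 1.0].

    Illustrative only; this is not Raster Vision's default transform.
    """
    lo, hi = min(channel), max(channel)
    if hi == lo:
        # A constant channel has no range to stretch, so map it to zeros.
        return [0.0] * len(channel)
    return [(v - lo) / (hi - lo) for v in channel]


def normalize_display_groups(
    groups: DisplayGroups, nb_img_channels: int
) -> dict[str, list[int]]:
    """Convert either form of channel_display_groups to {title: indices}.

    Hypothetical helper; the title format for list/tuple input is assumed.
    """
    if isinstance(groups, dict):
        items = list(groups.items())
    else:
        # No titles were given, so derive one from each group's indices.
        items = [(f"Channels: {list(g)}", g) for g in groups]
    result: dict[str, list[int]] = {}
    for title, group in items:
        indices = list(group)
        for i in indices:
            if not 0 <= i < nb_img_channels:
                raise ValueError(
                    f"channel index {i} is out of range for "
                    f"{nb_img_channels} channels"
                )
        result[title] = indices
    return result


print(min_max_scale([-2.0, 0.0, 2.0]))  # [0.0, 0.5, 1.0]
print(normalize_display_groups([(0, 1, 2), (3,)], nb_img_channels=4))
# {'Channels: [0, 1, 2]': [0, 1, 2], 'Channels: [3]': [3]}
print(normalize_display_groups({"RGB": [0, 1, 2], "IR": [3]}, nb_img_channels=4))
# {'RGB': [0, 1, 2], 'IR': [3]}
```

Normalizing both forms to one dict keeps downstream plotting code to a single case: every subplot then has a title and a validated list of channel indices.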
- get_batch(dataset: Dataset, batch_sz: int = 4, **kwargs) → tuple[torch.Tensor, Any]
Generate a batch from a dataset.
This is a convenience method for generating a batch of data to plot.
- Parameters:
dataset (Dataset) – A PyTorch Dataset.
batch_sz (int) – Batch size. Defaults to 4.
**kwargs – Extra args for DataLoader.
- Returns:
An (x, y) tuple where x is a batch of images and y is the corresponding labels.
- Return type:
tuple[Tensor, Any]
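Conceptually, get_batch draws the first `batch_sz` samples from the dataset and collates them into an (x, y) pair. The sketch below imitates that in plain Python so it runs without PyTorch; `first_batch` and `transpose_collate` are hypothetical names, each dataset item is assumed to be an (x, y) pair, and the real method goes through a PyTorch `DataLoader` (which is why it accepts extra `DataLoader` args).

```python
from typing import Any, Callable, Optional, Sequence


def transpose_collate(samples: list[tuple[Any, Any]]) -> tuple[list[Any], list[Any]]:
    # Turn [(x0, y0), (x1, y1), ...] into ([x0, x1, ...], [y0, y1, ...]).
    xs, ys = zip(*samples)
    return list(xs), list(ys)


def first_batch(
    dataset: Sequence[tuple[Any, Any]],
    batch_sz: int = 4,
    collate_fn: Optional[Callable[[list[tuple[Any, Any]]], Any]] = None,
) -> Any:
    """Return the first batch of up to batch_sz samples, unshuffled.

    Hypothetical stand-in for get_batch: the real method goes through a
    PyTorch DataLoader, which this plain-Python loop only imitates.
    """
    if len(dataset) == 0:
        raise ValueError("dataset is empty, so no batch can be formed")
    samples = [dataset[i] for i in range(min(batch_sz, len(dataset)))]
    collate = collate_fn if collate_fn is not None else transpose_collate
    return collate(samples)


dataset = [("img0", ["box_a"]), ("img1", []), ("img2", ["box_b", "box_c"])]
x, y = first_batch(dataset, batch_sz=2)
print(x)  # ['img0', 'img1']
print(y)  # [['box_a'], []]
```

Passing a `collate_fn` mirrors how a visualizer-specific collate function (see get_collate_fn) would be handed to the `DataLoader`.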
- get_channel_display_groups(nb_img_channels: int) → Union[dict[str, Sequence[int]], Sequence[Sequence[int]]]
- get_collate_fn()
Returns a custom collate_fn to use in a DataLoader, or None if the default collate_fn should be used.
See https://pytorch.org/docs/stable/data.html#working-with-collate-fn
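Object detection is a typical case where a custom collate_fn is needed: images of equal size can be combined into one batch, but each image can have a different number of boxes, so the labels cannot be stacked into a single array. The sketch below illustrates that idea in plain Python with a hypothetical `detection_collate`; it reflects common practice in PyTorch detection pipelines and is not a description of what this class's collate_fn returns. Images are modeled as single-channel lists of rows, and boxes as 4-tuples, purely for illustration.

```python
Box = tuple[float, float, float, float]
Image = list[list[float]]  # one channel, rows of pixel values


def detection_collate(
    batch: list[tuple[Image, list[Box]]],
) -> tuple[list[Image], list[list[Box]]]:
    """Collate (image, boxes) samples whose box counts differ.

    Images must share a size so they can form one batch (with PyTorch they
    would typically be combined with torch.stack). Box lists stay a plain
    list, because a default collate_fn would try to stack them and fail
    whenever two images have different numbers of boxes.
    """
    images = [img for img, _ in batch]
    boxes = [b for _, b in batch]
    sizes = {(len(img), len(img[0]) if img else 0) for img in images}
    if len(sizes) > 1:
        raise ValueError(f"images in a batch must share a size, got {sorted(sizes)}")
    return images, boxes


batch = [
    ([[0.0, 1.0], [2.0, 3.0]], [(0.0, 0.0, 1.0, 1.0)]),
    ([[4.0, 5.0], [6.0, 7.0]], []),  # an image with no objects
]
images, boxes = detection_collate(batch)
print(len(images), boxes)  # 2 [[(0.0, 0.0, 1.0, 1.0)], []]
```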
- plot_batch(x: Tensor, y: Optional[Sequence] = None, output_path: str | None = None, z: Optional[Sequence] = None, batch_limit: int | None = None, show: bool = False)
Plot a whole batch in a grid using plot_xyz.
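Plotting a batch "in a grid" means choosing a number of rows and columns and a figure size. The sketch below assumes one row per sample (capped by `batch_limit`) and one column per channel display group; this layout and the helper names `grid_shape` and `grid_figsize` are hypothetical, and the library's own choice is made by get_plot_nrows and get_plot_ncols.

```python
from typing import Optional


def grid_shape(
    batch_sz: int, nb_display_groups: int, batch_limit: Optional[int] = None
) -> tuple[int, int]:
    """Return (nrows, ncols): one row per sample, one column per channel group.

    This layout is an assumption for illustration, not a description of
    get_plot_nrows/get_plot_ncols.
    """
    nrows = batch_sz if batch_limit is None else min(batch_sz, batch_limit)
    if nrows < 1 or nb_display_groups < 1:
        raise ValueError("need at least one sample and one channel display group")
    return nrows, nb_display_groups


def grid_figsize(nrows: int, ncols: int, scale: float = 3.0) -> tuple[float, float]:
    # scale is a hypothetical size in inches for each subplot;
    # matplotlib's figsize is (width, height), so columns come first.
    return ncols * scale, nrows * scale


nrows, ncols = grid_shape(batch_sz=8, nb_display_groups=2, batch_limit=4)
print(nrows, ncols)  # 4 2
print(grid_figsize(nrows, ncols))  # (6.0, 12.0)
```

With matplotlib, `plt.subplots(nrows, ncols, figsize=grid_figsize(nrows, ncols), squeeze=False)` returns a 2-D array of axes, and each row could then serve as the `axs` argument for one sample's plot_xyz call.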
- plot_xyz(axs: Sequence, x: Tensor, y: rastervision.pytorch_learner.object_detection_utils.BoxList | None = None, z: rastervision.pytorch_learner.object_detection_utils.BoxList | None = None, plot_title: bool = True) → None
Plot image, ground truth labels, and predicted labels.
- Parameters:
axs (Sequence) – matplotlib axes on which to plot
x (Tensor) – image
y (rastervision.pytorch_learner.object_detection_utils.BoxList | None) – ground truth labels
z (rastervision.pytorch_learner.object_detection_utils.BoxList | None) – optional predicted labels
plot_title (bool) – Whether to show subplot titles.
- Return type:
None
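To draw labels, plot_xyz has to turn each box and its class into something matplotlib can render. The sketch below shows one way to prepare keyword arguments for `matplotlib.patches.Rectangle`. It assumes boxes are (xmin, ymin, xmax, ymax) pixel coordinates and that integer RGB entries in `class_colors` use the 0 to 255 range; both are assumptions, as are the helper names `to_mpl_color`, `box_to_rect`, and `box_patches`, and this is not a description of how plot_xyz is implemented.

```python
from typing import Sequence, Union

Color = Union[str, tuple[int, int, int]]
Box = tuple[float, float, float, float]  # assumed (xmin, ymin, xmax, ymax)


def to_mpl_color(color: Color) -> Union[str, tuple[float, float, float]]:
    # Matplotlib takes RGB tuples as floats in [0, 1]; names pass through as-is.
    if isinstance(color, str):
        return color
    r, g, b = color
    return (r / 255, g / 255, b / 255)


def box_to_rect(box: Box) -> tuple[tuple[float, float], float, float]:
    # Rectangle wants an anchor corner plus width and height.
    xmin, ymin, xmax, ymax = box
    if xmax < xmin or ymax < ymin:
        raise ValueError(f"malformed box {box}")
    return (xmin, ymin), xmax - xmin, ymax - ymin


def box_patches(
    boxes: Sequence[Box], class_ids: Sequence[int], class_colors: Sequence[Color]
) -> list[dict]:
    """Build keyword arguments for one Rectangle per box, colored by class."""
    patches = []
    for box, class_id in zip(boxes, class_ids):
        xy, width, height = box_to_rect(box)
        patches.append({
            "xy": xy,
            "width": width,
            "height": height,
            "edgecolor": to_mpl_color(class_colors[class_id]),
        })
    return patches


print(box_patches([(10, 20, 40, 60)], [1], ["red", (0, 255, 0)]))
# [{'xy': (10, 20), 'width': 30, 'height': 40, 'edgecolor': (0.0, 1.0, 0.0)}]
```

Each dict can then be drawn with `ax.add_patch(Rectangle(fill=False, **p))`, where `Rectangle` comes from `matplotlib.patches` and `ax` is one of the axes in `axs`.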
- property class_colors
- property class_names