from typing import (TYPE_CHECKING, Sequence, Optional, List, Dict, Union,
Tuple, Any)
from abc import ABC, abstractmethod
import numpy as np
import torch
from torch import Tensor
import albumentations as A
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from rastervision.pipeline.file_system import make_dir
from rastervision.pytorch_learner.utils import (
deserialize_albumentation_transform, validate_albumentation_transform,
MinMaxNormalize)
from rastervision.pytorch_learner.learner_config import (
RGBTuple,
ChannelInds,
ensure_class_colors,
validate_channel_display_groups,
get_default_channel_display_groups,
)
if TYPE_CHECKING:
from torch.utils.data import Dataset
from matplotlib.figure import Figure
class Visualizer(ABC):
    """Base class for plotting samples from computer vision PyTorch Datasets.

    Subclasses must implement :meth:`plot_xyz`, which draws a single sample
    (image, ground truth, and optionally predictions) onto one row of axes.
    """

    # Multiplier applied to the number of rows/columns to obtain the figure's
    # ``figsize`` (see get_plot_params()).
    scale: float = 3.

    def __init__(self,
                 class_names: List[str],
                 class_colors: Optional[List[Union[str, RGBTuple]]] = None,
                 transform: Optional[Dict] = A.to_dict(MinMaxNormalize()),
                 channel_display_groups: Optional[Union[Dict[
                     str, ChannelInds], Sequence[ChannelInds]]] = None):
        """Constructor.

        Args:
            class_names: names of classes
            class_colors: Colors used to display classes. Can be color
                3-tuples in list form.
            transform: An Albumentations transform serialized as a dict that
                will be applied to each image before it is plotted. Mainly
                useful for undoing any data transformation that you do not
                want included in the plot, such as normalization. The default
                value will shift and scale the image so the values range from
                0.0 to 1.0 which is the expected range for the plotting
                function. This default is useful for cases where the values
                after normalization are close to zero which makes the plot
                difficult to see.
            channel_display_groups: Groups of image channels to display
                together as a subplot when plotting the data and predictions.
                Can be a list or tuple of groups (e.g. [(0, 1, 2), (3,)]) or a
                dict containing title-to-group mappings
                (e.g. {"RGB": [0, 1, 2], "IR": [3]}), where each group is a
                list or tuple of channel indices and title is a string that
                will be used as the title of the subplot for that group.
        """
        self.class_names = class_names
        self.class_colors = ensure_class_colors(self.class_names, class_colors)
        self.transform = validate_albumentation_transform(transform)
        self._channel_display_groups = validate_channel_display_groups(
            channel_display_groups)

    @abstractmethod
    def plot_xyz(self,
                 axs,
                 x: Tensor,
                 y: Sequence,
                 z: Optional[Sequence] = None,
                 plot_title: bool = True):
        """Plot image, ground truth labels, and predicted labels.

        Args:
            axs: matplotlib axes on which to plot
            x: image
            y: ground truth labels
            z: optional predicted labels
            plot_title: flag passed by plot_batch() for temporal inputs; it
                is True only for the first timestep of each batch item so
                that subplot titles are shown once per item.
        """
        pass

    def plot_batch(self,
                   x: Tensor,
                   y: Sequence,
                   output_path: Optional[str] = None,
                   z: Optional[Sequence] = None,
                   batch_limit: Optional[int] = None,
                   show: bool = False):
        """Plot a whole batch in a grid using plot_xyz.

        Args:
            x: batch of images; either 4-D (N, C, H, W) or 5-D
                (N, T, C, H, W)
            y: ground truth labels
            output_path: local path where to save plot image
            z: optional predicted labels
            batch_limit: optional limit on (rendered) batch size
            show: if True, call ``plt.show()`` after plotting (and after
                saving, if output_path is given)

        Raises:
            ValueError: if x is neither 4-D nor 5-D.
        """
        params = self.get_plot_params(
            x=x, y=y, z=z, output_path=output_path, batch_limit=batch_limit)
        nrows = params['subplot_args']['nrows']
        if nrows == 0:
            return

        fig = None
        try:
            if x.ndim == 4:
                fig, axs = plt.subplots(**params['fig_args'],
                                        **params['subplot_args'])
                plot_xyz_args = params['plot_xyz_args']
                self._plot_batch(fig, axs, plot_xyz_args, x, y=y, z=z)
            elif x.ndim == 5:
                # If a temporal dimension is present, we divide the figure
                # into multiple subfigures--one for each batch item. Then, in
                # each subfigure, we plot all timesteps as if they were a
                # single batch. To delineate the boundary b/w batch items, we
                # adopt the convention of only displaying subplot titles once
                # per batch item (above the first row in each item).
                T = x.shape[1]
                ncols = params['subplot_args']['ncols']
                # figsize is a numpy array, so it can be scaled in place.
                params['fig_args']['figsize'][1] *= T
                fig = plt.figure(**params['fig_args'])
                # Use nrows (which honors batch_limit) rather than the raw
                # batch size: plot_xyz_args only has nrows entries and the
                # figure height was computed for nrows items.
                subfigs = fig.subfigures(
                    nrows=nrows, ncols=1, hspace=0.0, squeeze=False)
                for i, subfig in enumerate(subfigs.flat):
                    # squeeze=False keeps the axes array 2-D even when T or
                    # ncols is 1, which _plot_batch relies on to iterate
                    # over rows.
                    axs = subfig.subplots(
                        nrows=T, ncols=ncols, squeeze=False)
                    # One (copied) kwargs dict per timestep; only the first
                    # timestep shows titles.
                    plot_xyz_args = [
                        dict(params['plot_xyz_args'][i], plot_title=(t == 0))
                        for t in range(T)
                    ]
                    _x = x[i]
                    # Labels are repeated so every timestep row gets the
                    # labels of its batch item.
                    _y = [y[i]] * T
                    _z = None if z is None else [z[i]] * T
                    self._plot_batch(fig, axs, plot_xyz_args, _x, y=_y, z=_z)
            else:
                raise ValueError('Expected x to have 4 or 5 dims, but found '
                                 f'x.shape: {x.shape}')

            # Save before showing: with a blocking interactive backend the
            # figure is destroyed once the window is closed, so saving
            # afterwards could write an empty image.
            if output_path is not None:
                make_dir(output_path, use_dirname=True)
                fig.savefig(output_path, bbox_inches='tight', pad_inches=0.2)
            if show:
                plt.show()
        finally:
            # Always release the figure, even if plotting or saving failed.
            if fig is not None:
                plt.close(fig)

    def _plot_batch(
            self,
            fig: 'Figure',
            axs: Sequence,
            plot_xyz_args: List[dict],
            x: Tensor,
            y: Optional[Sequence] = None,
            z: Optional[Sequence] = None,
    ):
        """Plot one row of axes per sample via plot_xyz.

        Args:
            fig: the figure being drawn on (not used by this base
                implementation).
            axs: 2-D collection of axes; each element is the row of axes for
                one sample.
            plot_xyz_args: per-row extra keyword arguments for plot_xyz.
            x: images of shape (N, C, H, W).
            y: optional ground truth labels, indexed per row.
            z: optional predicted labels, indexed per row.
        """
        # (N, c, h, w) --> (N, h, w, c)
        x = x.permute(0, 2, 3, 1)

        # apply transform, if given
        if self.transform is not None:
            tf = deserialize_albumentation_transform(self.transform)
            # detach().cpu() makes the numpy conversion work for tensors
            # that require grad or live on an accelerator.
            imgs = [
                tf(image=img)['image'] for img in x.detach().cpu().numpy()
            ]
            x = torch.from_numpy(np.stack(imgs))

        for i, row_axs in enumerate(axs):
            _y = None if y is None else y[i]
            _z = None if z is None else z[i]
            self.plot_xyz(row_axs, x[i], _y, z=_z, **plot_xyz_args[i])

    def get_channel_display_groups(
            self, nb_img_channels: int
    ) -> Union[Dict[str, ChannelInds], Sequence[ChannelInds]]:
        """Return the channel display groups to use for plotting.

        Args:
            nb_img_channels: number of channels in the image; only used when
                no groups were passed to the constructor.

        Returns:
            The groups given at construction if any, otherwise the default
            groups for nb_img_channels channels.
        """
        # The default channel_display_groups object depends on the number of
        # channels in the image. This number is not known when the Visualizer
        # is constructed which is why it needs to be created later.
        if self._channel_display_groups is not None:
            return self._channel_display_groups
        return get_default_channel_display_groups(nb_img_channels)

    def get_collate_fn(self) -> Optional[callable]:
        """Returns a custom collate_fn to use in DataLoader.

        None is returned if default collate_fn should be used.

        See https://pytorch.org/docs/stable/data.html#working-with-collate-fn
        """
        return None

    def get_batch(self, dataset: 'Dataset', batch_sz: int = 4,
                  **kwargs) -> Tuple[Tensor, Any]:
        """Generate a batch from a dataset.

        This is a convenience method for generating a batch of data to plot.

        Args:
            dataset (Dataset): A PyTorch Dataset.
            batch_sz (int): Batch size. Defaults to 4.
            **kwargs: Extra args for :class:`~torch.utils.data.DataLoader`.

        Returns:
            Tuple[Tensor, Any]: (x, y) tuple where x is images and y is labels.

        Raises:
            ValueError: if the data loader yields no batch.
        """
        collate_fn = self.get_collate_fn()
        dl = DataLoader(dataset, batch_sz, collate_fn=collate_fn, **kwargs)
        try:
            x, y = next(iter(dl))
        except StopIteration:
            # The StopIteration itself carries no useful context.
            raise ValueError('dataset did not return a batch') from None
        return x, y

    def get_plot_nrows(self, **kwargs) -> int:
        """Return the number of batch items to plot.

        Reads ``x`` (required) and ``batch_limit`` (optional) from kwargs and
        returns the batch size (first dim of x), capped at batch_limit when
        one is given.
        """
        x = kwargs['x']
        batch_limit = kwargs.get('batch_limit')
        batch_sz = x.shape[0]
        nrows = min(batch_sz,
                    batch_limit) if batch_limit is not None else batch_sz
        return nrows

    def get_plot_ncols(self, **kwargs) -> int:
        """Return the number of subplot columns, one per channel group.

        Reads ``x`` (required) from kwargs to determine the number of image
        channels.
        """
        x = kwargs['x']
        # Channels are third-from-last for both (N, C, H, W) and
        # (N, T, C, H, W) inputs; indexing with 1 would pick up the temporal
        # dimension in the 5-D case.
        nb_img_channels = x.shape[-3]
        ncols = len(self.get_channel_display_groups(nb_img_channels))
        return ncols

    def get_plot_params(self, **kwargs) -> dict:
        """Build the parameters used by plot_batch.

        Args:
            **kwargs: the keyword arguments forwarded by plot_batch (x, y, z,
                output_path, batch_limit).

        Returns:
            dict with keys:
                'fig_args': kwargs for creating the figure,
                'subplot_args': kwargs for ``plt.subplots``,
                'plot_xyz_args': one dict of extra plot_xyz kwargs per row.
        """
        nrows = self.get_plot_nrows(**kwargs)
        ncols = self.get_plot_ncols(**kwargs)
        params = {
            'fig_args': {
                'constrained_layout': True,
                # A numpy array (rather than a tuple) so that plot_batch can
                # rescale the height in place for temporal inputs.
                'figsize': np.array((self.scale * ncols, self.scale * nrows)),
            },
            'subplot_args': {
                'nrows': nrows,
                'ncols': ncols,
                'squeeze': False
            },
            'plot_xyz_args': [{} for _ in range(nrows)]
        }
        return params