Source code for rastervision.pytorch_learner.dataset.visualizer.semantic_segmentation_visualizer

from typing import TYPE_CHECKING, Optional, Sequence, Union

import torch
import numpy as np
import matplotlib.colors as mcolors
import matplotlib.patches as mpatches

from rastervision.pytorch_learner.dataset.visualizer import Visualizer  # NOQA
from rastervision.pytorch_learner.utils import (
    color_to_triple, plot_channel_groups, channel_groups_to_imgs)

if TYPE_CHECKING:
    from matplotlib.pyplot import Axes
    from matplotlib.colors import Colormap


class SemanticSegmentationVisualizer(Visualizer):
    """Plots samples from semantic segmentation Datasets."""

    def plot_xyz(self,
                 axs: Sequence,
                 x: torch.Tensor,
                 y: Optional[Union[torch.Tensor, np.ndarray]] = None,
                 z: Optional[Union[torch.Tensor, np.ndarray]] = None,
                 plot_title: bool = True) -> None:
        """Plot one sample's image and, optionally, its labels/prediction.

        The first ``len(channel_groups)`` axes receive the image channel
        groups. If ``y`` is given, the axis right after them receives the
        ground-truth label map. If ``z`` is given, the last axis receives the
        predicted label map. When ``self.class_names`` is non-empty, a class
        legend is attached to the last axis.

        Args:
            axs: Sequence of matplotlib axes for this sample. Presumably one
                row sized by ``get_plot_ncols`` — TODO confirm with caller.
            x: Image tensor for this sample.
            y: Optional ground-truth label map, drawn with ``plot_gt``.
            z: Optional prediction, drawn with ``plot_pred``. A 3-D ``z`` is
                reduced to a label map by argmax over its first dimension.
            plot_title: If True, set titles on the plotted axes.
        """
        # NOTE(review): x.shape[1] is used as the channel count, mirroring
        # get_plot_ncols (which reads the same index off kwargs['x']). For a
        # single un-batched sample, dim 1 is not the channel axis in either
        # (C, H, W) or (H, W, C) layout — confirm the layout received here.
        channel_groups = self.get_channel_display_groups(x.shape[1])

        img_axes = axs[:len(channel_groups)]

        # plot image
        imgs = channel_groups_to_imgs(x, channel_groups)
        plot_channel_groups(
            img_axes, imgs, channel_groups, plot_title=plot_title)

        if y is None and z is None:
            return

        # plot labels
        # Color strings are converted to triples. Dividing by 255 implies the
        # triples are 0-255 RGB values; matplotlib expects floats in [0, 1].
        class_colors = self.class_colors
        colors = [
            color_to_triple(c) if isinstance(c, str) else c
            for c in class_colors
        ]
        colors = np.array(colors) / 255.
        cmap = mcolors.ListedColormap(colors)

        if y is not None:
            label_ax: 'Axes' = axs[len(channel_groups)]
            self.plot_gt(label_ax, y, num_classes=len(colors), cmap=cmap)
            if plot_title:
                label_ax.set_title('Ground truth')

        if z is not None:
            pred_ax = axs[-1]
            self.plot_pred(pred_ax, z, num_classes=len(colors), cmap=cmap)
            if plot_title:
                pred_ax.set_title('Predicted labels')

        # add a legend to the rightmost subplot
        class_names = self.class_names
        if class_names:
            legend_items = [
                mpatches.Patch(facecolor=col, edgecolor='black', label=name)
                for col, name in zip(colors, class_names)
            ]
            # Anchor the legend just outside the right edge of the last axis.
            axs[-1].legend(
                handles=legend_items,
                loc='center left',
                bbox_to_anchor=(1., 0.5))

    def plot_gt(self, ax: 'Axes', y: Union[torch.Tensor, np.ndarray],
                num_classes: int, cmap: 'Colormap', **kwargs):
        """Draw a label map on ``ax`` with ticks hidden.

        Args:
            ax: Axis to draw on.
            y: Label map passed straight to ``ax.imshow``.
            num_classes: Upper bound of the color scale (``vmax``).
            cmap: Colormap used to color the labels.
            **kwargs: Extra keyword arguments forwarded to ``ax.imshow``.
        """
        # Presumably cmap holds num_classes colors (as plot_xyz builds it);
        # with vmin=0 and vmax=num_classes, integer label k then falls into
        # the k-th color bin. interpolation='none' keeps labels unblended.
        ax.imshow(
            y,
            vmin=0,
            vmax=num_classes,
            cmap=cmap,
            interpolation='none',
            **kwargs)
        ax.set_xticks([])
        ax.set_yticks([])

    def plot_pred(self, ax: 'Axes', z: Union[torch.Tensor, np.ndarray],
                  num_classes: int, cmap: 'Colormap', **kwargs):
        """Draw a prediction on ``ax``, reducing 3-D scores to labels first.

        Args:
            ax: Axis to draw on.
            z: Either a label map, or a 3-D array/tensor. The 3-D form is
                reduced by argmax over dim 0, presumably
                (num_classes, H, W) scores — TODO confirm.
            num_classes: Upper bound of the color scale.
            cmap: Colormap used to color the labels.
            **kwargs: Extra keyword arguments forwarded to ``plot_gt``.
        """
        if z.ndim == 3:
            # torch reductions take `dim`; numpy's ndarray.argmax takes
            # `axis` and rejects `dim`, so branch on the container type.
            if isinstance(z, torch.Tensor):
                z = z.argmax(dim=0)
            else:
                z = np.asarray(z).argmax(axis=0)
        self.plot_gt(ax, y=z, num_classes=num_classes, cmap=cmap, **kwargs)

    def get_plot_ncols(self, **kwargs) -> int:
        """Return the number of subplot columns needed for one sample.

        One column per channel display group of ``kwargs['x']``, plus one for
        ``kwargs['y']`` and one for ``kwargs['z']`` when they are not None.

        Raises:
            KeyError: If ``'x'`` is not in ``kwargs``.
        """
        x = kwargs['x']
        # Presumably x here is batched (N, C, H, W), making dim 1 the channel
        # count — TODO confirm.
        nb_img_channels = x.shape[1]
        ncols = len(self.get_channel_display_groups(nb_img_channels))
        if kwargs.get('y') is not None:
            ncols += 1
        if kwargs.get('z') is not None:
            ncols += 1
        return ncols