# Source code for rastervision.pytorch_learner.dataset.visualizer.classification_visualizer

from typing import TYPE_CHECKING, Optional, Sequence
from textwrap import wrap

import torch

from rastervision.pytorch_learner.dataset.visualizer import Visualizer  # NOQA
from rastervision.pytorch_learner.utils import (plot_channel_groups,
                                                channel_groups_to_imgs)

if TYPE_CHECKING:
    from matplotlib.pyplot import Axes


class ClassificationVisualizer(Visualizer):
    """Plots samples from image classification Datasets."""

    def plot_xyz(self,
                 axs: Sequence['Axes'],
                 x: torch.Tensor,
                 y: Optional[int] = None,
                 z: Optional[torch.Tensor] = None,
                 plot_title: bool = True) -> None:
        """Plot the image channel groups and a label panel for one sample.

        All axes except the last are used for the image channel groups; the
        last axis is used for the label (ground truth text or prediction
        bar plot).

        Args:
            axs: Axes to draw on. The last one is reserved for the label.
            x: Image tensor. Its size along dimension 1 is used as the
                number of image channels.
            y: Ground truth class index. Plotted as text only when ``z`` is
                None; otherwise it is forwarded to :meth:`plot_pred` to
                color the bars.
            z: Per-class prediction scores. If given, they are plotted as a
                bar plot of softmax probabilities.
            plot_title: Whether to set titles on the axes.
        """
        channel_groups = self.get_channel_display_groups(x.shape[1])

        img_axes = axs[:-1]
        label_ax = axs[-1]

        # plot image
        imgs = channel_groups_to_imgs(x, channel_groups)
        plot_channel_groups(
            img_axes, imgs, channel_groups, plot_title=plot_title)

        # plot label
        # Long class names are wrapped to 16 characters per line so they fit
        # in the label panel.
        class_names = [
            '-\n-'.join(wrap(c, width=16)) for c in self.class_names
        ]
        if z is not None:
            self.plot_pred(label_ax, class_names, z, y=y)
            if plot_title:
                label_ax.set_title('Prediction')
        elif y is not None:
            self.plot_gt(label_ax, class_names, y)

    def plot_gt(self, ax: 'Axes', class_names: Sequence[str],
                y: torch.Tensor):
        """Display ground truth class names as text.

        Args:
            ax: Axes to draw on. Its axis lines and ticks are turned off.
            class_names: Class names, indexed by class index.
            y: Ground truth class index, used to index ``class_names``.
        """
        class_name = class_names[y]
        ax.text(
            x=.5,
            y=.5,
            s=class_name,
            ha='center',
            va='center',
            fontdict={
                'size': 20,
                'family': 'sans-serif'
            })
        ax.set_xlim((0, 1))
        ax.set_ylim((0, 1))
        ax.axis('off')

    def plot_pred(self,
                  ax: 'Axes',
                  class_names: Sequence[str],
                  z: torch.Tensor,
                  y: Optional[torch.Tensor] = None):
        """Plot predictions.

        Plots predicted class probabilities as a horizontal bar plot. If
        ground truth, y, is provided, the bar colors represent: green =
        ground truth, dark-red = wrong prediction, light-gray = other. In
        case predicted class matches ground truth, only one bar will be
        green and the others will be light-gray.

        Args:
            ax: Axes to draw the bar plot on.
            class_names: Class names used as the bar labels.
            z: Per-class scores; softmax is applied over the last
                dimension to obtain probabilities.
            y: Optional ground truth class index. If None, all bars are
                light-gray.
        """
        # Detach and move to CPU so that matplotlib can consume the values
        # even if z requires grad or lives on an accelerator.
        class_probabilities = z.softmax(dim=-1).detach().cpu().numpy()
        class_index_pred = int(z.argmax(dim=-1))

        bar_colors = ['lightgray'] * len(z)
        if y is not None:
            class_index_gt = int(y)
            # Paint the prediction first and the ground truth second: when
            # they coincide the bar ends up green, otherwise the wrong
            # prediction stays dark-red.
            bar_colors[class_index_pred] = 'darkred'
            bar_colors[class_index_gt] = 'green'

        ax.barh(
            y=class_names,
            width=class_probabilities,
            color=bar_colors,
            edgecolor='black')
        ax.set_xlim((0, 1))
        ax.xaxis.grid(linestyle='--', alpha=1)
        ax.set_xlabel('Probability')

    def get_plot_ncols(self, **kwargs) -> int:
        """Return the number of plot columns needed per sample.

        This is one column per channel display group plus one for the label.

        Args:
            **kwargs: Must contain ``'x'``, a tensor whose size along
                dimension 1 is the number of image channels.

        Returns:
            Number of channel display groups plus one.
        """
        x = kwargs['x']
        nb_img_channels = x.shape[1]
        ncols = len(self.get_channel_display_groups(nb_img_channels)) + 1
        return ncols