from typing import TYPE_CHECKING, Optional, Sequence
from textwrap import wrap

import torch
import numpy as np
import matplotlib.pyplot as plt

from rastervision.pytorch_learner.dataset.visualizer import Visualizer  # NOQA
from rastervision.pytorch_learner.utils import (plot_channel_groups,
                                                channel_groups_to_imgs)

if TYPE_CHECKING:
    from matplotlib.pyplot import Axes

class RegressionVisualizer(Visualizer):
    """Plots samples from image regression Datasets."""

    def plot_xyz(self,
                 axs: Sequence,
                 x: torch.Tensor,
                 y: Optional[torch.Tensor],
                 z: Optional[torch.Tensor] = None,
                 plot_title: bool = True) -> None:
        """Plot one sample: image channel groups plus a bar plot of values.

        All axes except the last are used for the image channel groups; the
        last axis is used for the target/prediction bar plot.

        Args:
            axs: Axes to draw on. The last one receives the bar plot.
            x: Image tensor. ``x.shape[1]`` is used as the channel count.
            y: Target values, or ``None`` if unavailable.
            z: Predicted values, or ``None`` if unavailable.
            plot_title: Whether to set titles on the subplots.

        Note:
            If both ``y`` and ``z`` are ``None``, the last axis is left
            untouched.
        """
        channel_groups = self.get_channel_display_groups(x.shape[1])

        img_axes = axs[:-1]
        label_ax: 'Axes' = axs[-1]

        # plot image
        imgs = channel_groups_to_imgs(x, channel_groups)
        plot_channel_groups(
            img_axes, imgs, channel_groups, plot_title=plot_title)

        # plot label
        # Break long class names into chunks of at most 8 characters so the
        # tick labels do not eat into the plot area.
        class_names = self.class_names
        class_names = ['-\n-'.join(wrap(c, width=8)) for c in class_names]
        if y is not None and z is None:
            self.plot_gt(label_ax, class_names, y)
            if plot_title:
                label_ax.set_title('Ground truth')
        elif z is not None:
            # plot_pred() handles both the "pred only" and "pred + gt" cases.
            self.plot_pred(label_ax, class_names, z, y=y)

    def plot_gt(self, ax: 'Axes', class_names: Sequence[str],
                y: torch.Tensor):
        """Plot targets as a horizontal bar plot with values at the tips.

        Args:
            ax: Axis to draw on.
            class_names: Labels for the bars, one per target value.
            y: Target values; used as the bar widths.
        """
        bars_gt = ax.barh(
            y=class_names, width=y, color='lightgray', edgecolor='black')
        # show values on the end of bars
        ax.bar_label(bars_gt, fmt='%.3f', padding=3)

        ax.xaxis.grid(linestyle='--', alpha=1)
        ax.set_xlabel('Value')
        ax.spines['right'].set_visible(False)
        ax.get_yaxis().tick_left()

    def plot_pred(self,
                  ax: 'Axes',
                  class_names: Sequence[str],
                  z: torch.Tensor,
                  y: Optional[torch.Tensor] = None):
        """Plot targets and predictions as a grouped horizontal bar plot.

        If ``y`` is ``None``, only the predictions are plotted, as a single
        thicker bar per class centered on its tick.

        Args:
            ax: Axis to draw on.
            class_names: Tick labels, one per value.
            z: Predicted values; used as the widths of the "pred" bars.
            y: Optional target values; used as the widths of the "true" bars.
        """
        has_gt = y is not None
        # display targets and predictions as a grouped horizontal bar plot
        bar_thickness = 0.35 if has_gt else 0.70
        y_tick_locs = np.arange(len(class_names))

        if has_gt:
            bars_gt = ax.barh(
                y=y_tick_locs + bar_thickness / 2,
                width=y,
                height=bar_thickness,
                color='lightgray',
                edgecolor='black',
                label='true')
            # show values on the end of bars
            ax.bar_label(bars_gt, fmt='%.3f', padding=3)

        # Only shift the prediction bars off the tick when they share it with
        # the ground truth bars; a lone bar should be centered on its tick so
        # that it lines up with its class label.
        pred_offset = bar_thickness / 2 if has_gt else 0.
        bars_pred = ax.barh(
            y=y_tick_locs - pred_offset,
            width=z,
            height=bar_thickness,
            color=plt.get_cmap('tab10')(0),
            edgecolor='black',
            label='pred')
        # show values on the end of bars
        ax.bar_label(bars_pred, fmt='%.3f', padding=3)

        ax.set_yticks(ticks=y_tick_locs, labels=class_names)
        # place the legend just above the axes
        ax.legend(ncol=2, loc='lower center', bbox_to_anchor=(0.5, 1.0))

        ax.xaxis.grid(linestyle='--', alpha=1)
        ax.set_xlabel('Value')
        ax.spines['right'].set_visible(False)
        ax.get_yaxis().tick_left()

    def get_plot_ncols(self, **kwargs) -> int:
        """Return the number of subplot columns needed for one sample.

        This is one column per channel display group plus one for the
        bar plot.

        Args:
            **kwargs: Must contain ``x``, the image tensor; ``x.shape[1]`` is
                used as the channel count.
        """
        x = kwargs['x']
        nb_img_channels = x.shape[1]
        ncols = len(self.get_channel_display_groups(nb_img_channels)) + 1
        return ncols