from typing import (Any, Dict, Sequence, Tuple, Optional, Union, List,
Iterable, Container)
from os.path import basename, join, isfile
import logging
import torch
from torch import nn
from torch.hub import _import_module
import numpy as np
from PIL import ImageColor
import albumentations as A
from albumentations.core.transforms_interface import ImageOnlyTransform
import cv2
import pandas as pd
from rastervision.pipeline.file_system import get_tmp_dir
from rastervision.pipeline.config import ConfigError
log = logging.getLogger(__name__)
def color_to_triple(color: Optional[str] = None) -> Tuple[int, int, int]:
    """Given a PIL ImageColor string, return a triple of integers
    representing the red, green, and blue values.

    If color is None, return a random color.

    Args:
        color: A PIL ImageColor string

    Returns:
        A triple of integers
    """
if color is None:
r = np.random.randint(0, 0x100)
g = np.random.randint(0, 0x100)
b = np.random.randint(0, 0x100)
return (r, g, b)
else:
return ImageColor.getrgb(color)
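# Example usage of color_to_triple (a minimal sketch; the random triple shown
# is just illustrative):
#   >>> color_to_triple('red')
#   (255, 0, 0)
#   >>> color_to_triple()   # random color
#   (12, 187, 66)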
def compute_conf_mat(out, y, num_labels):
    """Compute a confusion matrix.

    Rows correspond to ground truth (y), columns to predictions (out).
    """
    labels = torch.arange(0, num_labels).to(out.device)
    return ((out == labels[:, None]) & (y == labels[:, None, None])).sum(
        dim=2, dtype=torch.float32)
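# Example usage of compute_conf_mat (a minimal sketch): rows index ground
# truth, columns index predictions.
#   >>> out = torch.tensor([0, 1, 1, 2])   # predicted labels
#   >>> y = torch.tensor([0, 1, 2, 2])     # ground truth labels
#   >>> conf_mat = compute_conf_mat(out, y, num_labels=3)
#   >>> conf_mat
#   tensor([[1., 0., 0.],
#           [0., 1., 0.],
#           [0., 1., 1.]])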
def compute_conf_mat_metrics(conf_mat, label_names, eps=1e-6):
    """Compute per-class and weighted-average precision, recall, and F1.

    Rows of conf_mat are ground truth and columns are predictions.
    """
# eps is to avoid dividing by zero.
eps = torch.tensor(eps)
conf_mat = conf_mat.cpu()
gt_count = conf_mat.sum(dim=1)
pred_count = conf_mat.sum(dim=0)
total = conf_mat.sum()
true_pos = torch.diag(conf_mat)
precision = true_pos / torch.max(pred_count, eps)
recall = true_pos / torch.max(gt_count, eps)
f1 = (2 * precision * recall) / torch.max(precision + recall, eps)
weights = gt_count / total
weighted_precision = (weights * precision).sum()
weighted_recall = (weights * recall).sum()
weighted_f1 = ((2 * weighted_precision * weighted_recall) / torch.max(
weighted_precision + weighted_recall, eps))
metrics = {
'avg_precision': weighted_precision.item(),
'avg_recall': weighted_recall.item(),
'avg_f1': weighted_f1.item()
}
for ind, label in enumerate(label_names):
        metrics.update({
            f'{label}_precision': precision[ind].item(),
            f'{label}_recall': recall[ind].item(),
            f'{label}_f1': f1[ind].item(),
        })
return metrics
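# Example usage of compute_conf_mat_metrics (a minimal sketch, continuing the
# compute_conf_mat example above): the result contains 'avg_precision',
# 'avg_recall' and 'avg_f1' plus '<label>_precision', '<label>_recall' and
# '<label>_f1' for each label name.
#   >>> metrics = compute_conf_mat_metrics(conf_mat, ['a', 'b', 'c'])
#   >>> metrics['a_precision'], metrics['a_recall'], metrics['a_f1']
#   (1.0, 1.0, 1.0)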
class SplitTensor(nn.Module):
    """ Wrapper around `torch.split` """

    def __init__(self, size_or_sizes, dim):
        super().__init__()
        self.size_or_sizes = size_or_sizes
        self.dim = dim

    def forward(self, X):
        return X.split(self.size_or_sizes, dim=self.dim)
class Parallel(nn.ModuleList):
    """ Passes inputs through multiple `nn.Module`s in parallel.

    Returns a tuple of outputs.
    """

    def __init__(self, *args):
        super().__init__(args)

    def forward(self, xs):
        if isinstance(xs, torch.Tensor):
            return tuple(m(xs) for m in self)
        assert len(xs) == len(self)
        return tuple(m(x) for m, x in zip(self, xs))
class AddTensors(nn.Module):
    """ Adds all its inputs together. """

    def forward(self, xs):
        return sum(xs)
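# Example composition (a minimal sketch): split a 5-channel input into a
# 3-channel and a 2-channel group, pass each group through its own conv layer,
# and sum the outputs. This is the pattern used by adjust_conv_channels below.
#   >>> m = nn.Sequential(
#   ...     SplitTensor((3, 2), dim=1),
#   ...     Parallel(nn.Conv2d(3, 8, 3), nn.Conv2d(2, 8, 3)),
#   ...     AddTensors())
#   >>> m(torch.randn(1, 5, 64, 64)).shape
#   torch.Size([1, 8, 62, 62])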
class MinMaxNormalize(ImageOnlyTransform):
    """Albumentations transform that normalizes image to desired min and max values.

    This will shift and scale the image appropriately to achieve the desired min and
    max.
    """
    def __init__(
            self,
            min_val=0.0,
            max_val=1.0,
            dtype=cv2.CV_32F,
            always_apply=False,
            p=1.0,
    ):
        """Constructor.

        Args:
            min_val: the minimum value that output should have
            max_val: the maximum value that output should have
            dtype: the dtype of output image
        """
        super(MinMaxNormalize, self).__init__(always_apply, p)
        self.min_val = min_val
        self.max_val = max_val
        self.dtype = dtype
def _apply_on_channel(self, image, **params):
out = cv2.normalize(
image,
None,
self.min_val,
self.max_val,
cv2.NORM_MINMAX,
dtype=self.dtype)
# We need to clip because sometimes values are slightly less or more than
# min_val and max_val due to rounding errors.
return np.clip(out, self.min_val, self.max_val)
    def apply(self, image, **params):
if image.ndim <= 2:
return self._apply_on_channel(image, **params)
assert image.ndim == 3
chs = [
self._apply_on_channel(ch, **params)
for ch in image.transpose(2, 0, 1)
]
out = np.stack(chs, axis=2)
return out
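# Example usage of MinMaxNormalize (a minimal sketch; the input is synthetic):
# scale a multi-band uint16 image to float32 values in [0, 1].
#   >>> tf = MinMaxNormalize(min_val=0., max_val=1.)
#   >>> img = np.random.randint(0, 2**16, size=(256, 256, 4), dtype=np.uint16)
#   >>> out = tf(image=img)['image']
#   >>> out.dtype, float(out.min()), float(out.max())
#   (dtype('float32'), 0.0, 1.0)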
def adjust_conv_channels(old_conv: nn.Conv2d,
                         in_channels: int,
                         pretrained: bool = True
                         ) -> Union[nn.Conv2d, nn.Sequential]:
    """Adjust the number of input channels of a conv layer, optionally
    retaining the pretrained weights of the original channels."""
if in_channels == old_conv.in_channels:
return old_conv
# These parameters will be the same for the new conv layer.
# This list should be kept up to date with the Conv2d definition.
old_conv_args = {
'out_channels': old_conv.out_channels,
'kernel_size': old_conv.kernel_size,
'stride': old_conv.stride,
'padding': old_conv.padding,
'dilation': old_conv.dilation,
'groups': old_conv.groups,
'bias': old_conv.bias is not None,
'padding_mode': old_conv.padding_mode
}
if not pretrained:
# simply replace the first conv layer with one with the
# correct number of input channels
new_conv = nn.Conv2d(in_channels=in_channels, **old_conv_args)
return new_conv
if in_channels > old_conv.in_channels:
# insert a new conv layer parallel to the existing one
# and sum their outputs
extra_channels = in_channels - old_conv.in_channels
extra_conv = nn.Conv2d(in_channels=extra_channels, **old_conv_args)
new_conv = nn.Sequential(
# split input along channel dim
SplitTensor((old_conv.in_channels, extra_channels), dim=1),
# each split goes to its respective conv layer
Parallel(old_conv, extra_conv),
# sum the parallel outputs
AddTensors())
return new_conv
elif in_channels < old_conv.in_channels:
new_conv = nn.Conv2d(in_channels=in_channels, **old_conv_args)
pretrained_kernels = old_conv.weight.data[:, :in_channels]
new_conv.weight.data[:, :in_channels] = pretrained_kernels
return new_conv
else:
        raise ConfigError('Unexpected case in adjust_conv_channels().')
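# Example usage of adjust_conv_channels (a minimal sketch; assumes torchvision
# is available, and the exact argument for loading pretrained weights varies
# by torchvision version): adapt a pretrained ResNet-18 to 8-band imagery.
#   >>> from torchvision import models
#   >>> model = models.resnet18(pretrained=True)
#   >>> model.conv1 = adjust_conv_channels(model.conv1, in_channels=8)
#   >>> model.conv1(torch.randn(1, 8, 224, 224)).shape
#   torch.Size([1, 64, 112, 112])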
def plot_channel_groups(axs: Iterable,
                        imgs: Iterable[Union[np.ndarray, torch.Tensor]],
                        channel_groups: dict) -> None:
for title, ax, img in zip(channel_groups.keys(), axs, imgs):
ax.imshow(img)
ax.set_title(title)
ax.set_xticks([])
ax.set_yticks([])
def channel_groups_to_imgs(
        x: torch.Tensor,
        channel_groups: Dict[str, Sequence[int]]) -> List[torch.Tensor]:
imgs = []
for title, ch_inds in channel_groups.items():
img = x[..., ch_inds]
if len(ch_inds) == 1:
# repeat single channel 3 times
img = img.expand(-1, -1, 3)
elif len(ch_inds) == 2:
# add a 3rd channel with all pixels set to 0.5
h, w, _ = x.shape
third_channel = torch.full((h, w, 1), fill_value=.5)
img = torch.cat((img, third_channel), dim=-1)
elif len(ch_inds) > 3:
# only use the first 3 channels
            log.warning(f'Only plotting first 3 channels of channel-group '
                        f'{title}: {ch_inds}.')
img = x[..., ch_inds[:3]]
imgs.append(img)
return imgs
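# Example usage of channel_groups_to_imgs and plot_channel_groups (a minimal
# sketch; assumes matplotlib is available):
#   >>> import matplotlib.pyplot as plt
#   >>> x = torch.rand(256, 256, 4)  # (H, W, C) chip with 4 bands
#   >>> channel_groups = {'RGB': [0, 1, 2], 'NIR': [3]}
#   >>> imgs = channel_groups_to_imgs(x, channel_groups)
#   >>> [img.shape for img in imgs]
#   [torch.Size([256, 256, 3]), torch.Size([256, 256, 3])]
#   >>> fig, axs = plt.subplots(1, len(channel_groups))
#   >>> plot_channel_groups(axs, imgs, channel_groups)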
def log_metrics_to_csv(csv_path: str, metrics: Dict[str, Any]):
"""Append epoch metrics to CSV file."""
# dict --> single-row DataFrame
metrics_df = pd.DataFrame.from_records([metrics])
    # if the file already exists, append the row without re-writing the header
log_file_exists = isfile(csv_path)
metrics_df.to_csv(
csv_path, mode='a', header=(not log_file_exists), index=False)
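# Example usage of log_metrics_to_csv (a minimal sketch; the path and metric
# values are illustrative):
#   >>> log_metrics_to_csv('/tmp/log.csv', {'epoch': 0, 'train_loss': 0.72})
#   >>> log_metrics_to_csv('/tmp/log.csv', {'epoch': 1, 'train_loss': 0.55})
#   # /tmp/log.csv now contains a header row plus one row per call.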
def aggregate_metrics(
        outputs: List[Dict[str, Union[float, torch.Tensor]]],
        exclude_keys: Container[str] = frozenset({'conf_mat'})
) -> Dict[str, float]:
    """Aggregate the output of validate_step at the end of the epoch.

    Args:
        outputs: A list of outputs of Learner.validate_step().
        exclude_keys: Keys to ignore. These will not be aggregated and will not
            be included in the output. Defaults to {'conf_mat'}.

    Returns:
        Dict[str, float]: Dict with aggregated values.
    """
metrics = {}
metric_names = outputs[0].keys()
for metric_name in metric_names:
if metric_name in exclude_keys:
continue
metric_vals = [out[metric_name] for out in outputs]
elem = metric_vals[0]
if isinstance(elem, torch.Tensor):
if elem.ndim == 0:
metric_vals = torch.stack(metric_vals)
else:
metric_vals = torch.cat(metric_vals)
metric_avg = metric_vals.float().mean().item()
else:
metric_avg = sum(metric_vals) / len(metric_vals)
metrics[metric_name] = metric_avg
return metrics
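# Example usage of aggregate_metrics (a minimal sketch): average scalar losses
# over validation batches while skipping the per-batch confusion matrices.
#   >>> outputs = [
#   ...     {'val_loss': torch.tensor(0.75), 'conf_mat': torch.eye(2)},
#   ...     {'val_loss': torch.tensor(0.25), 'conf_mat': torch.eye(2)},
#   ... ]
#   >>> aggregate_metrics(outputs)
#   {'val_loss': 0.5}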