from typing import TYPE_CHECKING, Optional
import warnings
from os.path import join
import logging
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import torch
import torch.nn.functional as F
from rastervision.pytorch_learner.learner import Learner
from rastervision.pytorch_learner.dataset.visualizer import (
RegressionVisualizer)
if TYPE_CHECKING:
import torch.nn as nn
warnings.filterwarnings('ignore')
log = logging.getLogger(__name__)
class RegressionLearner(Learner):
    """Learner for regression tasks.

    Trains a model that outputs continuous values and evaluates it with
    per-target absolute-error metrics plus diagnostic plots (ground-truth
    vs. prediction scatter plots and error histograms).
    """

    def get_visualizer_class(self):
        """Return the visualizer class used to plot regression batches."""
        return RegressionVisualizer

    def build_model(self, model_def_path: Optional[str] = None) -> 'nn.Module':
        """Build the model.

        Override to pass class_names, pos_class_names, and prob_class_names.

        Args:
            model_def_path: optional path to a hubconf dir containing a
                custom model definition.

        Returns:
            The built model.
        """
        cfg = self.cfg
        model = cfg.model.build(
            num_classes=cfg.data.num_classes,
            in_channels=cfg.data.img_channels,
            save_dir=self.modules_dir,
            hubconf_dir=model_def_path,
            class_names=cfg.data.class_names,
            pos_class_names=cfg.data.pos_class_names,
            prob_class_names=cfg.data.prob_class_names)
        return model

    def on_overfit_start(self):
        # Overfit mode needs the same target statistics as normal training.
        self.on_train_start()

    def on_train_start(self):
        """Compute per-target medians over the whole training set.

        The medians are stored on ``self.target_medians`` and used to
        normalize absolute errors in ``validate_step``.

        NOTE(review): a median of zero for any target makes the scaled
        error infinite/NaN — assumed not to occur in practice; confirm
        against the data.
        """
        ys = [y for _, y in self.train_dl]
        y = torch.cat(ys, dim=0)
        self.target_medians = y.median(dim=0).values.to(self.device)

    def build_metric_names(self):
        """Return the names of all metrics logged each epoch."""
        metric_names = [
            'epoch', 'train_time', 'valid_time', 'train_loss', 'val_loss'
        ]
        # Two per-target metrics: raw and median-scaled absolute error.
        for label in self.cfg.data.class_names:
            metric_names.extend([
                '{}_abs_error'.format(label),
                '{}_scaled_abs_error'.format(label)
            ])
        return metric_names

    def train_step(self, batch, batch_ind):
        """Compute the summed L1 training loss for one batch."""
        x, y = batch
        out = self.post_forward(self.model(x))
        return {'train_loss': F.l1_loss(out, y, reduction='sum')}

    def validate_step(self, batch, batch_nb):
        """Compute validation loss and per-target absolute errors.

        Returns:
            dict with 'val_loss' plus, per target, the summed absolute
            error and the absolute error scaled by the training-set
            median (see ``on_train_start``).
        """
        x, y = batch
        out = self.post_forward(self.model(x))
        val_loss = F.l1_loss(out, y, reduction='sum')

        abs_error = torch.abs(out - y).sum(dim=0)
        scaled_abs_error = (
            torch.abs(out - y) / self.target_medians).sum(dim=0)

        metrics = {'val_loss': val_loss}
        for ind, label in enumerate(self.cfg.data.class_names):
            metrics['{}_abs_error'.format(label)] = abs_error[ind]
            metrics['{}_scaled_abs_error'.format(label)] = scaled_abs_error[
                ind]
        return metrics

    def prob_to_pred(self, x):
        # Regression outputs are already the predictions; identity mapping.
        return x

    def eval_model(self, split):
        """Evaluate the model on a split and save diagnostic plots.

        Writes '<split>_scatter.png' and '<split>_err_hist.png' to the
        output dir.

        Args:
            split: name of the dataset split to evaluate on.
        """
        super().eval_model(split)

        y, out = self.predict_dataloader(
            self.get_dataloader(split), return_format='yz', raw_out=False)
        self._make_scatter_plot(y, out, split)
        self._make_error_hist(y, out, split)

    def _sample_scatter_inds(self, num_samples: int) -> torch.Tensor:
        """Return indices of points to scatter-plot, subsampling randomly
        if there are more than plot_options.max_scatter_points of them."""
        max_scatter_points = self.cfg.data.plot_options.max_scatter_points
        if num_samples > max_scatter_points:
            return torch.randperm(
                num_samples, dtype=torch.long)[:max_scatter_points]
        return torch.arange(0, num_samples, dtype=torch.long)

    def _make_scatter_plot(self, y, out, split):
        """Save a ground-truth vs. prediction scatter plot, one subplot
        per target, to '<split>_scatter.png'."""
        scatter_inds = self._sample_scatter_inds(y.shape[0])
        ncols = len(self.cfg.data.class_names)
        nrows = 1
        fig = plt.figure(
            constrained_layout=True, figsize=(5 * ncols, 5 * nrows))
        grid = gridspec.GridSpec(ncols=ncols, nrows=nrows, figure=fig)
        for label_ind, label in enumerate(self.cfg.data.class_names):
            ax = fig.add_subplot(grid[label_ind])
            ax.scatter(
                y[scatter_inds, label_ind],
                out[scatter_inds, label_ind],
                c='blue',
                alpha=0.1)
            ax.set_title('{} on {} set'.format(label, split))
            ax.set_xlabel('ground truth')
            ax.set_ylabel('predictions')
        scatter_path = join(self.output_dir, '{}_scatter.png'.format(split))
        plt.savefig(scatter_path)
        # Close the figure so repeated evaluations don't accumulate open
        # matplotlib figures (memory leak).
        plt.close(fig)

    def _make_error_hist(self, y, out, split):
        """Save a histogram of absolute prediction errors, one subplot per
        target, to '<split>_err_hist.png'."""
        ncols = len(self.cfg.data.class_names)
        nrows = 1
        fig = plt.figure(
            constrained_layout=True, figsize=(5 * ncols, 5 * nrows))
        grid = gridspec.GridSpec(ncols=ncols, nrows=nrows, figure=fig)
        hist_bins = self.cfg.data.plot_options.hist_bins
        for label_ind, label in enumerate(self.cfg.data.class_names):
            ax = fig.add_subplot(grid[label_ind])
            errs = torch.abs(y[:, label_ind] - out[:, label_ind]).tolist()
            ax.hist(errs, bins=hist_bins)
            ax.set_title('{} on {} set'.format(label, split))
            ax.set_xlabel('prediction error')
        hist_path = join(self.output_dir, '{}_err_hist.png'.format(split))
        plt.savefig(hist_path)
        plt.close(fig)