Source code for rastervision.pytorch_backend.pytorch_learner_backend

from typing import TYPE_CHECKING
from os.path import join, splitext
import tempfile

import numpy as np
from tqdm.auto import tqdm

from rastervision.pipeline import rv_config_ as rv_config
from rastervision.pipeline.file_system import (make_dir, upload_or_copy,
                                               zipdir)
from rastervision.core.backend import Backend, SampleWriter
from rastervision.core.data.utils.misc import save_img
from rastervision.core.data_sample import DataSample
from rastervision.pytorch_learner.learner import Learner

if TYPE_CHECKING:
    from torch.utils.data import Dataset
    from rastervision.core.data import ClassConfig, DatasetConfig, Scene
    from rastervision.core.rv_pipeline import RVPipelineConfig, ChipOptions
    from rastervision.pytorch_learner import DataConfig, LearnerConfig

# Names of the dataset splits. chip_dataset() pairs these positionally with the
# (train, valid, test) datasets returned by the chip data config's build().
SPLITS = ['train', 'valid', 'test']


def write_chip(chip: np.ndarray, path: str) -> None:
    """Write ``chip`` to ``path``, picking the format from the extension.

    A path whose extension is ``.npy`` is written with ``np.save`` without
    any conversion. Every other path is written with ``save_img`` after the
    chip has been cast to ``uint8``.

    Args:
        chip: Array to store.
        path: Destination file path; its extension selects the format.
    """
    _, suffix = splitext(path)
    if suffix != '.npy':
        # Image formats are written as 8-bit data.
        save_img(chip.astype(np.uint8), path)
        return
    np.save(path, chip)
def get_image_ext(chip: np.ndarray) -> str:
    """Pick the file extension (without a leading dot) for storing ``chip``.

    Args:
        chip: Array to be stored.

    Returns:
        ``'png'`` if the chip is 2-dimensional or its last axis has length 3,
        otherwise ``'npy'``.
    """
    fits_in_png = chip.ndim == 2 or chip.shape[-1] == 3
    return 'png' if fits_in_png else 'npy'
class PyTorchLearnerSampleWriter(SampleWriter):
    """Context manager that collects samples locally and ships them as a zip.

    On entry a temporary directory is created under ``tmp_dir`` with a
    ``samples`` sub-directory. On exit that sub-directory is zipped and the
    archive is uploaded/copied to ``output_uri``. ``write_sample`` is left
    for subclasses to implement.
    """

    def __init__(self, output_uri: str, class_config: 'ClassConfig',
                 tmp_dir: str):
        """Constructor.

        Args:
            output_uri (str): URI of directory where zip file of chips should
                be placed.
            class_config (ClassConfig): used to convert class ids to names
                which may be needed for some training data formats.
            tmp_dir (str): local directory which is root of any temporary
                directories that are created.
        """
        self.output_uri = output_uri
        self.class_config = class_config
        self.tmp_dir = tmp_dir

    def __enter__(self):
        """Create the temporary sample directory and reset the sample index.

        Returns:
            This writer instance.
        """
        self.tmp_dir_obj = tempfile.TemporaryDirectory(dir=self.tmp_dir)
        self.sample_dir = join(self.tmp_dir_obj.name, 'samples')
        make_dir(self.sample_dir)
        self.sample_ind = 0
        return self

    def __exit__(self, type, value, traceback):
        """Zip the sample directory, send it to ``output_uri``, and clean up.

        This method is called once per instance of the chip command. A number
        of instances of the chip command can run simultaneously to process
        chips in parallel.

        NOTE(review): nothing here makes the destination unique; separate
        instances presumably avoid overwriting each other's output by being
        given distinct ``output_uri`` values (e.g. containing a uuid) --
        confirm against the caller.

        The temporary directory is removed even if zipping or uploading
        fails. Exceptions are not suppressed.
        """
        try:
            output_path = join(self.tmp_dir_obj.name, 'output.zip')
            zipdir(self.sample_dir, output_path)
            upload_or_copy(output_path, self.output_uri)
        finally:
            # Always release the temp dir so a failed zip/upload does not
            # leave chips lying around on local disk.
            self.tmp_dir_obj.cleanup()

    def write_sample(self, sample: 'DataSample') -> None:
        """Write a single sample to disk.

        Must be implemented by subclasses. This class only initializes
        ``self.sample_ind`` (in ``__enter__``) and never advances it, so
        implementations presumably need to increment it themselves to keep
        the paths from ``get_image_path`` unique.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()

    def get_image_path(self, sample: 'DataSample') -> str:
        """Decide the save location of the image.

        Also ensures that the target directory exists. The path has the form
        ``{sample_dir}/{split}/img/{name}.{ext}`` where ``name`` is
        ``{scene_id}-{sample_ind}`` when the sample has a scene ID and
        ``{sample_ind}`` otherwise. A sample without a split is placed
        directly under ``{sample_dir}/img``.

        Args:
            sample: Sample whose ``split``, ``scene_id`` and ``chip`` are
                used to build the path.

        Returns:
            Path at which the sample's image should be written.
        """
        split = '' if sample.split is None else sample.split
        img_dir = join(self.sample_dir, split, 'img')
        make_dir(img_dir)

        if sample.scene_id is not None:
            sample_name = f'{sample.scene_id}-{self.sample_ind}'
        else:
            sample_name = f'{self.sample_ind}'
        ext = self.get_image_ext(sample.chip)
        img_path = join(img_dir, f'{sample_name}.{ext}')
        return img_path

    def get_image_ext(self, chip: np.ndarray) -> str:
        """Decide which format to store the image in.

        Delegates to the module-level ``get_image_ext``; overridable hook.
        """
        return get_image_ext(chip)

    def write_chip(self, chip: np.ndarray, path: str) -> None:
        """Save chip as either a PNG image or a numpy array.

        Delegates to the module-level ``write_chip``; overridable hook.
        """
        write_chip(chip, path)
class PyTorchLearnerBackend(Backend):
    """Backend that uses the rastervision.pytorch_learner package to train
    models.

    ``get_sample_writer``, ``predict_scene`` and ``_make_chip_data_config``
    raise ``NotImplementedError`` here and must be provided by subclasses.
    """

    def __init__(self, pipeline_cfg: 'RVPipelineConfig',
                 learner_cfg: 'LearnerConfig', tmp_dir: str):
        """Constructor.

        Args:
            pipeline_cfg: Pipeline configuration; stored as-is.
            learner_cfg: Learner configuration used to build learners and to
                supply chipping defaults.
            tmp_dir: Directory passed to the learner for temporary files.
        """
        self.pipeline_cfg = pipeline_cfg
        self.learner_cfg = learner_cfg
        self.tmp_dir = tmp_dir
        # Populated by load_model().
        self.learner = None

    def train(self, source_bundle_uri=None):
        """Build a learner in training mode and run its ``main()``.

        Args:
            source_bundle_uri: If given, the learner is built from this model
                bundle (with ``self.learner_cfg`` passed as ``cfg``);
                otherwise it is built directly from ``self.learner_cfg``.
        """
        if source_bundle_uri is not None:
            learner = self._build_learner_from_bundle(
                bundle_uri=source_bundle_uri,
                cfg=self.learner_cfg,
                training=True)
        else:
            learner = self.learner_cfg.build(self.tmp_dir, training=True)
        learner.main()

    def load_model(self, uri: str | None = None):
        """Load a learner from a model bundle into ``self.learner``.

        Args:
            uri: Model bundle URI. If ``None``, the URI reported by
                ``self.learner_cfg.get_model_bundle_uri()`` is used.
        """
        self.learner = self._build_learner_from_bundle(
            bundle_uri=uri, training=False)

    def _build_learner_from_bundle(self,
                                   bundle_uri: str | None = None,
                                   cfg: 'LearnerConfig | None' = None,
                                   training: bool = False):
        """Create a ``Learner`` from a model bundle.

        Args:
            bundle_uri: Model bundle URI. If ``None``, falls back to
                ``self.learner_cfg.get_model_bundle_uri()``.
            cfg: Forwarded to ``Learner.from_model_bundle`` as ``cfg``.
            training: Forwarded to ``Learner.from_model_bundle``.

        Returns:
            The value returned by ``Learner.from_model_bundle``.
        """
        if bundle_uri is None:
            bundle_uri = self.learner_cfg.get_model_bundle_uri()
        return Learner.from_model_bundle(
            bundle_uri, self.tmp_dir, cfg=cfg, training=training)

    def get_sample_writer(self):
        """Return the sample writer used for chipping.

        Must be implemented by subclasses. ``chip_dataset`` uses the returned
        object as a context manager.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()

    def chip_dataset(self,
                     dataset: 'DatasetConfig',
                     chip_options: 'ChipOptions',
                     dataloader_kw: dict | None = None) -> None:
        """Chip every non-empty split of ``dataset`` with one sample writer.

        Args:
            dataset: Dataset configuration to chip.
            chip_options: Chipping options, passed on to
                ``_make_chip_data_config`` and ``chip_pytorch_dataset``.
            dataloader_kw: Extra keyword arguments for the ``DataLoader``.
                ``None`` (the default) means no extra arguments.
        """
        # Normalize here so chip_pytorch_dataset() always receives a dict,
        # even if a subclass overrides it.
        if dataloader_kw is None:
            dataloader_kw = {}

        data_config = self._make_chip_data_config(dataset, chip_options)
        train_ds, valid_ds, test_ds = data_config.build(for_chipping=True)

        with self.get_sample_writer() as sample_writer:
            for split, ds in zip(SPLITS, [train_ds, valid_ds, test_ds]):
                if len(ds) == 0:
                    continue
                self.chip_pytorch_dataset(
                    ds,
                    sample_writer=sample_writer,
                    chip_options=chip_options,
                    split=split,
                    dataloader_kw=dataloader_kw)

    def chip_pytorch_dataset(
            self,
            dataset: 'Dataset',
            sample_writer: 'PyTorchLearnerSampleWriter',
            chip_options: 'ChipOptions',
            split: str | None = None,
            dataloader_kw: dict | None = None,
    ) -> None:
        """Iterate over ``dataset`` and write each kept chip as a sample.

        The dataset is read through a ``DataLoader`` whose batches are
        unpacked as ``((xs, ys), ws)``. Each ``(x, y)`` pair accepted by
        ``chip_options.keep_chip`` is wrapped in a ``DataSample`` (with ``w``
        as its window) and handed to ``sample_writer.write_sample``.

        The worker count and batch size come from the ``CHIP_NUM_WORKERS``
        and ``CHIP_BATCH_SIZE`` options in the ``rastervision`` config
        namespace, defaulting to ``learner_cfg.data.num_workers`` and
        ``learner_cfg.solver.batch_sz`` respectively.

        Args:
            dataset: Dataset to chip.
            sample_writer: Writer that receives the samples.
            chip_options: Provides ``keep_chip(x, y)`` for filtering.
            split: Split name recorded on each sample and shown in the
                progress bar description.
            dataloader_kw: Extra ``DataLoader`` keyword arguments; these
                override the defaults set here. ``None`` (the default) means
                no extra arguments.
        """
        from torch.utils.data import DataLoader

        num_workers = rv_config.get_namespace_option(
            'rastervision',
            'CHIP_NUM_WORKERS',
            default=self.learner_cfg.data.num_workers)
        batch_size = rv_config.get_namespace_option(
            'rastervision',
            'CHIP_BATCH_SIZE',
            default=self.learner_cfg.solver.batch_sz)

        dl_kw = dict(
            batch_size=int(batch_size),
            num_workers=int(num_workers),
            shuffle=False,
            pin_memory=True)
        if dataloader_kw:
            # Caller-supplied options take precedence over the defaults.
            dl_kw.update(dataloader_kw)
        dl = DataLoader(dataset, **dl_kw)

        if split is not None:
            desc = f'Chipping {split} scenes.'
        else:
            desc = 'Chipping dataset.'
        with tqdm(total=len(dataset), desc=desc) as bar:
            for (xs, ys), ws in dl:
                for x, y, w in zip(xs, ys, ws):
                    if not chip_options.keep_chip(x, y):
                        continue
                    sample = DataSample(chip=x, label=y, window=w, split=split)
                    sample_writer.write_sample(sample)
                    # Only written samples advance the bar, so it stops short
                    # of the total when chips are filtered out.
                    bar.update(1)

    def predict_scene(self,
                      scene: 'Scene',
                      chip_sz: int,
                      stride: int | None = None):
        """Make predictions for a scene.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()

    def _make_chip_data_config(self, dataset: 'DatasetConfig',
                               chip_options: 'ChipOptions') -> 'DataConfig':
        """Build the data config that ``chip_dataset`` uses for chipping.

        Must be implemented by subclasses. ``chip_dataset`` calls
        ``build(for_chipping=True)`` on the result and expects three datasets
        (train, valid, test) back.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()