Source code for rastervision.pytorch_backend.pytorch_chip_classification

from typing import TYPE_CHECKING, Iterator, Optional
from os.path import join
import uuid

from rastervision.pipeline.file_system import (make_dir)
from rastervision.pytorch_backend.pytorch_learner_backend import (
    PyTorchLearnerSampleWriter, PyTorchLearnerBackend)
from rastervision.pytorch_learner.dataset import (
    ClassificationSlidingWindowGeoDataset)
from rastervision.core.data import ChipClassificationLabels

if TYPE_CHECKING:
    import numpy as np
    from rastervision.core.data import Scene
    from rastervision.core.data_sample import DataSample


class PyTorchChipClassificationSampleWriter(PyTorchLearnerSampleWriter):
    """Sample writer for chip classification.

    Each chip is written under ``self.sample_dir`` into a split directory
    (``train`` or ``valid``) and then a sub-directory named after the chip's
    class, i.e. an image-folder style layout.
    """

    def write_sample(self, sample: 'DataSample'):
        """Write a training or validation sample to disk.

        The chip is written to
        ``(train|valid)/{class_name}/{scene_id}-{ind}.{ext}``, where ``ind``
        is ``self.sample_ind`` and ``ext`` is chosen by
        ``self.get_image_ext()`` based on the chip (it is not necessarily
        ``png``).

        Samples whose window has no class label are skipped entirely and do
        not advance ``self.sample_ind``.

        Args:
            sample (DataSample): The sample to write. Its ``labels`` are
                queried for the class ID of ``sample.window``, and
                ``sample.is_train`` selects the ``train``/``valid`` split.
        """
        class_id = sample.labels.get_cell_class_id(sample.window)
        # If a chip is not associated with a class, don't
        # use it in training data.
        if class_id is None:
            return
        split_name = 'train' if sample.is_train else 'valid'
        img_path = self.get_image_path(split_name, sample, class_id)
        self.write_chip(sample.chip, img_path)
        self.sample_ind += 1

    def get_image_path(self, split_name: str, sample: 'DataSample',
                       class_id: int) -> str:
        """Build the output path for a sample's chip.

        The per-class directory ``{sample_dir}/{split_name}/{class_name}`` is
        passed to ``make_dir`` before the path is returned, so the chip can be
        written there directly.

        Args:
            split_name (str): Split sub-directory; ``write_sample`` passes
                ``'train'`` or ``'valid'``.
            sample (DataSample): The sample being written. Its ``scene_id``
                and ``chip`` determine the file name and extension.
            class_id (int): Index into ``self.class_config.names``.

        Returns:
            str: Path of the form
            ``{sample_dir}/{split_name}/{class_name}/{scene_id}-{ind}.{ext}``.
        """
        class_name = self.class_config.names[class_id]
        img_dir = join(self.sample_dir, split_name, class_name)
        make_dir(img_dir)
        sample_name = f'{sample.scene_id}-{self.sample_ind}'
        ext = self.get_image_ext(sample.chip)
        img_path = join(img_dir, f'{sample_name}.{ext}')
        return img_path
class PyTorchChipClassification(PyTorchLearnerBackend):
    """PyTorch learner backend for chip classification."""

    def get_sample_writer(self):
        """Create the sample writer used to store training chips.

        Chips are written to a uniquely named (UUID4) ``.zip`` URI under
        ``self.pipeline_cfg.chip_uri``.

        Returns:
            PyTorchChipClassificationSampleWriter: Writer configured with the
            pipeline's dataset class config and this backend's ``tmp_dir``.
        """
        output_uri = join(self.pipeline_cfg.chip_uri, f'{uuid.uuid4()}.zip')
        return PyTorchChipClassificationSampleWriter(
            output_uri, self.pipeline_cfg.dataset.class_config, self.tmp_dir)

    def predict_scene(self,
                      scene: 'Scene',
                      chip_sz: int,
                      stride: Optional[int] = None
                      ) -> 'ChipClassificationLabels':
        """Predict a class for every sliding-window chip of a scene.

        Args:
            scene (Scene): Scene to make predictions on.
            chip_sz (int): Size of each sliding window.
            stride (Optional[int]): Step between windows. Defaults to
                ``chip_sz``, i.e. non-overlapping windows.

        Returns:
            ChipClassificationLabels: Labels built from the raw model outputs
            for each window of the dataset.
        """
        if stride is None:
            stride = chip_sz

        # Lazily load the model on first use.
        if self.learner is None:
            self.load_model()

        # Important to use self.learner.cfg.data instead of
        # self.learner_cfg.data because of the updates
        # Learner.from_model_bundle() makes to the custom transforms.
        base_tf, _ = self.learner.cfg.data.get_data_transforms()
        ds = ClassificationSlidingWindowGeoDataset(
            scene, size=chip_sz, stride=stride, transform=base_tf)

        # np.ndarray is the array type; np.array is a factory function and
        # is not valid as an annotation.
        predictions: Iterator['np.ndarray'] = self.learner.predict_dataset(
            ds,
            raw_out=True,
            numpy_out=True,
            progress_bar=True,
            progress_bar_kw=dict(desc=f'Making predictions on {scene.id}'))

        labels = ChipClassificationLabels.from_predictions(
            ds.windows, predictions)
        return labels