Source code for rastervision.pytorch_backend.pytorch_chip_classification_config

from rastervision.pipeline.config import register_config
from rastervision.pytorch_backend.pytorch_learner_backend_config import (
    PyTorchLearnerBackendConfig)
from rastervision.pytorch_learner.learner_config import default_augmentors
from rastervision.pytorch_learner.classification_learner_config import (
    ClassificationModelConfig, ClassificationLearnerConfig,
    ClassificationImageDataConfig)
from rastervision.pytorch_backend.pytorch_chip_classification import (
    PyTorchChipClassification)


def clf_learner_backend_config_upgrader(cfg_dict, version):  # pragma: no cover
    """Upgrade a serialized chip-classification backend config dict.

    For ``version == 0`` configs, data-related keys that used to live at the
    top level of the backend config are moved (popped out of ``cfg_dict``)
    into a nested ``'data'`` entry. That entry is produced by building,
    updating and validating a ``ClassificationImageDataConfig`` and then
    serializing it with ``.dict()``. Configs of any other version are
    returned untouched.

    Args:
        cfg_dict: The serialized config dict. It is mutated in place.
        version: The version the dict was serialized with.

    Returns:
        The (possibly upgraded) ``cfg_dict``.
    """
    if version == 0:
        fields = {
            'augmentors': default_augmentors,
            'group_uris': None,
            'group_train_sz': None,
            'group_train_sz_rel': None,
            'num_workers': 4,
            'img_sz': None,
            'base_transform': None,
            'aug_transform': None,
            'plot_options': None,
            'preview_batch_limit': None
        }
        # Only keys actually present in the old config are moved; absent
        # keys are left for ClassificationImageDataConfig to default.
        data_cfg_dict = {
            key: cfg_dict.pop(key, default_val)
            for key, default_val in fields.items() if key in cfg_dict
        }
        # Use .get(): 'img_sz' may be absent from an old config (the filter
        # above skips it). Indexing would then raise KeyError instead of
        # falling back to the 256 default.
        if data_cfg_dict.get('img_sz') is None:
            data_cfg_dict['img_sz'] = 256

        data_cfg = ClassificationImageDataConfig(**data_cfg_dict)
        data_cfg.update()
        data_cfg.validate_config()
        cfg_dict['data'] = data_cfg.dict()
    return cfg_dict


@register_config(
    'pytorch_chip_classification_backend',
    upgrader=clf_learner_backend_config_upgrader)
class PyTorchChipClassificationConfig(PyTorchLearnerBackendConfig):
    """Configure a :class:`.PyTorchChipClassification` backend."""

    model: ClassificationModelConfig

    def get_learner_config(self, pipeline):
        """Create, update and validate the learner config for this backend.

        The learner config receives this backend's ``data``, ``model`` and
        ``solver`` plus its tensorboard and checkpoint flags. Its output
        location is taken from ``pipeline.train_uri``.

        Args:
            pipeline: The pipeline config; only ``train_uri`` is read here.

        Returns:
            A validated ``ClassificationLearnerConfig``.
        """
        learner_kwargs = dict(
            data=self.data,
            model=self.model,
            solver=self.solver,
            output_uri=pipeline.train_uri,
            log_tensorboard=self.log_tensorboard,
            run_tensorboard=self.run_tensorboard,
            save_all_checkpoints=self.save_all_checkpoints,
        )
        learner_cfg = ClassificationLearnerConfig(**learner_kwargs)
        learner_cfg.update()
        learner_cfg.validate_config()
        return learner_cfg

    def build(self, pipeline, tmp_dir):
        """Instantiate the :class:`.PyTorchChipClassification` backend.

        Args:
            pipeline: The pipeline config, passed through to the backend and
                used to build the learner config.
            tmp_dir: Temporary directory handed to the backend.
        """
        learner_cfg = self.get_learner_config(pipeline)
        return PyTorchChipClassification(pipeline, learner_cfg, tmp_dir)