Source code for rastervision.pytorch_backend.pytorch_object_detection_config

from rastervision.pipeline.config import register_config
from rastervision.pytorch_backend.pytorch_learner_backend_config import (
    PyTorchLearnerBackendConfig)
from rastervision.pytorch_learner.learner_config import default_augmentors
from rastervision.pytorch_learner.object_detection_learner_config import (
    ObjectDetectionModelConfig, ObjectDetectionLearnerConfig,
    ObjectDetectionImageDataConfig)
from rastervision.pytorch_backend.pytorch_object_detection import (
    PyTorchObjectDetection)


def objdet_learner_backend_config_upgrader(cfg_dict, version):
    """Upgrade a serialized object-detection backend config dict.

    In schema version 0, data-related fields were stored directly on the
    backend config. For such dicts, every listed field that is present in
    ``cfg_dict`` is moved into a new nested ``'data'`` entry. That entry is
    built, updated and validated through
    :class:`ObjectDetectionImageDataConfig`. Dicts of any other version are
    returned unchanged.

    Args:
        cfg_dict (dict): Serialized backend config. For version 0 it is
            modified in place: migrated keys are popped and ``'data'`` is set.
        version (int): Schema version that ``cfg_dict`` was serialized with.

    Returns:
        dict: The same ``cfg_dict`` object, upgraded if ``version == 0``.
    """
    if version == 0:
        # Fields that moved from the backend config into the data config.
        # NOTE: only keys actually present in cfg_dict are migrated. The values
        # below are not injected for absent keys; such keys are simply not
        # passed to ObjectDetectionImageDataConfig.
        fields = {
            'augmentors': default_augmentors,
            'group_uris': None,
            'group_train_sz': None,
            'group_train_sz_rel': None,
            'num_workers': 4,
            'img_sz': None,
            'base_transform': None,
            'aug_transform': None,
            'plot_options': None,
            'preview_batch_limit': None
        }
        # Plain loop rather than a comprehension, because popping mutates
        # cfg_dict.
        data_cfg_dict = {}
        for key in fields:
            if key in cfg_dict:
                data_cfg_dict[key] = cfg_dict.pop(key)

        # Old configs may omit img_sz entirely or store it as None. Use .get()
        # so a missing key falls back to 256 instead of raising KeyError.
        if data_cfg_dict.get('img_sz') is None:
            data_cfg_dict['img_sz'] = 256

        data_cfg = ObjectDetectionImageDataConfig(**data_cfg_dict)
        data_cfg.update()
        data_cfg.validate_config()
        cfg_dict['data'] = data_cfg.dict()
    return cfg_dict


@register_config(
    'pytorch_object_detection_backend',
    upgrader=objdet_learner_backend_config_upgrader)
class PyTorchObjectDetectionConfig(PyTorchLearnerBackendConfig):
    """Configure a :class:`.PyTorchObjectDetection` backend."""

    # Object-detection model configuration forwarded to the learner config.
    model: ObjectDetectionModelConfig

    def get_learner_config(self, pipeline):
        """Build the learner config used to train/predict with this backend.

        Args:
            pipeline: The pipeline config; its ``train_uri`` becomes the
                learner's ``output_uri``.

        Returns:
            ObjectDetectionLearnerConfig: A config on which ``update()`` has
            already been called.
        """
        learner_kwargs = dict(
            data=self.data,
            model=self.model,
            solver=self.solver,
            test_mode=self.test_mode,
            output_uri=pipeline.train_uri,
            log_tensorboard=self.log_tensorboard,
            run_tensorboard=self.run_tensorboard)
        learner_cfg = ObjectDetectionLearnerConfig(**learner_kwargs)
        learner_cfg.update()
        return learner_cfg

    def build(self, pipeline, tmp_dir):
        """Instantiate the :class:`.PyTorchObjectDetection` backend.

        Args:
            pipeline: The pipeline config, passed to the backend and used to
                derive the learner config.
            tmp_dir: Temporary directory handed to the backend.

        Returns:
            PyTorchObjectDetection: The constructed backend.
        """
        return PyTorchObjectDetection(
            pipeline, self.get_learner_config(pipeline), tmp_dir)