Source code for rastervision.pytorch_learner.learner_pipeline
from rastervision.pipeline.pipeline import Pipeline
from rastervision.pytorch_learner import LearnerConfig
[docs]class LearnerPipeline(Pipeline):
"""Simple Pipeline that is a wrapper around Learner.main()
This supports the ability to use the pytorch_learner package to train models using
the RV pipeline package and its runner functionality without the rest of RV.
"""
commands = ['train']
gpu_commands = ['train']
[docs] def train(self):
learner_cfg: LearnerConfig = self.config.learner
learner = learner_cfg.build(learner_cfg, self.tmp_dir)
learner.main()