Source code for rastervision.core.cli

from typing import List, Optional
import click

from rastervision.pipeline.file_system import get_tmp_dir
from rastervision.core.predictor import Predictor

class OptionEatAll(click.Option):
    """A click Option that consumes every following token up to the next option.

    Lets an option take a variable number of values on the command line,
    e.g. ``--channel-order 2 1 0``. All tokens after the option name, up to
    (but not including) the next token that starts with an option prefix,
    are collected and handed to click as a single list value.
    """

    def __init__(self, *args, **kwargs):
        # The parser hook installed in add_to_parser() passes a *list* of
        # tokens as this option's value. Without an explicit type, click
        # falls back to STRING. In click >= 8, STRING's converter calls str()
        # on non-bytes input, which would turn ['2', '1', '0'] into the
        # string "['2', '1', '0']". Defaulting the type to ``list`` keeps
        # the collected tokens a list under every click version. A
        # caller-supplied ``type`` still takes precedence.
        kwargs.setdefault('type', list)
        super().__init__(*args, **kwargs)
        # Set by add_to_parser(): the parser-level option object registered
        # for this Option, and that object's original ``process`` method.
        self._previous_parser_process = None
        self._eat_all_parser = None

    def add_to_parser(self, parser, ctx):
        """Register the option, then hook its parser entry to eat extra values.

        After normal registration, the parser-level option object for this
        Option is looked up by name. Its ``process`` method is replaced with
        a wrapper that also collects every following token that does not
        look like an option.

        Note: this relies on click's private parser attributes
        (``_long_opt``, ``_short_opt``) and on overriding ``process``.

        Returns:
            Whatever ``click.Option.add_to_parser`` returns.
        """

        def is_next_option(arg: str) -> bool:
            # A token starting with any of the parser option's prefixes
            # (presumably '-' / '--') marks the start of the next option.
            return any(
                arg.startswith(prefix)
                for prefix in self._eat_all_parser.prefixes)

        def parser_process(value, state):
            # Replacement for the parser option's ``process`` method.
            values = [value]
            # Grab everything up to the next option.
            while state.rargs and not is_next_option(state.rargs[0]):
                values.append(state.rargs.pop(0))
            # Delegate to the original process with the collected list.
            self._previous_parser_process(values, state)

        retval = super().add_to_parser(parser, ctx)
        for name in self.opts:
            our_parser = (parser._long_opt.get(name)
                          or parser._short_opt.get(name))
            if our_parser:
                self._eat_all_parser = our_parser
                self._previous_parser_process = our_parser.process
                our_parser.process = parser_process
                break
        return retval
@click.command(
    'predict', short_help='Use a model bundle to predict on new images.')
@click.argument('model_bundle')
@click.argument('image_uri')
@click.argument('label_uri')
@click.option(
    '--update-stats',
    '-a',
    is_flag=True,
    help=('Run an analysis on this individual image, as '
          'opposed to using any analysis like statistics '
          'that exist in the prediction package'))
@click.option(
    '--channel-order',
    cls=OptionEatAll,
    help='List of indices comprising channel_order. Example: 2 1 0')
@click.option(
    '--scene-group',
    help='Name of the scene group whose stats will be used by the '
    'StatsTransformer. Requires the stats for this scene group to be present '
    'inside the bundle.')
def predict(model_bundle: str,
            image_uri: str,
            label_uri: str,
            update_stats: bool = False,
            channel_order: Optional[List[str]] = None,
            scene_group: Optional[str] = None):
    """Make predictions on the images at IMAGE_URI using MODEL_BUNDLE and store the prediction output at LABEL_URI.
    """
    # NOTE: the docstring above is rendered by click as the command's
    # --help text, so it is kept verbatim.

    # Channel indices arrive from the command line as strings; convert them
    # to ints before handing them to the Predictor. None means "not given".
    band_indices: Optional[List[int]] = (
        None if channel_order is None else list(map(int, channel_order)))

    # The Predictor is given a temporary directory that lives only for the
    # duration of this prediction run.
    with get_tmp_dir() as scratch_dir:
        predictor = Predictor(model_bundle, scratch_dir, update_stats,
                              band_indices, scene_group)
        predictor.predict([image_uri], label_uri)