# Source code for rastervision.pytorch_learner.utils.torch_hub

from typing import Any
from pathlib import Path
from os.path import join, isdir, realpath
import shutil
from glob import glob

import torch.hub

from rastervision.pipeline.file_system import (download_if_needed, unzip,
                                               get_tmp_dir)

def _remove_dir(path):
    """ Remove a directory if it exists. """
    if isdir(path):

def _repo_name_to_dir_name(repo: str) -> str:
    """Convert repo name to dir name per torch.hub naming conventions.

    Adapted from torch.hub._get_cache_or_reload()

        repo: <repo-owner>/<erpo-name>[:tag]

        Directory name
    from torch.hub import _parse_repo_info
    repo_owner, repo_name, branch = _parse_repo_info(repo)
    normalized_br = branch.replace('/', '_')
    dir_name = '_'.join([repo_owner, repo_name, normalized_br])
    return dir_name

def _uri_to_dir_name(uri: str) -> str:
    """ Determine directory name from a URI. """
    return Path(uri).stem

[docs]def get_hubconf_dir_from_cfg(cfg, parent: str | None = '') -> str: """Determine destination directory name from an ExternalModuleConfig. If a parent path is provided, the dir name is appended to it. Args: cfg (ExternalModuleConfig): an ExternalModuleConfig parent: Parent path. Defaults to ''. Returns: Directory name or path """ if is not None: dir_name = elif cfg.uri is not None: dir_name = _uri_to_dir_name(cfg.uri) else: dir_name = _repo_name_to_dir_name(cfg.github_repo) path = join(parent, dir_name) return path
def torch_hub_load_github(repo: str,
                          entrypoint: str,
                          *args,
                          dst_dir: str | None = None,
                          **kwargs) -> Any:
    """Load an entrypoint from a github repo using :func:`torch.hub.load`.

    Args:
        repo: <repo-owner>/<repo-name>[:tag]
        entrypoint: Name of a Callable present in ``hubconf.py``.
        *args: Args to be passed to the entrypoint.
        dst_dir: If provided, the repo's directory is moved there from the
            torch.hub cache directory. Defaults to None.
        **kwargs: Keyword args to be passed to the entrypoint.

    Returns:
        Any: The output from calling the entrypoint.
    """
    out = torch.hub.load(
        repo,
        entrypoint,
        *args,
        source='github',
        skip_validation=True,
        **kwargs)

    if dst_dir is not None:
        # Directory under the torch.hub cache where the repo was placed.
        orig_dir = join(torch.hub.get_dir(), _repo_name_to_dir_name(repo))
        # If dst_dir resolves to the cache dir itself, removing it first would
        # delete the freshly downloaded repo and the move would then fail.
        if realpath(orig_dir) != realpath(dst_dir):
            _remove_dir(dst_dir)
            shutil.move(orig_dir, dst_dir)
    return out
def torch_hub_load_uri(uri: str,
                       entrypoint: str,
                       *args,
                       dst_dir: str | None = None,
                       **kwargs) -> Any:
    """Load an entrypoint from a uri.

    Load an entrypoint from:
        - a local uri of a zip file, or
        - a local uri of a directory, or
        - a remote uri of zip file.

    The zip file should either have ``hubconf.py`` at the top level or
    contain a single sub-directory that contains ``hubconf.py`` at its top
    level. In the latter case, that sub-directory is what gets moved to
    dst_dir.

    Args:
        uri: A URI.
        entrypoint: Name of a Callable present in ``hubconf.py``.
        *args: Args to be passed to the entrypoint.
        dst_dir: If provided, the contents from the uri are moved (zip) or
            copied (directory) there. Defaults to None.
        **kwargs: Keyword args to be passed to the entrypoint.

    Returns:
        Any: The output from calling the entrypoint.
    """
    uri_path = Path(uri)
    is_zip = uri_path.suffix.lower() == '.zip'
    if is_zip:
        zip_path = download_if_needed(uri)
        with get_tmp_dir() as tmp_dir:
            unzip_dir = join(tmp_dir, uri_path.stem)
            _remove_dir(unzip_dir)
            unzip(zip_path, target_dir=unzip_dir)
            # glob() already returns a list; note that it skips dot-files.
            unzipped_contents = glob(f'{unzip_dir}/*', recursive=False)
            # If the archive wraps everything in a single top-level
            # directory, load from (and relocate) that directory instead.
            if len(unzipped_contents) == 1 and isdir(unzipped_contents[0]):
                src_dir = unzipped_contents[0]
            else:
                src_dir = unzip_dir
            out = torch_hub_load_local(src_dir, entrypoint, *args, **kwargs)
            # Must happen inside the `with`, before the tmp dir is cleaned up.
            if dst_dir is not None:
                _remove_dir(dst_dir)
                shutil.move(src_dir, dst_dir)
    else:
        # assume uri is local
        out = torch_hub_load_local(uri, entrypoint, *args, **kwargs)
        # Skip the copy when uri already is dst_dir; removing it first would
        # destroy the source.
        if dst_dir is not None and realpath(uri) != realpath(dst_dir):
            _remove_dir(dst_dir)
            shutil.copytree(uri, dst_dir)
    return out
def torch_hub_load_local(hubconf_dir: str, entrypoint: str, *args,
                         **kwargs) -> Any:
    """Call :func:`torch.hub.load` on a local directory (``source='local'``).

    Historical note: the logic that used to live here now lives inside
    :func:`torch.hub.load` itself (its ``source='local'`` path), so this
    function simply forwards to it.

    Args:
        hubconf_dir: Directory containing ``hubconf.py``.
        entrypoint: Name of a Callable present in ``hubconf.py``.
        *args: Args to be passed to the entrypoint.
        **kwargs: Keyword args to be passed to the entrypoint.

    Returns:
        Any: The output from calling the entrypoint.
    """
    hub_load = torch.hub.load
    entrypoint_output = hub_load(
        hubconf_dir,
        entrypoint,
        *args,
        skip_validation=True,
        source='local',
        **kwargs,
    )
    return entrypoint_output