Source code for rastervision.pytorch_learner.utils.torch_hub

from typing import Any, Optional
from pathlib import Path
from os.path import join, isdir, realpath
import shutil
from glob import glob

import torch.hub

from rastervision.pipeline.file_system import (download_if_needed, unzip,
                                               get_tmp_dir)


def _remove_dir(path):
    """ Remove a directory if it exists. """
    if isdir(path):
        shutil.rmtree(path)


def _repo_name_to_dir_name(repo: str) -> str:
    """Convert repo name to dir name per torch.hub naming conventions.

    Adapted from torch.hub._get_cache_or_reload()

    Args:
        repo (str): <repo-owner>/<erpo-name>[:tag]

    Returns:
        str: directory name
    """
    from torch.hub import _parse_repo_info
    repo_owner, repo_name, branch = _parse_repo_info(repo)
    normalized_br = branch.replace('/', '_')
    dir_name = '_'.join([repo_owner, repo_name, normalized_br])
    return dir_name


def _uri_to_dir_name(uri: str) -> str:
    """ Determine directory name from a URI. """
    return Path(uri).stem


def get_hubconf_dir_from_cfg(cfg, parent: Optional[str] = '') -> str:
    """Determine destination directory name from an ExternalModuleConfig.

    The name is taken from the first of these that is not None: ``cfg.name``,
    a name derived from ``cfg.uri``, a name derived from ``cfg.github_repo``.
    If a parent path is provided, the dir name is appended to it.

    Args:
        cfg (ExternalModuleConfig): an ExternalModuleConfig
        parent (str, optional): Parent path. ``None`` is treated the same as
            the empty string. Defaults to ''.

    Returns:
        str: directory name or path
    """
    if cfg.name is not None:
        dir_name = cfg.name
    elif cfg.uri is not None:
        dir_name = _uri_to_dir_name(cfg.uri)
    else:
        dir_name = _repo_name_to_dir_name(cfg.github_repo)
    # The annotation allows None, but os.path.join() rejects it.
    path = join(parent or '', dir_name)
    return path
def torch_hub_load_github(repo: str, hubconf_dir: str, entrypoint: str, *args,
                          **kwargs) -> Any:
    """Load an entrypoint from a github repo using torch.hub.load().

    After loading, the repo directory under ``torch.hub.get_dir()`` is moved
    to ``hubconf_dir``, replacing any directory already there.

    Args:
        repo (str): <repo-owner>/<repo-name>[:tag]
        hubconf_dir (str): Where the contents from the uri will finally
            be saved to.
        entrypoint (str): Name of a callable present in hubconf.py.
        *args: Args to be passed to the entrypoint.
        **kwargs: Keyword args to be passed to the entrypoint.

    Returns:
        Any: The output from calling the entrypoint.
    """
    out = torch.hub.load(
        repo,
        entrypoint,
        *args,
        source='github',
        skip_validation=True,
        **kwargs)

    orig_dir = join(torch.hub.get_dir(), _repo_name_to_dir_name(repo))
    # Only move if needed: when hubconf_dir already is the torch.hub repo
    # directory, removing it first would delete the freshly loaded repo and
    # the subsequent move would fail.
    if realpath(orig_dir) != realpath(hubconf_dir):
        _remove_dir(hubconf_dir)
        shutil.move(orig_dir, hubconf_dir)

    return out
def torch_hub_load_uri(uri: str, hubconf_dir: str, entrypoint: str, *args,
                       **kwargs) -> Any:
    """Load an entrypoint from a uri.

    Load an entrypoint from:
    - a local uri of a zip file, or
    - a local uri of a directory, or
    - a remote uri of zip file.

    The zip file should either have hubconf.py at the top level or contain
    a single sub-directory that contains hubconf.py at its top level. In the
    latter case, the sub-directory will be copied to hubconf_dir.

    Args:
        uri (str): A URI.
        hubconf_dir (str): The target directory where the contents from the uri
            will finally be saved to.
        entrypoint (str): Name of a callable present in hubconf.py.
        *args: Args to be passed to the entrypoint.
        **kwargs: Keyword args to be passed to the entrypoint.

    Returns:
        Any: The output from calling the entrypoint.
    """
    uri_path = Path(uri)
    is_zip = uri_path.suffix.lower() == '.zip'
    if is_zip:
        # fetch (if remote) and unzip into a scratch directory
        zip_path = download_if_needed(uri)
        with get_tmp_dir() as tmp_dir:
            unzip_dir = join(tmp_dir, uri_path.stem)
            _remove_dir(unzip_dir)
            unzip(zip_path, target_dir=unzip_dir)
            unzipped_contents = glob(f'{unzip_dir}/*', recursive=False)

            _remove_dir(hubconf_dir)

            # if the top level only contains a directory, that directory
            # becomes hubconf_dir; otherwise the whole unzipped tree does
            has_single_sub_dir = ((len(unzipped_contents) == 1)
                                  and isdir(unzipped_contents[0]))
            src_dir = unzipped_contents[0] if has_single_sub_dir else unzip_dir
            shutil.move(src_dir, hubconf_dir)
    # assume uri is local and attempt copying
    else:
        # only copy if needed
        if realpath(uri) != realpath(hubconf_dir):
            _remove_dir(hubconf_dir)
            shutil.copytree(uri, hubconf_dir)

    out = torch_hub_load_local(hubconf_dir, entrypoint, *args, **kwargs)
    return out
def torch_hub_load_local(hubconf_dir: str, entrypoint: str, *args,
                         **kwargs) -> Any:
    """Load an entrypoint from a local directory using torch.hub.load().

    Args:
        hubconf_dir (str): Path passed to torch.hub.load() as the local
            repo directory.
        entrypoint (str): Name of a callable present in hubconf.py.
        *args: Args to be passed to the entrypoint.
        **kwargs: Keyword args to be passed to the entrypoint.

    Returns:
        Any: The output from calling the entrypoint.
    """
    return torch.hub.load(
        hubconf_dir,
        entrypoint,
        *args,
        source='local',
        skip_validation=True,
        **kwargs)