Source code for rastervision.aws_s3.s3_file_system

from typing import Tuple
import io
import os
import subprocess
from datetime import datetime
from urllib.parse import urlparse

import boto3
from everett.manager import ConfigurationMissingError
from tqdm.auto import tqdm

from rastervision.pipeline.file_system import (FileSystem, NotReadableError,
                                               NotWritableError)

AWS_S3 = 'aws_s3'


# Code from https://alexwlchan.net/2017/07/listing-s3-keys/
def get_matching_s3_objects(bucket, prefix='', suffix='',
                            request_payer='None'):
    """Generate objects in an S3 bucket.

    :param bucket: Name of the S3 bucket.
    :param prefix: Only fetch objects whose key starts with
        this prefix (optional).
    :param suffix: Only fetch objects whose keys end with
        this suffix (optional).
    """
    s3 = S3FileSystem.get_client()
    kwargs = {'Bucket': bucket, 'RequestPayer': request_payer}

    # If the prefix is a single string (not a tuple of strings), we can
    # do the filtering directly in the S3 API.
    if isinstance(prefix, str):
        kwargs['Prefix'] = prefix

    while True:
        # The S3 API response is a large blob of metadata.
        # 'Contents' contains information about the listed objects.
        resp = s3.list_objects_v2(**kwargs)
        try:
            contents = resp['Contents']
        except KeyError:
            return

        for obj in contents:
            key = obj['Key']
            if key.startswith(prefix) and key.endswith(suffix):
                yield obj

        # The S3 API is paginated, returning up to 1000 keys at a time.
        # Pass the continuation token into the next request, until we
        # reach the final page (when this field is missing).
        try:
            kwargs['ContinuationToken'] = resp['NextContinuationToken']
        except KeyError:
            break

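
# A minimal usage sketch for the generator above; the bucket name, prefix,
# and suffix are hypothetical and assume credentials with read access:
#
#     for obj in get_matching_s3_objects(
#             'my-bucket', prefix='images/', suffix='.tif'):
#         print(obj['Key'], obj['Size'])
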
def get_matching_s3_keys(bucket, prefix='', suffix='', request_payer='None'):
    """Generate the keys in an S3 bucket.

    :param bucket: Name of the S3 bucket.
    :param prefix: Only fetch keys that start with this prefix (optional).
    :param suffix: Only fetch keys that end with this suffix (optional).
    """
    for obj in get_matching_s3_objects(bucket, prefix, suffix, request_payer):
        yield obj['Key']

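
# The same sketch via the key-only wrapper (hypothetical names):
#
#     keys = list(get_matching_s3_keys('my-bucket', prefix='images/'))
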
def progressbar(total_size: int, desc: str):
    return tqdm(
        total=total_size,
        desc=desc,
        unit='B',
        unit_scale=True,
        unit_divisor=1024,
        mininterval=0.5,
        delay=5)

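
# Sketch of how this helper is used below: total_size is a byte count, and
# boto3 transfer callbacks feed the bar incremental byte counts:
#
#     with progressbar(file_size, desc='Downloading') as bar:
#         ...  # pass a callback that calls bar.update(num_bytes)
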
class S3FileSystem(FileSystem):
    """A FileSystem for interacting with files stored on AWS S3.

    Uses Everett configuration of form:

    ```
    [AWS_S3]
    requester_pays=True
    ```
    """

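
    # For example, requester-pays support can be switched on in a Raster
    # Vision profile file (a sketch; the exact file location depends on your
    # rv_config profile setup, commonly ~/.rastervision/default):
    #
    #     [AWS_S3]
    #     requester_pays=True
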
    @staticmethod
    def get_request_payer():
        # Import here to avoid circular reference.
        from rastervision.pipeline import rv_config_ as rv_config
        try:
            s3_config = rv_config.get_namespace_config(AWS_S3)
            # 'None' needs the quotes because boto3 cannot handle None.
            return ('requester' if s3_config(
                'requester_pays', parser=bool, default='False') else 'None')
        except ConfigurationMissingError:
            return 'None'

    @staticmethod
    def get_session():
        return boto3.Session()

    @staticmethod
    def get_client():
        if os.getenv('AWS_NO_SIGN_REQUEST', '').lower() == 'yes':
            from botocore import UNSIGNED
            from botocore.config import Config
            s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
            return s3
        return S3FileSystem.get_session().client('s3')

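
    # Sketch: setting this environment variable (any casing of 'yes') makes
    # the client unsigned, which allows reading public buckets without
    # credentials:
    #
    #     import os
    #     os.environ['AWS_NO_SIGN_REQUEST'] = 'yes'
    #     s3 = S3FileSystem.get_client()  # anonymous, unsigned client
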
    @staticmethod
    def matches_uri(uri: str, mode: str) -> bool:
        parsed_uri = urlparse(uri)
        return parsed_uri.scheme == 's3'

    @staticmethod
    def parse_uri(uri: str) -> Tuple[str, str]:
        """Parse bucket name and key from an S3 URI."""
        parsed_uri = urlparse(uri)
        bucket, key = parsed_uri.netloc, parsed_uri.path[1:]
        return bucket, key

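
    # A quick sketch of the parsing behavior (hypothetical URI):
    #
    #     >>> S3FileSystem.parse_uri('s3://my-bucket/path/to/scene.tif')
    #     ('my-bucket', 'path/to/scene.tif')
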
    @staticmethod
    def file_exists(uri: str, include_dir: bool = True) -> bool:
        # Lazily load boto.
        import botocore

        parsed_uri = urlparse(uri)
        bucket = parsed_uri.netloc
        key = parsed_uri.path[1:]

        request_payer = S3FileSystem.get_request_payer()
        if include_dir:
            s3 = S3FileSystem.get_client()
            try:
                # Ensure key ends in slash so that this won't pick up files
                # that contain the key as a prefix, but aren't actually
                # directories. Example: if key is 'model' then we don't want
                # to consider model-123 a match.
                dir_key = key if key[-1] == '/' else key + '/'
                response = s3.list_objects_v2(
                    Bucket=bucket,
                    Prefix=dir_key,
                    MaxKeys=1,
                    RequestPayer=request_payer)
                if response['KeyCount'] == 0:
                    return S3FileSystem.file_exists(uri, include_dir=False)
                return True
            except botocore.exceptions.ClientError:
                return False
        else:
            s3r = S3FileSystem.get_session().resource('s3')
            try:
                s3r.Object(bucket, key).load(RequestPayer=request_payer)
                return True
            except botocore.exceptions.ClientError:
                return False

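
    # Usage sketch (hypothetical URIs): with the default include_dir=True, a
    # URI counts as existing if it names either an object or a "directory"
    # prefix. For a bucket containing 'model/weights.pth', both return True:
    #
    #     S3FileSystem.file_exists('s3://my-bucket/model/weights.pth')
    #     S3FileSystem.file_exists('s3://my-bucket/model')
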
    @staticmethod
    def read_str(uri: str) -> str:
        return S3FileSystem.read_bytes(uri).decode('utf-8')

    @staticmethod
    def read_bytes(uri: str) -> bytes:
        import botocore

        s3 = S3FileSystem.get_client()
        request_payer = S3FileSystem.get_request_payer()
        bucket, key = S3FileSystem.parse_uri(uri)
        with io.BytesIO() as file_buffer:
            try:
                # Pass RequestPayer here too, so that sizing the progress bar
                # doesn't fail on requester-pays buckets.
                file_size = s3.head_object(
                    Bucket=bucket, Key=key,
                    RequestPayer=request_payer)['ContentLength']
                with progressbar(file_size, desc='Downloading') as bar:
                    s3.download_fileobj(
                        Bucket=bucket,
                        Key=key,
                        Fileobj=file_buffer,
                        Callback=lambda bytes: bar.update(bytes),
                        ExtraArgs={'RequestPayer': request_payer})
                return file_buffer.getvalue()
            except botocore.exceptions.ClientError as e:
                raise NotReadableError(f'Could not read {uri}') from e

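
    # Read-path sketch (hypothetical URI); read_str is the same call plus a
    # UTF-8 decode:
    #
    #     data = S3FileSystem.read_bytes('s3://my-bucket/labels.json')
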
    @staticmethod
    def write_str(uri: str, data: str) -> None:
        data = bytes(data, encoding='utf-8')
        S3FileSystem.write_bytes(uri, data)

    @staticmethod
    def write_bytes(uri: str, data: bytes) -> None:
        s3 = S3FileSystem.get_client()
        bucket, key = S3FileSystem.parse_uri(uri)
        file_size = len(data)
        with io.BytesIO(data) as str_buffer:
            try:
                with progressbar(file_size, desc='Uploading') as bar:
                    s3.upload_fileobj(
                        Fileobj=str_buffer,
                        Bucket=bucket,
                        Key=key,
                        Callback=lambda bytes: bar.update(bytes))
            except Exception as e:
                raise NotWritableError(f'Could not write {uri}') from e

    @staticmethod
    def sync_from_dir(src_dir_uri: str,
                      dst_dir: str,
                      delete: bool = False) -> None:  # pragma: no cover
        command = ['aws', 's3', 'sync', src_dir_uri, dst_dir]
        if delete:
            command.append('--delete')
        request_payer = S3FileSystem.get_request_payer()
        # get_request_payer() returns the string 'None' or 'requester', so
        # compare explicitly rather than truth-testing (which is always True).
        if request_payer == 'requester':
            command.append('--request-payer')
        subprocess.run(command)

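
    # This shells out to the AWS CLI. With delete=True and a requester-pays
    # config, the subprocess invocation is equivalent to (hypothetical URIs):
    #
    #     aws s3 sync s3://my-bucket/train/ /opt/data/train/ \
    #         --delete --request-payer
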
    @staticmethod
    def sync_to_dir(src_dir: str,
                    dst_dir_uri: str,
                    delete: bool = False) -> None:  # pragma: no cover
        S3FileSystem.sync_from_dir(src_dir, dst_dir_uri, delete=delete)

    @staticmethod
    def copy_to(src_path: str, dst_uri: str) -> None:
        s3 = S3FileSystem.get_client()
        bucket, key = S3FileSystem.parse_uri(dst_uri)
        if os.path.isfile(src_path):
            file_size = os.path.getsize(src_path)
            try:
                with progressbar(file_size, desc='Uploading') as bar:
                    s3.upload_file(
                        Filename=src_path,
                        Bucket=bucket,
                        Key=key,
                        Callback=lambda bytes: bar.update(bytes))
            except Exception as e:
                raise NotWritableError(f'Could not write {dst_uri}') from e
        else:
            S3FileSystem.sync_to_dir(src_path, dst_uri, delete=True)

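
    # Usage sketch (hypothetical paths): a single file is uploaded directly,
    # while a directory falls through to sync_to_dir with delete=True:
    #
    #     S3FileSystem.copy_to('/tmp/model.pth', 's3://my-bucket/model.pth')
    #     S3FileSystem.copy_to('/tmp/output/', 's3://my-bucket/output/')
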
    @staticmethod
    def copy_from(src_uri: str, dst_path: str) -> None:
        import botocore

        s3 = S3FileSystem.get_client()
        request_payer = S3FileSystem.get_request_payer()
        bucket, key = S3FileSystem.parse_uri(src_uri)
        try:
            # Pass RequestPayer here too, so that sizing the progress bar
            # doesn't fail on requester-pays buckets.
            file_size = s3.head_object(
                Bucket=bucket, Key=key,
                RequestPayer=request_payer)['ContentLength']
            with progressbar(file_size, desc='Downloading') as bar:
                s3.download_file(
                    Bucket=bucket,
                    Key=key,
                    Filename=dst_path,
                    Callback=lambda bytes: bar.update(bytes),
                    ExtraArgs={'RequestPayer': request_payer})
        except botocore.exceptions.ClientError as e:
            raise NotReadableError(f'Could not read {src_uri}') from e

    @staticmethod
    def local_path(uri: str, download_dir: str) -> str:
        parsed_uri = urlparse(uri)
        path = os.path.join(download_dir, 's3', parsed_uri.netloc,
                            parsed_uri.path[1:])
        return path

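
    # Sketch of the URI-to-path mapping on a POSIX system (hypothetical
    # arguments):
    #
    #     >>> S3FileSystem.local_path('s3://my-bucket/scenes/1.tif', '/tmp/cache')
    #     '/tmp/cache/s3/my-bucket/scenes/1.tif'
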
    @staticmethod
    def last_modified(uri: str) -> datetime:
        bucket, key = S3FileSystem.parse_uri(uri)
        s3 = S3FileSystem.get_client()
        request_payer = S3FileSystem.get_request_payer()
        head_data = s3.head_object(
            Bucket=bucket, Key=key, RequestPayer=request_payer)
        return head_data['LastModified']

    @staticmethod
    def list_paths(uri, ext=''):
        request_payer = S3FileSystem.get_request_payer()
        parsed_uri = urlparse(uri)
        bucket = parsed_uri.netloc
        prefix = parsed_uri.path[1:]
        keys = get_matching_s3_keys(
            bucket, prefix, suffix=ext, request_payer=request_payer)
        return [os.path.join('s3://', bucket, key) for key in keys]
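
    # Usage sketch (hypothetical URI): list all GeoTIFF keys under a prefix
    # as full s3:// URIs:
    #
    #     tifs = S3FileSystem.list_paths('s3://my-bucket/scenes', ext='.tif')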