Source code for rastervision.aws_s3.s3_file_system

from typing import Tuple
import io
import os
import subprocess
from datetime import datetime
from urllib.parse import urlparse

from everett.manager import ConfigurationMissingError
from tqdm.auto import tqdm

from rastervision.pipeline.file_system import (FileSystem, NotReadableError,
                                               NotWritableError)

AWS_S3 = 'aws_s3'


# Code from https://alexwlchan.net/2017/07/listing-s3-keys/
def get_matching_s3_objects(bucket, prefix='', suffix='',
                            request_payer='None'):
    """
    Generate objects in an S3 bucket.

    Each yielded item is one entry of the 'Contents' list returned by
    ``list_objects_v2``.

    :param bucket: Name of the S3 bucket.
    :param prefix: Only fetch objects whose key starts with
        this prefix (optional). May be a string or a tuple of strings.
    :param suffix: Only fetch objects whose keys end with
        this suffix (optional).
    :param request_payer: Value forwarded as the ``RequestPayer`` argument
        of every ``list_objects_v2`` call.
    """
    import boto3

    client = boto3.client('s3')
    list_args = dict(Bucket=bucket, RequestPayer=request_payer)

    # Server-side prefix filtering is only possible for a single string;
    # a tuple of prefixes is filtered client-side below.
    if isinstance(prefix, str):
        list_args['Prefix'] = prefix

    while True:
        page = client.list_objects_v2(**list_args)

        # A page without 'Contents' lists no objects: nothing left to yield.
        if 'Contents' not in page:
            return

        yield from (entry for entry in page['Contents']
                    if entry['Key'].startswith(prefix)
                    and entry['Key'].endswith(suffix))

        # The listing is paginated; the last page carries no continuation
        # token.
        if 'NextContinuationToken' not in page:
            return
        list_args['ContinuationToken'] = page['NextContinuationToken']
def get_matching_s3_keys(bucket, prefix='', suffix='', request_payer='None'):
    """
    Generate the keys in an S3 bucket.

    Thin wrapper around :func:`get_matching_s3_objects` that yields only the
    'Key' field of every matching object.

    :param bucket: Name of the S3 bucket.
    :param prefix: Only fetch keys that start with this prefix (optional).
    :param suffix: Only fetch keys that end with this suffix (optional).
    :param request_payer: Passed through to
        :func:`get_matching_s3_objects`.
    """
    matching_objects = get_matching_s3_objects(bucket, prefix, suffix,
                                               request_payer)
    yield from (entry['Key'] for entry in matching_objects)
def progressbar(total_size: int, desc: str):
    """Build a tqdm progress bar for a transfer of ``total_size`` bytes.

    Args:
        total_size: Passed to tqdm as ``total``.
        desc: Label passed to tqdm as ``desc``.

    Returns:
        The tqdm instance.
    """
    bar_options = dict(
        total=total_size,
        desc=desc,
        unit='B',
        unit_scale=True,
        unit_divisor=1024,
        mininterval=0.5,
        delay=5,
    )
    return tqdm(**bar_options)
class S3FileSystem(FileSystem):
    """A FileSystem for interacting with files stored on AWS S3.

    Uses Everett configuration of form:
    ```
    [AWS_S3]
    requester_pays=True
    ```
    """

    @staticmethod
    def get_request_payer() -> str:
        """Return the ``RequestPayer`` value to send with S3 requests.

        Returns:
            ``'requester'`` if ``requester_pays`` is enabled in the AWS_S3
            config namespace, otherwise the *string* ``'None'``.
        """
        # Import here to avoid circular reference.
        from rastervision.pipeline import rv_config_ as rv_config

        try:
            s3_config = rv_config.get_namespace_config(AWS_S3)
            requester_pays = s3_config(
                'requester_pays', parser=bool, default='False')
        except ConfigurationMissingError:
            return 'None'
        # 'None' needs the quotes because boto3 cannot handle None.
        return 'requester' if requester_pays else 'None'

    @staticmethod
    def get_session():
        """Return a new ``boto3.Session``."""
        # Lazily load boto
        import boto3
        return boto3.Session()

    @staticmethod
    def matches_uri(uri: str, mode: str) -> bool:
        """Return True if ``uri`` uses the ``s3`` scheme.

        ``mode`` is accepted for interface compatibility and is not used.
        """
        return urlparse(uri).scheme == 's3'

    @staticmethod
    def parse_uri(uri: str) -> Tuple[str, str]:
        """Parse bucket name and key from an S3 URI."""
        parsed_uri = urlparse(uri)
        # Drop the leading '/' of the path to obtain the key.
        bucket, key = parsed_uri.netloc, parsed_uri.path[1:]
        return bucket, key

    @staticmethod
    def file_exists(uri: str, include_dir: bool = True) -> bool:
        """Check whether an object (or, optionally, a "directory") exists.

        Args:
            uri: S3 URI to check.
            include_dir: If True, also return True when at least one object
                exists under ``uri`` treated as a directory prefix.

        Returns:
            True if a match is found; False if not, or if the S3 request
            fails with a ``ClientError``.
        """
        # Lazily load boto
        from botocore.exceptions import ClientError

        bucket, key = S3FileSystem.parse_uri(uri)
        request_payer = S3FileSystem.get_request_payer()

        if not include_dir:
            # An empty key (bucket root) can never name an object.
            if not key:
                return False
            s3r = S3FileSystem.get_session().resource('s3')
            try:
                s3r.Object(bucket, key).load(RequestPayer=request_payer)
            except ClientError:
                return False
            return True

        s3 = S3FileSystem.get_session().client('s3')
        # Ensure key ends in slash so that this won't pick up files that
        # contain the key as a prefix, but aren't actually directories.
        # Example: if key is 'model' then we don't want to consider
        # model-123 a match. An empty key (bucket root) is used as-is.
        if not key or key.endswith('/'):
            dir_key = key
        else:
            dir_key = key + '/'
        try:
            response = s3.list_objects_v2(
                Bucket=bucket,
                Prefix=dir_key,
                MaxKeys=1,
                RequestPayer=request_payer)
        except ClientError:
            return False
        if response['KeyCount'] > 0:
            return True
        # Nothing under the directory prefix: fall back to an exact-key check.
        return S3FileSystem.file_exists(uri, include_dir=False)

    @staticmethod
    def read_str(uri: str) -> str:
        """Download the object at ``uri`` and decode it as UTF-8."""
        return S3FileSystem.read_bytes(uri).decode('utf-8')

    @staticmethod
    def read_bytes(uri: str) -> bytes:
        """Download the object at ``uri`` into memory.

        Raises:
            NotReadableError: If S3 responds with a ``ClientError``.
        """
        from botocore.exceptions import ClientError

        s3 = S3FileSystem.get_session().client('s3')
        request_payer = S3FileSystem.get_request_payer()
        bucket, key = S3FileSystem.parse_uri(uri)
        with io.BytesIO() as file_buffer:
            try:
                # The size is only needed for the progress bar. RequestPayer
                # is sent here too, matching the download call below.
                file_size = s3.head_object(
                    Bucket=bucket, Key=key,
                    RequestPayer=request_payer)['ContentLength']
                with progressbar(file_size, desc='Downloading') as bar:
                    s3.download_fileobj(
                        Bucket=bucket,
                        Key=key,
                        Fileobj=file_buffer,
                        Callback=bar.update,
                        ExtraArgs={'RequestPayer': request_payer})
                return file_buffer.getvalue()
            except ClientError as e:
                raise NotReadableError(f'Could not read {uri}') from e

    @staticmethod
    def write_str(uri: str, data: str) -> None:
        """Encode ``data`` as UTF-8 and upload it to ``uri``."""
        S3FileSystem.write_bytes(uri, data.encode('utf-8'))

    @staticmethod
    def write_bytes(uri: str, data: bytes) -> None:
        """Upload ``data`` to ``uri``.

        Raises:
            NotWritableError: If the upload raises any exception.
        """
        s3 = S3FileSystem.get_session().client('s3')
        bucket, key = S3FileSystem.parse_uri(uri)
        file_size = len(data)
        with io.BytesIO(data) as str_buffer:
            try:
                with progressbar(file_size, desc='Uploading') as bar:
                    s3.upload_fileobj(
                        Fileobj=str_buffer,
                        Bucket=bucket,
                        Key=key,
                        Callback=bar.update)
            except Exception as e:
                raise NotWritableError(f'Could not write {uri}') from e

    @staticmethod
    def sync_from_dir(src_dir_uri: str, dst_dir: str,
                      delete: bool = False) -> None:  # pragma: no cover
        """Run ``aws s3 sync`` from ``src_dir_uri`` to ``dst_dir``.

        Args:
            src_dir_uri: Source passed to ``aws s3 sync``.
            dst_dir: Destination passed to ``aws s3 sync``.
            delete: If True, add ``--delete`` to the command.
        """
        command = ['aws', 's3', 'sync', src_dir_uri, dst_dir]
        if delete:
            command.append('--delete')
        # get_request_payer() returns the strings 'requester' or 'None', both
        # of which are truthy, so the value must be compared explicitly.
        if S3FileSystem.get_request_payer() == 'requester':
            command.append('--request-payer')
        # NOTE(review): the exit status of the command is not checked.
        subprocess.run(command)

    @staticmethod
    def sync_to_dir(src_dir: str, dst_dir_uri: str,
                    delete: bool = False) -> None:  # pragma: no cover
        """Run ``aws s3 sync`` from ``src_dir`` to ``dst_dir_uri``."""
        S3FileSystem.sync_from_dir(src_dir, dst_dir_uri, delete=delete)

    @staticmethod
    def copy_to(src_path: str, dst_uri: str) -> None:
        """Upload a local file, or sync a local directory, to ``dst_uri``.

        If ``src_path`` is not a regular file it is handed to
        :meth:`sync_to_dir` with ``delete=True``.

        Raises:
            NotWritableError: If the file upload raises any exception.
        """
        if not os.path.isfile(src_path):
            S3FileSystem.sync_to_dir(src_path, dst_uri, delete=True)
            return

        s3 = S3FileSystem.get_session().client('s3')
        bucket, key = S3FileSystem.parse_uri(dst_uri)
        file_size = os.path.getsize(src_path)
        try:
            with progressbar(file_size, desc='Uploading') as bar:
                s3.upload_file(
                    Filename=src_path,
                    Bucket=bucket,
                    Key=key,
                    Callback=bar.update)
        except Exception as e:
            raise NotWritableError(f'Could not write {dst_uri}') from e

    @staticmethod
    def copy_from(src_uri: str, dst_path: str) -> None:
        """Download the object at ``src_uri`` to the local ``dst_path``.

        Raises:
            NotReadableError: If S3 responds with a ``ClientError``.
        """
        from botocore.exceptions import ClientError

        s3 = S3FileSystem.get_session().client('s3')
        request_payer = S3FileSystem.get_request_payer()
        bucket, key = S3FileSystem.parse_uri(src_uri)
        try:
            # The size is only needed for the progress bar. RequestPayer is
            # sent here too, matching the download call below.
            file_size = s3.head_object(
                Bucket=bucket, Key=key,
                RequestPayer=request_payer)['ContentLength']
            with progressbar(file_size, desc='Downloading') as bar:
                s3.download_file(
                    Bucket=bucket,
                    Key=key,
                    Filename=dst_path,
                    Callback=bar.update,
                    ExtraArgs={'RequestPayer': request_payer})
        except ClientError as e:
            raise NotReadableError(f'Could not read {src_uri}') from e

    @staticmethod
    def local_path(uri: str, download_dir: str) -> str:
        """Return the local path corresponding to ``uri``.

        The result is ``<download_dir>/s3/<bucket>/<key>``.
        """
        bucket, key = S3FileSystem.parse_uri(uri)
        return os.path.join(download_dir, 's3', bucket, key)

    @staticmethod
    def last_modified(uri: str) -> datetime:
        """Return the 'LastModified' value reported by ``head_object``."""
        bucket, key = S3FileSystem.parse_uri(uri)
        s3 = S3FileSystem.get_session().client('s3')
        request_payer = S3FileSystem.get_request_payer()
        head_data = s3.head_object(
            Bucket=bucket, Key=key, RequestPayer=request_payer)
        return head_data['LastModified']

    @staticmethod
    def list_paths(uri: str, ext: str = '') -> list:
        """List the S3 URIs of all objects whose key starts with ``uri``'s key.

        Args:
            uri: S3 URI whose key is used as the prefix.
            ext: Only include keys ending with this suffix.

        Returns:
            A list of ``s3://<bucket>/<key>`` strings.
        """
        request_payer = S3FileSystem.get_request_payer()
        bucket, prefix = S3FileSystem.parse_uri(uri)
        keys = get_matching_s3_keys(
            bucket, prefix, suffix=ext, request_payer=request_payer)
        # Build URIs by plain string formatting: os.path.join would use the
        # platform separator and would drop the bucket for a key starting
        # with '/'.
        return [f's3://{bucket}/{key}' for key in keys]