Source code for rastervision.aws_s3.s3_file_system

from typing import Any, Iterator
import io
import os
import subprocess
from datetime import datetime
from urllib.parse import urlparse

import boto3
from tqdm.auto import tqdm

from rastervision.pipeline.file_system import (FileSystem, NotReadableError,
                                               NotWritableError)

AWS_S3 = 'aws_s3'


# Code from https://alexwlchan.net/2017/07/listing-s3-keys/
def get_matching_s3_objects(bucket: str,
                            prefix: str = '',
                            suffix: str = '',
                            delimiter: str = '/',
                            request_payer: str = 'None'
                            ) -> Iterator[tuple[str, Any]]:
    """Generate objects in an S3 bucket.

    Args:
        bucket: Name of the S3 bucket.
        prefix: Only fetch objects whose key starts with this prefix.
        suffix: Only fetch objects whose keys end with this suffix.
        delimiter: Delimiter used to group keys into "directories"
            (``CommonPrefixes``). Defaults to ``'/'``.
        request_payer: Set to ``'requester'`` when listing a requester-pays
            bucket. Defaults to ``'None'``.
    """
    s3 = S3FileSystem.get_client()
    kwargs = dict(
        Bucket=bucket,
        RequestPayer=request_payer,
        Delimiter=delimiter,
        Prefix=prefix,
    )
    while True:
        resp: dict = s3.list_objects_v2(**kwargs)
        dirs: list[dict[str, Any]] = resp.get('CommonPrefixes', [])
        files: list[dict[str, Any]] = resp.get('Contents', [])
        for obj in dirs:
            key: str = obj['Prefix']
            if key.startswith(prefix) and key.endswith(suffix):
                yield key, obj
        for obj in files:
            key: str = obj['Key']
            if key.startswith(prefix) and key.endswith(suffix):
                yield key, obj
        # The S3 API is paginated, returning up to 1000 keys at a time.
        # Pass the continuation token into the next request, until we
        # reach the final page (when this field is missing).
        try:
            kwargs['ContinuationToken'] = resp['NextContinuationToken']
        except KeyError:
            break

def get_matching_s3_keys(bucket: str,
                         prefix: str = '',
                         suffix: str = '',
                         delimiter: str = '/',
                         request_payer: str = 'None') -> Iterator[str]:
    """Generate the keys in an S3 bucket.

    Args:
        bucket: Name of the S3 bucket.
        prefix: Only fetch keys that start with this prefix.
        suffix: Only fetch keys that end with this suffix.
        delimiter: Delimiter used to group keys into "directories".
            Defaults to ``'/'``.
        request_payer: Set to ``'requester'`` when listing a requester-pays
            bucket. Defaults to ``'None'``.
    """
    obj_iterator = get_matching_s3_objects(
        bucket,
        prefix=prefix,
        suffix=suffix,
        delimiter=delimiter,
        request_payer=request_payer)
    out = (key for key, _ in obj_iterator)
    return out

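# Illustrative usage sketch (not part of the original module): iterate over
# all GeoTIFF keys under a prefix. The bucket name and prefix below are
# placeholders.
#
#   for key in get_matching_s3_keys(
#           'my-bucket', prefix='imagery/', suffix='.tif'):
#       print(key)
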
def progressbar(total_size: int, desc: str):
    return tqdm(
        total=total_size,
        desc=desc,
        unit='B',
        unit_scale=True,
        unit_divisor=1024,
        mininterval=0.5,
        delay=5)

class S3FileSystem(FileSystem):
    """A FileSystem for interacting with files stored on AWS S3.

    Uses Everett configuration of form:

    ```
    [AWS_S3]
    requester_pays=True
    ```
    """

    @staticmethod
    def get_request_payer() -> str:
        # Attempt to get from the environment.
        request_payer = os.getenv('AWS_REQUEST_PAYER', 'None')

        # Attempt to get from the RV config.
        if request_payer == 'None':
            # Import here to avoid a circular reference.
            from rastervision.pipeline import rv_config_ as rv_config
            requester_pays = rv_config.get_namespace_option(
                AWS_S3, 'requester_pays', as_bool=True)
            if requester_pays:
                request_payer = 'requester'
        return request_payer

    @staticmethod
    def get_session():
        return boto3.Session()

    @staticmethod
    def get_client():
        if os.getenv('AWS_NO_SIGN_REQUEST', '').lower() == 'yes':
            from botocore import UNSIGNED
            from botocore.config import Config
            s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
            return s3
        return S3FileSystem.get_session().client('s3')

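    # Note (illustrative, not part of the original module): with the
    # environment variable AWS_NO_SIGN_REQUEST=YES set, get_client() returns
    # an unsigned client, which allows reading public buckets without
    # credentials, e.g.:
    #
    #   os.environ['AWS_NO_SIGN_REQUEST'] = 'YES'
    #   s3 = S3FileSystem.get_client()
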
    @staticmethod
    def matches_uri(uri: str, mode: str) -> bool:
        parsed_uri = urlparse(uri)
        return parsed_uri.scheme == 's3'

    @staticmethod
    def parse_uri(uri: str) -> tuple[str, str]:
        """Parse bucket name and key from an S3 URI."""
        parsed_uri = urlparse(uri)
        bucket, key = parsed_uri.netloc, parsed_uri.path[1:]
        return bucket, key

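    # Example (illustrative; the URI is a placeholder):
    #
    #   S3FileSystem.parse_uri('s3://my-bucket/a/b.tif')
    #   # -> ('my-bucket', 'a/b.tif')
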
    @staticmethod
    def file_exists(uri: str, include_dir: bool = True) -> bool:
        # Lazily load botocore.
        import botocore

        parsed_uri = urlparse(uri)
        bucket = parsed_uri.netloc
        key = parsed_uri.path[1:]
        request_payer = S3FileSystem.get_request_payer()

        if include_dir:
            s3 = S3FileSystem.get_client()
            try:
                # Ensure the key ends in a slash so that this won't pick up
                # files that contain the key as a prefix but aren't actually
                # directories. Example: if the key is 'model' then we don't
                # want to consider 'model-123' a match.
                dir_key = key if key.endswith('/') else key + '/'
                response = s3.list_objects_v2(
                    Bucket=bucket,
                    Prefix=dir_key,
                    MaxKeys=1,
                    RequestPayer=request_payer)
                if response['KeyCount'] == 0:
                    return S3FileSystem.file_exists(uri, include_dir=False)
                return True
            except botocore.exceptions.ClientError:
                return False
        else:
            s3r = S3FileSystem.get_session().resource('s3')
            try:
                s3r.Object(bucket, key).load(RequestPayer=request_payer)
                return True
            except botocore.exceptions.ClientError:
                return False

    @staticmethod
    def read_str(uri: str) -> str:
        return S3FileSystem.read_bytes(uri).decode('utf-8')

    @staticmethod
    def read_bytes(uri: str) -> bytes:
        import botocore

        s3 = S3FileSystem.get_client()
        request_payer = S3FileSystem.get_request_payer()
        bucket, key = S3FileSystem.parse_uri(uri)
        with io.BytesIO() as file_buffer:
            try:
                obj = s3.head_object(
                    Bucket=bucket, Key=key, RequestPayer=request_payer)
                file_size = obj['ContentLength']
                with progressbar(file_size, desc='Downloading') as bar:
                    s3.download_fileobj(
                        Bucket=bucket,
                        Key=key,
                        Fileobj=file_buffer,
                        Callback=lambda bytes: bar.update(bytes),
                        ExtraArgs={'RequestPayer': request_payer})
                return file_buffer.getvalue()
            except botocore.exceptions.ClientError as e:
                raise NotReadableError(f'Could not read {uri}') from e

    @staticmethod
    def write_str(uri: str, data: str) -> None:
        data = bytes(data, encoding='utf-8')
        S3FileSystem.write_bytes(uri, data)

    @staticmethod
    def write_bytes(uri: str, data: bytes) -> None:
        s3 = S3FileSystem.get_client()
        bucket, key = S3FileSystem.parse_uri(uri)
        file_size = len(data)
        with io.BytesIO(data) as str_buffer:
            try:
                with progressbar(file_size, desc='Uploading') as bar:
                    s3.upload_fileobj(
                        Fileobj=str_buffer,
                        Bucket=bucket,
                        Key=key,
                        Callback=lambda bytes: bar.update(bytes))
            except Exception as e:
                raise NotWritableError(f'Could not write {uri}') from e

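    # Illustrative round trip (not part of the original module); the URI is a
    # placeholder:
    #
    #   S3FileSystem.write_str('s3://my-bucket/hello.txt', 'hello')
    #   assert S3FileSystem.read_str('s3://my-bucket/hello.txt') == 'hello'
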
    @staticmethod
    def sync_from_dir(src_dir_uri: str,
                      dst_dir: str,
                      delete: bool = False) -> None:  # pragma: no cover
        command = ['aws', 's3', 'sync', src_dir_uri, dst_dir]
        if delete:
            command.append('--delete')
        # get_request_payer() returns the string 'None' when requester-pays is
        # not configured, so compare against 'requester' explicitly.
        request_payer = S3FileSystem.get_request_payer()
        if request_payer == 'requester':
            command.append('--request-payer')
        subprocess.run(command)

    @staticmethod
    def sync_to_dir(src_dir: str,
                    dst_dir_uri: str,
                    delete: bool = False) -> None:  # pragma: no cover
        S3FileSystem.sync_from_dir(src_dir, dst_dir_uri, delete=delete)

    @staticmethod
    def copy_to(src_path: str, dst_uri: str) -> None:
        s3 = S3FileSystem.get_client()
        bucket, key = S3FileSystem.parse_uri(dst_uri)
        if os.path.isfile(src_path):
            file_size = os.path.getsize(src_path)
            try:
                with progressbar(file_size, desc='Uploading') as bar:
                    s3.upload_file(
                        Filename=src_path,
                        Bucket=bucket,
                        Key=key,
                        Callback=lambda bytes: bar.update(bytes))
            except Exception as e:
                raise NotWritableError(f'Could not write {dst_uri}') from e
        else:
            S3FileSystem.sync_to_dir(src_path, dst_uri, delete=True)

    @staticmethod
    def copy_from(src_uri: str, dst_path: str) -> None:
        import botocore

        s3 = S3FileSystem.get_client()
        request_payer = S3FileSystem.get_request_payer()
        bucket, key = S3FileSystem.parse_uri(src_uri)
        try:
            obj = s3.head_object(
                Bucket=bucket, Key=key, RequestPayer=request_payer)
            file_size = obj['ContentLength']
            with progressbar(file_size, desc='Downloading') as bar:
                s3.download_file(
                    Bucket=bucket,
                    Key=key,
                    Filename=dst_path,
                    Callback=lambda bytes: bar.update(bytes),
                    ExtraArgs={'RequestPayer': request_payer})
        except botocore.exceptions.ClientError as e:
            raise NotReadableError(f'Could not read {src_uri}') from e

    @staticmethod
    def local_path(uri: str, download_dir: str) -> str:
        parsed_uri = urlparse(uri)
        path = os.path.join(download_dir, 's3', parsed_uri.netloc,
                            parsed_uri.path[1:])
        return path

    @staticmethod
    def last_modified(uri: str) -> datetime:
        bucket, key = S3FileSystem.parse_uri(uri)
        s3 = S3FileSystem.get_client()
        request_payer = S3FileSystem.get_request_payer()
        head_data = s3.head_object(
            Bucket=bucket, Key=key, RequestPayer=request_payer)
        return head_data['LastModified']

    @staticmethod
    def list_paths(uri: str, ext: str = '',
                   delimiter: str = '/') -> list[str]:
        request_payer = S3FileSystem.get_request_payer()
        if not uri.endswith('/'):
            uri += '/'
        parsed_uri = urlparse(uri)
        bucket = parsed_uri.netloc
        prefix = parsed_uri.path[1:]
        keys = get_matching_s3_keys(
            bucket,
            prefix,
            suffix=ext,
            delimiter=delimiter,
            request_payer=request_payer)
        paths = [os.path.join('s3://', bucket, key) for key in keys]
        return paths
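

# Illustrative sketch (not part of the original module): a minimal demo of the
# S3FileSystem API, assuming valid AWS credentials and a placeholder bucket.
if __name__ == '__main__':
    demo_uri = 's3://my-bucket/some/prefix/'  # placeholder URI
    if S3FileSystem.file_exists(demo_uri):
        # List all GeoTIFFs under the prefix and print when each was modified.
        for path in S3FileSystem.list_paths(demo_uri, ext='.tif'):
            print(path, S3FileSystem.last_modified(path))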