Source code for rastervision.aws_s3.s3_file_system

from typing import Any, Iterator, Tuple
import io
import os
import subprocess
from datetime import datetime
from urllib.parse import urlparse

import boto3
from everett.manager import ConfigurationMissingError
from import tqdm

from rastervision.pipeline.file_system import (FileSystem, NotReadableError,

# Name of the Everett configuration namespace that holds the S3 options; it is
# passed to rv_config.get_namespace_config() in S3FileSystem.get_request_payer().
AWS_S3 = 'aws_s3'

# Code from
def get_matching_s3_objects(
        bucket: str,
        prefix: str = '',
        suffix: str = '',
        delimiter: str = '/',
        request_payer: str = 'None') -> Iterator[tuple[str, Any]]:
    """Generate objects in an S3 bucket.

    Args:
        bucket: Name of the S3 bucket.
        prefix: Only fetch objects whose key starts with this prefix.
        suffix: Only fetch objects whose keys end with this suffix.
        delimiter: Forwarded as ``Delimiter`` to ``list_objects_v2``.
        request_payer: Forwarded as ``RequestPayer`` to ``list_objects_v2``.

    Yields:
        ``(key, entry)`` pairs. For each page of results, the entries listed
        under ``CommonPrefixes`` (keyed by ``'Prefix'``) come first, followed
        by the entries listed under ``Contents`` (keyed by ``'Key'``). Both
        kinds are filtered by ``prefix`` and ``suffix``.
    """
    client = S3FileSystem.get_client()
    request = {
        'Bucket': bucket,
        'RequestPayer': request_payer,
        'Delimiter': delimiter,
        'Prefix': prefix,
    }

    while True:
        page: dict = client.list_objects_v2(**request)

        # Each group pairs the field that holds the entry's name with the
        # entries themselves; "directories" are reported before files.
        groups = (
            ('Prefix', page.get('CommonPrefixes', {})),
            ('Key', page.get('Contents', {})),
        )
        for name_field, entries in groups:
            for entry in entries:
                key = entry[name_field]
                if key.startswith(prefix) and key.endswith(suffix):
                    yield key, entry

        # The S3 API is paginated, returning up to 1000 keys at a time.
        # Feed the continuation token into the next request until the final
        # page is reached (signalled by the token field being absent).
        if 'NextContinuationToken' not in page:
            break
        request['ContinuationToken'] = page['NextContinuationToken']
def get_matching_s3_keys(bucket: str,
                         prefix: str = '',
                         suffix: str = '',
                         delimiter: str = '/',
                         request_payer: str = 'None') -> Iterator[str]:
    """Generate the keys in an S3 bucket.

    Thin wrapper around ``get_matching_s3_objects`` that discards the
    per-entry metadata and keeps only the key of each match.

    Args:
        bucket: Name of the S3 bucket.
        prefix: Only fetch keys that start with this prefix.
        suffix: Only fetch keys that end with this suffix.
        delimiter: Forwarded unchanged to ``get_matching_s3_objects``.
        request_payer: Forwarded unchanged to ``get_matching_s3_objects``.

    Returns:
        A lazy iterator over the matching keys.
    """
    matches = get_matching_s3_objects(
        bucket=bucket,
        prefix=prefix,
        suffix=suffix,
        delimiter=delimiter,
        request_payer=request_payer)
    return (name for name, _entry in matches)
def progressbar(total_size: int, desc: str):
    """Build the tqdm progress bar used for S3 uploads and downloads.

    Args:
        total_size: Passed to tqdm as ``total``; callers pass a byte count.
        desc: Label passed to tqdm as ``desc``.

    Returns:
        The object returned by ``tqdm``, configured with unit ``'B'``,
        unit scaling with a divisor of 1024, ``mininterval=0.5`` and
        ``delay=5``.
    """
    options = {
        'total': total_size,
        'desc': desc,
        'unit': 'B',
        'unit_scale': True,
        'unit_divisor': 1024,
        'mininterval': 0.5,
        'delay': 5,
    }
    return tqdm(**options)
class S3FileSystem(FileSystem):
    """A FileSystem for interacting with files stored on AWS S3.

    Uses Everett configuration of form:

    ```
    [AWS_S3]
    requester_pays=True
    ```
    """

    @staticmethod
    def get_request_payer() -> str:
        """Return the ``RequestPayer`` value to send with S3 requests.

        Returns:
            ``'requester'`` when ``requester_pays`` is enabled in the AWS_S3
            configuration namespace, otherwise the *string* ``'None'``.
        """
        # Import here to avoid circular reference.
        from rastervision.pipeline import rv_config_ as rv_config
        try:
            s3_config = rv_config.get_namespace_config(AWS_S3)
            # 'None' needs the quotes because boto3 cannot handle None.
            return ('requester' if s3_config(
                'requester_pays', parser=bool, default='False') else 'None')
        except ConfigurationMissingError:
            return 'None'

    @staticmethod
    def get_session():
        """Return a new default ``boto3.Session``."""
        return boto3.Session()

    @staticmethod
    def get_client():
        """Return an S3 client.

        If the environment variable ``AWS_NO_SIGN_REQUEST`` is set to ``yes``
        (case-insensitive), the client is configured with the ``UNSIGNED``
        signature version; otherwise it is created from ``get_session()``.
        """
        if os.getenv('AWS_NO_SIGN_REQUEST', '').lower() == 'yes':
            from botocore import UNSIGNED
            from botocore.config import Config
            s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
            return s3
        return S3FileSystem.get_session().client('s3')

    @staticmethod
    def matches_uri(uri: str, mode: str) -> bool:
        """Return True if ``uri`` uses the ``s3`` scheme (``mode`` is unused)."""
        parsed_uri = urlparse(uri)
        return parsed_uri.scheme == 's3'

    @staticmethod
    def parse_uri(uri: str) -> Tuple[str, str]:
        """Parse bucket name and key from an S3 URI."""
        parsed_uri = urlparse(uri)
        bucket, key = parsed_uri.netloc, parsed_uri.path[1:]
        return bucket, key

    @staticmethod
    def file_exists(uri: str, include_dir: bool = True) -> bool:
        """Check whether ``uri`` exists on S3.

        Args:
            uri: The S3 URI to check.
            include_dir: If True, first treat the key as a "directory"
                (a key prefix ending in ``/``) and fall back to an exact
                object lookup only if nothing is found under that prefix.
                If False, only the exact object lookup is performed.

        Returns:
            True if a match is found; False if there is none or if the
            request fails with a ``botocore`` ``ClientError``.
        """
        # Lazily load boto
        import botocore

        bucket, key = S3FileSystem.parse_uri(uri)
        request_payer = S3FileSystem.get_request_payer()

        if include_dir:
            s3 = S3FileSystem.get_client()
            try:
                # Ensure key ends in slash so that this won't pick up files
                # that contain the key as a prefix, but aren't actually
                # directories. Example: if key is 'model' then we don't want
                # to consider model-123 a match.
                # An empty key (bucket root) is used as-is; indexing key[-1]
                # on it would raise IndexError.
                if not key or key.endswith('/'):
                    dir_key = key
                else:
                    dir_key = key + '/'
                response = s3.list_objects_v2(
                    Bucket=bucket,
                    Prefix=dir_key,
                    MaxKeys=1,
                    RequestPayer=request_payer)
                if response['KeyCount'] == 0:
                    return S3FileSystem.file_exists(uri, include_dir=False)
                return True
            except botocore.exceptions.ClientError:
                return False
        else:
            if not key:
                # There is no key to look up, so there is no object to find.
                return False
            s3r = S3FileSystem.get_session().resource('s3')
            try:
                s3r.Object(bucket, key).load(RequestPayer=request_payer)
                return True
            except botocore.exceptions.ClientError:
                return False

    @staticmethod
    def read_str(uri: str) -> str:
        """Download the object at ``uri`` and decode it as UTF-8."""
        return S3FileSystem.read_bytes(uri).decode('utf-8')

    @staticmethod
    def read_bytes(uri: str) -> bytes:
        """Download the object at ``uri`` into memory.

        Raises:
            NotReadableError: If S3 responds with a ``ClientError``.
        """
        import botocore

        s3 = S3FileSystem.get_client()
        request_payer = S3FileSystem.get_request_payer()
        bucket, key = S3FileSystem.parse_uri(uri)
        with io.BytesIO() as file_buffer:
            try:
                # The object size is needed up front for the progress bar.
                obj = s3.head_object(
                    Bucket=bucket, Key=key, RequestPayer=request_payer)
                file_size = obj['ContentLength']
                with progressbar(file_size, desc='Downloading') as bar:
                    s3.download_fileobj(
                        Bucket=bucket,
                        Key=key,
                        Fileobj=file_buffer,
                        Callback=bar.update,
                        ExtraArgs={'RequestPayer': request_payer})
                return file_buffer.getvalue()
            except botocore.exceptions.ClientError as e:
                raise NotReadableError('Could not read {}'.format(uri)) from e

    @staticmethod
    def write_str(uri: str, data: str) -> None:
        """Encode ``data`` as UTF-8 and upload it to ``uri``."""
        S3FileSystem.write_bytes(uri, data.encode('utf-8'))

    @staticmethod
    def write_bytes(uri: str, data: bytes) -> None:
        """Upload ``data`` to ``uri``.

        Raises:
            NotWritableError: If the upload fails for any reason.
        """
        s3 = S3FileSystem.get_client()
        bucket, key = S3FileSystem.parse_uri(uri)
        file_size = len(data)
        with io.BytesIO(data) as str_buffer:
            try:
                with progressbar(file_size, desc='Uploading') as bar:
                    s3.upload_fileobj(
                        Fileobj=str_buffer,
                        Bucket=bucket,
                        Key=key,
                        Callback=bar.update)
            except Exception as e:
                raise NotWritableError(f'Could not write {uri}') from e

    @staticmethod
    def sync_from_dir(src_dir_uri: str, dst_dir: str,
                      delete: bool = False) -> None:  # pragma: no cover
        """Synchronize ``src_dir_uri`` to ``dst_dir`` with ``aws s3 sync``.

        Args:
            src_dir_uri: Source directory passed to ``aws s3 sync``.
            dst_dir: Destination directory passed to ``aws s3 sync``.
            delete: If True, pass ``--delete`` to the command.

        Raises:
            subprocess.CalledProcessError: If the command exits with a
                non-zero status.
        """
        command = ['aws', 's3', 'sync', src_dir_uri, dst_dir]
        if delete:
            command.append('--delete')
        # get_request_payer() returns the *string* 'None' when requester-pays
        # is off, which is truthy; compare against the enabled value instead.
        if S3FileSystem.get_request_payer() == 'requester':
            command.extend(['--request-payer', 'requester'])
        subprocess.run(command, check=True)

    @staticmethod
    def sync_to_dir(src_dir: str, dst_dir_uri: str,
                    delete: bool = False) -> None:  # pragma: no cover
        """Synchronize ``src_dir`` to ``dst_dir_uri``.

        Delegates to ``sync_from_dir`` with the same arguments.
        """
        S3FileSystem.sync_from_dir(src_dir, dst_dir_uri, delete=delete)

    @staticmethod
    def copy_to(src_path: str, dst_uri: str) -> None:
        """Upload the local ``src_path`` to ``dst_uri``.

        A regular file is uploaded directly; anything else is handed to
        ``sync_to_dir`` with ``delete=True``.

        Raises:
            NotWritableError: If uploading a single file fails.
        """
        s3 = S3FileSystem.get_client()
        bucket, key = S3FileSystem.parse_uri(dst_uri)
        if os.path.isfile(src_path):
            file_size = os.path.getsize(src_path)
            try:
                with progressbar(file_size, desc='Uploading') as bar:
                    s3.upload_file(
                        Filename=src_path,
                        Bucket=bucket,
                        Key=key,
                        Callback=bar.update)
            except Exception as e:
                raise NotWritableError(f'Could not write {dst_uri}') from e
        else:
            S3FileSystem.sync_to_dir(src_path, dst_uri, delete=True)

    @staticmethod
    def copy_from(src_uri: str, dst_path: str) -> None:
        """Download the object at ``src_uri`` to the local ``dst_path``.

        Raises:
            NotReadableError: If S3 responds with a ``ClientError``.
        """
        import botocore

        s3 = S3FileSystem.get_client()
        request_payer = S3FileSystem.get_request_payer()
        bucket, key = S3FileSystem.parse_uri(src_uri)
        try:
            # The object size is needed up front for the progress bar.
            obj = s3.head_object(
                Bucket=bucket, Key=key, RequestPayer=request_payer)
            file_size = obj['ContentLength']
            with progressbar(file_size, desc='Downloading') as bar:
                s3.download_file(
                    Bucket=bucket,
                    Key=key,
                    Filename=dst_path,
                    Callback=bar.update,
                    ExtraArgs={'RequestPayer': request_payer})
        except botocore.exceptions.ClientError as e:
            # Chain the cause, as read_bytes() does, to keep the S3 error.
            raise NotReadableError(f'Could not read {src_uri}') from e

    @staticmethod
    def local_path(uri: str, download_dir: str) -> str:
        """Return the local path for ``uri`` under ``download_dir``.

        The result is ``<download_dir>/s3/<bucket>/<key>``.
        """
        bucket, key = S3FileSystem.parse_uri(uri)
        path = os.path.join(download_dir, 's3', bucket, key)
        return path

    @staticmethod
    def last_modified(uri: str) -> datetime:
        """Return the ``LastModified`` value from a ``head_object`` call."""
        bucket, key = S3FileSystem.parse_uri(uri)
        s3 = S3FileSystem.get_client()
        request_payer = S3FileSystem.get_request_payer()
        head_data = s3.head_object(
            Bucket=bucket, Key=key, RequestPayer=request_payer)
        return head_data['LastModified']

    @staticmethod
    def list_paths(uri: str, ext: str = '',
                   delimiter: str = '/') -> list[str]:
        """List the S3 URIs found under ``uri``.

        Args:
            uri: S3 URI whose key is used as the listing prefix.
            ext: Only keys ending with this suffix are returned.
            delimiter: Forwarded to ``get_matching_s3_keys``.

        Returns:
            A list of ``s3://<bucket>/<key>`` URIs.
        """
        request_payer = S3FileSystem.get_request_payer()
        bucket, prefix = S3FileSystem.parse_uri(uri)
        keys = get_matching_s3_keys(
            bucket,
            prefix,
            suffix=ext,
            delimiter=delimiter,
            request_payer=request_payer)
        # S3 URIs always use '/', so build them explicitly instead of using
        # the platform-dependent os.path.join.
        return [f's3://{bucket}/{key}' for key in keys]