Skip to content

Commit 0cd2878

Browse files
EnableAsync and gaotongxiao
authored and committed
[Feature] AWS S3 obtainer support (#1888)
* feat: add aws s3 obtainer feat: add aws s3 obtainer fix: format fix: format * fix: avoid duplicated code fix: code format * fix: runtime.txt * fix: remove duplicated code
1 parent bbe8964 commit 0cd2878

File tree

3 files changed

+125
-1
lines changed

3 files changed

+125
-1
lines changed
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
from .aws_s3_obtainer import AWSS3Obtainer
23
from .naive_data_obtainer import NaiveDataObtainer
34

4-
__all__ = ['NaiveDataObtainer']
5+
__all__ = ['NaiveDataObtainer', 'AWSS3Obtainer']
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import os.path as osp
3+
import ssl
4+
from typing import Dict, List, Optional
5+
6+
from mmengine import mkdir_or_exist
7+
8+
from mmocr.registry import DATA_OBTAINERS
9+
from .naive_data_obtainer import NaiveDataObtainer
10+
11+
# NOTE(review): this swaps the ssl module's default HTTPS context factory for
# the unverified one at import time. Because it patches the shared ``ssl``
# module, the change is not limited to this file: any code in the same
# process that relies on the default HTTPS context will stop verifying
# server certificates. Confirm this is intentional before relying on it.
ssl._create_default_https_context = ssl._create_unverified_context
12+
13+
14+
@DATA_OBTAINERS.register_module()
class AWSS3Obtainer(NaiveDataObtainer):
    """An AWS S3 obtainer.

    download -> extract -> move

    ``s3://`` URLs are fetched with an unsigned (anonymous) boto3 S3 client;
    every other URL is delegated to :class:`NaiveDataObtainer`.

    Args:
        files (list[dict]): A list of file information.
        cache_path (str): The path to cache the downloaded files.
        data_root (str): The root path of the dataset. It is usually set auto-
            matically and users do not need to set it manually in config file
            in most cases.
        task (str): The task of the dataset. It is usually set automatically
            and users do not need to set it manually in config file
            in most cases.

    Raises:
        ImportError: If boto3 (or botocore) cannot be imported.
    """

    def __init__(self, files: List[Dict], cache_path: str, data_root: str,
                 task: str) -> None:
        # boto3 is an optional dependency, so it is imported lazily and a
        # readable hint is raised when it is missing.
        try:
            import boto3
            from botocore import UNSIGNED
            from botocore.config import Config
        except ImportError as exc:
            raise ImportError(
                'Please install boto3 to download hiertext dataset.') from exc
        self.files = files
        self.cache_path = cache_path
        self.data_root = data_root
        self.task = task
        # Unsigned requests: no AWS credentials are sent with the requests.
        self.s3_client = boto3.client(
            's3', config=Config(signature_version=UNSIGNED))
        self.total_length = 0
        # Create the directory layout used by the later pipeline stages.
        mkdir_or_exist(self.data_root)
        mkdir_or_exist(osp.join(self.data_root, f'{task}_imgs'))
        mkdir_or_exist(osp.join(self.data_root, 'annotations'))
        mkdir_or_exist(self.cache_path)

    def find_bucket_key(self, s3_path: str) -> tuple:
        """Split an s3 path of the form ``bucket/key`` into its two parts.

        Args:
            s3_path (str): The AWS s3 path, without the ``s3://`` prefix.

        Returns:
            tuple(str, str): The bucket name and the object key. The key is
            an empty string when ``s3_path`` contains no ``/``.
        """
        # Only the first '/' separates bucket and key; the key itself may
        # contain further slashes.
        bucket, _, s3_key = s3_path.partition('/')
        return bucket, s3_key

    def s3_download(self, s3_bucket: str, s3_object_key: str, dst_path: str):
        """Download an object from an s3 bucket with a progress display.

        Args:
            s3_bucket (str): The s3 bucket to download the file.
            s3_object_key (str): The s3 object key to download the file.
            dst_path (str): The destination path to save the file.
        """
        # Ask for the object size up front so progress can be shown as a
        # percentage.
        meta_data = self.s3_client.head_object(
            Bucket=s3_bucket, Key=s3_object_key)
        total_length = int(meta_data.get('ContentLength', 0))
        downloaded = 0
        file_name = osp.basename(dst_path)

        def progress(chunk):
            # ``chunk`` is the number of bytes transferred since the last
            # callback invocation.
            nonlocal downloaded
            downloaded += chunk
            if total_length > 0:
                percent = min(100. * downloaded / total_length, 100)
            else:
                # Empty object or unknown size: avoid dividing by zero.
                percent = 100.
            print(f'\rDownloading {file_name}: {percent:.2f}%', end='')

        print(f'Downloading {dst_path}')
        self.s3_client.download_file(
            s3_bucket, s3_object_key, dst_path, Callback=progress)

    def download(self, url: Optional[str], dst_path: str) -> None:
        """Download file from given url with progress bar.

        Args:
            url (str, optional): The url to download the file. ``s3://`` urls
                are downloaded with the S3 client; any other url is handed to
                the parent class. ``None`` means there is no direct link and
                the file must already exist at ``dst_path``.
            dst_path (str): The destination path to save the file.

        Raises:
            FileNotFoundError: If ``url`` is None and ``dst_path`` does not
                exist.
            NotImplementedError: If ``url`` is a magnet link.
        """
        if url is None:
            if not osp.exists(dst_path):
                raise FileNotFoundError(
                    'Direct url is not available for this dataset.'
                    ' Please manually download the required files'
                    ' following the guides.')
            # The file was supplied manually and there is nothing to fetch.
            return

        if url.startswith('magnet'):
            raise NotImplementedError('Please use any BitTorrent client to '
                                      'download the following magnet link to '
                                      f'{osp.abspath(dst_path)} and '
                                      f'try again.\nLink: {url}')

        print('Downloading...')
        print(f'URL: {url}')
        print(f'Destination: {osp.abspath(dst_path)}')
        print('If you stuck here for a long time, please check your network, '
              'or manually download the file to the destination path and '
              'run the script again.')
        if url.startswith('s3://'):
            # Strip the scheme, leaving ``bucket/key``.
            bucket, key = self.find_bucket_key(url[len('s3://'):])
            self.s3_download(bucket, key, osp.abspath(dst_path))
        else:
            # Anything that is not an S3 url goes through the parent
            # implementation instead of being silently skipped.
            super().download(url, dst_path)
        # Terminate the in-place ('\r') progress line.
        print('')

requirements/optional.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
boto3

0 commit comments

Comments
 (0)