Skip to content
This repository was archived by the owner on Nov 8, 2022. It is now read-only.

Commit 83f9a90

Browse files
author
Izsak, Peter
committed
Merge branch 'deepti/restructure_downloads' into 'master'
Add model file caching; consolidate S3 URLs. - Create a single mechanism for getting the path of a model. - Consolidate all model S3 URLs.
2 parents 30591cc + c9cf928 commit 83f9a90

File tree

3 files changed

+378
-6
lines changed

3 files changed

+378
-6
lines changed
+172
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
# ******************************************************************************
2+
# Copyright 2017-2018 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ******************************************************************************
16+
17+
from nlp_architect.utils.io import uncompress_file, zipfile_list
18+
from nlp_architect.utils.file_cache import cached_path
19+
20+
from nlp_architect import LIBRARY_OUT
21+
22+
S3_PREFIX = "https://s3-us-west-2.amazonaws.com/nlp-architect-data/"
23+
24+
25+
class PretrainedModel:
    """Generic base class for downloading and caching pre-trained models.

    Concrete subclasses are singletons; obtain them via ``get_instance()``.

    Usage Example:

        chunker = ChunkerModel.get_instance()
        chunker2 = ChunkerModel.get_instance()
        print(chunker, chunker2)
        print("Local File path = ", chunker.get_file_path())
        files_models = chunker2.get_model_files()
        for idx, file_name in enumerate(files_models):
            print(str(idx) + ": " + file_name)
    """

    def __init__(self, model_name, sub_path, files):
        """
        Args:
            model_name (str): short model name; also the cache sub-directory
            sub_path (str): S3 key prefix (appended to ``S3_PREFIX``)
            files (list of str): names of the files that make up the model
        """
        if isinstance(self, (BistModel, ChunkerModel, MrcModel, IntentModel, AbsaModel,
                             NerModel)):
            if self._instance is not None:  # pylint: disable=no-member
                raise Exception("This class is a singleton!")
        self.model_name = model_name
        self.base_path = S3_PREFIX + sub_path
        self.files = files
        self.download_path = LIBRARY_OUT / 'pretrained_models' / self.model_name
        self.model_files = []

    @classmethod
    # pylint: disable=no-member
    def get_instance(cls):
        """
        Static instance access method

        Args:
            cls (Class name): Calling class
        """
        if cls._instance is None:
            cls()  # pylint: disable=no-value-for-parameter
        return cls._instance

    def _fetch_file(self, filename):
        """Download one model file into the cache (if needed), unzipping archives.

        Returns:
            str: local cached path of the (possibly zipped) file.
        """
        cached_file_path, need_downloading = cached_path(
            self.base_path + filename, self.download_path)
        if filename.endswith('zip') and need_downloading:
            print('Unzipping...')
            uncompress_file(cached_file_path, outpath=self.download_path)
            print('Done.')
        return cached_file_path

    def get_file_path(self):
        """
        Return local file path of downloaded model files
        """
        for filename in self.files:
            self._fetch_file(filename)
        return self.download_path

    def get_model_files(self):
        """
        Return individual file names of downloaded models
        """
        # Rebuild the list from scratch so repeated calls do not keep
        # appending duplicate entries to self.model_files.
        self.model_files = []
        for filename in self.files:
            cached_file_path = self._fetch_file(filename)
            if filename.endswith('zip'):
                # Archives are listed by their contents, not the archive name.
                self.model_files.extend(zipfile_list(cached_file_path))
            else:
                self.model_files.append(filename)
        return self.model_files
93+
94+
95+
# Model-specific classes developers instantiate where model has to be used
96+
97+
class BistModel(PretrainedModel):
    """
    Singleton accessor for the pre-trained BIST dependency-parse model.

    The archive is downloaded on first use and unzipped into the cache.
    """
    files = ['bist-pretrained.zip']
    sub_path = 'models/dep_parse/'
    _instance = None

    def __init__(self):
        super().__init__('bist', BistModel.sub_path, BistModel.files)
        BistModel._instance = self
108+
109+
110+
class IntentModel(PretrainedModel):
    """
    Singleton accessor for the pre-trained Intent-extraction model.

    Model files are downloaded into the local cache on first use.
    """
    files = ['model_info.dat', 'model.h5']
    sub_path = 'models/intent/'
    _instance = None

    def __init__(self):
        super().__init__('intent', IntentModel.sub_path, IntentModel.files)
        IntentModel._instance = self
121+
122+
123+
class MrcModel(PretrainedModel):
    """
    Singleton accessor for the pre-trained MRC (reading-comprehension) model.

    Both archives are downloaded and unzipped into the cache on first use.
    """
    files = ['mrc_data.zip', 'mrc_model.zip']
    sub_path = 'models/mrc/'
    _instance = None

    def __init__(self):
        super().__init__('mrc', MrcModel.sub_path, MrcModel.files)
        MrcModel._instance = self
134+
135+
136+
class NerModel(PretrainedModel):
    """
    Singleton accessor for the pre-trained NER model.

    Model files are downloaded into the local cache on first use.
    """
    files = ['model_v4.h5', 'model_info_v4.dat']
    sub_path = 'models/ner/'
    _instance = None

    def __init__(self):
        super().__init__('ner', NerModel.sub_path, NerModel.files)
        NerModel._instance = self
147+
148+
149+
class AbsaModel(PretrainedModel):
    """
    Singleton accessor for the pre-trained ABSA (aspect-based sentiment) model.

    The rerank model file is downloaded into the local cache on first use.
    """
    files = ['rerank_model.h5']
    sub_path = 'models/absa/'
    _instance = None

    def __init__(self):
        super().__init__('absa', AbsaModel.sub_path, AbsaModel.files)
        AbsaModel._instance = self
160+
161+
162+
class ChunkerModel(PretrainedModel):
    """
    Singleton accessor for the pre-trained Chunker model.

    Model files are downloaded into the local cache on first use.
    """
    files = ['model.h5', 'model_info.dat.params']
    sub_path = 'models/chunker/'
    _instance = None

    def __init__(self):
        super().__init__('chunker', ChunkerModel.sub_path, ChunkerModel.files)
        ChunkerModel._instance = self

nlp_architect/utils/file_cache.py

+185
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
# ******************************************************************************
2+
# Copyright 2017-2018 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ******************************************************************************
16+
"""
17+
Utilities for working with the local dataset cache.
18+
"""
19+
import os
20+
import logging
21+
import shutil
22+
import tempfile
23+
import json
24+
from urllib.parse import urlparse
25+
from pathlib import Path
26+
from typing import Tuple, Union, IO
27+
from hashlib import sha256
28+
29+
from nlp_architect import LIBRARY_OUT
30+
from nlp_architect.utils.io import load_json_file
31+
32+
import requests
33+
34+
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
35+
36+
MODEL_CACHE = LIBRARY_OUT / 'pretrained_models'
37+
38+
39+
def cached_path(url_or_filename: Union[str, Path], cache_dir: str = None) -> Tuple[str, bool]:
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.

    Args:
        url_or_filename: an http(s) URL, or a local file path
        cache_dir: directory for cached downloads (defaults to ``MODEL_CACHE``)

    Returns:
        Tuple of (local file path, whether the file was (re)downloaded).

    Raises:
        FileNotFoundError: the input is a local path that does not exist.
        ValueError: the input is neither a URL nor a local path.
    """
    if cache_dir is None:
        cache_dir = MODEL_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)

    parsed = urlparse(url_or_filename)

    if parsed.scheme in ('http', 'https'):
        # URL, so get it from the cache (downloading if necessary)
        return get_from_cache(url_or_filename, cache_dir)
    if os.path.exists(url_or_filename):
        # File, and it exists. Return (path, False) — not a bare string —
        # so the return shape matches the URL branch; callers unpack
        # two values.
        print("File already exists. No further processing needed.")
        return url_or_filename, False
    if parsed.scheme == '':
        # File, but it doesn't exist.
        raise FileNotFoundError("file {} not found".format(url_or_filename))

    # Something unknown
    raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
68+
69+
70+
def url_to_filename(url: str, etag: str = None) -> str:
    """
    Convert `url` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the url's, delimited
    by a period.
    """
    basename = url.split('/')[-1]
    if not basename.endswith('zip'):
        # Plain model files keep their original names; only zip archives
        # are stored under a content-addressed (hashed) name.
        return basename
    pieces = [sha256(url.encode('utf-8')).hexdigest()]
    if etag:
        pieces.append(sha256(etag.encode('utf-8')).hexdigest())
    return '.'.join(pieces)
88+
89+
90+
def filename_to_url(filename: str, cache_dir: str = None) -> Tuple[str, str]:
    """
    Return the url and etag (which may be ``None``) stored for `filename`.
    Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist.
    """
    directory = MODEL_CACHE if cache_dir is None else cache_dir
    cache_path = os.path.join(directory, filename)
    meta_path = cache_path + '.json'

    # Both the cached payload and its metadata must be present.
    for required in (cache_path, meta_path):
        if not os.path.exists(required):
            raise FileNotFoundError("file {} not found".format(required))

    with open(meta_path) as meta_file:
        metadata = json.load(meta_file)
    return metadata['url'], metadata['etag']
112+
113+
114+
def http_get(url: str, temp_file: IO) -> None:
    """Stream the contents of *url* into *temp_file* in 1 KB chunks."""
    response = requests.get(url, stream=True)
    for data in response.iter_content(chunk_size=1024):
        # keep-alive heartbeats arrive as empty chunks — skip them
        if data:
            temp_file.write(data)
119+
120+
121+
def get_from_cache(url: str, cache_dir: str = None) -> Tuple[str, bool]:
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.

    Args:
        url: http(s) URL of the file to fetch
        cache_dir: directory for cached files (defaults to ``MODEL_CACHE``)

    Returns:
        Tuple of (local cached path, whether the file was (re)downloaded).

    Raises:
        IOError: the HEAD request for `url` did not return status 200.
    """
    if cache_dir is None:
        cache_dir = MODEL_CACHE

    os.makedirs(cache_dir, exist_ok=True)

    response = requests.head(url, allow_redirects=True)
    if response.status_code != 200:
        raise IOError("HEAD request failed for url {} with status code {}"
                      .format(url, response.status_code))
    etag = response.headers.get("ETag")

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # Zip archives are cached under a hashed name, so their metadata can use
    # the plain '.json' suffix; other files keep their own name, so the
    # metadata gets a distinct '_meta_' marker to avoid name clashes.
    # Computed once here instead of separately in both branches below.
    if url.split('/')[-1].endswith('zip'):
        meta_path = cache_path + '.json'
    else:
        meta_path = cache_path + '_meta_' + '.json'

    need_downloading = True

    if os.path.exists(cache_path) and os.path.exists(meta_path):
        # check if etag has changed comparing with the metadata; a missing
        # metadata file is treated as a cache miss instead of crashing here.
        meta = load_json_file(meta_path)
        if meta['etag'] == etag:
            print('file already present')
            need_downloading = False

    if need_downloading:
        print("File not present or etag changed")
        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)

            # GET file object
            http_get(url, temp_file)

            # we are copying the file before closing it, so flush to avoid truncation
            temp_file.flush()
            # shutil.copyfileobj() starts at the current position, so go to the start
            temp_file.seek(0)

            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)

            logger.info("creating metadata file for %s", cache_path)
            meta = {'url': url, 'etag': etag}
            with open(meta_path, 'w') as meta_file:
                json.dump(meta, meta_file)

            logger.info("removing temp file %s", temp_file.name)

    return cache_path, need_downloading

nlp_architect/utils/io.py

+21-6
Original file line numberDiff line numberDiff line change
@@ -73,18 +73,33 @@ def uncompress_file(filepath: str or os.PathLike, outpath='.'):
7373
outpath (str): path to extract to
7474
"""
7575
filepath = str(filepath)
76-
if filepath.endswith('.zip'):
77-
with zipfile.ZipFile(filepath) as z:
78-
z.extractall(outpath)
79-
elif filepath.endswith('.gz'):
76+
if filepath.endswith('.gz'):
8077
if os.path.isdir(outpath):
8178
raise ValueError('output path for gzip must be a file')
8279
with gzip.open(filepath, 'rb') as fp:
8380
file_content = fp.read()
8481
with open(outpath, 'wb') as fp:
8582
fp.write(file_content)
86-
else:
87-
raise ValueError('Unsupported archive provided. Method supports only .zip/.gz files.')
83+
return None
84+
# To unzip zipped model files having SHA-encoded etag and url as filename
85+
# raise ValueError('Unsupported archive provided. Method supports only .zip/.gz files.')
86+
with zipfile.ZipFile(filepath) as z:
87+
z.extractall(outpath)
88+
return [x for x in z.namelist() if not (x.startswith('__MACOSX') or x.endswith('/'))]
89+
90+
91+
def zipfile_list(filepath: str or os.PathLike):
    """
    List the files inside a given zip file

    Args:
        filepath (str): path to file

    Returns:
        String list of filenames
    """
    with zipfile.ZipFile(filepath) as archive:
        entries = archive.namelist()

    def _is_junk(name):
        # Drop macOS resource-fork folders and bare directory entries.
        return name.startswith('__MACOSX') or name.endswith('/')

    return [name for name in entries if not _is_junk(name)]
88103

89104

90105
def gzip_str(g_str):

0 commit comments

Comments
 (0)