Skip to content

Commit

Permalink
Use urlopen from urllib.request to download models instead of wget (o…
Browse files Browse the repository at this point in the history
…nnx#4006)

* Replace wget with Python urllib.request

Signed-off-by: Chun-Wei Chen <[email protected]>
  • Loading branch information
jcwchen authored Feb 11, 2022
1 parent c2a3364 commit c9d61b6
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 43 deletions.
2 changes: 1 addition & 1 deletion VERSION_NUMBER
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.11.0rc1
1.11.0rc2
85 changes: 50 additions & 35 deletions onnx/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from urllib.error import HTTPError
import json
import os
import wget # type: ignore
import hashlib
from io import BytesIO
from typing import List, Optional, Dict, Any, Tuple, cast, Set, IO
Expand Down Expand Up @@ -69,17 +68,17 @@ def __repr__(self) -> str:

def set_dir(new_dir: str) -> None:
"""
Set the current ONNX hub cache location
@param new_dir: location of new model hub cache
Sets the current ONNX hub cache location
:param new_dir: location of new model hub cache
"""
global _ONNX_HUB_DIR
_ONNX_HUB_DIR = new_dir


def get_dir() -> str:
"""
Get the current ONNX hub cache location
@return: The location of the ONNX hub model cache
Gets the current ONNX hub cache location
:return: The location of the ONNX hub model cache
"""
return _ONNX_HUB_DIR

Expand All @@ -93,26 +92,26 @@ def _parse_repo_info(repo: str) -> Tuple[str, str, str]:
if ":" in repo:
repo_ref = repo.split("/")[1].split(":")[1]
else:
repo_ref = "master"
repo_ref = "main"
return repo_owner, repo_name, repo_ref


def _verify_repo_ref(repo: str) -> bool:
"""
Verifies whether the given model repo can be trusted.
A model repo can be trusted if it matches onnx/models:master.
A model repo can be trusted if it matches onnx/models:main.
"""
repo_owner, repo_name, repo_ref = _parse_repo_info(repo)
return (repo_owner == "onnx") and (repo_name == "models") and (repo_ref == "master")
return (repo_owner == "onnx") and (repo_name == "models") and (repo_ref == "main")


def _get_base_url(repo: str, lfs: bool = False) -> str:
"""
Gets the base github url from a repo specification string
@param repo: The location of the model repo in format "user/repo[:branch]".
If no branch is found will default to "master"
@param lfs: whether the url is for downloading lfs models
@return: the base github url for downloading
:param repo: The location of the model repo in format "user/repo[:branch]".
If no branch is found will default to "main"
:param lfs: whether the url is for downloading lfs models
:return: the base github url for downloading
"""
repo_owner, repo_name, repo_ref = _parse_repo_info(repo)

Expand All @@ -122,22 +121,38 @@ def _get_base_url(repo: str, lfs: bool = False) -> str:
return "https://raw.githubusercontent.com/{}/{}/{}/".format(repo_owner, repo_name, repo_ref)


def _download_file(url: str, file_name: str) -> None:
    """
    Downloads the content at the given url into the file named file_name.

    The payload is streamed in fixed-size chunks into a temporary sibling file
    (file_name + ".part") which is moved over file_name only after the whole
    transfer succeeded, so an interrupted download never leaves a truncated
    file at file_name.

    :param url: a url of download link
    :param file_name: a specified file name for the downloaded file
    :raises: any error raised while opening the url, reading from it or
        writing the file is propagated to the caller unchanged
    """
    chunk_size = 16384  # 1024 * 16
    partial_name = file_name + ".part"
    try:
        with urlopen(url) as response, open(partial_name, 'wb') as f:
            # Loads progressively with chunk_size to keep memory bounded for huge models
            while True:
                chunk = response.read(chunk_size)
                if not chunk:
                    break
                f.write(chunk)
        # Publish the file only once it is complete; os.replace overwrites
        # an existing destination in a single step.
        os.replace(partial_name, file_name)
    except BaseException:
        # Do not leave a half-written temporary file behind, then re-raise
        # the original error (including KeyboardInterrupt).
        try:
            os.remove(partial_name)
        except OSError:
            # Nothing to clean up, e.g. the url could not be opened at all.
            pass
        raise


def list_models(
repo: str = "onnx/models:master", model: Optional[str] = None, tags: Optional[List[str]] = None
repo: str = "onnx/models:main", model: Optional[str] = None, tags: Optional[List[str]] = None
) -> List[ModelInfo]:
"""
Get the list of model info consistent with a given name and tags
Gets the list of model info consistent with a given name and tags
@param repo: The location of the model repo in format "user/repo[:branch]".
If no branch is found will default to "master"
@param model: The name of the model to search for. If `None`, will return all models with matching tags.
@param tags: A list of tags to filter models by. If `None`, will return all models with matching name.
:param repo: The location of the model repo in format "user/repo[:branch]".
If no branch is found will default to "main"
:param model: The name of the model to search for. If `None`, will return all models with matching tags.
:param tags: A list of tags to filter models by. If `None`, will return all models with matching name.
"""
base_url = _get_base_url(repo)
manifest_url = base_url + "ONNX_HUB_MANIFEST.json"
try:
with urlopen(manifest_url) as f:
manifest: List[ModelInfo] = [ModelInfo(info) for info in json.load(cast(IO[str], f))]
with urlopen(manifest_url) as response:
manifest: List[ModelInfo] = [ModelInfo(info) for info in json.load(cast(IO[str], response))]
except HTTPError as e:
raise AssertionError("Could not find manifest at {}".format(manifest_url), e)

Expand All @@ -157,14 +172,14 @@ def list_models(
return matching_info_list


def get_model_info(model: str, repo: str = "onnx/models:master", opset: Optional[int] = None) -> ModelInfo:
def get_model_info(model: str, repo: str = "onnx/models:main", opset: Optional[int] = None) -> ModelInfo:
"""
Get the model info matching the given name and opset.
Gets the model info matching the given name and opset.
@param model: The name of the onnx model in the manifest. This field is case-sensitive
@param repo: The location of the model repo in format "user/repo[:branch]".
If no branch is found will default to "master"
@param opset: The opset of the model to get. The default of `None` will return the model with largest opset.
:param model: The name of the onnx model in the manifest. This field is case-sensitive
:param repo: The location of the model repo in format "user/repo[:branch]".
If no branch is found will default to "main"
:param opset: The opset of the model to get. The default of `None` will return the model with largest opset.
"""
matching_models = list_models(repo, model)
if not matching_models:
Expand All @@ -182,20 +197,20 @@ def get_model_info(model: str, repo: str = "onnx/models:master", opset: Optional

def load(
model: str,
repo: str = "onnx/models:master",
repo: str = "onnx/models:main",
opset: Optional[int] = None,
force_reload: bool = False,
silent: bool = False,
) -> Optional[onnx.ModelProto]:
"""
Download a model by name from the onnx model hub
Downloads a model by name from the onnx model hub
@param model: The name of the onnx model in the manifest. This field is case-sensitive
@param repo: The location of the model repo in format "user/repo[:branch]".
If no branch is found will default to "master"
@param opset: The opset of the model to download. The default of `None` automatically chooses the largest opset
@param force_reload: Whether to force the model to re-download even if its already found in the cache
@param silent: Whether to suppress the warning message if the repo is not trusted.
:param model: The name of the onnx model in the manifest. This field is case-sensitive
:param repo: The location of the model repo in format "user/repo[:branch]".
If no branch is found will default to "main"
:param opset: The opset of the model to download. The default of `None` automatically chooses the largest opset
:param force_reload: Whether to force the model to re-download even if its already found in the cache
:param silent: Whether to suppress the warning message if the repo is not trusted.
"""
selected_model = get_model_info(model, repo, opset)
local_model_path_arr = selected_model.model_path.split("/")
Expand All @@ -218,7 +233,7 @@ def load(
os.makedirs(os.path.dirname(local_model_path), exist_ok=True)
lfs_url = _get_base_url(repo, True)
print("Downloading {} to local path {}".format(model, local_model_path))
wget.download(lfs_url + selected_model.model_path, local_model_path)
_download_file(lfs_url + selected_model.model_path, local_model_path)
else:
print("Using cached {} model from {}".format(model, local_model_path))

Expand Down
2 changes: 1 addition & 1 deletion onnx/test/hub_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
class TestModelHub(unittest.TestCase):
def setUp(self) -> None:
self.name = "MNIST"
self.repo = "onnx/models:master"
self.repo = "onnx/models:main"
self.opset = 7

def test_force_reload(self) -> None:
Expand Down
3 changes: 1 addition & 2 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,4 @@ wheel
jupyter
pyzmq
setuptools
twine
wget
twine
3 changes: 1 addition & 2 deletions requirements-release.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,4 @@ nbval
ipython
wheel
setuptools
twine
wget
twine
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
numpy >= 1.16.6
protobuf >= 3.12.2
typing-extensions >= 3.6.2.1
wget
typing-extensions >= 3.6.2.1

0 comments on commit c9d61b6

Please sign in to comment.