30
30
from onnx import ModelProto
31
31
32
32
from deepsparse .log import get_main_logger
33
- from deepsparse .utils .onnx import MODEL_ONNX_NAME , truncate_onnx_model
34
- from sparsezoo import Model
33
+ from deepsparse .utils .onnx import MODEL_ONNX_NAME , model_to_path , truncate_onnx_model
35
34
from sparsezoo .utils import save_onnx
36
35
37
36
38
37
__all__ = [
38
+ "get_deployment_path" ,
39
39
"setup_transformers_pipeline" ,
40
40
"overwrite_transformer_onnx_model_inputs" ,
41
41
"fix_numpy_types" ,
@@ -62,12 +62,12 @@ def setup_transformers_pipeline(
62
62
:param sequence_length: The sequence length to use for the model
63
63
:param tokenizer_padding_side: The side to pad on for the tokenizer,
64
64
either "left" or "right"
65
- :param engine_kwargs: The kwargs to pass to the engine
66
65
:param onnx_model_name: The name of the onnx model to be loaded.
67
66
If not specified, defaults are used (see fetch_onnx_file_path)
67
+ :param engine_kwargs: The kwargs to pass to the engine
68
68
:return: The model path, config, tokenizer, and engine kwargs
69
69
"""
70
- model_path , config , tokenizer = setup_onnx_file_path (
70
+ model_path , config , tokenizer = fetch_onnx_file_path (
71
71
model_path , sequence_length , onnx_model_name
72
72
)
73
73
@@ -87,7 +87,7 @@ def setup_transformers_pipeline(
87
87
return model_path , config , tokenizer , engine_kwargs
88
88
89
89
90
- def setup_onnx_file_path (
90
+ def fetch_onnx_file_path (
91
91
model_path : str ,
92
92
sequence_length : int ,
93
93
onnx_model_name : Optional [str ] = None ,
@@ -102,6 +102,7 @@ def setup_onnx_file_path(
102
102
:param onnx_model_name: optionally, the precise name of the ONNX model
103
103
of interest may be specified. If not specified, the default ONNX model
104
104
name will be used (refer to `get_deployment_path` for details)
105
+ :param task: task to use for the config. Defaults to None
105
106
:return: file path to the processed ONNX file for the engine to compile
106
107
"""
107
108
deployment_path , onnx_path = get_deployment_path (model_path , onnx_model_name )
@@ -148,6 +149,7 @@ def get_deployment_path(
148
149
the deployment directory
149
150
"""
150
151
onnx_model_name = onnx_model_name or MODEL_ONNX_NAME
152
+
151
153
if os .path .isfile (model_path ):
152
154
# return the parent directory of the ONNX file
153
155
return os .path .dirname (model_path ), model_path
@@ -163,22 +165,9 @@ def get_deployment_path(
163
165
)
164
166
return model_path , os .path .join (model_path , onnx_model_name )
165
167
166
- elif model_path .startswith ("zoo:" ):
167
- zoo_model = Model (model_path )
168
- deployment_path = zoo_model .deployment_directory_path
169
- return deployment_path , os .path .join (deployment_path , onnx_model_name )
170
- elif model_path .startswith ("hf:" ):
171
- from huggingface_hub import snapshot_download
172
-
173
- deployment_path = snapshot_download (repo_id = model_path .replace ("hf:" , "" , 1 ))
174
- onnx_path = os .path .join (deployment_path , onnx_model_name )
175
- if not os .path .isfile (onnx_path ):
176
- raise ValueError (
177
- f"{ onnx_model_name } not found in transformers model directory "
178
- f"{ deployment_path } . Be sure that an export of the model is written to "
179
- f"{ onnx_path } "
180
- )
181
- return deployment_path , onnx_path
168
+ elif model_path .startswith ("zoo:" ) or model_path .startswith ("hf:" ):
169
+ onnx_model_path = model_to_path (model_path )
170
+ return os .path .dirname (onnx_model_path ), onnx_model_path
182
171
else :
183
172
raise ValueError (
184
173
f"model_path { model_path } is not a valid file, directory, or zoo stub"
0 commit comments