@@ -18,6 +18,10 @@
__all__ = ["Model"]

TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "false").lower() in ["true", "1"]
+DISABLE_TENSOR_CACHE = os.getenv("DISABLE_TENSOR_CACHE", "false").lower() in [
+    "true",
+    "1",
+]

# Disable gradients
torch.set_grad_enabled(False)

@@ -32,6 +36,29 @@
    __all__.append(FlashBert)


+def wrap_model_if_hpu(model_handle, device):
+    """Wrap the model in HPU graph if the device is HPU."""
+    if device.type == "hpu":
+        from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+
+        model_handle.model = wrap_in_hpu_graph(
+            model_handle.model, disable_tensor_cache=DISABLE_TENSOR_CACHE
+        )
+    return model_handle
+
+
+def create_model(model_class, model_path, device, datatype, pool="cls"):
+    """Create a model instance and wrap it if needed."""
+    model_handle = model_class(
+        model_path,
+        device,
+        datatype,
+        pool,
+        trust_remote=TRUST_REMOTE_CODE,
+    )
+    return wrap_model_if_hpu(model_handle, device)
+
+
def get_model(model_path: Path, dtype: Optional[str], pool: str):
    if dtype == "float32":
        datatype = torch.float32
@@ -46,6 +73,7 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
    logger.info(f"backend device: {device}")

    config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)
+
    if (
        hasattr(config, "auto_map")
        and isinstance(config.auto_map, dict)
@@ -54,8 +82,9 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
        == "jinaai/jina-bert-v2-qk-post-norm--modeling_bert.JinaBertModel"
    ):
        # Add specific offline modeling for model "jinaai/jina-embeddings-v2-base-code" which uses "autoMap" to reference code in other repository
-        return FlashJinaBert(model_path, config, device, datatype, pool)
-    elif config.model_type == "bert":
+        return create_model(FlashJinaBert, model_path, device, datatype)
+
+    if config.model_type == "bert":
        config: BertConfig
        if (
            use_ipex()
@@ -66,98 +95,36 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
        ):
            if pool != "cls":
                if config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
-                    return MaskedLanguageModel(
-                        model_path,
-                        device,
-                        datatype,
-                        trust_remote=TRUST_REMOTE_CODE,
+                    return create_model(
+                        MaskedLanguageModel, model_path, device, datatype, pool
                    )
-                return DefaultModel(
-                    model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
-                )
+                return create_model(DefaultModel, model_path, device, datatype, pool)
+
            try:
-                return FlashBert(model_path, device, datatype)
-            except FileNotFoundError as e:
+                return create_model(FlashBert, model_path, device, datatype)
+            except FileNotFoundError:
                logger.info(
                    "Do not have safetensors file for this model, use default transformers model path instead"
                )
-                return DefaultModel(
-                    model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
-                )
+                return create_model(DefaultModel, model_path, device, datatype, pool)
+
        if config.architectures[0].endswith("Classification"):
-            return ClassificationModel(
-                model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
-            )
+            return create_model(ClassificationModel, model_path, device, datatype)
        elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
-            return MaskedLanguageModel(
-                model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
-            )
+            return create_model(MaskedLanguageModel, model_path, device, datatype)
        else:
-            return DefaultModel(
-                model_path,
-                device,
-                datatype,
-                pool,
-                trust_remote=TRUST_REMOTE_CODE,
-            )
-    elif config.model_type == "mistral" and device.type == "hpu":
+            return create_model(DefaultModel, model_path, device, datatype, pool)
+
+    if config.model_type == "mistral" and device.type == "hpu":
        try:
-            return FlashMistral(
-                model_path,
-                device,
-                datatype,
-                pool,
-            )
-        except FileNotFoundError as e:
-            return DefaultModel(
-                model_path,
-                device,
-                datatype,
-                pool,
-                trust_remote=TRUST_REMOTE_CODE,
-            )
+            return create_model(FlashMistral, model_path, device, datatype, pool)
+        except FileNotFoundError:
+            return create_model(DefaultModel, model_path, device, datatype, pool)
+
+    # Default case
+    if config.architectures[0].endswith("Classification"):
+        return create_model(ClassificationModel, model_path, device, datatype)
+    elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
+        return create_model(MaskedLanguageModel, model_path, device, datatype)
    else:
-        if device.type == "hpu":
-            from habana_frameworks.torch.hpu import wrap_in_hpu_graph
-
-            if config.architectures[0].endswith("Classification"):
-                model_handle = ClassificationModel(
-                    model_path,
-                    device,
-                    datatype,
-                    trust_remote=TRUST_REMOTE_CODE,
-                )
-            elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
-                model_handle = MaskedLanguageModel(
-                    model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
-                )
-            else:
-                model_handle = DefaultModel(
-                    model_path,
-                    device,
-                    datatype,
-                    pool,
-                    trust_remote=TRUST_REMOTE_CODE,
-                )
-            model_handle.model = wrap_in_hpu_graph(model_handle.model)
-            return model_handle
-        elif use_ipex():
-            if config.architectures[0].endswith("Classification"):
-                return ClassificationModel(
-                    model_path,
-                    device,
-                    datatype,
-                    trust_remote=TRUST_REMOTE_CODE,
-                )
-            elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
-                return MaskedLanguageModel(
-                    model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
-                )
-            else:
-                return DefaultModel(
-                    model_path,
-                    device,
-                    datatype,
-                    pool,
-                    trust_remote=TRUST_REMOTE_CODE,
-                )
+        return create_model(DefaultModel, model_path, device, datatype, pool)
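
For context, a minimal sketch of how the helpers introduced in this diff compose. The model class, path, dtype, and pool value below are illustrative only, and it assumes this module is importable as text_embeddings_server.models:

# Illustrative usage sketch (not part of the diff).
import torch
from pathlib import Path

from text_embeddings_server.models import DefaultModel, create_model

device = torch.device("cpu")  # on an "hpu" device, wrap_model_if_hpu() additionally wraps model.model in an HPU graph
model = create_model(
    DefaultModel,                 # any of the model classes dispatched by get_model()
    Path("/data/example-model"),  # hypothetical local model directory
    device,
    torch.float32,
    pool="mean",
)
# create_model forwards trust_remote=TRUST_REMOTE_CODE to the model class; on HPU,
# the graph wrap honors disable_tensor_cache=DISABLE_TENSOR_CACHE from the environment.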