
Commit 53eae1b

[python, gaudi] Refactor the code and add wrap_in_hpu_graph to corner case (#625)
Signed-off-by: Liu, Kaixuan <[email protected]>
1 parent bedb2e5 commit 53eae1b

File tree

7 files changed: +74 additions, −104 deletions


backends/python/server/text_embeddings_server/models/__init__.py

Lines changed: 54 additions & 87 deletions
@@ -18,6 +18,10 @@
 __all__ = ["Model"]

 TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "false").lower() in ["true", "1"]
+DISABLE_TENSOR_CACHE = os.getenv("DISABLE_TENSOR_CACHE", "false").lower() in [
+    "true",
+    "1",
+]
 # Disable gradients
 torch.set_grad_enabled(False)

@@ -32,6 +36,29 @@
     __all__.append(FlashBert)


+def wrap_model_if_hpu(model_handle, device):
+    """Wrap the model in HPU graph if the device is HPU."""
+    if device.type == "hpu":
+        from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+
+        model_handle.model = wrap_in_hpu_graph(
+            model_handle.model, disable_tensor_cache=DISABLE_TENSOR_CACHE
+        )
+    return model_handle
+
+
+def create_model(model_class, model_path, device, datatype, pool="cls"):
+    """Create a model instance and wrap it if needed."""
+    model_handle = model_class(
+        model_path,
+        device,
+        datatype,
+        pool,
+        trust_remote=TRUST_REMOTE_CODE,
+    )
+    return wrap_model_if_hpu(model_handle, device)
+
+
 def get_model(model_path: Path, dtype: Optional[str], pool: str):
     if dtype == "float32":
         datatype = torch.float32

@@ -46,6 +73,7 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
     logger.info(f"backend device: {device}")

     config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)
+
     if (
         hasattr(config, "auto_map")
         and isinstance(config.auto_map, dict)

@@ -54,8 +82,9 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
         == "jinaai/jina-bert-v2-qk-post-norm--modeling_bert.JinaBertModel"
     ):
         # Add specific offline modeling for model "jinaai/jina-embeddings-v2-base-code" which uses "autoMap" to reference code in other repository
-        return FlashJinaBert(model_path, config, device, datatype, pool)
-    elif config.model_type == "bert":
+        return create_model(FlashJinaBert, model_path, device, datatype)
+
+    if config.model_type == "bert":
         config: BertConfig
         if (
             use_ipex()

@@ -66,98 +95,36 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
         ):
             if pool != "cls":
                 if config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
-                    return MaskedLanguageModel(
-                        model_path,
-                        device,
-                        datatype,
-                        trust_remote=TRUST_REMOTE_CODE,
+                    return create_model(
+                        MaskedLanguageModel, model_path, device, datatype, pool
                     )
-                return DefaultModel(
-                    model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
-                )
+                return create_model(DefaultModel, model_path, device, datatype, pool)
+
             try:
-                return FlashBert(model_path, device, datatype)
-            except FileNotFoundError as e:
+                return create_model(FlashBert, model_path, device, datatype)
+            except FileNotFoundError:
                 logger.info(
                     "Do not have safetensors file for this model, use default transformers model path instead"
                 )
-                return DefaultModel(
-                    model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
-                )
+                return create_model(DefaultModel, model_path, device, datatype, pool)
+
         if config.architectures[0].endswith("Classification"):
-            return ClassificationModel(
-                model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
-            )
+            return create_model(ClassificationModel, model_path, device, datatype)
         elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
-            return MaskedLanguageModel(
-                model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
-            )
+            return create_model(MaskedLanguageModel, model_path, device, datatype)
         else:
-            return DefaultModel(
-                model_path,
-                device,
-                datatype,
-                pool,
-                trust_remote=TRUST_REMOTE_CODE,
-            )
-    elif config.model_type == "mistral" and device.type == "hpu":
+            return create_model(DefaultModel, model_path, device, datatype, pool)
+
+    if config.model_type == "mistral" and device.type == "hpu":
         try:
-            return FlashMistral(
-                model_path,
-                device,
-                datatype,
-                pool,
-            )
-        except FileNotFoundError as e:
-            return DefaultModel(
-                model_path,
-                device,
-                datatype,
-                pool,
-                trust_remote=TRUST_REMOTE_CODE,
-            )
+            return create_model(FlashMistral, model_path, device, datatype, pool)
+        except FileNotFoundError:
+            return create_model(DefaultModel, model_path, device, datatype, pool)
+
+    # Default case
+    if config.architectures[0].endswith("Classification"):
+        return create_model(ClassificationModel, model_path, device, datatype)
+    elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
+        return create_model(MaskedLanguageModel, model_path, device, datatype)
     else:
-        if device.type == "hpu":
-            from habana_frameworks.torch.hpu import wrap_in_hpu_graph
-
-            if config.architectures[0].endswith("Classification"):
-                model_handle = ClassificationModel(
-                    model_path,
-                    device,
-                    datatype,
-                    trust_remote=TRUST_REMOTE_CODE,
-                )
-            elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
-                model_handle = MaskedLanguageModel(
-                    model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
-                )
-            else:
-                model_handle = DefaultModel(
-                    model_path,
-                    device,
-                    datatype,
-                    pool,
-                    trust_remote=TRUST_REMOTE_CODE,
-                )
-            model_handle.model = wrap_in_hpu_graph(model_handle.model)
-            return model_handle
-        elif use_ipex():
-            if config.architectures[0].endswith("Classification"):
-                return ClassificationModel(
-                    model_path,
-                    device,
-                    datatype,
-                    trust_remote=TRUST_REMOTE_CODE,
-                )
-            elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
-                return MaskedLanguageModel(
-                    model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
-                )
-            else:
-                return DefaultModel(
-                    model_path,
-                    device,
-                    datatype,
-                    pool,
-                    trust_remote=TRUST_REMOTE_CODE,
-                )
+        return create_model(DefaultModel, model_path, device, datatype, pool)
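
After this refactor every branch of get_model funnels through the create_model factory, and HPU graph wrapping happens in exactly one place (wrap_model_if_hpu) instead of being repeated per model class. A minimal sketch of how the new entry point is driven, assuming a hypothetical local model directory and settings; note that DISABLE_TENSOR_CACHE is read once at module import time, so it must be set before the backend is imported:

    import os
    from pathlib import Path

    # Opt in to disabling the HPU graph tensor cache; unset, or any value
    # other than "true"/"1" (case-insensitive), keeps the previous default
    # behavior (disable_tensor_cache=False).
    os.environ["DISABLE_TENSOR_CACHE"] = "true"

    from text_embeddings_server.models import get_model

    # Hypothetical model path, dtype, and pool. Whatever class get_model
    # selects, create_model builds it with the shared constructor signature,
    # and wrap_model_if_hpu applies wrap_in_hpu_graph only on HPU devices.
    model = get_model(Path("/data/my-embedding-model"), dtype="float32", pool="cls")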

backends/python/server/text_embeddings_server/models/classification_model.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ def __init__(
         model_path: Path,
         device: torch.device,
         dtype: torch.dtype,
+        pool: str = "cls",
         trust_remote: bool = False,
     ):
         model = AutoModelForSequenceClassification.from_pretrained(
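
The pool argument is unused by sequence classification; it exists so that ClassificationModel (like MaskedLanguageModel and DefaultModel below) shares one constructor shape with the flash models, letting create_model instantiate any class positionally. A sketch of the common shape the factory assumes (not the full signatures):

    from pathlib import Path

    import torch

    # Every model class now accepts the same leading parameters, so
    # create_model(model_class, ...) can stay generic; classes with no
    # pooling concept, like this one, simply ignore pool.
    def __init__(
        self,
        model_path: Path,
        device: torch.device,
        dtype: torch.dtype,
        pool: str = "cls",
        trust_remote: bool = False,
    ): ...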

backends/python/server/text_embeddings_server/models/default_model.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ def __init__(
         model_path: Path,
         device: torch.device,
         dtype: torch.dtype,
-        pool: str,
+        pool: str = "cls",
         trust_remote: bool = False,
     ):
         model = (

backends/python/server/text_embeddings_server/models/flash_bert.py

Lines changed: 8 additions & 5 deletions
@@ -294,7 +294,14 @@ def forward(


 class FlashBert(Model):
-    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
+    def __init__(
+        self,
+        model_path: Path,
+        device: torch.device,
+        dtype: torch.dtype,
+        pool: str = "cls",
+        trust_remote: bool = False,
+    ):
         config = BertConfig.from_pretrained(model_path)

         if hasattr(config, "max_seq_length"):

@@ -306,10 +313,6 @@ def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
             model = FlashBertModel(f, device, dtype, config)
         self.device = device
         self.dtype = dtype
-        if device.type == "hpu":
-            from habana_frameworks.torch.hpu import wrap_in_hpu_graph
-
-            model = wrap_in_hpu_graph(model, disable_tensor_cache=False)
         self.hidden_size = config.hidden_size

         super(FlashBert, self).__init__(model=model, dtype=dtype, device=device)
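
The wrapping deleted here (and in FlashMistral and FlashJinaBert below) is not lost behavior: it moves to wrap_model_if_hpu in models/__init__.py. One difference worth flagging is that the old in-constructor call hard-coded disable_tensor_cache=False, while the central helper reads the new environment variable; since that variable parses to False when unset, the default behavior is unchanged. A mirror of the parsing added in this commit:

    import os

    # Only "true" or "1" (case-insensitive) disable the tensor cache; an
    # unset variable reproduces the old hard-coded False.
    DISABLE_TENSOR_CACHE = os.getenv("DISABLE_TENSOR_CACHE", "false").lower() in [
        "true",
        "1",
    ]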

backends/python/server/text_embeddings_server/models/flash_mistral.py

Lines changed: 6 additions & 5 deletions
@@ -364,7 +364,12 @@ def forward(

 class FlashMistral(Model):
     def __init__(
-        self, model_path: Path, device: torch.device, dtype: torch.dtype, pool: str
+        self,
+        model_path: Path,
+        device: torch.device,
+        dtype: torch.dtype,
+        pool: str = "cls",
+        trust_remote: bool = False,
     ):
         config = MistralConfig.from_pretrained(model_path)

@@ -379,10 +384,6 @@ def __init__(
         model = FlashMistralModel(model_path, index_data, device, dtype, config)
         self.device = device
         self.dtype = dtype
-        if device.type == "hpu":
-            from habana_frameworks.torch.hpu import wrap_in_hpu_graph
-
-            model = wrap_in_hpu_graph(model, disable_tensor_cache=False)
         self.hidden_size = config.hidden_size

         super(FlashMistral, self).__init__(model=model, dtype=dtype, device=device)

backends/python/server/text_embeddings_server/models/jinaBert_model.py

Lines changed: 3 additions & 6 deletions
@@ -478,11 +478,12 @@ class FlashJinaBert(Model):
     def __init__(
         self,
         model_path: Path,
-        config: AutoConfig,
         device: torch.device,
         dtype: torch.dtype,
-        pool: str,
+        pool: str = "mean",
+        trust_remote: bool = True,
     ):
+        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote)
         if hasattr(config, "max_seq_length"):
             self.max_input_length = config.max_seq_length
         else:

@@ -494,10 +495,6 @@ def __init__(
         self.pooling = DefaultPooling(self.hidden_size, pooling_mode=pool)
         self.device = device
         self.dtype = dtype
-        if device.type == "hpu":
-            from habana_frameworks.torch.hpu import wrap_in_hpu_graph
-
-            model = wrap_in_hpu_graph(model, disable_tensor_cache=False)
         self.hidden_size = config.hidden_size

         super(FlashJinaBert, self).__init__(model=model, dtype=dtype, device=device)
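
Because FlashJinaBert now resolves its own AutoConfig, the caller no longer has to build and pass one, which is what lets get_model route it through the generic factory. A sketch of the call-site change implied by this diff (names from the commit):

    # Before: the caller fetched the config and passed it explicitly.
    #     config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)
    #     return FlashJinaBert(model_path, config, device, datatype, pool)
    #
    # After: the generic factory builds it like any other class. Note that
    # create_model supplies its own pool default ("cls"), so the "mean"
    # default in this signature only applies to direct construction.
    model = create_model(FlashJinaBert, model_path, device, datatype)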

backends/python/server/text_embeddings_server/models/masked_model.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ def __init__(
         model_path: Path,
         device: torch.device,
         dtype: torch.dtype,
+        pool: str = "cls",
         trust_remote: bool = False,
     ):
         model = (
