diff --git a/caikit_nlp/config/config.yml b/caikit_nlp/config/config.yml
index 6f440a22..09d6af89 100644
--- a/caikit_nlp/config/config.yml
+++ b/caikit_nlp/config/config.yml
@@ -52,6 +52,8 @@ embedding:
   autocast: false
   # For testing, set device to "mps" on MacOS or "xpu" for IPEX GPU.
   # Otherwise, the default does automatic checks for cuda GPU (else cpu).
+  # Use graph mode with IPEX CPU
+  graphmode: false
   device: ""

 runtime:
diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py
index c1eb7173..0a85b72d 100644
--- a/caikit_nlp/modules/text_embedding/embedding.py
+++ b/caikit_nlp/modules/text_embedding/embedding.py
@@ -14,6 +14,7 @@
 # Standard
 from collections.abc import Sized
+from contextlib import nullcontext
 from enum import Enum, auto
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, TypeVar, Union
 import importlib
@@ -84,6 +85,7 @@ def __init__(self, *args, **kwargs):  # pylint: disable=unused-argument
 AUTOCAST = env_val_to_bool(val=embedding_cfg.get("autocast"))
 IPEX = env_val_to_bool(val=embedding_cfg.get("ipex"))
+GRAPH_MODE = env_val_to_bool(val=embedding_cfg.get("graphmode"))
 PT2_COMPILE = env_val_to_bool(val=embedding_cfg.get("pt2_compile"))
 RETRIES = env_val_to_int(val=embedding_cfg.get("retries"), default=0)
 BATCH_SIZE = env_val_to_int(val=embedding_cfg.get("batch_size"), default=0)
@@ -801,6 +803,14 @@ def sum_token_count(


 class SentenceTransformerWithTruncate(SentenceTransformer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if GRAPH_MODE:
+            # Initialize the compiled model right after the base class initialization
+            self.compiled_model = (
+                self._apply_graph_mode()
+            )  # Compile and store the graph model
+
     def _truncate_input_tokens(
         self,
         truncate_input_tokens: int,
@@ -904,6 +914,44 @@ def _truncate_input_tokens(

         return TruncatedTokensTuple(tokenized, input_token_count)

+    def _apply_graph_mode(self) -> torch.jit.ScriptModule:
+        """
+        Compiles the model into a TorchScript graph using predefined fixed-size randomized
+        input tensors. The tensors simulate typical input structures without relying
+        on actual input feature data.
+
+        :return: A TorchScript graph that is optimized for inference.
+ """ + self.eval() + + max_seq_length = self.max_seq_length + vocab_size = self.tokenizer.vocab_size + + # Generate random input_ids within the vocabulary range and a full attention mask + input_ids = torch.randint(low=0, high=vocab_size, size=(1, max_seq_length)) + attention_mask = torch.ones(1, max_seq_length).int() + + # Context manager for automatic mixed precision, if applicable + context_manager = torch.cpu.amp.autocast() if AUTOCAST else nullcontext() + + with torch.no_grad(), context_manager: + # Trace the model with the synthetic input to create a TorchScript graph + compiled_graph = torch.jit.trace( + self, + ( + { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + ), + strict=False, + ) + + # Freeze the compiled graph to optimize it for runtime performance + compiled_graph = torch.jit.freeze(compiled_graph) + + return compiled_graph + def encode( self, sentences: Union[str, List[str]], @@ -954,7 +1002,7 @@ def encode( output_value, normalize_embeddings, ) - + # torchscript requires eval mode self.eval() if convert_to_tensor: @@ -999,20 +1047,18 @@ def encode( features = batch_to_device(features, device) - if AUTOCAST: - with torch.no_grad(), torch.cpu.amp.autocast(): - out_features = self.forward(features) - embeddings = out_features["sentence_embedding"] - if convert_to_numpy: - embeddings = embeddings.detach().cpu() - all_embeddings.extend(embeddings) - else: - with torch.no_grad(): - out_features = self.forward(features) - embeddings = out_features["sentence_embedding"] - if convert_to_numpy: - embeddings = embeddings.detach().cpu() - all_embeddings.extend(embeddings) + # Determine which model to use based on GRAPH_MODE + model_to_use = self.compiled_model if GRAPH_MODE else self.forward + + # Execution context based on AUTOCAST + context_manager = torch.cpu.amp.autocast() if AUTOCAST else nullcontext() + + with torch.no_grad(), context_manager: + out_features = model_to_use(features) + embeddings = out_features["sentence_embedding"] + if convert_to_numpy: + embeddings = embeddings.detach().cpu() + all_embeddings.extend(embeddings) # Restore original order all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]