diff --git a/gritlm/gritlm.py b/gritlm/gritlm.py
index 8d4d244..7d352fa 100644
--- a/gritlm/gritlm.py
+++ b/gritlm/gritlm.py
@@ -3,7 +3,7 @@
 import numpy as np
 import torch
 from tqdm import tqdm
-from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, BatchEncoding
 
 
 class GritLM(torch.nn.Module):
@@ -92,7 +92,7 @@ def encode_corpus(self, corpus: Union[List[str], str, List[Dict[str, str]]], **k
     @torch.no_grad()
     def encode(
         self,
-        sentences: Union[List[str], str],
+        sentences: Union[Union[BatchEncoding, dict], Union[List[str], str]],
         batch_size: int = 256,
         max_length: int = 512,
         instruction: str = "",
@@ -106,25 +106,36 @@ def encode(
         if self.num_gpus > 1:
             batch_size *= self.num_gpus
 
-        input_was_string = False
         if isinstance(sentences, str):
             sentences = [sentences]
-            input_was_string = True
+            input_type = "string"
+        elif isinstance(sentences, dict):
+            input_type = "dict"
+        elif isinstance(sentences, BatchEncoding):
+            sentences = dict(sentences)
+            input_type = "dict"
+        else:
+            input_type = "list"
 
         all_embeddings, all_kv_caches = [], []
         for start_index in tqdm(range(0, len(sentences), batch_size), desc="Batches", disable=len(sentences)<256):
-            sentences_batch = [
-                instruction + s + self.embed_eos for s in sentences[start_index:start_index + batch_size]
-            ]
-            # This will prepend the bos token if the tokenizer has `add_bos_token=True`
-            inputs = self.tokenizer(
-                sentences_batch,
-                padding=True,
-                truncation=True,
-                return_tensors='pt',
-                max_length=max_length,
-                add_special_tokens=add_special_tokens,
-            ).to(self.device)
+            if input_type == "list" or input_type == "string":
+                sentences_batch = sentences[start_index:start_index + batch_size]
+                sentences_batch = [
+                    instruction + s + self.embed_eos for s in sentences_batch
+                ]
+                # This will prepend the bos token if the tokenizer has `add_bos_token=True`
+                inputs = self.tokenizer(
+                    sentences_batch,
+                    padding=True,
+                    truncation=True,
+                    return_tensors='pt',
+                    max_length=max_length,
+                    add_special_tokens=add_special_tokens,
+                ).to(self.device)
+            elif input_type == "dict":
+                inputs = {k: v[start_index:start_index + batch_size] for k,v in sentences.items() if isinstance(v, torch.Tensor)}
+                inputs = {k: v.to(self.device) for k,v in inputs.items() if isinstance(v, torch.Tensor)}
 
             if (self.attn is not None) and (self.attn[:2] == 'bb'):
                 inputs["is_causal"] = False
@@ -166,7 +177,7 @@ def encode(
         all_embeddings = (
             torch.cat(all_embeddings, dim=0) if convert_to_tensor else np.concatenate(all_embeddings, axis=0)
         )
-        if input_was_string:
+        if input_type == "string":
            all_embeddings = all_embeddings[0]
        if get_cache:
            # all_kv_caches = (
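
For context, here is a minimal sketch of how the new pre-tokenized path could be exercised once this patch is applied. The checkpoint name and constructor arguments are assumptions for illustration only and are not part of the diff:

```python
from gritlm import GritLM

# Assumption: any GritLM checkpoint works here; the 7B model is used
# purely for illustration.
model = GritLM("GritLM/GritLM-7B", torch_dtype="auto")

# Pre-tokenize with the model's own tokenizer. With this patch, encode()
# converts the returned BatchEncoding to a plain dict, slices each tensor
# value batch-wise, and moves the slices to the model device.
inputs = model.tokenizer(
    ["A first document.", "A second document."],
    padding=True,
    truncation=True,
    return_tensors="pt",
    max_length=512,
)
embeddings = model.encode(inputs)
```

Note that the dict branch bypasses the `instruction + s + self.embed_eos` formatting step, so any instruction prefix or embed EOS marker has to be baked into the text before tokenization.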