From 99247bc7cec42553bfacf2a92c4b86602404d2af Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Fri, 23 Aug 2024 11:04:51 +0330 Subject: [PATCH 01/10] Package into a PR --- pyproject.toml | 2 +- src/lighteval/few_shot_manager.py | 41 ++--- src/lighteval/models/abstract_model.py | 42 +++-- src/lighteval/models/base_model.py | 35 ++-- src/lighteval/models/endpoint_model.py | 242 +++++++++++++------------ src/lighteval/models/nanotron_model.py | 63 +++---- src/lighteval/models/tgi_model.py | 34 +--- src/lighteval/tasks/lighteval_task.py | 5 +- src/lighteval/tasks/requests.py | 20 +- tests/test_endpoint_model.py | 178 ++++++++++++++++++ 10 files changed, 423 insertions(+), 239 deletions(-) create mode 100644 tests/test_endpoint_model.py diff --git a/pyproject.toml b/pyproject.toml index e301d7afd..bb44b229f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,7 @@ nanotron = [ ] tensorboardX = ["tensorboardX"] quality = ["ruff==v0.2.2","pre-commit"] -tests = ["pytest==7.4.0"] +tests = ["pytest==7.4.0", "docker"] dev = ["lighteval[accelerate,quality,tests]"] extended_tasks = [ "langdetect", # ifeval diff --git a/src/lighteval/few_shot_manager.py b/src/lighteval/few_shot_manager.py index 081703a8a..af3c0c493 100644 --- a/src/lighteval/few_shot_manager.py +++ b/src/lighteval/few_shot_manager.py @@ -27,10 +27,11 @@ from itertools import cycle from typing import TYPE_CHECKING, Optional -from transformers import AutoTokenizer, PreTrainedTokenizer +from huggingface_hub import ChatCompletionInputMessage +from transformers import PreTrainedTokenizerBase from lighteval.logging.hierarchical_logger import hlog_warn -from lighteval.tasks.requests import Doc +from lighteval.tasks.requests import Context, Conversation, Doc if TYPE_CHECKING: @@ -181,27 +182,25 @@ def init_fewshot_sampling_balanced( def get_examples_with_chat_template( self, task: "LightevalTask", - tokenizer: AutoTokenizer, example: str, instruction: str, fewshot_ex: list[str], system_prompt: str, - ): - examples = [] + ) -> Conversation: + examples: Conversation = [] if system_prompt is not None: - examples.append({"role": "system", "content": system_prompt}) + examples.append(ChatCompletionInputMessage(role="system", content=system_prompt)) for ex in fewshot_ex: - examples.append({"role": "user", "content": task.doc_to_text_without_instructions(ex)}) - examples.append({"role": "assistant", "content": task.doc_to_target(ex)}) + examples.append(ChatCompletionInputMessage(role="user", content=task.doc_to_text_without_instructions(ex))) + examples.append(ChatCompletionInputMessage(role="assistant", content=task.doc_to_target(ex))) # We add the actual example - examples.append({"role": "user", "content": example}) + examples.append(ChatCompletionInputMessage(role="user", content=example)) # We add the initial instruction if present, after the system prompt of before the task - if examples[0]["role"] == "system": - examples[0]["content"] = examples[0]["content"] + instruction + if examples[0].role == "system": + examples[0].content = examples[0].content + instruction else: - examples[0]["content"] = instruction + examples[0]["content"] - - return tokenizer.apply_chat_template(examples, tokenize=False, add_generation_prompt=True) + examples[0].content = instruction + examples[0].content + return examples def get_examples( self, @@ -209,7 +208,7 @@ def get_examples( example: str, instruction: str, fewshot_ex: list[str], - ): + ) -> str: if len(fewshot_ex) == 0: return instruction + example @@ -220,7 +219,7 @@ def get_examples( return 
instruction + labeled_examples + example def create_multi_turn_contexts( - self, doc: Doc, use_chat_template: bool, system_prompt: Optional[str], tokenizer: PreTrainedTokenizer + self, doc: Doc, use_chat_template: bool, system_prompt: Optional[str], tokenizer: PreTrainedTokenizerBase ) -> list[str]: """Creates N contexts (depending on the number of turn) for a tasks. Multi turn tasks need use chat templating. @@ -268,10 +267,10 @@ def fewshot_context( sampler: Optional[random.Random] = None, truncate_few_shots: bool = False, max_model_length: Optional[int] = None, - tokenizer: Optional[AutoTokenizer] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, use_chat_template=False, system_prompt: str = None, - ): + ) -> tuple[Context, int]: """Returns a fewshot context string that is made up of a prepended description (if provided), the `num_fewshot` number of examples, and an appended prompt example. @@ -300,13 +299,12 @@ def fewshot_context( if use_chat_template: output = self.get_examples_with_chat_template( task=task, - tokenizer=tokenizer, example=example, instruction=instruction, fewshot_ex=fewshot_ex, system_prompt=system_prompt, ) - toks = tokenizer(output)["input_ids"] + toks = tokenizer.apply_chat_template(output, add_generation_prompt=True) else: output = self.get_examples(task=task, example=example, instruction=instruction, fewshot_ex=fewshot_ex) toks = tokenizer(output)["input_ids"] @@ -324,13 +322,12 @@ def fewshot_context( if use_chat_template: output = self.get_examples_with_chat_template( task=task, - tokenizer=tokenizer, example=example, instruction=instruction, fewshot_ex=fewshot_ex[:num_effective_fewshots], system_prompt=system_prompt, ) - toks = tokenizer(output)["input_ids"] + toks = tokenizer.apply_chat_template(output, add_generation_prompt=True) else: output = self.get_examples( task=task, diff --git a/src/lighteval/models/abstract_model.py b/src/lighteval/models/abstract_model.py index b9111c311..1668fbe43 100644 --- a/src/lighteval/models/abstract_model.py +++ b/src/lighteval/models/abstract_model.py @@ -24,7 +24,8 @@ from typing import Optional, Union import torch -from transformers import BatchEncoding +from huggingface_hub import ChatCompletionInputMessage +from transformers import BatchEncoding, PreTrainedTokenizerBase from lighteval.models.model_config import EnvConfig from lighteval.models.model_output import ( @@ -34,12 +35,14 @@ LoglikelihoodSingleTokenReturn, ) from lighteval.tasks.requests import ( + Conversation, GreedyUntilMultiTurnRequest, GreedyUntilRequest, LoglikelihoodRequest, LoglikelihoodRollingRequest, LoglikelihoodSingleTokenRequest, ) +from lighteval.utils import as_list TokenSequence = Union[list[int], torch.LongTensor, torch.Tensor, BatchEncoding] @@ -64,7 +67,7 @@ def cleanup(self): @property @abstractmethod - def tokenizer(self): + def tokenizer(self) -> PreTrainedTokenizerBase: raise NotImplementedError @property @@ -133,17 +136,34 @@ def loglikelihood_single_token( return NotImplemented # Tokenization utils - def tok_encode(self, str_to_encode: str | list[str], add_special_tokens: Optional[bool] = None) -> TokenSequence: + def tok_encode( + self, + input: str | list[str] | ChatCompletionInputMessage | Conversation | list[Conversation], + add_special_tokens: Optional[bool] = None, + ) -> TokenSequence: if add_special_tokens is None: add_special_tokens = self.add_special_tokens - if isinstance(str_to_encode, str): - return self.tokenizer.encode(str_to_encode, add_special_tokens=add_special_tokens) - return self.tokenizer( - 
str_to_encode, - padding=True, - add_special_tokens=add_special_tokens, - return_tensors="pt", - ) + if isinstance(input, str): + return self.tokenizer.encode(input, add_special_tokens=add_special_tokens) + elif isinstance(input, ChatCompletionInputMessage) or isinstance(input[0], ChatCompletionInputMessage): + return self.tokenizer.apply_chat_template( + as_list(input), add_generation_prompt=True, add_special_tokens=add_special_tokens + ) + elif isinstance(input, list) and isinstance(input[0], str): + return self.tokenizer( + input, + padding=True, + add_special_tokens=add_special_tokens, + return_tensors="pt", + ) + else: + return self.tokenizer.apply_chat_template( + input, + add_generation_prompt=True, + add_special_tokens=add_special_tokens, + padding=True, + return_tensors="pt", + ) def tok_encode_pair(self, context, continuation): """Encodes a context, continuation pair by taking care of the spaces in between.""" diff --git a/src/lighteval/models/base_model.py b/src/lighteval/models/base_model.py index 666c01319..3589cadac 100644 --- a/src/lighteval/models/base_model.py +++ b/src/lighteval/models/base_model.py @@ -531,16 +531,26 @@ def greedy_until( context = [c.context for c in batch] - # See doc https://huggingface.co/docs/transformers/v4.38.2/en/pad_truncation#padding-and-truncation - # Will do left truncation and padding, as defined when creating the tokenizer - tokenized = self.tokenizer( - context, - truncation="longest_first", # we truncate to the model max length if needed - padding="longest", # we pad to the longest sequence - return_tensors="pt", - max_length=self.max_length - 1, # we always allow minimum one token of generation - add_special_tokens=self.add_special_tokens, - ).to(self.device) + if self.use_chat_template: + tokenized = self.tokenizer.apply_chat_template( + context, + truncation="longest_first", + padding="longest", + return_tensors="pt", + max_length=self.max_length - 1, + add_special_tokens=self.add_special_tokens, + ) + else: + # See doc https://huggingface.co/docs/transformers/v4.38.2/en/pad_truncation#padding-and-truncation + # Will do left truncation and padding, as defined when creating the tokenizer + tokenized = self.tokenizer( + context, + truncation="longest_first", # we truncate to the model max length if needed + padding="longest", # we pad to the longest sequence + return_tensors="pt", + max_length=self.max_length - 1, # we always allow minimum one token of generation + add_special_tokens=self.add_special_tokens, + ).to(self.device) # The main question for this step is the following: # Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk @@ -566,10 +576,7 @@ def greedy_until( input_ids=tokenized["input_ids"], input_lengths=[len(item == 1) for item in tokenized["attention_mask"]], input_mask=tokenized["attention_mask"], - truncated=[ - len(c) - tokenized["input_ids"].shape[1] if len(c) > tokenized["input_ids"].shape[1] else 0 - for c in context - ], + truncated=[max(len(c.tokenized_context) - tokenized["input_ids"].shape[1], 0) for c in batch], padded=[sum(mask == 0) for mask in tokenized["attention_mask"]], ) diff --git a/src/lighteval/models/endpoint_model.py b/src/lighteval/models/endpoint_model.py index 87959ef61..aaa13ffb7 100644 --- a/src/lighteval/models/endpoint_model.py +++ b/src/lighteval/models/endpoint_model.py @@ -21,14 +21,20 @@ # SOFTWARE. 
import asyncio -from typing import Coroutine, List, Optional, Union +from dataclasses import asdict +from typing import Coroutine, List, Optional, TypeAlias, Union, cast import torch from huggingface_hub import ( AsyncInferenceClient, + ChatCompletionInput, + ChatCompletionInputMessage, + ChatCompletionOutput, InferenceClient, InferenceEndpoint, InferenceEndpointTimeoutError, + TextGenerationInput, + TextGenerationInputGenerateParameters, TextGenerationOutput, create_inference_endpoint, get_inference_endpoint, @@ -41,16 +47,25 @@ from lighteval.logging.hierarchical_logger import hlog, hlog_err, hlog_warn from lighteval.models.abstract_model import LightevalModel from lighteval.models.model_config import EnvConfig, InferenceEndpointModelConfig, InferenceModelConfig -from lighteval.models.model_output import GenerateReturn, LoglikelihoodReturn, LoglikelihoodSingleTokenReturn +from lighteval.models.model_output import ( + GenerateReturn, + LoglikelihoodReturn, + LoglikelihoodSingleTokenReturn, +) from lighteval.tasks.requests import ( GreedyUntilRequest, LoglikelihoodRequest, LoglikelihoodRollingRequest, LoglikelihoodSingleTokenRequest, + Request, ) from lighteval.utils import as_list +EndpointInput: TypeAlias = TextGenerationInput | ChatCompletionInput +EndpointOutput: TypeAlias = TextGenerationOutput | ChatCompletionOutput + + BATCH_SIZE = 50 @@ -75,7 +90,6 @@ def __init__( repository=config.repository, revision=config.revision, framework=config.framework, - task="text-generation", accelerator=config.accelerator, vendor=config.vendor, region=config.region, @@ -151,89 +165,127 @@ def max_length(self): self._max_length = 2048 return self._max_length - def _async_process_request( - self, context: str, stop_tokens: list[str], max_tokens: int - ) -> Coroutine[None, list[TextGenerationOutput], str]: - # Todo: add an option to launch with conversational instead for chat prompts - # https://huggingface.co/docs/huggingface_hub/v0.20.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient.conversational - generated_text = self.async_client.text_generation( - prompt=context, - details=True, - decoder_input_details=True, - max_new_tokens=max_tokens, - stop_sequences=stop_tokens, - # truncate=, + def _process_request( + self, prepared_request: EndpointInput, request: Request + ) -> EndpointOutput | Coroutine[None, None, EndpointOutput]: + client = self.async_client if self.use_async else self.client + if isinstance(prepared_request, TextGenerationInput): + # https://github.com/huggingface/huggingface_hub/issues/2471 + request_as_dict = asdict(prepared_request) + request_as_dict["parameters"]["stop_sequences"] = request_as_dict["parameters"]["stop"] + del request_as_dict["parameters"]["stop"] + + return client.text_generation(prepared_request.inputs, **request_as_dict["parameters"]) + elif isinstance(prepared_request, ChatCompletionInput): + return client.chat_completion(**prepared_request) + + def _process_generate_response(self, response: EndpointOutput, request: GreedyUntilRequest) -> GenerateReturn: + is_chat = isinstance(response, ChatCompletionOutput) + if is_chat: + logits = [t.logprob for t in response.choices[0].logprobs.content] + input_tokens = request.tokenized_context + generated_tokens = self.tokenizer.convert_tokens_to_ids( + [t.token for t in response.choices[0].logprobs.content] + ) + else: + logits = [t.logprob for t in response.details.tokens] + input_tokens = [t.id for t in response.details.prefill] + generated_tokens = [t.id for t in response.details.tokens] + 
return GenerateReturn( + result=response.choices[0].message.content if is_chat else response.generated_text, + logits=logits if request.use_logits else None, + input_tokens=input_tokens, + generated_tokens=generated_tokens, + truncated_tokens_count=-1, + padded_tokens_count=-1, ) - return generated_text - - def _process_request(self, context: str, stop_tokens: list[str], max_tokens: int) -> TextGenerationOutput: - # Todo: add an option to launch with conversational instead for chat prompts - # https://huggingface.co/docs/huggingface_hub/v0.20.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient.conversational - generated_text = self.client.text_generation( - prompt=context, - details=True, - decoder_input_details=True, - max_new_tokens=max_tokens, - stop_sequences=stop_tokens, - # truncate=, - ) + def _process_logprob_response( + self, response: TextGenerationOutput, request: LoglikelihoodRequest | LoglikelihoodRollingRequest + ) -> LoglikelihoodReturn: + cont_toks = torch.tensor(request.tokenized_continuation) + len_choice = len(cont_toks) - return generated_text + logits = sum([t.logprob for t in response.details.prefill[-len_choice:]]) + max_equal = all( + response.details.tokens[i].id == response.details.top_tokens[i][0]["id"] for i in range(-len_choice, 0) + ) + return LoglikelihoodReturn( + result=(logits, max_equal), + input_tokens=[t.id for t in response.details.prefill[:-len_choice]], + generated_tokens=-1, + truncated_tokens_count=-1, + padded_tokens_count=-1, + ) - async def _async_process_batch_generate( + async def _async_process_batch( self, - requests: list[GreedyUntilRequest], - ) -> list[TextGenerationOutput]: + requests: list[Request], + ) -> list[EndpointOutput]: return await asyncio.gather( *[ - self._async_process_request( - context=request.context, - stop_tokens=as_list(request.stop_sequence), - max_tokens=request.generation_size, + cast( + Coroutine[None, None, EndpointOutput], + self._process_request(self._prepare_request(request), request), ) for request in requests ] ) - def _process_batch_generate( + def _process_batch( self, - requests: list[GreedyUntilRequest], - ) -> list[TextGenerationOutput]: + requests: list[Request], + ) -> list[EndpointOutput]: return [ - self._process_request( - context=request.context, - stop_tokens=as_list(request.stop_sequence), - max_tokens=request.generation_size, - ) + cast(EndpointOutput, self._process_request(self._prepare_request(request), request)) for request in requests ] - async def _async_process_batch_logprob( - self, requests: list[LoglikelihoodRequest], rolling: bool = False - ) -> list[TextGenerationOutput]: - return await asyncio.gather( - *[ - self._async_process_request( - context=request.context if rolling else request.context + request.choice, - stop_tokens=[], - max_tokens=1, - ) - for request in requests - ] - ) - - def _process_batch_logprob( - self, requests: list[LoglikelihoodRequest], rolling: bool = False - ) -> list[TextGenerationOutput]: - return [ - self._process_request( - context=request.context if rolling else request.context + request.choice, - stop_tokens=[], - max_tokens=1, + def _prepare_request(self, request: Request) -> EndpointInput: + if isinstance(request, GreedyUntilRequest): + stop = as_list(request.stop_sequence) or None + max_tokens = request.generation_size + context = request.context + elif isinstance(request, (LoglikelihoodRequest, LoglikelihoodRollingRequest)): + stop = None + max_tokens = 1 + rolling = isinstance(request, LoglikelihoodRollingRequest) + if 
rolling: + context = request.context + elif isinstance(request.context, str): + context = request.context + request.choice + else: + context = request.context + ChatCompletionInputMessage(role="assistant", content=request.choice) + if not isinstance(context, str): + context = self.tokenizer.apply_chat_template(context, add_generation_prompt=True, tokenize=False) + + if isinstance(context, str): + prepared_request = TextGenerationInput( + inputs=context, + parameters=TextGenerationInputGenerateParameters( + details=True, + decoder_input_details=True, + do_sample=False, + seed=42, + max_new_tokens=max_tokens, + stop=stop, + return_full_text=False, + top_n_tokens=1, + ), ) - for request in requests - ] + else: + prepared_request = ChatCompletionInput( + messages=context, + model=self.name, + logprobs=True, + stop=stop, + max_tokens=max_tokens, + seed=42, + temperature=0.0, + top_logprobs=1, + stream=False, + ) + return prepared_request def greedy_until( self, @@ -260,8 +312,6 @@ def greedy_until( for batch in tqdm( dataloader, desc="Greedy generation", position=1, leave=False, disable=self.disable_tqdm ): - # the `returns_logits` flag is only used to filter the results, we always request the full details. - returns_logits = batch[0].use_logits num_samples = batch[0].num_samples if num_samples > 1: hlog_err( @@ -269,18 +319,11 @@ def greedy_until( ) if self.use_async: - responses = asyncio.run(self._async_process_batch_generate(batch)) + responses = asyncio.run(self._async_process_batch(batch)) else: - responses = self._process_batch_generate(batch) + responses = self._process_batch(batch) for response in responses: - results.append( - GenerateReturn( - result=response.generated_text, - logits=[item.logprob for item in response.details.prefill] if returns_logits else None, - truncated_tokens_count=-1, - padded_tokens_count=-1, - ) - ) + results.append(self._process_generate_response(response)) return dataset.get_original_order(results) @@ -305,26 +348,11 @@ def loglikelihood( for batch in tqdm(dataloader, desc="Loglikelihoods", position=1, leave=False, disable=self.disable_tqdm): if self.use_async: - responses = asyncio.run(self._async_process_batch_logprob(batch)) + responses = asyncio.run(self._async_process_batch(batch)) else: - responses = self._process_batch_logprob(batch) + responses = self._process_batch(batch) for cur_request, response in zip(batch, responses): - cont_toks = torch.tensor(cur_request.tokenized_continuation) - len_choice = len(cont_toks) - - logits = [t.logprob for t in response.details.prefill[-len_choice:] if t.logprob is not None] - - greedy_tokens = torch.tensor(logits).argmax(dim=-1) - max_equal = (greedy_tokens == cont_toks).all().squeeze(0) - results.append( - LoglikelihoodReturn( - result=(sum(logits), bool(max_equal)), - input_tokens=[t.id for t in response.details.prefill[:-len_choice]], - generated_tokens=[t.id for t in response.details.prefill[-len_choice:]], - truncated_tokens_count=-1, - padded_tokens_count=-1, - ) - ) + results.append(self._process_logprob_response(cast(TextGenerationOutput, response), cur_request)) return dataset.get_original_order(results) @@ -353,21 +381,11 @@ def loglikelihood_rolling( dataloader, desc="Loglikelihoods, rolling", position=1, leave=False, disable=self.disable_tqdm ): if self.use_async: - responses = asyncio.run(self._async_process_batch_logprob(batch, rolling=True)) + responses = asyncio.run(self._async_process_batch(batch)) else: - responses = self._process_batch_logprob(batch, rolling=True) - for response in responses: 
- logits = [t.logprob for t in response.details.tokens[:-1]] - - results.append( - LoglikelihoodReturn( - result=sum(logits), - input_tokens=[t.id for t in response.details.prefill], - generated_tokens=[t.id for t in response.details.tokens[:-1]], - truncated_tokens_count=-1, - padded_tokens_count=-1, - ) - ) + responses = self._process_batch(batch) + for response, request in zip(responses, batch): + results.append(self._process_logprob_response(cast(TextGenerationOutput, response), request)) return dataset.get_original_order(results) diff --git a/src/lighteval/models/nanotron_model.py b/src/lighteval/models/nanotron_model.py index d2d6a07f9..845e7da97 100644 --- a/src/lighteval/models/nanotron_model.py +++ b/src/lighteval/models/nanotron_model.py @@ -321,37 +321,9 @@ def forward_batch(batch_size): logger.warning("Determined largest batch size: %d", batch_size) return batch_size - def tok_encode(self, string: str, add_special_tokens: Optional[bool] = None) -> TokenSequence: - # TODO: Merge `tok_encode_batch` here. - if add_special_tokens is None: - add_special_tokens = self.add_special_tokens - return self.tokenizer.encode(string, add_special_tokens=add_special_tokens) - - def tok_encode_batch(self, strings: List[str]) -> TokenSequence: - return self.tokenizer( - strings, - padding=True, - add_special_tokens=self.add_special_tokens, - return_tensors="pt", - ) - - def tok_decode(self, tokens: torch.LongTensor) -> List[str]: - return self.tokenizer.batch_decode(tokens, skip_special_tokens=True) - def _model_call(self, inputs: torch.Tensor) -> torch.Tensor: return self.model(inputs) - def _encode_pair(self, context, continuation): - n_spaces = len(context) - len(context.rstrip()) - if n_spaces > 0: - continuation = context[-n_spaces:] + continuation - context = context[:-n_spaces] - whole_enc = self.tok_encode(context + continuation) - context_enc = self.tok_encode(context) - context_enc_len = len(context_enc) - continuation_enc = whole_enc[context_enc_len:] - return context_enc, continuation_enc - def homogeneize_ending_conditions(self, ending_condition: tuple | dict | list | str) -> tuple[list, int]: """Ending conditions are submitted in several possible formats. By default in lighteval we pass them as tuples (stop sequence, max number of items). 
@@ -1200,16 +1172,26 @@ def greedy_until( context = [c.context for c in batch] - # See doc https://huggingface.co/docs/transformers/v4.38.2/en/pad_truncation#padding-and-truncation - # Will do left truncation and padding, as defined when creating the tokenizer - tokenized = self.tokenizer( - context, - truncation="longest_first", # we truncate to the model max length if needed - padding="longest", # we pad to the longest sequence - return_tensors="pt", - max_length=self.max_length - 1, # we always allow minimum one token of generation - add_special_tokens=self.add_special_tokens, - ).to(self.device) + if isinstance(context[0], str): + # See doc https://huggingface.co/docs/transformers/v4.38.2/en/pad_truncation#padding-and-truncation + # Will do left truncation and padding, as defined when creating the tokenizer + tokenized = self.tokenizer( + context, + truncation="longest_first", # we truncate to the model max length if needed + padding="longest", # we pad to the longest sequence + return_tensors="pt", + max_length=self.max_length - 1, # we always allow minimum one token of generation + add_special_tokens=self.add_special_tokens, + ).to(self.device) + else: + tokenized = self.tokenizer.apply_chat_template( + context, + truncation="longest_first", + padding="longest", + return_tensors="pt", + max_length=self.max_length - 1, + add_special_tokens=self.add_special_tokens, + ) # The main question for this step is the following: # Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk @@ -1232,10 +1214,7 @@ def greedy_until( input_ids=tokenized["input_ids"], input_lengths=[len(item == 1) for item in tokenized["attention_mask"]], input_mask=tokenized["attention_mask"], - truncated=[ - len(c) - tokenized["input_ids"].shape[1] if len(c) > tokenized["input_ids"].shape[1] else 0 - for c in context - ], + truncated=[max(len(c.tokenized_context) - tokenized["input_ids"].shape[1], 0) for c in batch], padded=[sum(mask == 0) for mask in tokenized["attention_mask"]], ) diff --git a/src/lighteval/models/tgi_model.py b/src/lighteval/models/tgi_model.py index 754152587..6de873575 100644 --- a/src/lighteval/models/tgi_model.py +++ b/src/lighteval/models/tgi_model.py @@ -20,19 +20,11 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import asyncio -from typing import Coroutine - import requests -from huggingface_hub import TextGenerationOutput +from huggingface_hub import AsyncInferenceClient, InferenceClient from transformers import AutoTokenizer from lighteval.models.endpoint_model import InferenceEndpointModel -from lighteval.utils import NO_TGI_ERROR_MSG, is_tgi_available - - -if is_tgi_available(): - from text_generation import AsyncClient BATCH_SIZE = 50 @@ -44,17 +36,14 @@ def divide_chunks(array, n): yield array[i : i + n] -# inherit from InferenceEndpointModel instead of LightevalModel since they both use the same interface, and only overwrite -# the client functions, since they use a different client. 
class ModelClient(InferenceEndpointModel): _DEFAULT_MAX_LENGTH: int = 4096 def __init__(self, address, auth_token=None, model_id=None) -> None: - if not is_tgi_available(): - raise ImportError(NO_TGI_ERROR_MSG) headers = {} if auth_token is None else {"Authorization": f"Bearer {auth_token}"} - self.client = AsyncClient(address, headers=headers, timeout=240) + self.client = InferenceClient(address, headers=headers, timeout=240) + self.async_client = AsyncInferenceClient(address, headers=headers, timeout=240) self._max_gen_toks = 256 self.model_info = requests.get(f"{address}/info", headers=headers).json() if "model_id" not in self.model_info: @@ -64,22 +53,7 @@ def __init__(self, address, auth_token=None, model_id=None) -> None: self._tokenizer = AutoTokenizer.from_pretrained(self.model_info["model_id"]) self._add_special_tokens = True self.use_async = True - - def _async_process_request( - self, context: str, stop_tokens: list[str], max_tokens: int - ) -> Coroutine[None, list[TextGenerationOutput], str]: - # Todo: add an option to launch with conversational instead for chat prompts - generated_text = self.client.generate( - prompt=context, - decoder_input_details=True, - max_new_tokens=max_tokens, - stop_sequences=stop_tokens, - ) - - return generated_text - - def _process_request(self, *args, **kwargs) -> TextGenerationOutput: - return asyncio.run(self._async_process_request(*args, **kwargs)) + self.name = address def set_cache_hook(self, cache_hook): self.cache_hook = cache_hook diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py index 07120b711..902eb74fa 100644 --- a/src/lighteval/tasks/lighteval_task.py +++ b/src/lighteval/tasks/lighteval_task.py @@ -45,6 +45,7 @@ from lighteval.models.base_model import BaseModel from lighteval.models.model_output import ModelReturn from lighteval.tasks.requests import ( + Context, Doc, GreedyUntilMultiTurnRequest, GreedyUntilRequest, @@ -438,14 +439,14 @@ def get_request_type(self) -> list[RequestType]: # noqa C901 return list(set(request_types)) def construct_requests( - self, formatted_doc: Doc, context: str, document_id_seed: str, current_task_name: str + self, formatted_doc: Doc, context: Context, document_id_seed: str, current_task_name: str ) -> Dict[RequestType, List[Request]]: """ Constructs a list of requests from the task based on the given parameters. Args: formatted_doc (Doc): Formatted document almost straight from the dataset. - ctx (str): Context, which is the few shot examples + the query. + context (Context): Context, which is the few shot examples + the query. document_id_seed (str): Index of the document in the task appended with the seed used for the few shot sampling. current_task_name (str): Name of the current task. 
diff --git a/src/lighteval/tasks/requests.py b/src/lighteval/tasks/requests.py index 2bd690232..542ed7a91 100644 --- a/src/lighteval/tasks/requests.py +++ b/src/lighteval/tasks/requests.py @@ -23,11 +23,17 @@ import json from dataclasses import asdict, dataclass from enum import Enum, auto -from typing import NamedTuple, Optional, Union +from typing import List, NamedTuple, Optional, TypeAlias, Union + +from huggingface_hub import ChatCompletionInputMessage from lighteval.utils import as_list +# We later could move this and similar types to lighteval/types.py +Conversation: TypeAlias = List[ChatCompletionInputMessage] + + class RequestType(Enum): LOGLIKELIHOOD = auto() LOGLIKELIHOOD_SINGLE_TOKEN = auto() @@ -36,6 +42,9 @@ class RequestType(Enum): GREEDY_UNTIL_MULTI_TURN = auto() +Context: TypeAlias = object + + @dataclass class Request: """ @@ -48,13 +57,13 @@ class Request: task_name (str): The name of the task. example_index (int): The index of the example. request_index (int): The index of the request. - context (str): The context for the request. + context (ContextType): The context for the request. """ task_name: str example_index: int request_index: int - context: str + context: Context @dataclass @@ -117,7 +126,7 @@ class GreedyUntilRequest(Request): stop_sequence: Union[str, tuple[str], list[str]] generation_size: int request_type = RequestType.GREEDY_UNTIL - tokenized_context: list[int] = None + tokenized_context: Optional[list[int]] = None num_samples: int = None use_logits: bool = False @@ -135,6 +144,7 @@ class GreedyUntilMultiTurnRequest(Request): stop_sequence: str generation_size: int + context: Conversation request_type = RequestType.GREEDY_UNTIL_MULTI_TURN use_logits: bool = False @@ -173,7 +183,7 @@ class Doc: target_for_fewshot_sorting: Optional[str] = None # will probably have to be removed in the future # Filled when parsing and adding the few-shot context - ctx: Optional[str] = "" + ctx: Optional[Context] = "" num_asked_few_shots: int = -1 num_effective_few_shots: int = -1 diff --git a/tests/test_endpoint_model.py b/tests/test_endpoint_model.py new file mode 100644 index 000000000..a22c9c869 --- /dev/null +++ b/tests/test_endpoint_model.py @@ -0,0 +1,178 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +import os +import random +import time +from collections import defaultdict +from typing import Iterator, TypeAlias + +import docker +import pytest +import requests +import torch +from huggingface_hub import ChatCompletionInputMessage, ChatCompletionOutputMessage + +from lighteval.evaluator import EvaluationTracker, evaluate +from lighteval.metrics.metrics import Metrics +from lighteval.models.tgi_model import ModelClient as TGIModel +from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig, create_requests_from_tasks +from lighteval.tasks.requests import ( + Doc, + Request, + RequestType, +) + + +TOKEN = os.environ.get("HF_TOKEN") +CACHE_PATH = os.getenv("HF_HOME", ".") + + +@pytest.fixture(scope="module") +def tgi_model() -> Iterator[TGIModel]: + client = docker.from_env() + port = random.randint(8000, 9000) + container = client.containers.run( + "ghcr.io/huggingface/text-generation-inference:2.2.0", + command=[ + "--model-id", + "hf-internal-testing/tiny-random-LlamaForCausalLM", + "--dtype", + "float16", + ], + detach=True, + name="lighteval-tgi-model-test", + auto_remove=True, + ports={"80/tcp": port}, + ) + address = f"http://localhost:{port}" + for _ in range(30): + try: + if requests.get(f"{address}/health"): + break + except Exception: + time.sleep(1) + else: + raise RuntimeError("Couldn't setup TGI server.") + model = TGIModel(address) + yield model + container.stop() + container.wait() + model.cleanup() + + +RequestDict: TypeAlias = dict[RequestType, list[Request]] + + +class TestEndpointModel: + @pytest.fixture + def task(self) -> LightevalTask: + eval_docs = [ + Doc(query="How are you?", choices=["Fine, thanks!", "Not bad!"], instruction="Tell me:\n\n", gold_index=0), + Doc( + query="Comment vas-tu?", + choices=["Ca va! 
Merci!", "Comme ci, comme ça"], + instruction="Tell me:\n\n", + gold_index=0, + ), + ] + fewshot_docs = [ + Doc(query="كيف حالك؟", choices=["جيد شكراً!", "ليس سيئًا!"], instruction="Tell me:\n\n", gold_index=0), + Doc( + query="Wie geht es dir?", + choices=["Gut, danke!", "Nicht schlecht!"], + instruction="Tell me:\n\n", + gold_index=0, + ), + ] + task_config = LightevalTaskConfig( + "test", lambda _: _, "", "", [Metrics.loglikelihood_acc, Metrics.exact_match, Metrics.byte_perplexity] + ) + task = LightevalTask("test", task_config) + task._docs = eval_docs + task._fewshot_docs = fewshot_docs + return task + + @pytest.fixture + def zero_shot_request_dict(self, task: LightevalTask) -> RequestDict: + result = defaultdict(list) + for i, doc in enumerate(task.eval_docs()): + if i % 2 == 0: + context = [ChatCompletionInputMessage(role="user", content=doc.query)] + else: + context = doc.query + doc_result = task.construct_requests(doc, context, f"{i}_0", "custom|test|0") + for req_type in doc_result: + result[req_type].extend(doc_result[req_type]) + return result + + def test_model_tokenizer_api(self, zero_shot_request_dict: RequestDict, tgi_model: TGIModel): + assert tgi_model.tok_encode(ChatCompletionInputMessage("user", "Hi there!")) == tgi_model.tok_encode( + [ChatCompletionInputMessage("user", "Hi there!")] + ) + assert isinstance( + tgi_model.tok_encode([req.context for req in zero_shot_request_dict[RequestType.GREEDY_UNTIL]]), + torch.Tensor, + ) + + def test_greedy_until(self, zero_shot_request_dict: RequestDict, tgi_model: TGIModel): + returns = tgi_model.greedy_until(zero_shot_request_dict[RequestType.GREEDY_UNTIL]) + assert len(returns) == 6 + assert all(isinstance(r.result, ChatCompletionOutputMessage) and r.result.content for r in returns[:3]) + assert None not in (returns[2].logits, returns[5].logits) + + def test_loglikelihood(self, zero_shot_request_dict: RequestDict, tgi_model: TGIModel): + returns = tgi_model.loglikelihood(zero_shot_request_dict[RequestType.LOGLIKELIHOOD]) + assert len(returns) == 4 + assert all(r.result[0] is not None for r in returns) + + returns = tgi_model.loglikelihood_rolling(zero_shot_request_dict[RequestType.LOGLIKELIHOOD_ROLLING]) + assert len(returns) == 2 + assert all(r.result[0] is not None for r in returns) + + @pytest.mark.parametrize("num_fewshot", [0, 2]) + @pytest.mark.parametrize("use_chat_template", [False, True]) + def test_integration(self, task: LightevalTask, tgi_model: TGIModel, num_fewshot: int, use_chat_template: bool): + evaluation_tracker = EvaluationTracker() + task_dict = {"custom|test": task} + evaluation_tracker.task_config_logger.log(task_dict) + requests_dict, docs = create_requests_from_tasks( + task_dict=task_dict, + fewshot_dict={"custom|test": [(num_fewshot, False)]}, + num_fewshot_seeds=0, + lm=tgi_model, + max_samples=1, + evaluation_tracker=evaluation_tracker, + use_chat_template=use_chat_template, + ) + + evaluation_tracker = evaluate( + lm=tgi_model, + requests_dict=requests_dict, + docs=docs, + task_dict=task_dict, + override_bs=1, + evaluation_tracker=evaluation_tracker, + ) + evaluation_tracker.metrics_logger.aggregate(task_dict=task_dict) + evaluation_tracker.details_logger.aggregate() + evaluation_tracker.generate_final_dict() From 7dbad3a95732bace4dec4cbb276c51e30aab108d Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Sun, 25 Aug 2024 02:52:19 +0330 Subject: [PATCH 02/10] Add tests --- src/lighteval/models/abstract_model.py | 14 ++- src/lighteval/models/base_model.py | 1 + 
src/lighteval/models/nanotron_model.py | 1 + tests/test_base_model.py | 145 +++++++++++++++++++++++++ tests/test_endpoint_model.py | 48 ++++---- 5 files changed, 183 insertions(+), 26 deletions(-) create mode 100644 tests/test_base_model.py diff --git a/src/lighteval/models/abstract_model.py b/src/lighteval/models/abstract_model.py index 1668fbe43..29ab87b28 100644 --- a/src/lighteval/models/abstract_model.py +++ b/src/lighteval/models/abstract_model.py @@ -163,14 +163,18 @@ def tok_encode( add_special_tokens=add_special_tokens, padding=True, return_tensors="pt", + return_dict=True, ) - def tok_encode_pair(self, context, continuation): + def tok_encode_pair(self, context: str | Conversation, continuation: str | ChatCompletionInputMessage): """Encodes a context, continuation pair by taking care of the spaces in between.""" - n_spaces = len(context) - len(context.rstrip()) - if n_spaces > 0: - continuation = context[-n_spaces:] + continuation - context = context[:-n_spaces] + if isinstance(context, str): + n_spaces = len(context) - len(context.rstrip()) + if n_spaces > 0: + continuation = context[-n_spaces:] + continuation + context = context[:-n_spaces] + else: + continuation = [continuation] whole_enc = self.tok_encode(context + continuation) context_enc = self.tok_encode(context) context_enc_len = len(context_enc) diff --git a/src/lighteval/models/base_model.py b/src/lighteval/models/base_model.py index 3589cadac..cee5906f5 100644 --- a/src/lighteval/models/base_model.py +++ b/src/lighteval/models/base_model.py @@ -539,6 +539,7 @@ def greedy_until( return_tensors="pt", max_length=self.max_length - 1, add_special_tokens=self.add_special_tokens, + return_dict=True, ) else: # See doc https://huggingface.co/docs/transformers/v4.38.2/en/pad_truncation#padding-and-truncation diff --git a/src/lighteval/models/nanotron_model.py b/src/lighteval/models/nanotron_model.py index 845e7da97..13f9dea08 100644 --- a/src/lighteval/models/nanotron_model.py +++ b/src/lighteval/models/nanotron_model.py @@ -1191,6 +1191,7 @@ def greedy_until( return_tensors="pt", max_length=self.max_length - 1, add_special_tokens=self.add_special_tokens, + return_dict=True, ) # The main question for this step is the following: diff --git a/tests/test_base_model.py b/tests/test_base_model.py new file mode 100644 index 000000000..da695fd1b --- /dev/null +++ b/tests/test_base_model.py @@ -0,0 +1,145 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +import os +from typing import Iterator, TypeAlias + +import pytest +from huggingface_hub import ChatCompletionInputMessage +from transformers import BatchEncoding + +from lighteval.evaluator import EvaluationTracker, evaluate +from lighteval.metrics.metrics import Metrics +from lighteval.models.base_model import BaseModel +from lighteval.models.model_config import BaseModelConfig, EnvConfig +from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig, create_requests_from_tasks +from lighteval.tasks.requests import ( + Doc, + Request, + RequestType, +) + + +TOKEN = os.environ.get("HF_TOKEN") +CACHE_PATH = os.getenv("HF_HOME", ".") + + +@pytest.fixture(scope="module") +def base_model() -> Iterator[BaseModel]: + config = BaseModelConfig("hf-internal-testing/tiny-random-LlamaForCausalLM") + return BaseModel(config, EnvConfig(CACHE_PATH, TOKEN)) + + +RequestDict: TypeAlias = dict[RequestType, list[Request]] + + +def test_abstract_model_tokenizer_api(base_model: BaseModel): + encoded = base_model.tok_encode("Hi there!") + assert isinstance(encoded, list) and isinstance(encoded[0], int) + + encoded = base_model.tok_encode(ChatCompletionInputMessage("user", "Hi there!")) + assert encoded == base_model.tok_encode([ChatCompletionInputMessage("user", "Hi there!")]) + assert isinstance(encoded, list) and isinstance(encoded[0], int) + + assert isinstance( + base_model.tok_encode(["Hi there!", "Hello there!"]), + BatchEncoding, + ) + + assert isinstance(base_model.tok_encode([[ChatCompletionInputMessage("user", "Hi there!")]]), BatchEncoding) + + +class TestBaseModel: + @pytest.fixture + def task(self) -> LightevalTask: + eval_docs = [ + Doc( + query="Tell me:\n\nHow are you?", + choices=["Fine, thanks!", "Not bad!"], + instruction="Tell me:\n\n", + gold_index=0, + ), + Doc( + query="Tell me:\n\nComment vas-tu?", + choices=["Ca va! 
Merci!", "Comme ci, comme ça"], + instruction="Tell me:\n\n", + gold_index=0, + ), + ] + fewshot_docs = [ + Doc( + query="Tell me:\n\nكيف حالك؟", + choices=["جيد شكراً!", "ليس سيئًا!"], + instruction="Tell me:\n\n", + gold_index=0, + ), + Doc( + query="Tell me:\n\nWie geht es dir?", + choices=["Gut, danke!", "Nicht schlecht!"], + instruction="Tell me:\n\n", + gold_index=0, + ), + ] + task_config = LightevalTaskConfig( + name="test", + prompt_function=lambda _: _, + hf_repo="", + hf_subset="", + metric=[Metrics.loglikelihood_acc, Metrics.exact_match, Metrics.byte_perplexity], + generation_size=5, + stop_sequence=[], + ) + task = LightevalTask("test", task_config) + task._docs = eval_docs + task._fewshot_docs = fewshot_docs + return task + + @pytest.mark.parametrize("num_fewshot", [0, 2]) + @pytest.mark.parametrize("use_chat_template", [False, True]) + def test_integration(self, task: LightevalTask, base_model: BaseModel, num_fewshot: int, use_chat_template: bool): + base_model.use_chat_template = use_chat_template + + evaluation_tracker = EvaluationTracker() + task_dict = {"custom|test": task} + evaluation_tracker.task_config_logger.log(task_dict) + requests_dict, docs = create_requests_from_tasks( + task_dict=task_dict, + fewshot_dict={"custom|test": [(num_fewshot, False)]}, + num_fewshot_seeds=0, + lm=base_model, + max_samples=1, + evaluation_tracker=evaluation_tracker, + use_chat_template=use_chat_template, + system_prompt=None, + ) + + evaluation_tracker = evaluate( + lm=base_model, + requests_dict=requests_dict, + docs=docs, + task_dict=task_dict, + override_bs=1, + evaluation_tracker=evaluation_tracker, + ) + evaluation_tracker.metrics_logger.aggregate(task_dict=task_dict) + evaluation_tracker.details_logger.aggregate() + evaluation_tracker.generate_final_dict() diff --git a/tests/test_endpoint_model.py b/tests/test_endpoint_model.py index a22c9c869..932ec6122 100644 --- a/tests/test_endpoint_model.py +++ b/tests/test_endpoint_model.py @@ -29,8 +29,7 @@ import docker import pytest import requests -import torch -from huggingface_hub import ChatCompletionInputMessage, ChatCompletionOutputMessage +from huggingface_hub import ChatCompletionInputMessage from lighteval.evaluator import EvaluationTracker, evaluate from lighteval.metrics.metrics import Metrics @@ -87,25 +86,41 @@ class TestEndpointModel: @pytest.fixture def task(self) -> LightevalTask: eval_docs = [ - Doc(query="How are you?", choices=["Fine, thanks!", "Not bad!"], instruction="Tell me:\n\n", gold_index=0), Doc( - query="Comment vas-tu?", + query="Tell me:\n\nHow are you?", + choices=["Fine, thanks!", "Not bad!"], + instruction="Tell me:\n\n", + gold_index=0, + ), + Doc( + query="Tell me:\n\nComment vas-tu?", choices=["Ca va! 
Merci!", "Comme ci, comme ça"], instruction="Tell me:\n\n", gold_index=0, ), ] fewshot_docs = [ - Doc(query="كيف حالك؟", choices=["جيد شكراً!", "ليس سيئًا!"], instruction="Tell me:\n\n", gold_index=0), Doc( - query="Wie geht es dir?", + query="Tell me:\n\nكيف حالك؟", + choices=["جيد شكراً!", "ليس سيئًا!"], + instruction="Tell me:\n\n", + gold_index=0, + ), + Doc( + query="Tell me:\n\nWie geht es dir?", choices=["Gut, danke!", "Nicht schlecht!"], instruction="Tell me:\n\n", gold_index=0, ), ] task_config = LightevalTaskConfig( - "test", lambda _: _, "", "", [Metrics.loglikelihood_acc, Metrics.exact_match, Metrics.byte_perplexity] + name="test", + prompt_function=lambda _: _, + hf_repo="", + hf_subset="", + metric=[Metrics.loglikelihood_acc, Metrics.exact_match, Metrics.byte_perplexity], + generation_size=5, + stop_sequence=[], ) task = LightevalTask("test", task_config) task._docs = eval_docs @@ -125,28 +140,18 @@ def zero_shot_request_dict(self, task: LightevalTask) -> RequestDict: result[req_type].extend(doc_result[req_type]) return result - def test_model_tokenizer_api(self, zero_shot_request_dict: RequestDict, tgi_model: TGIModel): - assert tgi_model.tok_encode(ChatCompletionInputMessage("user", "Hi there!")) == tgi_model.tok_encode( - [ChatCompletionInputMessage("user", "Hi there!")] - ) - assert isinstance( - tgi_model.tok_encode([req.context for req in zero_shot_request_dict[RequestType.GREEDY_UNTIL]]), - torch.Tensor, - ) - def test_greedy_until(self, zero_shot_request_dict: RequestDict, tgi_model: TGIModel): returns = tgi_model.greedy_until(zero_shot_request_dict[RequestType.GREEDY_UNTIL]) - assert len(returns) == 6 - assert all(isinstance(r.result, ChatCompletionOutputMessage) and r.result.content for r in returns[:3]) - assert None not in (returns[2].logits, returns[5].logits) + assert len(returns) == 4 + assert all(r.result is not None for r in returns) def test_loglikelihood(self, zero_shot_request_dict: RequestDict, tgi_model: TGIModel): returns = tgi_model.loglikelihood(zero_shot_request_dict[RequestType.LOGLIKELIHOOD]) - assert len(returns) == 4 + assert len(returns) == 8 assert all(r.result[0] is not None for r in returns) returns = tgi_model.loglikelihood_rolling(zero_shot_request_dict[RequestType.LOGLIKELIHOOD_ROLLING]) - assert len(returns) == 2 + assert len(returns) == 4 assert all(r.result[0] is not None for r in returns) @pytest.mark.parametrize("num_fewshot", [0, 2]) @@ -163,6 +168,7 @@ def test_integration(self, task: LightevalTask, tgi_model: TGIModel, num_fewshot max_samples=1, evaluation_tracker=evaluation_tracker, use_chat_template=use_chat_template, + system_prompt=None, ) evaluation_tracker = evaluate( From e1a5bc156550ffab9105c9e17ed3eba16d27319b Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Sun, 25 Aug 2024 22:07:35 +0330 Subject: [PATCH 03/10] Adapt with huggingface_hub change in ChatCompletionInputMessage --- tests/test_base_model.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/test_base_model.py b/tests/test_base_model.py index da695fd1b..050b5ffb4 100644 --- a/tests/test_base_model.py +++ b/tests/test_base_model.py @@ -56,8 +56,8 @@ def test_abstract_model_tokenizer_api(base_model: BaseModel): encoded = base_model.tok_encode("Hi there!") assert isinstance(encoded, list) and isinstance(encoded[0], int) - encoded = base_model.tok_encode(ChatCompletionInputMessage("user", "Hi there!")) - assert encoded == base_model.tok_encode([ChatCompletionInputMessage("user", "Hi there!")]) + encoded = 
base_model.tok_encode(ChatCompletionInputMessage(role="user", content="Hi there!")) + assert encoded == base_model.tok_encode([ChatCompletionInputMessage(role="user", content="Hi there!")]) assert isinstance(encoded, list) and isinstance(encoded[0], int) assert isinstance( @@ -65,7 +65,9 @@ def test_abstract_model_tokenizer_api(base_model: BaseModel): BatchEncoding, ) - assert isinstance(base_model.tok_encode([[ChatCompletionInputMessage("user", "Hi there!")]]), BatchEncoding) + assert isinstance( + base_model.tok_encode([[ChatCompletionInputMessage(role="user", content="Hi there!")]]), BatchEncoding + ) class TestBaseModel: From b44362b52154ad867ad453dd6981b81c8e6e8bdd Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Sun, 25 Aug 2024 22:38:36 +0330 Subject: [PATCH 04/10] Fix sth in tgi_model --- tests/test_endpoint_model.py | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/tests/test_endpoint_model.py b/tests/test_endpoint_model.py index 932ec6122..325145d43 100644 --- a/tests/test_endpoint_model.py +++ b/tests/test_endpoint_model.py @@ -27,6 +27,7 @@ from typing import Iterator, TypeAlias import docker +import docker.errors import pytest import requests from huggingface_hub import ChatCompletionInputMessage @@ -49,20 +50,25 @@ @pytest.fixture(scope="module") def tgi_model() -> Iterator[TGIModel]: client = docker.from_env() - port = random.randint(8000, 9000) - container = client.containers.run( - "ghcr.io/huggingface/text-generation-inference:2.2.0", - command=[ - "--model-id", - "hf-internal-testing/tiny-random-LlamaForCausalLM", - "--dtype", - "float16", - ], - detach=True, - name="lighteval-tgi-model-test", - auto_remove=True, - ports={"80/tcp": port}, - ) + + try: + container = client.containers.get("lighteval-tgi-model-test") + port = container.ports["80/tcp"][0]["HostPort"] + except docker.errors.NotFound: + port = random.randint(8000, 9000) + container = client.containers.run( + "ghcr.io/huggingface/text-generation-inference:2.2.0", + command=[ + "--model-id", + "hf-internal-testing/tiny-random-LlamaForCausalLM", + "--dtype", + "float16", + ], + detach=True, + name="lighteval-tgi-model-test", + auto_remove=False, + ports={"80/tcp": port}, + ) address = f"http://localhost:{port}" for _ in range(30): try: @@ -76,6 +82,7 @@ def tgi_model() -> Iterator[TGIModel]: yield model container.stop() container.wait() + container.remove() model.cleanup() From a31d2a2385b9a8202edba380c25f36d958662010 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 28 Aug 2024 07:01:06 +0330 Subject: [PATCH 05/10] Fix tiny bugs --- src/lighteval/models/abstract_model.py | 5 +---- src/lighteval/models/endpoint_model.py | 20 +++++++------------- src/lighteval/models/tgi_model.py | 6 +++--- tests/test_endpoint_model.py | 10 +++++----- 4 files changed, 16 insertions(+), 25 deletions(-) diff --git a/src/lighteval/models/abstract_model.py b/src/lighteval/models/abstract_model.py index eca6ae13e..ec5e64e11 100644 --- a/src/lighteval/models/abstract_model.py +++ b/src/lighteval/models/abstract_model.py @@ -168,9 +168,7 @@ def tok_encode( if isinstance(input, str): return self.tokenizer.encode(input, add_special_tokens=add_special_tokens) elif isinstance(input, ChatCompletionInputMessage) or isinstance(input[0], ChatCompletionInputMessage): - return self.tokenizer.apply_chat_template( - as_list(input), add_generation_prompt=True, add_special_tokens=add_special_tokens - ) + return self.tokenizer.apply_chat_template(as_list(input), 
add_special_tokens=add_special_tokens) elif isinstance(input, list) and isinstance(input[0], str): return self.tokenizer( input, @@ -181,7 +179,6 @@ def tok_encode( else: return self.tokenizer.apply_chat_template( input, - add_generation_prompt=True, add_special_tokens=add_special_tokens, padding=True, return_tensors="pt", diff --git a/src/lighteval/models/endpoint_model.py b/src/lighteval/models/endpoint_model.py index cc25fbb73..3d85d67f5 100644 --- a/src/lighteval/models/endpoint_model.py +++ b/src/lighteval/models/endpoint_model.py @@ -24,7 +24,6 @@ from dataclasses import asdict from typing import Coroutine, List, Optional, TypeAlias, Union, cast -import torch from huggingface_hub import ( AsyncInferenceClient, ChatCompletionInput, @@ -206,15 +205,10 @@ def _process_generate_response(self, response: EndpointOutput, request: GreedyUn def _process_logprob_response( self, response: TextGenerationOutput, request: LoglikelihoodRequest | LoglikelihoodRollingRequest ) -> LoglikelihoodResponse: - cont_toks = torch.tensor(request.tokenized_continuation) - len_choice = len(cont_toks) - - logits = sum([t.logprob for t in response.details.prefill[-len_choice:]]) - max_equal = all( - response.details.tokens[i].id == response.details.top_tokens[i][0]["id"] for i in range(-len_choice, 0) - ) + len_choice = len(request.tokenized_continuation) + logits = sum([t.logprob for t in response.details.prefill[1:][-len_choice:]]) return LoglikelihoodResponse( - result=(logits, max_equal), + result=(logits, True) if isinstance(request, LoglikelihoodRequest) else logits, input_tokens=[t.id for t in response.details.prefill[:-len_choice]], generated_tokens=-1, truncated_tokens_count=-1, @@ -258,9 +252,9 @@ def _prepare_request(self, request: Request) -> EndpointInput: elif isinstance(request.context, str): context = request.context + request.choice else: - context = request.context + ChatCompletionInputMessage(role="assistant", content=request.choice) + context = request.context + [ChatCompletionInputMessage(role="assistant", content=request.choice)] if not isinstance(context, str): - context = self.tokenizer.apply_chat_template(context, add_generation_prompt=True, tokenize=False) + context = self.tokenizer.apply_chat_template(context, tokenize=False) if isinstance(context, str): prepared_request = TextGenerationInput( @@ -325,8 +319,8 @@ def greedy_until( responses = asyncio.run(self._async_process_batch(batch)) else: responses = self._process_batch(batch) - for response in responses: - results.append(self._process_generate_response(response)) + for response, request in zip(responses, batch): + results.append(self._process_generate_response(response, request)) return dataset.get_original_order(results) diff --git a/src/lighteval/models/tgi_model.py b/src/lighteval/models/tgi_model.py index 6de873575..92c3f65b4 100644 --- a/src/lighteval/models/tgi_model.py +++ b/src/lighteval/models/tgi_model.py @@ -42,8 +42,8 @@ class ModelClient(InferenceEndpointModel): def __init__(self, address, auth_token=None, model_id=None) -> None: headers = {} if auth_token is None else {"Authorization": f"Bearer {auth_token}"} - self.client = InferenceClient(address, headers=headers, timeout=240) - self.async_client = AsyncInferenceClient(address, headers=headers, timeout=240) + self.client = InferenceClient(base_url=address, headers=headers, timeout=240) + self.async_client = AsyncInferenceClient(base_url=address, headers=headers, timeout=240) self._max_gen_toks = 256 self.model_info = requests.get(f"{address}/info", 
headers=headers).json() if "model_id" not in self.model_info: @@ -53,7 +53,7 @@ def __init__(self, address, auth_token=None, model_id=None) -> None: self._tokenizer = AutoTokenizer.from_pretrained(self.model_info["model_id"]) self._add_special_tokens = True self.use_async = True - self.name = address + self.name = self.model_info["model_id"] def set_cache_hook(self, cache_hook): self.cache_hook = cache_hook diff --git a/tests/test_endpoint_model.py b/tests/test_endpoint_model.py index 325145d43..20c7701c8 100644 --- a/tests/test_endpoint_model.py +++ b/tests/test_endpoint_model.py @@ -149,17 +149,17 @@ def zero_shot_request_dict(self, task: LightevalTask) -> RequestDict: def test_greedy_until(self, zero_shot_request_dict: RequestDict, tgi_model: TGIModel): returns = tgi_model.greedy_until(zero_shot_request_dict[RequestType.GREEDY_UNTIL]) - assert len(returns) == 4 + assert len(returns) == 2 assert all(r.result is not None for r in returns) def test_loglikelihood(self, zero_shot_request_dict: RequestDict, tgi_model: TGIModel): returns = tgi_model.loglikelihood(zero_shot_request_dict[RequestType.LOGLIKELIHOOD]) - assert len(returns) == 8 - assert all(r.result[0] is not None for r in returns) + assert len(returns) == 4 + assert all(r.result is not None for r in returns) returns = tgi_model.loglikelihood_rolling(zero_shot_request_dict[RequestType.LOGLIKELIHOOD_ROLLING]) - assert len(returns) == 4 - assert all(r.result[0] is not None for r in returns) + assert len(returns) == 2 + assert all(r.result is not None for r in returns) @pytest.mark.parametrize("num_fewshot", [0, 2]) @pytest.mark.parametrize("use_chat_template", [False, True]) From ec769c5876023aac4d473a1d0bc8c385cbf8e2fd Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 28 Aug 2024 07:02:22 +0330 Subject: [PATCH 06/10] Adapt to Pipeline --- src/lighteval/models/tgi_model.py | 17 +++++++---- tests/test_base_model.py | 48 ++++++++++++++++++++----------- tests/test_endpoint_model.py | 48 ++++++++++++++++++++----------- 3 files changed, 73 insertions(+), 40 deletions(-) diff --git a/src/lighteval/models/tgi_model.py b/src/lighteval/models/tgi_model.py index 92c3f65b4..0621ea121 100644 --- a/src/lighteval/models/tgi_model.py +++ b/src/lighteval/models/tgi_model.py @@ -24,6 +24,7 @@ from huggingface_hub import AsyncInferenceClient, InferenceClient from transformers import AutoTokenizer +from lighteval.models.abstract_model import ModelInfo from lighteval.models.endpoint_model import InferenceEndpointModel @@ -45,15 +46,19 @@ def __init__(self, address, auth_token=None, model_id=None) -> None: self.client = InferenceClient(base_url=address, headers=headers, timeout=240) self.async_client = AsyncInferenceClient(base_url=address, headers=headers, timeout=240) self._max_gen_toks = 256 - self.model_info = requests.get(f"{address}/info", headers=headers).json() - if "model_id" not in self.model_info: + info = requests.get(f"{address}/info", headers=headers).json() + if "model_id" not in info: raise ValueError("Error occured when fetching info: " + str(self.model_info)) - if model_id: - self.model_info["model_id"] = model_id - self._tokenizer = AutoTokenizer.from_pretrained(self.model_info["model_id"]) + self.name = info["model_id"] + self.model_info = ModelInfo( + model_name=model_id or self.name, + model_sha=info["model_sha"], + model_dtype=info["model_dtype"] or "default", + model_size=-1, + ) + self._tokenizer = AutoTokenizer.from_pretrained(self.model_info.model_name) self._add_special_tokens = True self.use_async = True - self.name = 
self.model_info["model_id"] def set_cache_hook(self, cache_hook): self.cache_hook = cache_hook diff --git a/tests/test_base_model.py b/tests/test_base_model.py index 050b5ffb4..f622ff77f 100644 --- a/tests/test_base_model.py +++ b/tests/test_base_model.py @@ -22,15 +22,17 @@ import os from typing import Iterator, TypeAlias +from unittest.mock import patch import pytest from huggingface_hub import ChatCompletionInputMessage from transformers import BatchEncoding -from lighteval.evaluator import EvaluationTracker, evaluate +from lighteval.logging.evaluation_tracker import EvaluationTracker from lighteval.metrics.metrics import Metrics from lighteval.models.base_model import BaseModel from lighteval.models.model_config import BaseModelConfig, EnvConfig +from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig, create_requests_from_tasks from lighteval.tasks.requests import ( Doc, @@ -120,28 +122,40 @@ def task(self) -> LightevalTask: def test_integration(self, task: LightevalTask, base_model: BaseModel, num_fewshot: int, use_chat_template: bool): base_model.use_chat_template = use_chat_template + env_config = EnvConfig(token=TOKEN, cache_dir=CACHE_PATH) evaluation_tracker = EvaluationTracker() + pipeline_params = PipelineParameters( + launcher_type=ParallelismManager.NONE, + env_config=env_config, + use_chat_template=use_chat_template, + override_batch_size=1, + ) + + with patch("lighteval.pipeline.Pipeline._init_tasks_and_requests"): + pipeline = Pipeline( + tasks=f"custom|test|{num_fewshot}|0", + pipeline_parameters=pipeline_params, + evaluation_tracker=evaluation_tracker, + model=base_model, + ) task_dict = {"custom|test": task} evaluation_tracker.task_config_logger.log(task_dict) - requests_dict, docs = create_requests_from_tasks( + fewshot_dict = {"custom|test": [(num_fewshot, False)]} + pipeline.task_names_list = ["custom|test"] + pipeline.task_dict = task_dict + pipeline.fewshot_dict = fewshot_dict + requests, docs = create_requests_from_tasks( task_dict=task_dict, - fewshot_dict={"custom|test": [(num_fewshot, False)]}, - num_fewshot_seeds=0, + fewshot_dict=fewshot_dict, + num_fewshot_seeds=pipeline_params.num_fewshot_seeds, lm=base_model, - max_samples=1, + max_samples=pipeline_params.max_samples, evaluation_tracker=evaluation_tracker, use_chat_template=use_chat_template, - system_prompt=None, + system_prompt=pipeline_params.system_prompt, ) + pipeline.requests = requests + pipeline.docs = docs + evaluation_tracker.task_config_logger.log(task_dict) - evaluation_tracker = evaluate( - lm=base_model, - requests_dict=requests_dict, - docs=docs, - task_dict=task_dict, - override_bs=1, - evaluation_tracker=evaluation_tracker, - ) - evaluation_tracker.metrics_logger.aggregate(task_dict=task_dict) - evaluation_tracker.details_logger.aggregate() - evaluation_tracker.generate_final_dict() + pipeline.evaluate() diff --git a/tests/test_endpoint_model.py b/tests/test_endpoint_model.py index 20c7701c8..9550eb6a3 100644 --- a/tests/test_endpoint_model.py +++ b/tests/test_endpoint_model.py @@ -25,6 +25,7 @@ import time from collections import defaultdict from typing import Iterator, TypeAlias +from unittest.mock import patch import docker import docker.errors @@ -32,15 +33,17 @@ import requests from huggingface_hub import ChatCompletionInputMessage -from lighteval.evaluator import EvaluationTracker, evaluate +from lighteval.logging.evaluation_tracker import EvaluationTracker from 
lighteval.metrics.metrics import Metrics from lighteval.models.tgi_model import ModelClient as TGIModel +from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig, create_requests_from_tasks from lighteval.tasks.requests import ( Doc, Request, RequestType, ) +from lighteval.utils.utils import EnvConfig TOKEN = os.environ.get("HF_TOKEN") @@ -164,28 +167,39 @@ def test_loglikelihood(self, zero_shot_request_dict: RequestDict, tgi_model: TGI @pytest.mark.parametrize("num_fewshot", [0, 2]) @pytest.mark.parametrize("use_chat_template", [False, True]) def test_integration(self, task: LightevalTask, tgi_model: TGIModel, num_fewshot: int, use_chat_template: bool): + env_config = EnvConfig(token=TOKEN, cache_dir=CACHE_PATH) evaluation_tracker = EvaluationTracker() + pipeline_params = PipelineParameters( + launcher_type=ParallelismManager.NONE, + env_config=env_config, + use_chat_template=use_chat_template, + ) + + with patch("lighteval.pipeline.Pipeline._init_tasks_and_requests"): + pipeline = Pipeline( + tasks=f"custom|test|{num_fewshot}|0", + pipeline_parameters=pipeline_params, + evaluation_tracker=evaluation_tracker, + model=tgi_model, + ) task_dict = {"custom|test": task} evaluation_tracker.task_config_logger.log(task_dict) - requests_dict, docs = create_requests_from_tasks( + fewshot_dict = {"custom|test": [(num_fewshot, False)]} + pipeline.task_names_list = ["custom|test"] + pipeline.task_dict = task_dict + pipeline.fewshot_dict = fewshot_dict + requests, docs = create_requests_from_tasks( task_dict=task_dict, - fewshot_dict={"custom|test": [(num_fewshot, False)]}, - num_fewshot_seeds=0, + fewshot_dict=fewshot_dict, + num_fewshot_seeds=pipeline_params.num_fewshot_seeds, lm=tgi_model, - max_samples=1, + max_samples=pipeline_params.max_samples, evaluation_tracker=evaluation_tracker, use_chat_template=use_chat_template, - system_prompt=None, + system_prompt=pipeline_params.system_prompt, ) + pipeline.requests = requests + pipeline.docs = docs + evaluation_tracker.task_config_logger.log(task_dict) - evaluation_tracker = evaluate( - lm=tgi_model, - requests_dict=requests_dict, - docs=docs, - task_dict=task_dict, - override_bs=1, - evaluation_tracker=evaluation_tracker, - ) - evaluation_tracker.metrics_logger.aggregate(task_dict=task_dict) - evaluation_tracker.details_logger.aggregate() - evaluation_tracker.generate_final_dict() + pipeline.evaluate() From b29187161d69ed0997596268197285f16bb7e0b7 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Mon, 2 Sep 2024 22:15:48 +0330 Subject: [PATCH 07/10] Fix a tiny bug forgot to add in base_model.py --- src/lighteval/models/base_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lighteval/models/base_model.py b/src/lighteval/models/base_model.py index 081f7743d..3d5172d1d 100644 --- a/src/lighteval/models/base_model.py +++ b/src/lighteval/models/base_model.py @@ -553,7 +553,7 @@ def greedy_until( max_length=self.max_length - 1, add_special_tokens=self.add_special_tokens, return_dict=True, - ) + ).to(self.device) else: # See doc https://huggingface.co/docs/transformers/v4.38.2/en/pad_truncation#padding-and-truncation # Will do left truncation and padding, as defined when creating the tokenizer From 8c0018e9933c29a1916b8c6f95e8b0785084cfb4 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Thu, 5 Sep 2024 11:26:40 +0330 Subject: [PATCH 08/10] Improve endpoint tests and bug fix in endpoint model --- 
src/lighteval/models/endpoint_model.py | 7 +- src/lighteval/models/tgi_model.py | 2 +- tests/conftest.py | 12 +++ tests/test_base_model.py | 12 --- tests/test_endpoint_model.py | 81 ++++++++++++++----- tests/test_test.py | 103 +++++++++++++++++++++++++ 6 files changed, 181 insertions(+), 36 deletions(-) create mode 100644 tests/conftest.py create mode 100644 tests/test_test.py diff --git a/src/lighteval/models/endpoint_model.py b/src/lighteval/models/endpoint_model.py index 3d85d67f5..cfd711627 100644 --- a/src/lighteval/models/endpoint_model.py +++ b/src/lighteval/models/endpoint_model.py @@ -205,11 +205,10 @@ def _process_generate_response(self, response: EndpointOutput, request: GreedyUn def _process_logprob_response( self, response: TextGenerationOutput, request: LoglikelihoodRequest | LoglikelihoodRollingRequest ) -> LoglikelihoodResponse: - len_choice = len(request.tokenized_continuation) - logits = sum([t.logprob for t in response.details.prefill[1:][-len_choice:]]) + logits = sum([t.logprob for t in response.details.prefill[len(request.tokenized_context):]]) return LoglikelihoodResponse( result=(logits, True) if isinstance(request, LoglikelihoodRequest) else logits, - input_tokens=[t.id for t in response.details.prefill[:-len_choice]], + input_tokens=[t.id for t in response.details.prefill[len(request.tokenized_context):]], generated_tokens=-1, truncated_tokens_count=-1, padded_tokens_count=-1, @@ -255,6 +254,7 @@ def _prepare_request(self, request: Request) -> EndpointInput: context = request.context + [ChatCompletionInputMessage(role="assistant", content=request.choice)] if not isinstance(context, str): context = self.tokenizer.apply_chat_template(context, tokenize=False) + context = context.split(self.tokenizer.bos_token, 1)[-1] if isinstance(context, str): prepared_request = TextGenerationInput( @@ -290,6 +290,7 @@ def greedy_until( override_bs: Optional[int] = None, ) -> List[GenerativeResponse]: for request in requests: + # Why don't we set context to empty list here? 
request.tokenized_context = self.tok_encode(request.context) request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token] diff --git a/src/lighteval/models/tgi_model.py b/src/lighteval/models/tgi_model.py index 0621ea121..c359a0860 100644 --- a/src/lighteval/models/tgi_model.py +++ b/src/lighteval/models/tgi_model.py @@ -53,7 +53,7 @@ def __init__(self, address, auth_token=None, model_id=None) -> None: self.model_info = ModelInfo( model_name=model_id or self.name, model_sha=info["model_sha"], - model_dtype=info["model_dtype"] or "default", + model_dtype=info["model_dtype"] if "model_dtype" in info else "default", model_size=-1, ) self._tokenizer = AutoTokenizer.from_pretrained(self.model_info.model_name) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..555b43966 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,12 @@ +from typing import Iterator +import pytest + +from lighteval.models.model_config import BaseModelConfig +from lighteval.models.abstract_model import EnvConfig +from lighteval.models.base_model import BaseModel + + +@pytest.fixture(scope="module") +def base_model() -> Iterator[BaseModel]: + config = BaseModelConfig("hf-internal-testing/tiny-random-LlamaForCausalLM") + return BaseModel(config, EnvConfig()) \ No newline at end of file diff --git a/tests/test_base_model.py b/tests/test_base_model.py index f622ff77f..0d3c72be6 100644 --- a/tests/test_base_model.py +++ b/tests/test_base_model.py @@ -41,16 +41,6 @@ ) -TOKEN = os.environ.get("HF_TOKEN") -CACHE_PATH = os.getenv("HF_HOME", ".") - - -@pytest.fixture(scope="module") -def base_model() -> Iterator[BaseModel]: - config = BaseModelConfig("hf-internal-testing/tiny-random-LlamaForCausalLM") - return BaseModel(config, EnvConfig(CACHE_PATH, TOKEN)) - - RequestDict: TypeAlias = dict[RequestType, list[Request]] @@ -122,11 +112,9 @@ def task(self) -> LightevalTask: def test_integration(self, task: LightevalTask, base_model: BaseModel, num_fewshot: int, use_chat_template: bool): base_model.use_chat_template = use_chat_template - env_config = EnvConfig(token=TOKEN, cache_dir=CACHE_PATH) evaluation_tracker = EvaluationTracker() pipeline_params = PipelineParameters( launcher_type=ParallelismManager.NONE, - env_config=env_config, use_chat_template=use_chat_template, override_batch_size=1, ) diff --git a/tests/test_endpoint_model.py b/tests/test_endpoint_model.py index 9550eb6a3..98e33de0d 100644 --- a/tests/test_endpoint_model.py +++ b/tests/test_endpoint_model.py @@ -27,27 +27,27 @@ from typing import Iterator, TypeAlias from unittest.mock import patch +import torch import docker import docker.errors import pytest import requests from huggingface_hub import ChatCompletionInputMessage +from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, PreTrainedTokenizerFast from lighteval.logging.evaluation_tracker import EvaluationTracker from lighteval.metrics.metrics import Metrics from lighteval.models.tgi_model import ModelClient as TGIModel +from lighteval.models.base_model import BaseModel from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig, create_requests_from_tasks from lighteval.tasks.requests import ( + LoglikelihoodRequest, + LoglikelihoodRollingRequest, Doc, Request, RequestType, ) -from lighteval.utils.utils import EnvConfig - - -TOKEN = os.environ.get("HF_TOKEN") -CACHE_PATH = os.getenv("HF_HOME", ".") @pytest.fixture(scope="module") @@ 
-83,12 +83,19 @@ def tgi_model() -> Iterator[TGIModel]: raise RuntimeError("Couldn't setup TGI server.") model = TGIModel(address) yield model - container.stop() - container.wait() - container.remove() + # container.stop() + # container.wait() + # container.remove() model.cleanup() +@pytest.fixture(scope="module") +def reference_model_tokenizer() -> tuple[LlamaForCausalLM, PreTrainedTokenizerFast]: + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM") + return model, tokenizer + + RequestDict: TypeAlias = dict[RequestType, list[Request]] @@ -150,28 +157,62 @@ def zero_shot_request_dict(self, task: LightevalTask) -> RequestDict: result[req_type].extend(doc_result[req_type]) return result - def test_greedy_until(self, zero_shot_request_dict: RequestDict, tgi_model: TGIModel): - returns = tgi_model.greedy_until(zero_shot_request_dict[RequestType.GREEDY_UNTIL]) + def test_greedy_until(self, reference_model_tokenizer: tuple[LlamaForCausalLM, PreTrainedTokenizerFast], zero_shot_request_dict: RequestDict, tgi_model: TGIModel): + requests = zero_shot_request_dict[RequestType.GREEDY_UNTIL] + returns = tgi_model.greedy_until(requests) + model, tokenizer = reference_model_tokenizer assert len(returns) == 2 - assert all(r.result is not None for r in returns) - - def test_loglikelihood(self, zero_shot_request_dict: RequestDict, tgi_model: TGIModel): - returns = tgi_model.loglikelihood(zero_shot_request_dict[RequestType.LOGLIKELIHOOD]) + for req, res in zip(requests, returns): + is_chat = not isinstance(req.context, str) + tokenized_context = tokenizer.apply_chat_template(req.context, return_tensors='pt') if is_chat else tokenizer(req.context, return_tensors='pt')['input_ids'] + ref_context_continuaiton = model.generate(tokenized_context, tokenizer=tokenizer, stop_strings=req.stop_sequence, max_new_tokens=req.generation_size)[0].tolist() + continuation = tokenizer.decode(ref_context_continuaiton)[len(tokenizer.decode(tokenized_context[0].tolist())):] + assert continuation == res.result + + def test_loglikelihood(self, reference_model_tokenizer: tuple[LlamaForCausalLM, PreTrainedTokenizerFast], zero_shot_request_dict: RequestDict, tgi_model: TGIModel): + requests: list[LoglikelihoodRequest] = zero_shot_request_dict[RequestType.LOGLIKELIHOOD] + returns = tgi_model.loglikelihood(requests) + model, tokenizer = reference_model_tokenizer assert len(returns) == 4 - assert all(r.result is not None for r in returns) - + for req, res in zip(requests, returns): + is_chat = not isinstance(req.context, str) + sequence = req.context + [ChatCompletionInputMessage(role='assistant',content=req.choice)] if is_chat else req.context+req.choice + tokenized_sequence = tokenizer.apply_chat_template(sequence, return_tensors='pt') if is_chat else tokenizer(sequence, return_tensors='pt')['input_ids'] + + output = model.generate(tokenized_sequence, max_new_tokens=1, return_dict_in_generate=True, output_hidden_states=True) + with torch.no_grad(): + logprobs = torch.log_softmax(model.lm_head(output.hidden_states[0][-1]),dim=-1) + logprobs = logprobs.gather(dim=-1, index=tokenized_sequence[:,1:].unsqueeze(-1)) + context_length = len(tokenizer.apply_chat_template(req.context)) if is_chat else len(tokenizer.encode(req.context)) + continuation_logprob = logprobs[:, context_length-1:].sum() + + tokenized_choice = tokenized_sequence[:, context_length:] + assert tokenized_choice[0].tolist() == 
res.input_tokens + assert torch.allclose(torch.tensor(res.result[0]), continuation_logprob) + + def test_loglikelihood_rolling(self, reference_model_tokenizer: tuple[LlamaForCausalLM, PreTrainedTokenizerFast], zero_shot_request_dict: RequestDict, tgi_model: TGIModel): + model, tokenizer = reference_model_tokenizer + requests: list[LoglikelihoodRollingRequest] = zero_shot_request_dict[RequestType.LOGLIKELIHOOD_ROLLING] returns = tgi_model.loglikelihood_rolling(zero_shot_request_dict[RequestType.LOGLIKELIHOOD_ROLLING]) assert len(returns) == 2 - assert all(r.result is not None for r in returns) + for req, res in zip(requests, returns): + is_chat = not isinstance(req.context, str) + tokenized_context = tokenizer.apply_chat_template(req.context, return_tensors='pt') if is_chat else tokenizer(req.context, return_tensors='pt')['input_ids'] + output = model.generate(tokenized_context, max_new_tokens=1, return_dict_in_generate=True, output_hidden_states=True) + with torch.no_grad(): + logprobs = torch.log_softmax(model.lm_head(output.hidden_states[0][-1]),dim=-1) + logprob = logprobs.gather(dim=-1, index=tokenized_context[:,1:].unsqueeze(-1)).sum() + + assert tokenized_context[0, 1:].tolist() == res.input_tokens + assert torch.allclose(torch.tensor(res.result), logprob) @pytest.mark.parametrize("num_fewshot", [0, 2]) @pytest.mark.parametrize("use_chat_template", [False, True]) - def test_integration(self, task: LightevalTask, tgi_model: TGIModel, num_fewshot: int, use_chat_template: bool): - env_config = EnvConfig(token=TOKEN, cache_dir=CACHE_PATH) + def test_integration(self, task: LightevalTask, base_model: BaseModel, tgi_model: TGIModel, num_fewshot: int, use_chat_template: bool): + #TODO evaluation_tracker = EvaluationTracker() pipeline_params = PipelineParameters( launcher_type=ParallelismManager.NONE, - env_config=env_config, use_chat_template=use_chat_template, ) diff --git a/tests/test_test.py b/tests/test_test.py new file mode 100644 index 000000000..1612b6504 --- /dev/null +++ b/tests/test_test.py @@ -0,0 +1,103 @@ +import time +import random +import asyncio +from typing import Iterator + +import pytest +import docker +import requests +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM +from huggingface_hub import ( + InferenceClient, + AsyncInferenceClient, + TextGenerationOutput, +) + + +@pytest.fixture(params=["sync", "async"]) +def tgi_client(request) -> Iterator[InferenceClient|AsyncInferenceClient]: + client = docker.from_env() + + try: + container = client.containers.get("lighteval-tgi-model-test") + port = container.ports["80/tcp"][0]["HostPort"] + except docker.errors.NotFound: + port = random.randint(8000, 9000) + container = client.containers.run( + "ghcr.io/huggingface/text-generation-inference:2.2.0", + command=[ + "--model-id", + "hf-internal-testing/tiny-random-LlamaForCausalLM", + "--dtype", + "float16", + ], + detach=True, + name="lighteval-tgi-model-test", + auto_remove=False, + ports={"80/tcp": port}, + ) + address = f"http://localhost:{port}" + for _ in range(40): + try: + if requests.get(f"{address}/health"): + break + except Exception: + time.sleep(1) + else: + raise RuntimeError("Couldn't setup TGI server.") + + if request.param == "async": + yield AsyncInferenceClient(base_url=address) + elif request.param == "sync": + yield InferenceClient(base_url=address) + else: + raise RuntimeError() + + +def test_logprobs(tgi_client: InferenceClient|AsyncInferenceClient): + model: LlamaForCausalLM = 
AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM") + + # It raises error in async setting unless the size of `prompts` is < 3 + prompts = [ + "Tell me:\n\nHow are you?Fine, thanks!", + "Tell me:\n\nHow are you?Not bad!", + "Tell me:\n\nComment vas-tu?Comme ci, comme ça", + "Tell me:\n\nComment vas-tu?Ca va! Merci!", + ] + responses = [] + for prompt in prompts: + responses.append(tgi_client.text_generation( + prompt, + details=True, + decoder_input_details=True, + max_new_tokens=1, + stop_sequences=None, + do_sample=False, + return_full_text=False, + seed=42, + )) + if isinstance(tgi_client, AsyncInferenceClient): + loop = asyncio.get_event_loop() + responses: list[TextGenerationOutput] = loop.run_until_complete(asyncio.gather(*responses)) + + error = False + for prompt, response in zip(prompts, responses): + + tgi_logprobs = torch.tensor([t.logprob for t in response.details.prefill[1:]]) # Skipping whose logprob is None + + tokenized_sequence = tokenizer(prompt, return_tensors='pt')['input_ids'] + output = model.generate(tokenized_sequence, max_new_tokens=1, return_dict_in_generate=True, output_hidden_states=True) + with torch.no_grad(): + logprobs = torch.log_softmax(model.lm_head(output.hidden_states[0][-1]),dim=-1) + logprobs = logprobs.gather(dim=-1, index=tokenized_sequence[:,1:].unsqueeze(-1)).squeeze() + + if not torch.allclose(logprobs.sum(), tgi_logprobs.sum()): + print(f"====== prompt: {repr(prompt)} ======") + print("TGI logprobs:", tgi_logprobs.tolist()) + print("TGI tokens:",[t.id for t in response.details.prefill]) + print("Ref. logprobs:", logprobs.tolist()) + print("Ref. tokens:", tokenized_sequence[0].tolist()) + error = True + assert not error \ No newline at end of file From cafb1e6fec3e36e185815e139389a03d9fa70110 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Mon, 23 Sep 2024 19:48:11 +0330 Subject: [PATCH 09/10] Fix tests --- tests/conftest.py | 12 ----- tests/test_base_model.py | 7 ++- tests/test_endpoint_model.py | 95 ++++++++++++++++++++++++++---------- 3 files changed, 76 insertions(+), 38 deletions(-) delete mode 100644 tests/conftest.py diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index 555b43966..000000000 --- a/tests/conftest.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import Iterator -import pytest - -from lighteval.models.model_config import BaseModelConfig -from lighteval.models.abstract_model import EnvConfig -from lighteval.models.base_model import BaseModel - - -@pytest.fixture(scope="module") -def base_model() -> Iterator[BaseModel]: - config = BaseModelConfig("hf-internal-testing/tiny-random-LlamaForCausalLM") - return BaseModel(config, EnvConfig()) \ No newline at end of file diff --git a/tests/test_base_model.py b/tests/test_base_model.py index 0d3c72be6..65e082f13 100644 --- a/tests/test_base_model.py +++ b/tests/test_base_model.py @@ -20,7 +20,6 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
-import os from typing import Iterator, TypeAlias from unittest.mock import patch @@ -41,6 +40,12 @@ ) +@pytest.fixture(scope="module") +def base_model() -> Iterator[BaseModel]: + config = BaseModelConfig("hf-internal-testing/tiny-random-LlamaForCausalLM") + return BaseModel(config, EnvConfig(cache_dir=".")) + + RequestDict: TypeAlias = dict[RequestType, list[Request]] diff --git a/tests/test_endpoint_model.py b/tests/test_endpoint_model.py index 98e33de0d..4ec1e5491 100644 --- a/tests/test_endpoint_model.py +++ b/tests/test_endpoint_model.py @@ -20,31 +20,29 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import os import random import time from collections import defaultdict from typing import Iterator, TypeAlias from unittest.mock import patch -import torch import docker import docker.errors import pytest import requests +import torch from huggingface_hub import ChatCompletionInputMessage from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, PreTrainedTokenizerFast from lighteval.logging.evaluation_tracker import EvaluationTracker from lighteval.metrics.metrics import Metrics from lighteval.models.tgi_model import ModelClient as TGIModel -from lighteval.models.base_model import BaseModel from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig, create_requests_from_tasks from lighteval.tasks.requests import ( + Doc, LoglikelihoodRequest, LoglikelihoodRollingRequest, - Doc, Request, RequestType, ) @@ -157,59 +155,106 @@ def zero_shot_request_dict(self, task: LightevalTask) -> RequestDict: result[req_type].extend(doc_result[req_type]) return result - def test_greedy_until(self, reference_model_tokenizer: tuple[LlamaForCausalLM, PreTrainedTokenizerFast], zero_shot_request_dict: RequestDict, tgi_model: TGIModel): + def test_greedy_until( + self, + reference_model_tokenizer: tuple[LlamaForCausalLM, PreTrainedTokenizerFast], + zero_shot_request_dict: RequestDict, + tgi_model: TGIModel, + ): requests = zero_shot_request_dict[RequestType.GREEDY_UNTIL] returns = tgi_model.greedy_until(requests) model, tokenizer = reference_model_tokenizer assert len(returns) == 2 for req, res in zip(requests, returns): is_chat = not isinstance(req.context, str) - tokenized_context = tokenizer.apply_chat_template(req.context, return_tensors='pt') if is_chat else tokenizer(req.context, return_tensors='pt')['input_ids'] - ref_context_continuaiton = model.generate(tokenized_context, tokenizer=tokenizer, stop_strings=req.stop_sequence, max_new_tokens=req.generation_size)[0].tolist() - continuation = tokenizer.decode(ref_context_continuaiton)[len(tokenizer.decode(tokenized_context[0].tolist())):] + tokenized_context = ( + tokenizer.apply_chat_template(req.context, return_tensors="pt") + if is_chat + else tokenizer(req.context, return_tensors="pt")["input_ids"] + ) + ref_context_continuaiton = model.generate( + tokenized_context, + tokenizer=tokenizer, + stop_strings=req.stop_sequence, + max_new_tokens=req.generation_size, + )[0].tolist() + continuation = tokenizer.decode(ref_context_continuaiton)[ + len(tokenizer.decode(tokenized_context[0].tolist())) : + ] assert continuation == res.result - def test_loglikelihood(self, reference_model_tokenizer: tuple[LlamaForCausalLM, PreTrainedTokenizerFast], zero_shot_request_dict: RequestDict, tgi_model: TGIModel): + def test_loglikelihood( + self, + reference_model_tokenizer: tuple[LlamaForCausalLM, 
PreTrainedTokenizerFast], + zero_shot_request_dict: RequestDict, + tgi_model: TGIModel, + ): requests: list[LoglikelihoodRequest] = zero_shot_request_dict[RequestType.LOGLIKELIHOOD] - returns = tgi_model.loglikelihood(requests) + + # https://github.com/huggingface/text-generation-inference/issues/2502 + with patch.object(tgi_model, "use_async", False): + returns = tgi_model.loglikelihood(requests) + model, tokenizer = reference_model_tokenizer assert len(returns) == 4 for req, res in zip(requests, returns): is_chat = not isinstance(req.context, str) - sequence = req.context + [ChatCompletionInputMessage(role='assistant',content=req.choice)] if is_chat else req.context+req.choice - tokenized_sequence = tokenizer.apply_chat_template(sequence, return_tensors='pt') if is_chat else tokenizer(sequence, return_tensors='pt')['input_ids'] - - output = model.generate(tokenized_sequence, max_new_tokens=1, return_dict_in_generate=True, output_hidden_states=True) + sequence = ( + req.context + [ChatCompletionInputMessage(role="assistant", content=req.choice)] + if is_chat + else req.context + req.choice + ) + tokenized_sequence = ( + tokenizer.apply_chat_template(sequence, return_tensors="pt") + if is_chat + else tokenizer(sequence, return_tensors="pt")["input_ids"] + ) + + output = model.generate( + tokenized_sequence, max_new_tokens=1, return_dict_in_generate=True, output_hidden_states=True + ) with torch.no_grad(): - logprobs = torch.log_softmax(model.lm_head(output.hidden_states[0][-1]),dim=-1) - logprobs = logprobs.gather(dim=-1, index=tokenized_sequence[:,1:].unsqueeze(-1)) - context_length = len(tokenizer.apply_chat_template(req.context)) if is_chat else len(tokenizer.encode(req.context)) - continuation_logprob = logprobs[:, context_length-1:].sum() + logprobs = torch.log_softmax(model.lm_head(output.hidden_states[0][-1]), dim=-1) + logprobs = logprobs.gather(dim=-1, index=tokenized_sequence[:, 1:].unsqueeze(-1)) + context_length = ( + len(tokenizer.apply_chat_template(req.context)) if is_chat else len(tokenizer.encode(req.context)) + ) + continuation_logprob = logprobs[:, context_length - 1 :].sum() tokenized_choice = tokenized_sequence[:, context_length:] assert tokenized_choice[0].tolist() == res.input_tokens assert torch.allclose(torch.tensor(res.result[0]), continuation_logprob) - def test_loglikelihood_rolling(self, reference_model_tokenizer: tuple[LlamaForCausalLM, PreTrainedTokenizerFast], zero_shot_request_dict: RequestDict, tgi_model: TGIModel): + def test_loglikelihood_rolling( + self, + reference_model_tokenizer: tuple[LlamaForCausalLM, PreTrainedTokenizerFast], + zero_shot_request_dict: RequestDict, + tgi_model: TGIModel, + ): model, tokenizer = reference_model_tokenizer requests: list[LoglikelihoodRollingRequest] = zero_shot_request_dict[RequestType.LOGLIKELIHOOD_ROLLING] returns = tgi_model.loglikelihood_rolling(zero_shot_request_dict[RequestType.LOGLIKELIHOOD_ROLLING]) assert len(returns) == 2 for req, res in zip(requests, returns): is_chat = not isinstance(req.context, str) - tokenized_context = tokenizer.apply_chat_template(req.context, return_tensors='pt') if is_chat else tokenizer(req.context, return_tensors='pt')['input_ids'] - output = model.generate(tokenized_context, max_new_tokens=1, return_dict_in_generate=True, output_hidden_states=True) + tokenized_context = ( + tokenizer.apply_chat_template(req.context, return_tensors="pt") + if is_chat + else tokenizer(req.context, return_tensors="pt")["input_ids"] + ) + output = model.generate( + tokenized_context, 
max_new_tokens=1, return_dict_in_generate=True, output_hidden_states=True + ) with torch.no_grad(): - logprobs = torch.log_softmax(model.lm_head(output.hidden_states[0][-1]),dim=-1) - logprob = logprobs.gather(dim=-1, index=tokenized_context[:,1:].unsqueeze(-1)).sum() + logprobs = torch.log_softmax(model.lm_head(output.hidden_states[0][-1]), dim=-1) + logprob = logprobs.gather(dim=-1, index=tokenized_context[:, 1:].unsqueeze(-1)).sum() assert tokenized_context[0, 1:].tolist() == res.input_tokens assert torch.allclose(torch.tensor(res.result), logprob) @pytest.mark.parametrize("num_fewshot", [0, 2]) @pytest.mark.parametrize("use_chat_template", [False, True]) - def test_integration(self, task: LightevalTask, base_model: BaseModel, tgi_model: TGIModel, num_fewshot: int, use_chat_template: bool): - #TODO + def test_integration(self, task: LightevalTask, tgi_model: TGIModel, num_fewshot: int, use_chat_template: bool): evaluation_tracker = EvaluationTracker() pipeline_params = PipelineParameters( launcher_type=ParallelismManager.NONE, From fbe1398c2e1c518d664c15b046d918be91d61e54 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Fri, 11 Oct 2024 20:00:08 +0330 Subject: [PATCH 10/10] Add grammar param to endpoint model inputs --- src/lighteval/models/endpoint_model.py | 4 + tests/test_test.py | 103 ------------------------- 2 files changed, 4 insertions(+), 103 deletions(-) delete mode 100644 tests/test_test.py diff --git a/src/lighteval/models/endpoint_model.py b/src/lighteval/models/endpoint_model.py index cfd711627..2021df14b 100644 --- a/src/lighteval/models/endpoint_model.py +++ b/src/lighteval/models/endpoint_model.py @@ -242,9 +242,11 @@ def _prepare_request(self, request: Request) -> EndpointInput: stop = as_list(request.stop_sequence) or None max_tokens = request.generation_size context = request.context + grammar = request.generation_grammar elif isinstance(request, (LoglikelihoodRequest, LoglikelihoodRollingRequest)): stop = None max_tokens = 1 + grammar = None rolling = isinstance(request, LoglikelihoodRollingRequest) if rolling: context = request.context @@ -267,6 +269,7 @@ def _prepare_request(self, request: Request) -> EndpointInput: max_new_tokens=max_tokens, stop=stop, return_full_text=False, + grammar=grammar, top_n_tokens=1, ), ) @@ -280,6 +283,7 @@ def _prepare_request(self, request: Request) -> EndpointInput: seed=42, temperature=0.0, top_logprobs=1, + response_format=grammar, stream=False, ) return prepared_request diff --git a/tests/test_test.py b/tests/test_test.py deleted file mode 100644 index 1612b6504..000000000 --- a/tests/test_test.py +++ /dev/null @@ -1,103 +0,0 @@ -import time -import random -import asyncio -from typing import Iterator - -import pytest -import docker -import requests -import torch -from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM -from huggingface_hub import ( - InferenceClient, - AsyncInferenceClient, - TextGenerationOutput, -) - - -@pytest.fixture(params=["sync", "async"]) -def tgi_client(request) -> Iterator[InferenceClient|AsyncInferenceClient]: - client = docker.from_env() - - try: - container = client.containers.get("lighteval-tgi-model-test") - port = container.ports["80/tcp"][0]["HostPort"] - except docker.errors.NotFound: - port = random.randint(8000, 9000) - container = client.containers.run( - "ghcr.io/huggingface/text-generation-inference:2.2.0", - command=[ - "--model-id", - "hf-internal-testing/tiny-random-LlamaForCausalLM", - "--dtype", - "float16", - ], - detach=True, - 
name="lighteval-tgi-model-test", - auto_remove=False, - ports={"80/tcp": port}, - ) - address = f"http://localhost:{port}" - for _ in range(40): - try: - if requests.get(f"{address}/health"): - break - except Exception: - time.sleep(1) - else: - raise RuntimeError("Couldn't setup TGI server.") - - if request.param == "async": - yield AsyncInferenceClient(base_url=address) - elif request.param == "sync": - yield InferenceClient(base_url=address) - else: - raise RuntimeError() - - -def test_logprobs(tgi_client: InferenceClient|AsyncInferenceClient): - model: LlamaForCausalLM = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM") - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM") - - # It raises error in async setting unless the size of `prompts` is < 3 - prompts = [ - "Tell me:\n\nHow are you?Fine, thanks!", - "Tell me:\n\nHow are you?Not bad!", - "Tell me:\n\nComment vas-tu?Comme ci, comme ça", - "Tell me:\n\nComment vas-tu?Ca va! Merci!", - ] - responses = [] - for prompt in prompts: - responses.append(tgi_client.text_generation( - prompt, - details=True, - decoder_input_details=True, - max_new_tokens=1, - stop_sequences=None, - do_sample=False, - return_full_text=False, - seed=42, - )) - if isinstance(tgi_client, AsyncInferenceClient): - loop = asyncio.get_event_loop() - responses: list[TextGenerationOutput] = loop.run_until_complete(asyncio.gather(*responses)) - - error = False - for prompt, response in zip(prompts, responses): - - tgi_logprobs = torch.tensor([t.logprob for t in response.details.prefill[1:]]) # Skipping whose logprob is None - - tokenized_sequence = tokenizer(prompt, return_tensors='pt')['input_ids'] - output = model.generate(tokenized_sequence, max_new_tokens=1, return_dict_in_generate=True, output_hidden_states=True) - with torch.no_grad(): - logprobs = torch.log_softmax(model.lm_head(output.hidden_states[0][-1]),dim=-1) - logprobs = logprobs.gather(dim=-1, index=tokenized_sequence[:,1:].unsqueeze(-1)).squeeze() - - if not torch.allclose(logprobs.sum(), tgi_logprobs.sum()): - print(f"====== prompt: {repr(prompt)} ======") - print("TGI logprobs:", tgi_logprobs.tolist()) - print("TGI tokens:",[t.id for t in response.details.prefill]) - print("Ref. logprobs:", logprobs.tolist()) - print("Ref. tokens:", tokenized_sequence[0].tolist()) - error = True - assert not error \ No newline at end of file