Improve endpoint tests and fix a bug in the endpoint model
sadra-barikbin committed Sep 5, 2024
1 parent b291871 commit c3ac5d6
Showing 2 changed files with 51 additions and 17 deletions.
8 changes: 5 additions & 3 deletions src/lighteval/models/endpoint_model.py
@@ -205,11 +205,12 @@ def _process_generate_response(self, response: EndpointOutput, request: GreedyUn
    def _process_logprob_response(
        self, response: TextGenerationOutput, request: LoglikelihoodRequest | LoglikelihoodRollingRequest
    ) -> LoglikelihoodResponse:
        len_choice = len(request.tokenized_continuation)
        logits = sum([t.logprob for t in response.details.prefill[1:][-len_choice:]])
        if "Fine" in request.choice:
            print(response.details.prefill)
        logits = sum([t.logprob for t in response.details.prefill[len(request.tokenized_context):]])
        return LoglikelihoodResponse(
            result=(logits, True) if isinstance(request, LoglikelihoodRequest) else logits,
            input_tokens=[t.id for t in response.details.prefill[:-len_choice]],
            input_tokens=[t.id for t in response.details.prefill[len(request.tokenized_context):]],
            generated_tokens=-1,
            truncated_tokens_count=-1,
            padded_tokens_count=-1,
@@ -255,6 +256,7 @@ def _prepare_request(self, request: Request) -> EndpointInput:
            context = request.context + [ChatCompletionInputMessage(role="assistant", content=request.choice)]
        if not isinstance(context, str):
            context = self.tokenizer.apply_chat_template(context, tokenize=False)
            context = context.split(self.tokenizer.bos_token, 1)[-1]

        if isinstance(context, str):
            prepared_request = TextGenerationInput(
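The model-side change above comes down to how the prefill log-probabilities returned by the endpoint are sliced: the continuation score is now summed from `len(request.tokenized_context)` onward instead of counting `len_choice` tokens back from the end, and the chat-templated context is stripped of its leading BOS token, presumably so the prompt does not end up with a duplicate BOS when the server tokenizes it. Below is a minimal sketch of that slicing only; `PrefillToken` and the sample values are illustrative stand-ins, not the real TGI response types.

```python
# Sketch only: PrefillToken and the sample values stand in for the per-token
# prefill details returned by the endpoint; they are not the real TGI types.
from dataclasses import dataclass


@dataclass
class PrefillToken:
    id: int
    logprob: float


def continuation_logprob(prefill: list[PrefillToken], context_token_count: int) -> tuple[float, list[int]]:
    # Everything after the tokenized context belongs to the choice/continuation,
    # so sum those logprobs and report the same slice of token ids.
    continuation = prefill[context_token_count:]
    return sum(t.logprob for t in continuation), [t.id for t in continuation]


# A 3-token context followed by a 2-token continuation.
prefill = [
    PrefillToken(1, 0.0),
    PrefillToken(42, -1.2),
    PrefillToken(7, -0.3),
    PrefillToken(99, -0.75),
    PrefillToken(100, -0.25),
]
print(continuation_logprob(prefill, context_token_count=3))  # (-1.0, [99, 100])
```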
60 changes: 46 additions & 14 deletions tests/test_endpoint_model.py
@@ -27,18 +27,21 @@
from typing import Iterator, TypeAlias
from unittest.mock import patch

import torch
import docker
import docker.errors
import pytest
import requests
from huggingface_hub import ChatCompletionInputMessage
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, PreTrainedTokenizerFast

from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.metrics.metrics import Metrics
from lighteval.models.tgi_model import ModelClient as TGIModel
from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig, create_requests_from_tasks
from lighteval.tasks.requests import (
    LoglikelihoodRequest,
    Doc,
    Request,
    RequestType,
@@ -83,12 +86,19 @@ def tgi_model() -> Iterator[TGIModel]:
        raise RuntimeError("Couldn't setup TGI server.")
    model = TGIModel(address)
    yield model
    container.stop()
    container.wait()
    container.remove()
    # container.stop()
    # container.wait()
    # container.remove()
    model.cleanup()


@pytest.fixture(scope="module")
def reference_model_tokenizer() -> tuple[LlamaForCausalLM, PreTrainedTokenizerFast]:
    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
    return model, tokenizer


RequestDict: TypeAlias = dict[RequestType, list[Request]]


@@ -150,19 +160,41 @@ def zero_shot_request_dict(self, task: LightevalTask) -> RequestDict:
            result[req_type].extend(doc_result[req_type])
        return result

    def test_greedy_until(self, zero_shot_request_dict: RequestDict, tgi_model: TGIModel):
        returns = tgi_model.greedy_until(zero_shot_request_dict[RequestType.GREEDY_UNTIL])
    def test_greedy_until(self, reference_model_tokenizer: tuple[LlamaForCausalLM, PreTrainedTokenizerFast], zero_shot_request_dict: RequestDict, tgi_model: TGIModel):
        requests = zero_shot_request_dict[RequestType.GREEDY_UNTIL]
        returns = tgi_model.greedy_until(requests)
        model, tokenizer = reference_model_tokenizer
        assert len(returns) == 2
        assert all(r.result is not None for r in returns)

    def test_loglikelihood(self, zero_shot_request_dict: RequestDict, tgi_model: TGIModel):
        returns = tgi_model.loglikelihood(zero_shot_request_dict[RequestType.LOGLIKELIHOOD])
        for req, res in zip(requests, returns):
            is_chat = not isinstance(req.context, str)
            tokenized_context = tokenizer.apply_chat_template(req.context, return_tensors="pt") if is_chat else tokenizer(req.context, return_tensors="pt")["input_ids"]
            ref_context_continuation = model.generate(tokenized_context, tokenizer=tokenizer, stop_strings=req.stop_sequence, max_new_tokens=req.generation_size)[0].tolist()
            continuation = tokenizer.decode(ref_context_continuation)[len(tokenizer.decode(tokenized_context[0].tolist())):]
            assert continuation == res.result

    def test_loglikelihood(self, reference_model_tokenizer: tuple[LlamaForCausalLM, PreTrainedTokenizerFast], zero_shot_request_dict: RequestDict, tgi_model: TGIModel):
        requests: list[LoglikelihoodRequest] = zero_shot_request_dict[RequestType.LOGLIKELIHOOD]
        returns = tgi_model.loglikelihood(requests)
        model, tokenizer = reference_model_tokenizer
        assert len(returns) == 4
        assert all(r.result is not None for r in returns)

        returns = tgi_model.loglikelihood_rolling(zero_shot_request_dict[RequestType.LOGLIKELIHOOD_ROLLING])
        assert len(returns) == 2
        assert all(r.result is not None for r in returns)
        for req, res in zip(requests, returns):
            is_chat = not isinstance(req.context, str)
            sequence = req.context + [ChatCompletionInputMessage(role="assistant", content=req.choice)] if is_chat else req.context + req.choice
            tokenized_sequence = tokenizer.apply_chat_template(sequence, return_tensors="pt") if is_chat else tokenizer(sequence, return_tensors="pt")["input_ids"]
            context_length = len(tokenizer.apply_chat_template(req.context)) if is_chat else len(tokenizer.encode(req.context))
            output = model.generate(tokenized_sequence, max_new_tokens=1, return_dict_in_generate=True, output_hidden_states=True)
            with torch.no_grad():
                scores = torch.log_softmax(model.lm_head(output.hidden_states[0][-1]), dim=-1)[0]
            continuation_scores = scores[context_length - 1 : -1]
            tokenized_choice = tokenized_sequence[0][context_length:]
            score = continuation_scores.gather(dim=-1, index=tokenized_choice.view(-1, 1)).sum()

            assert tokenized_choice.tolist() == res.input_tokens
            assert torch.allclose(torch.tensor(res.result[0]), score, rtol=0.01)

        # returns = tgi_model.loglikelihood_rolling(zero_shot_request_dict[RequestType.LOGLIKELIHOOD_ROLLING])
        # assert len(returns) == 2
        # assert all(r.result is not None for r in returns)

    @pytest.mark.parametrize("num_fewshot", [0, 2])
    @pytest.mark.parametrize("use_chat_template", [False, True])
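The new `test_loglikelihood` recomputes the continuation score from the tiny reference Llama checkpoint via `generate(..., output_hidden_states=True)` and the model's `lm_head`. The same reference value can be obtained with a plain forward pass; the sketch below is not part of the commit and uses made-up context/choice strings, but follows the same shift-by-one gathering of continuation log-probabilities.

```python
# Sketch of the reference log-likelihood used in the test above, written as a
# plain forward pass. Model/tokenizer match the test fixture; the context and
# choice strings are illustrative.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")

context, choice = "Question: 1 + 1 = ?\nAnswer:", " 2"
input_ids = tokenizer(context + choice, return_tensors="pt")["input_ids"]
context_length = len(tokenizer.encode(context))

with torch.no_grad():
    log_probs = torch.log_softmax(model(input_ids).logits, dim=-1)[0]

# The token at position i is predicted by the logits at position i - 1,
# hence the one-step shift between scores and continuation token ids.
continuation_ids = input_ids[0, context_length:]
continuation_scores = log_probs[context_length - 1 : -1]
score = continuation_scores.gather(dim=-1, index=continuation_ids.view(-1, 1)).sum()
print(score.item())
```

Either way, the reference value should agree with the prefill-logprob sum returned by the endpoint model within the `rtol=0.01` the test asserts.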
