fix: return correct shape logits and add streaming test
drbh committed Oct 31, 2024
1 parent 17de599 commit e2b394e
Showing 4 changed files with 65 additions and 3 deletions.
@@ -0,0 +1,20 @@
{
  "choices": [
    {
      "delta": {
        "content": "",
        "role": "assistant",
        "tool_calls": null
      },
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null
    }
  ],
  "created": 1730416361,
  "id": "",
  "model": "Qwen/Qwen2-VL-7B-Instruct",
  "object": "chat.completion.chunk",
  "system_fingerprint": "2.4.1-dev0-native",
  "usage": null
}
42 changes: 42 additions & 0 deletions integration-tests/models/test_flash_qwen2_vl.py
@@ -40,3 +40,45 @@ async def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):
    )

    assert response == response_snapshot


@pytest.mark.private
async def test_flash_qwen2_vl_simple_streaming(flash_qwen2, response_snapshot):
    responses = await flash_qwen2.chat(
        max_tokens=100,
        seed=42,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
                        },
                    },
                    {"type": "text", "text": "Describe this image."},
                ],
            },
        ],
        stream=True,
    )

    count = 0
    generated = ""
    last_response = None
    try:
        async for response in responses:
            count += 1
            generated += response.choices[0].delta.content
            last_response = response
    except Exception:
        # the client library may raise when it cannot parse the terminating
        # "[DONE]" server-sent-events message as JSON
        pass

    assert (
        generated
        == "The image depicts an anthropomorphic rabbit, wearing a futuristic spacesuit, in an extraterrestrial environment. The setting appears to be a red planet resembling Mars, with rugged terrain and rocky formations in the background. The moon is visible in the distant sky, adding to the lunar landscape."
    )
    assert count == 58
    assert last_response == response_snapshot
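The try/except above exists because the client can raise once the server emits the stream-terminating "[DONE]" server-sent-events message, which is not valid JSON. For context, a minimal sketch of consuming the same stream over raw SSE with requests; the host, port, and request payload here are illustrative assumptions, not part of this commit:

import json

import requests


def stream_chat(messages, url="http://localhost:8080/v1/chat/completions"):
    # assumed local TGI server exposing the OpenAI-compatible chat route
    resp = requests.post(
        url,
        json={"messages": messages, "max_tokens": 100, "stream": True},
        stream=True,
        timeout=60,
    )
    generated = ""
    for line in resp.iter_lines():
        if not line.startswith(b"data:"):
            continue  # skip blank keep-alive lines
        payload = line[len(b"data:"):].strip()
        if payload == b"[DONE]":
            break  # the sentinel the test's try/except works around
        chunk = json.loads(payload)
        generated += chunk["choices"][0]["delta"].get("content") or ""
    return generated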
@@ -518,5 +518,5 @@ def forward(
         hidden_states, _ = self.norm(hidden_states)
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        logits = self.lm_head(hidden_states)
-        return logits, None
+        logits, speculative_logits = self.lm_head(hidden_states)
+        return logits, speculative_logits
2 changes: 1 addition & 1 deletion server/text_generation_server/models/vlm_causal_lm.py
@@ -364,7 +364,7 @@ def forward(
         lm_head_indices = batch.prefill_head_indices

         if self.model.config.model_type == "qwen2_vl":
-            if position_ids.dim() == 1:
+            if position_ids.dim() == 1 and batch.prefilling:
                 position_ids = self.model.get_position_ids(
                     input_ids, batch.image_grid_thw
                 )
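The tightened guard restricts rebuilding Qwen2-VL's multimodal rope (mrope) position ids to the prefill phase, since a 1-D `position_ids` tensor alone does not prove the batch is prefilling. A toy illustration of that ambiguity, assuming decode steps can also present 1-D ids; these variable names are illustrative and do not come from this commit:

import torch

prefill_positions = torch.arange(10)   # shape (10,) while prefilling a prompt
decode_positions = torch.tensor([10])  # shape (1,) during a decode step

for ids, prefilling in [(prefill_positions, True), (decode_positions, False)]:
    rebuild = ids.dim() == 1 and prefilling  # the fixed condition
    print(tuple(ids.shape), "rebuild mrope positions:", rebuild)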