|
9 | 9 | import openai # use the official client for correctness check
|
10 | 10 | from huggingface_hub import snapshot_download # downloading lora to test lora requests
|
11 | 11 |
|
| 12 | +# imports for guided decoding tests |
| 13 | +import json |
| 14 | +import jsonschema |
| 15 | +import re |
| 16 | + |
12 | 17 | from vllm.transformers_utils.tokenizer import get_tokenizer
|
13 | 18 |
|
14 | 19 | MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds
|
15 | 20 | MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # any model with a chat template should work here
|
16 | 21 | LORA_NAME = "typeof/zephyr-7b-beta-lora" # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here
|
17 | 22 |
|
| 23 | +TEST_SCHEMA = { |
| 24 | + "type": "object", |
| 25 | + "properties": { |
| 26 | + "name": { |
| 27 | + "type": "string" |
| 28 | + }, |
| 29 | + "age": { |
| 30 | + "type": "integer" |
| 31 | + }, |
| 32 | + "skills": { |
| 33 | + "type": "array", |
| 34 | + "items": { |
| 35 | + "type": "string", |
| 36 | + "maxLength": 10 |
| 37 | + }, |
| 38 | + "minItems": 3 |
| 39 | + }, |
| 40 | + "work history": { |
| 41 | + "type": "array", |
| 42 | + "items": { |
| 43 | + "type": "object", |
| 44 | + "properties": { |
| 45 | + "company": { |
| 46 | + "type": "string" |
| 47 | + }, |
| 48 | + "duration": { |
| 49 | + "type": "string" |
| 50 | + }, |
| 51 | + "position": { |
| 52 | + "type": "string" |
| 53 | + } |
| 54 | + }, |
| 55 | + "required": ["company", "position"] |
| 56 | + } |
| 57 | + } |
| 58 | + }, |
| 59 | + "required": ["name", "age", "skills", "work history"] |
| 60 | +} |
| 61 | + |
| 62 | +TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \ |
| 63 | + r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)" |
| 64 | + |
| 65 | +TEST_CHOICE = [ |
| 66 | + "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Ruby", |
| 67 | + "Swift", "Kotlin" |
| 68 | +] |
| 69 | + |
18 | 70 | pytestmark = pytest.mark.asyncio
|
19 | 71 |
|
20 | 72 |
|
@@ -325,6 +377,7 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
|
325 | 377 | max_tokens=max_tokens,
|
326 | 378 | temperature=0.0,
|
327 | 379 | logit_bias={str(token_id): 100},
|
| 380 | + seed=42, |
328 | 381 | )
|
329 | 382 | assert completion.choices[0].text is not None and len(
|
330 | 383 | completion.choices[0].text) >= 5
|
@@ -358,5 +411,189 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
|
358 | 411 | assert first_response != completion.choices[0].text
|
359 | 412 |
|
360 | 413 |
|
| 414 | +async def test_guided_json_completion(server, client: openai.AsyncOpenAI): |
| 415 | + completion = await client.completions.create( |
| 416 | + model=MODEL_NAME, |
| 417 | + prompt= |
| 418 | + f"Give an example JSON for an employee profile that fits this schema: {TEST_SCHEMA}", |
| 419 | + n=3, |
| 420 | + temperature=1.0, |
| 421 | + max_tokens=500, |
| 422 | + extra_body=dict(guided_json=TEST_SCHEMA)) |
| 423 | + |
| 424 | + assert completion.id is not None |
| 425 | + assert completion.choices is not None and len(completion.choices) == 3 |
| 426 | + for i in range(3): |
| 427 | + assert completion.choices[i].text is not None |
| 428 | + output_json = json.loads(completion.choices[i].text) |
| 429 | + jsonschema.validate(instance=output_json, schema=TEST_SCHEMA) |
| 430 | + |
| 431 | + |
| 432 | +async def test_guided_json_chat(server, client: openai.AsyncOpenAI): |
| 433 | + messages = [{ |
| 434 | + "role": "system", |
| 435 | + "content": "you are a helpful assistant" |
| 436 | + }, { |
| 437 | + "role": "user", |
| 438 | + "content": "Give an example JSON for an employee profile that " + \ |
| 439 | + f"fits this schema: {TEST_SCHEMA}" |
| 440 | + }] |
| 441 | + chat_completion = await client.chat.completions.create( |
| 442 | + model=MODEL_NAME, |
| 443 | + messages=messages, |
| 444 | + max_tokens=500, |
| 445 | + extra_body=dict(guided_json=TEST_SCHEMA)) |
| 446 | + message = chat_completion.choices[0].message |
| 447 | + assert message.content is not None |
| 448 | + json1 = json.loads(message.content) |
| 449 | + jsonschema.validate(instance=json1, schema=TEST_SCHEMA) |
| 450 | + |
| 451 | + messages.append({"role": "assistant", "content": message.content}) |
| 452 | + messages.append({ |
| 453 | + "role": |
| 454 | + "user", |
| 455 | + "content": |
| 456 | + "Give me another one with a different name and age" |
| 457 | + }) |
| 458 | + chat_completion = await client.chat.completions.create( |
| 459 | + model=MODEL_NAME, |
| 460 | + messages=messages, |
| 461 | + max_tokens=500, |
| 462 | + extra_body=dict(guided_json=TEST_SCHEMA)) |
| 463 | + message = chat_completion.choices[0].message |
| 464 | + assert message.content is not None |
| 465 | + json2 = json.loads(message.content) |
| 466 | + jsonschema.validate(instance=json2, schema=TEST_SCHEMA) |
| 467 | + assert json1["name"] != json2["name"] |
| 468 | + assert json1["age"] != json2["age"] |
| 469 | + |
| 470 | + |
| 471 | +async def test_guided_regex_completion(server, client: openai.AsyncOpenAI): |
| 472 | + completion = await client.completions.create( |
| 473 | + model=MODEL_NAME, |
| 474 | + prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}", |
| 475 | + n=3, |
| 476 | + temperature=1.0, |
| 477 | + max_tokens=20, |
| 478 | + extra_body=dict(guided_regex=TEST_REGEX)) |
| 479 | + |
| 480 | + assert completion.id is not None |
| 481 | + assert completion.choices is not None and len(completion.choices) == 3 |
| 482 | + for i in range(3): |
| 483 | + assert completion.choices[i].text is not None |
| 484 | + assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None |
| 485 | + |
| 486 | + |
| 487 | +async def test_guided_regex_chat(server, client: openai.AsyncOpenAI): |
| 488 | + messages = [{ |
| 489 | + "role": "system", |
| 490 | + "content": "you are a helpful assistant" |
| 491 | + }, { |
| 492 | + "role": |
| 493 | + "user", |
| 494 | + "content": |
| 495 | + f"Give an example IP address with this regex: {TEST_REGEX}" |
| 496 | + }] |
| 497 | + chat_completion = await client.chat.completions.create( |
| 498 | + model=MODEL_NAME, |
| 499 | + messages=messages, |
| 500 | + max_tokens=20, |
| 501 | + extra_body=dict(guided_regex=TEST_REGEX)) |
| 502 | + ip1 = chat_completion.choices[0].message.content |
| 503 | + assert ip1 is not None |
| 504 | + assert re.fullmatch(TEST_REGEX, ip1) is not None |
| 505 | + |
| 506 | + messages.append({"role": "assistant", "content": ip1}) |
| 507 | + messages.append({"role": "user", "content": "Give me a different one"}) |
| 508 | + chat_completion = await client.chat.completions.create( |
| 509 | + model=MODEL_NAME, |
| 510 | + messages=messages, |
| 511 | + max_tokens=20, |
| 512 | + extra_body=dict(guided_regex=TEST_REGEX)) |
| 513 | + ip2 = chat_completion.choices[0].message.content |
| 514 | + assert ip2 is not None |
| 515 | + assert re.fullmatch(TEST_REGEX, ip2) is not None |
| 516 | + assert ip1 != ip2 |
| 517 | + |
| 518 | + |
| 519 | +async def test_guided_choice_completion(server, client: openai.AsyncOpenAI): |
| 520 | + completion = await client.completions.create( |
| 521 | + model=MODEL_NAME, |
| 522 | + prompt="The best language for type-safe systems programming is ", |
| 523 | + n=2, |
| 524 | + temperature=1.0, |
| 525 | + max_tokens=10, |
| 526 | + extra_body=dict(guided_choice=TEST_CHOICE)) |
| 527 | + |
| 528 | + assert completion.id is not None |
| 529 | + assert completion.choices is not None and len(completion.choices) == 2 |
| 530 | + for i in range(2): |
| 531 | + assert completion.choices[i].text in TEST_CHOICE |
| 532 | + |
| 533 | + |
| 534 | +async def test_guided_choice_chat(server, client: openai.AsyncOpenAI): |
| 535 | + messages = [{ |
| 536 | + "role": "system", |
| 537 | + "content": "you are a helpful assistant" |
| 538 | + }, { |
| 539 | + "role": |
| 540 | + "user", |
| 541 | + "content": |
| 542 | + "The best language for type-safe systems programming is " |
| 543 | + }] |
| 544 | + chat_completion = await client.chat.completions.create( |
| 545 | + model=MODEL_NAME, |
| 546 | + messages=messages, |
| 547 | + max_tokens=10, |
| 548 | + extra_body=dict(guided_choice=TEST_CHOICE)) |
| 549 | + choice1 = chat_completion.choices[0].message.content |
| 550 | + assert choice1 in TEST_CHOICE |
| 551 | + |
| 552 | + messages.append({"role": "assistant", "content": choice1}) |
| 553 | + messages.append({ |
| 554 | + "role": "user", |
| 555 | + "content": "I disagree, pick another one" |
| 556 | + }) |
| 557 | + chat_completion = await client.chat.completions.create( |
| 558 | + model=MODEL_NAME, |
| 559 | + messages=messages, |
| 560 | + max_tokens=10, |
| 561 | + extra_body=dict(guided_choice=TEST_CHOICE)) |
| 562 | + choice2 = chat_completion.choices[0].message.content |
| 563 | + assert choice2 in TEST_CHOICE |
| 564 | + assert choice1 != choice2 |
| 565 | + |
| 566 | + |
| 567 | +async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI): |
| 568 | + with pytest.raises(openai.BadRequestError): |
| 569 | + _ = await client.completions.create( |
| 570 | + model=MODEL_NAME, |
| 571 | + prompt="Give an example JSON that fits this schema: 42", |
| 572 | + extra_body=dict(guided_json=42)) |
| 573 | + |
| 574 | + messages = [{ |
| 575 | + "role": "system", |
| 576 | + "content": "you are a helpful assistant" |
| 577 | + }, { |
| 578 | + "role": |
| 579 | + "user", |
| 580 | + "content": |
| 581 | + "The best language for type-safe systems programming is " |
| 582 | + }] |
| 583 | + with pytest.raises(openai.BadRequestError): |
| 584 | + _ = await client.chat.completions.create(model=MODEL_NAME, |
| 585 | + messages=messages, |
| 586 | + extra_body=dict(guided_regex={ |
| 587 | + 1: "Python", |
| 588 | + 2: "C++" |
| 589 | + })) |
| 590 | + |
| 591 | + with pytest.raises(openai.BadRequestError): |
| 592 | + _ = await client.completions.create( |
| 593 | + model=MODEL_NAME, |
| 594 | + prompt="Give an example string that fits this regex", |
| 595 | + extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA)) |
| 596 | + |
| 597 | + |
361 | 598 | if __name__ == "__main__":
|
362 | 599 | pytest.main([__file__])
|
0 commit comments