# Make model type backwards compatible (#1212)
### Description

This PR makes it possible to access the model type by falling back to the
GenAI config (`genai_config.json`) when the model object does not expose a
`type` attribute. It also adds the chat and system prompt templates for Qwen models.
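
In isolation, the fallback amounts to the sketch below (the `get_model_type` helper name is hypothetical; the example scripts inline this logic):

```python
import json
import os

def get_model_type(model, model_path):
    # Newer builds expose the type directly on the model object
    if hasattr(model, "type"):
        return model.type
    # Older builds (e.g. the v0.6.0 RCs): read it from the GenAI config instead
    with open(os.path.join(model_path, "genai_config.json"), "r") as f:
        return json.load(f)["model"]["type"]
```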

### Motivation and Context

This PR keeps the examples backwards compatible with the published release
candidates (RCs) for v0.6.0.

The Qwen chat template and Qwen system template were derived from the
following `transformers` session.

```
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", cache_dir="./cache_dir")
tokenizer_config.json: 100%|█████████████████████████████████████████████████████████████████████████████| 7.30k/7.30k [00:00<00:00, 9.64MB/s]
vocab.json: 100%|████████████████████████████████████████████████████████████████████████████████████████| 2.78M/2.78M [00:00<00:00, 16.5MB/s]
merges.txt: 100%|████████████████████████████████████████████████████████████████████████████████████████| 1.67M/1.67M [00:00<00:00, 15.6MB/s]
tokenizer.json: 100%|████████████████████████████████████████████████████████████████████████████████████| 7.03M/7.03M [00:00<00:00, 21.5MB/s]
>>> prompt = "Give me a short introduction to large language model."
>>> messages = [ {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."}, {"role": "user", "content": prompt} ]
>>> text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
>>> text
'<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nGive me a short introduction to large language model.<|im_end|>\n<|im_start|>assistant\n'
```
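
Splitting that rendered string into its per-turn pieces yields the two templates added in this PR. As a quick sanity check (plain `str.format`, mirroring how the example scripts apply the templates; the variable names here are illustrative):

```python
# The per-turn pieces of the rendered string above, written with the
# placeholder syntax the example scripts use
system_template = '<|im_start|>system\n{system_prompt}<|im_end|>\n'
chat_template = '<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n'

rendered = (
    system_template.format(system_prompt="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.")
    + chat_template.format(input="Give me a short introduction to large language model.")
)
# rendered now matches the tokenizer.apply_chat_template output above
print(rendered)
```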
kunal-vaishnavi authored Feb 1, 2025
1 parent 6da4195 commit dba39b4
Showing 2 changed files with 71 additions and 39 deletions.
**examples/python/model-chat.py** (56 changes: 36 additions & 20 deletions)
```diff
@@ -29,24 +29,38 @@ def main(args):
     search_options['batch_size'] = 1
 
     if args.verbose: print(search_options)
+
+    # Get model type
+    model_type = None
+    if hasattr(model, "type"):
+        model_type = model.type
+    else:
+        import json, os
+
+        with open(os.path.join(args.model_path, "genai_config.json"), "r") as f:
+            genai_config = json.load(f)
+            model_type = genai_config["model"]["type"]
+
     # Set chat template
     if args.chat_template:
         if args.chat_template.count('{') != 1 or args.chat_template.count('}') != 1:
             raise ValueError("Chat template must have exactly one pair of curly braces with input word in it, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'")
     else:
-        if model.type.startswith("phi2") or model.type.startswith("phi3"):
+        if model_type.startswith("phi2") or model_type.startswith("phi3"):
             args.chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'
-        elif model.type.startswith("phi4"):
+        elif model_type.startswith("phi4"):
             args.chat_template = '<|im_start|>user<|im_sep|>\n{input}<|im_end|>\n<|im_start|>assistant<|im_sep|>'
-        elif model.type.startswith("llama3"):
+        elif model_type.startswith("llama3"):
             args.chat_template = '<|start_header_id|>user<|end_header_id|>\n{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>'
-        elif model.type.startswith("llama2"):
+        elif model_type.startswith("llama2"):
             args.chat_template = '<s>{input}'
+        elif model_type.startswith("qwen2"):
+            args.chat_template = '<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n'
         else:
-            raise ValueError(f"Chat Template for model type {model.type} is not known. Please provide chat template using --chat_template")
+            raise ValueError(f"Chat Template for model type {model_type} is not known. Please provide chat template using --chat_template")
 
     if args.verbose:
-        print("Model type is:", model.type)
+        print("Model type is:", model_type)
         print("Chat Template is:", args.chat_template)
 
     params = og.GeneratorParams(model)
@@ -55,16 +69,22 @@ def main(args):
     if args.verbose: print("Generator created")
 
     # Set system prompt
-    if model.type.startswith('phi2') or model.type.startswith('phi3'):
-        system_prompt = f"<|system|>\n{args.system_prompt}<|end|>"
-    elif model.type.startswith('phi4'):
-        system_prompt = f"<|im_start|>system<|im_sep|>\n{args.system_prompt}<|im_end|>"
-    elif model.type.startswith("llama3"):
-        system_prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{args.system_prompt}<|eot_id|>"
-    elif model.type.startswith("llama2"):
-        system_prompt = f"<s>[INST] <<SYS>>\n{args.system_prompt}\n<</SYS>>"
-    else:
+    if "<|" in args.system_prompt and "|>" in args.system_prompt:
+        # User-provided system template already has tags
         system_prompt = args.system_prompt
+    else:
+        if model_type.startswith('phi2') or model_type.startswith('phi3'):
+            system_prompt = f"<|system|>\n{args.system_prompt}<|end|>"
+        elif model_type.startswith('phi4'):
+            system_prompt = f"<|im_start|>system<|im_sep|>\n{args.system_prompt}<|im_end|>"
+        elif model_type.startswith("llama3"):
+            system_prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{args.system_prompt}<|eot_id|>"
+        elif model_type.startswith("llama2"):
+            system_prompt = f"<s>[INST] <<SYS>>\n{args.system_prompt}\n<</SYS>>"
+        elif model_type.startswith("qwen2"):
+            system_prompt = f"<|im_start|>system\n{args.system_prompt}<|im_end|>\n"
+        else:
+            system_prompt = args.system_prompt
 
     system_tokens = tokenizer.encode(system_prompt)
     generator.append_tokens(system_tokens)
@@ -79,11 +99,7 @@ def main(args):
 
         if args.timings: started_timestamp = time.time()
 
-        # If there is a chat template, use it
-        prompt = text
-        if args.chat_template:
-            prompt = f'{args.chat_template.format(input=text)}'
-
+        prompt = f'{args.chat_template.format(input=text)}'
         input_tokens = tokenizer.encode(prompt)
 
         generator.append_tokens(input_tokens)
```
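
The system-prompt change above also makes the script respect a pre-tagged prompt. A minimal standalone sketch of that behavior (the `build_system_prompt` helper is hypothetical, and only the qwen2 branch of the template chain is shown):

```python
def build_system_prompt(system_prompt: str, model_type: str) -> str:
    if "<|" in system_prompt and "|>" in system_prompt:
        # Already tagged: use the user-provided template verbatim
        return system_prompt
    if model_type.startswith("qwen2"):
        return f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
    return system_prompt

# Plain text is wrapped in the Qwen system template
print(build_system_prompt("You are a helpful assistant.", "qwen2.5"))
# A pre-tagged prompt passes through unchanged
print(build_system_prompt("<|im_start|>system\nCustom instructions<|im_end|>\n", "qwen2.5"))
```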
**examples/python/model-qa.py** (54 changes: 35 additions & 19 deletions)
```diff
@@ -26,37 +26,57 @@ def main(args):
     search_options['batch_size'] = 1
 
     if args.verbose: print(search_options)
+
+    # Get model type
+    model_type = None
+    if hasattr(model, "type"):
+        model_type = model.type
+    else:
+        import json, os
+
+        with open(os.path.join(args.model_path, "genai_config.json"), "r") as f:
+            genai_config = json.load(f)
+            model_type = genai_config["model"]["type"]
+
     # Set chat template
     if args.chat_template:
         if args.chat_template.count('{') != 1 or args.chat_template.count('}') != 1:
             raise ValueError("Chat template must have exactly one pair of curly braces with input word in it, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'")
     else:
-        if model.type.startswith("phi2") or model.type.startswith("phi3"):
+        if model_type.startswith("phi2") or model_type.startswith("phi3"):
             args.chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'
-        elif model.type.startswith("phi4"):
+        elif model_type.startswith("phi4"):
             args.chat_template = '<|im_start|>user<|im_sep|>\n{input}<|im_end|>\n<|im_start|>assistant<|im_sep|>'
-        elif model.type.startswith("llama3"):
+        elif model_type.startswith("llama3"):
             args.chat_template = '<|start_header_id|>user<|end_header_id|>\n{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>'
-        elif model.type.startswith("llama2"):
+        elif model_type.startswith("llama2"):
             args.chat_template = '<s>{input}'
+        elif model_type.startswith("qwen2"):
+            args.chat_template = '<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n'
         else:
-            raise ValueError(f"Chat Template for model type {model.type} is not known. Please provide chat template using --chat_template")
+            raise ValueError(f"Chat Template for model type {model_type} is not known. Please provide chat template using --chat_template")
 
     params = og.GeneratorParams(model)
     params.set_search_options(**search_options)
     generator = og.Generator(model, params)
 
     # Set system prompt
-    if model.type.startswith('phi2') or model.type.startswith('phi3'):
-        system_prompt = f"<|system|>\n{args.system_prompt}<|end|>"
-    elif model.type.startswith('phi4'):
-        system_prompt = f"<|im_start|>system<|im_sep|>\n{args.system_prompt}<|im_end|>"
-    elif model.type.startswith("llama3"):
-        system_prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{args.system_prompt}<|eot_id|>"
-    elif model.type.startswith("llama2"):
-        system_prompt = f"<s>[INST] <<SYS>>\n{args.system_prompt}\n<</SYS>>"
-    else:
+    if "<|" in args.system_prompt and "|>" in args.system_prompt:
+        # User-provided system template already has tags
         system_prompt = args.system_prompt
+    else:
+        if model_type.startswith('phi2') or model_type.startswith('phi3'):
+            system_prompt = f"<|system|>\n{args.system_prompt}<|end|>"
+        elif model_type.startswith('phi4'):
+            system_prompt = f"<|im_start|>system<|im_sep|>\n{args.system_prompt}<|im_end|>"
+        elif model_type.startswith("llama3"):
+            system_prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{args.system_prompt}<|eot_id|>"
+        elif model_type.startswith("llama2"):
+            system_prompt = f"<s>[INST] <<SYS>>\n{args.system_prompt}\n<</SYS>>"
+        elif model_type.startswith("qwen2"):
+            system_prompt = f"<|im_start|>system\n{args.system_prompt}<|im_end|>\n"
+        else:
+            system_prompt = args.system_prompt
 
     system_tokens = tokenizer.encode(system_prompt)
     generator.append_tokens(system_tokens)
@@ -71,11 +91,7 @@ def main(args):
 
         if args.timings: started_timestamp = time.time()
 
-        # If there is a chat template, use it
-        prompt = text
-        if args.chat_template:
-            prompt = f'{args.chat_template.format(input=text)}'
-
+        prompt = f'{args.chat_template.format(input=text)}'
         input_tokens = tokenizer.encode(prompt)
 
         generator.append_tokens(input_tokens)
```
