diff --git a/examples/python/model-chat.py b/examples/python/model-chat.py index 12fb18a44..6068d7524 100644 --- a/examples/python/model-chat.py +++ b/examples/python/model-chat.py @@ -29,24 +29,38 @@ def main(args): search_options['batch_size'] = 1 if args.verbose: print(search_options) + + # Get model type + model_type = None + if hasattr(model, "type"): + model_type = model.type + else: + import json, os + + with open(os.path.join(args.model_path, "genai_config.json"), "r") as f: + genai_config = json.load(f) + model_type = genai_config["model"]["type"] + # Set chat template if args.chat_template: if args.chat_template.count('{') != 1 or args.chat_template.count('}') != 1: raise ValueError("Chat template must have exactly one pair of curly braces with input word in it, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'") else: - if model.type.startswith("phi2") or model.type.startswith("phi3"): + if model_type.startswith("phi2") or model_type.startswith("phi3"): args.chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>' - elif model.type.startswith("phi4"): + elif model_type.startswith("phi4"): args.chat_template = '<|im_start|>user<|im_sep|>\n{input}<|im_end|>\n<|im_start|>assistant<|im_sep|>' - elif model.type.startswith("llama3"): + elif model_type.startswith("llama3"): args.chat_template = '<|start_header_id|>user<|end_header_id|>\n{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>' - elif model.type.startswith("llama2"): + elif model_type.startswith("llama2"): args.chat_template = '{input}' + elif model_type.startswith("qwen2"): + args.chat_template = '<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n' else: - raise ValueError(f"Chat Template for model type {model.type} is not known. Please provide chat template using --chat_template") + raise ValueError(f"Chat Template for model type {model_type} is not known. Please provide chat template using --chat_template") if args.verbose: - print("Model type is:", model.type) + print("Model type is:", model_type) print("Chat Template is:", args.chat_template) params = og.GeneratorParams(model) @@ -55,16 +69,22 @@ def main(args): if args.verbose: print("Generator created") # Set system prompt - if model.type.startswith('phi2') or model.type.startswith('phi3'): - system_prompt = f"<|system|>\n{args.system_prompt}<|end|>" - elif model.type.startswith('phi4'): - system_prompt = f"<|im_start|>system<|im_sep|>\n{args.system_prompt}<|im_end|>" - elif model.type.startswith("llama3"): - system_prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{args.system_prompt}<|eot_id|>" - elif model.type.startswith("llama2"): - system_prompt = f"[INST] <>\n{args.system_prompt}\n<>" - else: + if "<|" in args.system_prompt and "|>" in args.system_prompt: + # User-provided system template already has tags system_prompt = args.system_prompt + else: + if model_type.startswith('phi2') or model_type.startswith('phi3'): + system_prompt = f"<|system|>\n{args.system_prompt}<|end|>" + elif model_type.startswith('phi4'): + system_prompt = f"<|im_start|>system<|im_sep|>\n{args.system_prompt}<|im_end|>" + elif model_type.startswith("llama3"): + system_prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{args.system_prompt}<|eot_id|>" + elif model_type.startswith("llama2"): + system_prompt = f"[INST] <>\n{args.system_prompt}\n<>" + elif model_type.startswith("qwen2"): + system_prompt = f"<|im_start|>system\n{args.system_prompt}<|im_end|>\n" + else: + system_prompt = args.system_prompt system_tokens = tokenizer.encode(system_prompt) generator.append_tokens(system_tokens) @@ -79,11 +99,7 @@ def main(args): if args.timings: started_timestamp = time.time() - # If there is a chat template, use it - prompt = text - if args.chat_template: - prompt = f'{args.chat_template.format(input=text)}' - + prompt = f'{args.chat_template.format(input=text)}' input_tokens = tokenizer.encode(prompt) generator.append_tokens(input_tokens) diff --git a/examples/python/model-qa.py b/examples/python/model-qa.py index 5e639ef2b..e5e449a77 100644 --- a/examples/python/model-qa.py +++ b/examples/python/model-qa.py @@ -26,37 +26,57 @@ def main(args): search_options['batch_size'] = 1 if args.verbose: print(search_options) + + # Get model type + model_type = None + if hasattr(model, "type"): + model_type = model.type + else: + import json, os + + with open(os.path.join(args.model_path, "genai_config.json"), "r") as f: + genai_config = json.load(f) + model_type = genai_config["model"]["type"] + # Set chat template if args.chat_template: if args.chat_template.count('{') != 1 or args.chat_template.count('}') != 1: raise ValueError("Chat template must have exactly one pair of curly braces with input word in it, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'") else: - if model.type.startswith("phi2") or model.type.startswith("phi3"): + if model_type.startswith("phi2") or model_type.startswith("phi3"): args.chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>' - elif model.type.startswith("phi4"): + elif model_type.startswith("phi4"): args.chat_template = '<|im_start|>user<|im_sep|>\n{input}<|im_end|>\n<|im_start|>assistant<|im_sep|>' - elif model.type.startswith("llama3"): + elif model_type.startswith("llama3"): args.chat_template = '<|start_header_id|>user<|end_header_id|>\n{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>' - elif model.type.startswith("llama2"): + elif model_type.startswith("llama2"): args.chat_template = '{input}' + elif model_type.startswith("qwen2"): + args.chat_template = '<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n' else: - raise ValueError(f"Chat Template for model type {model.type} is not known. Please provide chat template using --chat_template") + raise ValueError(f"Chat Template for model type {model_type} is not known. Please provide chat template using --chat_template") params = og.GeneratorParams(model) params.set_search_options(**search_options) generator = og.Generator(model, params) # Set system prompt - if model.type.startswith('phi2') or model.type.startswith('phi3'): - system_prompt = f"<|system|>\n{args.system_prompt}<|end|>" - elif model.type.startswith('phi4'): - system_prompt = f"<|im_start|>system<|im_sep|>\n{args.system_prompt}<|im_end|>" - elif model.type.startswith("llama3"): - system_prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{args.system_prompt}<|eot_id|>" - elif model.type.startswith("llama2"): - system_prompt = f"[INST] <>\n{args.system_prompt}\n<>" - else: + if "<|" in args.system_prompt and "|>" in args.system_prompt: + # User-provided system template already has tags system_prompt = args.system_prompt + else: + if model_type.startswith('phi2') or model_type.startswith('phi3'): + system_prompt = f"<|system|>\n{args.system_prompt}<|end|>" + elif model_type.startswith('phi4'): + system_prompt = f"<|im_start|>system<|im_sep|>\n{args.system_prompt}<|im_end|>" + elif model_type.startswith("llama3"): + system_prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{args.system_prompt}<|eot_id|>" + elif model_type.startswith("llama2"): + system_prompt = f"[INST] <>\n{args.system_prompt}\n<>" + elif model_type.startswith("qwen2"): + system_prompt = f"<|im_start|>system\n{args.system_prompt}<|im_end|>\n" + else: + system_prompt = args.system_prompt system_tokens = tokenizer.encode(system_prompt) generator.append_tokens(system_tokens) @@ -71,11 +91,7 @@ def main(args): if args.timings: started_timestamp = time.time() - # If there is a chat template, use it - prompt = text - if args.chat_template: - prompt = f'{args.chat_template.format(input=text)}' - + prompt = f'{args.chat_template.format(input=text)}' input_tokens = tokenizer.encode(prompt) generator.append_tokens(input_tokens)