diff --git a/llm_exl2_client_multi_outlines.py b/llm_exl2_client_multi_outlines.py
index fc43250..b89fcb1 100644
--- a/llm_exl2_client_multi_outlines.py
+++ b/llm_exl2_client_multi_outlines.py
@@ -745,4 +745,4 @@ async def get_nvidia_smi():
 
 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app, host=host, port=port, log_level="debug")
+    uvicorn.run(app, host=host, port=port, log_level="debug")
\ No newline at end of file
diff --git a/llm_exl2_dynamic_gen.py b/llm_exl2_dynamic_gen.py
index 1e170fd..7bca8b4 100644
--- a/llm_exl2_dynamic_gen.py
+++ b/llm_exl2_dynamic_gen.py
@@ -1,48 +1,45 @@
-import sys, os
-sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-
-from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache_Q4, ExLlamaV2Tokenizer, ExLlamaV2Lora
-from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2DynamicJob, ExLlamaV2Sampler
-from blessed import Terminal
-import pprint
-
 import asyncio
 import json
 import os
-import logging
 import time
 import configparser
 import argparse
-import tiktoken
-import torch
-import random
 from typing import AsyncIterable, List, Generator, Union, Optional
-
-import requests
-import sseclient
+import traceback
 import subprocess
-import textwrap
-
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
-import uuid
-import threading
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextStreamer, TextIteratorStreamer
 from threading import Thread
 import queue
-import uvicorn
-from io import StringIO
-from util import format_prompt_llama3, format_prompt, format_prompt_tess, format_prompt_commandr
-from util_merge import ExLlamaV2MergePassthrough
+import traceback
+import re
+
+
+import sys, os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from exllamav2 import (
+    ExLlamaV2,
+    ExLlamaV2Config,
+    ExLlamaV2Cache,
+    ExLlamaV2Cache_8bit,
+    ExLlamaV2Cache_Q4,
+    ExLlamaV2Tokenizer,
+)
+
+from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2DynamicJob, ExLlamaV2Sampler
+import uuid
+from blessed import Terminal
+import textwrap
+from outlines.integrations.exllamav2 import RegexFilter, TextFilter, JSONFilter, ChoiceFilter
 
 def generate_unique_id():
     return uuid.uuid4()
 
-# This is a demo and small stress to showcase some of the features of the dynamic batching generator.
-repo_str = 'commandr-exl2'
-
 class CompletionRequest(BaseModel):
     model: str
     prompt: Union[str, List[str]]
@@ -77,7 +74,42 @@ class ChatCompletionRequest(BaseModel):
     n: Optional[int] = 1 # default value of 1, batch size
     top_p: Optional[float] = 0.0 # default value of 0.0
     user: Optional[str] = None
+    stop_at: Optional[str] = None
+    outlines_type: Optional[str] = None
+    choices: Optional[list[str]] = None
+    regex: Optional[str] = None
+    json: Optional[str] = None
+    request_id: Optional[str] = None
+    partial_generation: Optional[str] = None
+
+#repo_str = 'theprofessor-exl2-speculative'
+parser = argparse.ArgumentParser(description='Run server with specified port.')
+
+# Add argument for port with default type as integer
+parser.add_argument('--port', type=int, help='Port to run the server on.')
+parser.add_argument('--max_context', type=int, default=8192, help='Context length.')
+parser.add_argument('--repo_str', type=str, default='llama3-70b-instruct', help='The model repository name')
+parser.add_argument('--total_context', type=int, default=32768, help="Total context length")
+parser.add_argument('--max_chunk_size', type=int, default=2048, help='Max chunk size.')
+parser.add_argument('--max_new_tokens', type=int, default=2048, help='Max new tokens.')
+parser.add_argument('--display_mode', type=int, default=1, help='Display mode.')
+parser.add_argument('--use_draft_model', action="store_true", help='Enable speculative decoding with a draft model')
+parser.add_argument('--not_paged', action="store_true", help='Disable paged attention')
+
+
+
+# Parse the arguments
+args = parser.parse_args()
+repo_str = args.repo_str
+
+config = configparser.ConfigParser()
+config.read('config.ini')
+
+repo_id = config.get(repo_str, 'repo')
+host = config.get('settings', 'host')
+
+port = args.port if args.port is not None else config.getint('settings', 'port')
 
 class StatusArea:
     def __init__(self, num_lines):
         self.num_lines = num_lines
@@ -169,14 +201,14 @@ def display(self):
         if self.console_line is not None:
             print(term.move_xy(0, self.console_line) + self.display_text)
 
-
-parser = argparse.ArgumentParser(description='Run server with specified port.')
-
-# Add argument for port with default type as integer
-parser.add_argument('--port', type=int, help='Port to run the server on.')
-
-# Parse the arguments
-args = parser.parse_args()
+def get_stop_conditions(tokenizer):
+    # Special case: llama3 models also stop on <|eot_id|>
+    if "llama3" in repo_str:
+        return [tokenizer.single_id("<|eot_id|>"), tokenizer.eos_token_id]
+    # elif prompt_format == "granite":
+    #     return [tokenizer.eos_token_id, "\n\nQuestion:"]
+    else:
+        return [tokenizer.eos_token_id]
 
 config = configparser.ConfigParser()
 config.read('config.ini')
@@ -191,25 +223,25 @@ def display(self):
 # 2: Print completions as jobs finish
 # 3: Step over output iteration by iteration
 # 4: Space heater mode (no output)
-display_mode = 1
+display_mode = args.display_mode
 
 # Whether to use paged mode or not. The generator is very handicapped in unpaged mode, does not support batching
 # or CFG, but it will work without flash-attn 2.5.7+
-paged = True
+paged = not args.not_paged
 
 # Where to find our model
 model_dir = repo_id
 
 # Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but
 # the total to distribute dynamically over however many jobs are active at once
-total_context = 32768
+total_context = args.total_context
 
 # Max individual context
-max_context = 8192
+max_context = args.max_context
 
 # N-gram or draft model speculative decoding. Largely detrimental to performance at higher batch sizes.
 use_ngram = False
-use_draft_model = False
+use_draft_model = args.use_draft_model
 if use_draft_model:
     model_dir = repo_id
     draft_model_dir = specrepo_id
@@ -219,24 +251,14 @@ def display(self):
 
 # Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a
 # new job is started, but at the expense of overall prompt ingestion speed.
-max_chunk_size = 2048
+max_chunk_size = args.max_chunk_size
 
 # Max new tokens per completion. For this example applies to all jobs.
-max_new_tokens = 2048
-
-# Use LMFE to constrain the output to JSON format. See schema and details below.
-json_mode = False
+max_new_tokens = args.max_new_tokens
 
 # Demonstrate token healing
 healing = True
 
-# Ban some phrases maybe
-ban_strings = None
-# ban_strings = [
-#     "person to person",
-#     "one person to another"
-# ]
-
 term = Terminal()
@@ -282,11 +304,14 @@ def display(self):
 )
 model.load_autosplit(cache, progress = True)
-#model.load([16,18,18,20])
 
 # Also, tokenizer
 print("Loading tokenizer...")
 tokenizer = ExLlamaV2Tokenizer(config)
 
+hf_tokenizer_kwargs = {}
+hf_tokenizer_kwargs.setdefault("padding_side", "left")
+hf_tokenizer = AutoTokenizer.from_pretrained(model_dir, **hf_tokenizer_kwargs)
+
 
 # Model Merge
@@ -332,48 +357,65 @@ def display(self):
 LLM_LINES = max_batch_size
 status_area = StatusArea(STATUS_LINES)
 displays = {}
+prompt_ids2jobs = {}
+print("*** Loaded.. now Inference...:")
 
-if json_mode:
-    print("Creating jobs... (initializing JSON filters could take a moment.)")
-
-def get_stop_conditions(prompt_format, tokenizer):
-    if prompt_format == "llama":
-        return [tokenizer.eos_token_id]
-    elif prompt_format == "llama3":
-        return [tokenizer.single_id("<|eot_id|>")]
-    elif prompt_format == "granite":
-        return [tokenizer.eos_token_id, "\n\nQuestion:"]
-
-
-# Only import lmfe if json_mode is set
-
-if json_mode:
-    import json
-    from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
-    from lmformatenforcer import JsonSchemaParser
-    from exllamav2.generator.filters import ExLlamaV2PrefixFilter
-    from pydantic import BaseModel
-    from typing import Literal
-
-    class JSONResponse(BaseModel):
-        response: str
-        confidence: Literal["low", "medium", "high"]
-        is_subjective: Literal["no", "yes", "possibly"]
-
-    schema_parser = JsonSchemaParser(JSONResponse.schema())
-
-
+# Taken from https://github.com/tiangolo/fastapi/discussions/11360
+class RequestCancelledMiddleware:
+    def __init__(self, app):
+        self.app = app
+
+    async def __call__(self, scope, receive, send):
+        global prompt_ids2jobs
+        if scope["type"] != "http":
+            await self.app(scope, receive, send)
+            return
+
+        # Shared queue for the request messages
+        queue = asyncio.Queue()
+        cancelled_request_ids = []
+        async def message_poller(sentinel, handler_task):
+            nonlocal queue
+            request_id = str(generate_unique_id())
+            while True:
+                message = await receive()
+                print(message)
+                if "body" in message:
+                    message["body"] = json.loads(message["body"].decode('utf8'))
+                    message["body"]["request_id"] = request_id
+                    message["body"] = str.encode(json.dumps(message["body"]))
+                    print(message)
+                if message["type"] == "http.disconnect":
+                    cancelled_request_ids.append(request_id)
+                    handler_task.cancel()
+                    return sentinel  # Break the loop
+
+                # Put the message in the queue
+                await queue.put(message)
+
+        sentinel = object()
+        handler_task = asyncio.create_task(self.app(scope, queue.get, send))
+        asyncio.create_task(message_poller(sentinel, handler_task))
+
+        try:
+            return await handler_task
+        except asyncio.CancelledError:
+            print("Cancelling request due to disconnect")
+            # TODO: Figure out how to get the prompt id that disconnected
+            while len(cancelled_request_ids) > 0:
+                cancelled_id = cancelled_request_ids.pop()
+                generator.cancel(prompt_ids2jobs[cancelled_id])
+                del prompt_ids2jobs[cancelled_id]
 
-print("*** Loaded.. now Inference...:")
 app = FastAPI(title="EXL2")
+app.add_middleware(RequestCancelledMiddleware)
 
 async def stream_response(prompt_id, timeout=180):
     global partial_responses
     while True:
-        await asyncio.sleep(0.001) # Sleep to yield control to the event loop
+        await asyncio.sleep(0.05) # Sleep to yield control to the event loop
 
         # Check if prompt_id exists in partial_responses
         if prompt_id in partial_responses:
@@ -388,73 +430,75 @@ async def stream_response(prompt_id, timeout=180):
             yield f'data: {{"id":"chatcmpl-{prompt_id}","object":"chat.completion.chunk","created":{int(time.time())},"model":"{repo_str}","choices":[{{"index":0,"delta":{{}},"finish_reason":"stop"}}]}}\n\n'
             break
 
-# Worker thread function
+
 def process_prompts():
-    global partial_responses
+    global partial_responses
+    global prompt_ids2jobs
+    try:
-    # To see what's going on, mode 1
-    while True:
-        while not prompts.empty() or len(input_ids):
-            while len(input_ids) < max_batch_size and not prompts.empty():
-                prompt_id, prompt, max_tokens, stream, temperature = prompts.get()
-                if json_mode:
-                    prompt += "\n\n Answer in JSON syntax."
-                    filters = [
-                        ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer),
-                        ExLlamaV2PrefixFilter(model, tokenizer, "{")
-                    ]
-                else:
-                    filters = None
-                ids = tokenizer.encode(prompt, encode_special_tokens = True)
-                prompt_tokens = ids.shape[-1]
-                new_tokens = prompt_tokens + max_tokens
-                #print("Processing prompt: " + str(prompt_id) + " Req tokens: " + str(new_tokens))
-                status_area.update(f"Processing prompt: {prompt_id} Req tokens: {new_tokens}", line=STATUS_LINES-1)
-                # Truncate if new_tokens exceed max_context
-                if new_tokens > max_context:
-                    # Calculate how many tokens to truncate
-                    ids = tokenizer.encode("Say, 'Prompt exceeds allowed length. Please try again.'")
-                    # Update new_tokens after truncation
+        while True:
+            while not prompts.empty() or len(input_ids):
+                while not prompts.empty():
+                    prompt_id, prompt, max_tokens, stream, temperature, outlines_dict = prompts.get()
+                    stop_at = outlines_dict.get("stop_at", None)
+                    if outlines_dict["type"] == "choices":
+                        filters = [ChoiceFilter(outlines_dict["choices"], hf_tokenizer)]
+                    elif outlines_dict["type"] == "json":
+                        filters = [JSONFilter(outlines_dict["json"], hf_tokenizer)]
+                    elif outlines_dict["type"] == "regex":
+                        filters = [RegexFilter(outlines_dict["regex"], hf_tokenizer)]
+                    else:
+                        filters = []
+                    ids = tokenizer.encode(prompt, encode_special_tokens = True)
                     prompt_tokens = ids.shape[-1]
                     new_tokens = prompt_tokens + max_tokens
-                    print("Truncating prompt: " + str(prompt_id) + " Req tokens: " + str(new_tokens))
-                prompt_length.append(prompt_tokens)
-                input_ids.append(ids)
-                #streamer.append(stream)
-                #prompt_ids.append(prompt_id)
-
-                job = ExLlamaV2DynamicJob(
-                    input_ids = ids,
-                    max_new_tokens = max_tokens,
-                    stop_conditions = get_stop_conditions('llama', tokenizer),
-                    gen_settings = ExLlamaV2Sampler.Settings(),
-                    banned_strings = ban_strings,
-                    filters = filters,
-                    filter_prefer_eos = True,
-                    token_healing = healing
-                )
-
-                job.prompt_length = prompt_tokens
-                job.input_ids = ids
-                job.streamer = stream
-                job.prompt_ids = prompt_id
-
-                generator.enqueue(job)
-                #displays = { job: JobStatusDisplay(job, line, STATUS_LINES) for line, job in enumerate(jobs) }
-                displays[job] = JobStatusDisplay(job, STATUS_LINES)
-
-                for index, (job, display) in enumerate(list(displays.items())):
-                    display.update_position(index%LLM_LINES) # Set position before updating
-
-
-            if(len(input_ids)):
-                #inputs = torch.cat([x[:, -1:] for x in input_ids], dim = 0)
-                #logits = model.forward(inputs, caches, input_mask = None).float().cpu()
-
+                    #print("Processing prompt: " + str(prompt_id) + " Req tokens: " + str(new_tokens))
+                    status_area.update(f"Processing prompt: {prompt_id} Req tokens: {new_tokens}", line=STATUS_LINES-1)
+                    # Truncate if new_tokens exceed max_context
+                    if new_tokens > max_context:
+                        # Calculate how many tokens to truncate
+                        ids = tokenizer.encode("Say, 'Prompt exceeds allowed length. Please try again.'")
+                        # Update new_tokens after truncation
+                        prompt_tokens = ids.shape[-1]
+                        new_tokens = prompt_tokens + max_tokens
+                        print("Truncating prompt: " + str(prompt_id) + " Req tokens: " + str(new_tokens))
+                    prompt_length.append(prompt_tokens)
+                    input_ids.append(ids)
+                    #streamer.append(stream)
+                    #prompt_ids.append(prompt_id)
+
+                    preferred_eos = get_stop_conditions(tokenizer)
+
+                    if stop_at is not None:
+                        preferred_eos.append(stop_at)
+
+                    gen_settings = ExLlamaV2Sampler.Settings()
+                    gen_settings.temperature = 1.0 if temperature > 1 else temperature # Clamp the temperature so it never exceeds 1.0
+
+                    job = ExLlamaV2DynamicJob(
+                        input_ids = ids,
+                        max_new_tokens = max_tokens,
+                        stop_conditions = preferred_eos if stop_at is None else [tokenizer.eos_token_id, stop_at],
+                        gen_settings = gen_settings,
+                        filters = filters,
+                        token_healing = healing
+                    )
+
+                    job.prompt_length = prompt_tokens
+                    job.input_ids = ids
+                    job.streamer = stream
+                    job.prompt_ids = prompt_id
+                    job.stop_at = stop_at
+
+                    generator.enqueue(job)
+                    #displays = { job: JobStatusDisplay(job, line, STATUS_LINES) for line, job in enumerate(jobs) }
+                    displays[job] = JobStatusDisplay(job, STATUS_LINES)
+
+                    for index, (job, display) in enumerate(list(displays.items())):
+                        display.update_position(index%LLM_LINES) # Set position before updating
+                    prompt_ids2jobs[prompt_id] = job
                 results = generator.iterate()
                 for r in results:
-                    #for i in range(len(input_ids)):
-                    #r = results[i]
                     job = r["job"]
                     displays[job].update(r)
                     displays[job].display()
@@ -463,21 +507,23 @@ def process_prompts():
                     outcontent = r.get("text", "")
                     reason = None
                     if(job.streamer):
+                        if r["eos"] and job.stop_at is not None:
+                            outcontent += job.stop_at
                        partial_response_data = {
-                           "id": f"chatcmpl-{job.prompt_ids}",
-                           "object": "chat.completion.chunk",
-                           "created": int(time.time()),
-                           "model": repo_str,
-                           "choices": [
-                               {
-                                   "index": 0,
-                                   "delta": {
-                                       "content": outcontent
-                                   },
-                                   "finish_reason": reason
-                               }
-                           ]
-                       }
+                            "id": f"chatcmpl-{job.prompt_ids}",
+                            "object": "chat.completion.chunk",
+                            "created": int(time.time()),
+                            "model": repo_str,
+                            "choices": [
+                                {
+                                    "index": 0,
+                                    "delta": {
+                                        "content": outcontent
+                                    },
+                                    "finish_reason": reason
+                                }
+                            ]
+                        }
 
                         # Initialize a list for new prompt_id or append to existing one
                         if job.prompt_ids not in partial_responses:
@@ -497,7 +543,6 @@ def process_prompts():
                         # Calculate token counts
                         completion_tokens_old = (tokenizer.encode(generated_text)).shape[-1]
-                        prompt_tokens_old = (tokenizer.encode(prompt)).shape[-1]
                         completion_tokens = r['new_tokens']
                         prompt_tokens = r['prompt_tokens']
@@ -515,6 +560,8 @@ def process_prompts():
                            responses[eos_prompt_id] = partial_response_data
                        else:# Construct the response based on the format
+                            if job.stop_at is not None:
+                                generated_text += job.stop_at
                            response_data = {
                                "id": f"chatcmpl-{prompt_id}",
                                "object": "chat.completion",
@@ -534,56 +581,267 @@ def process_prompts():
                                    "total_tokens": full_tokens
                                }
                            }
-
                            responses[eos_prompt_id] = response_data
+                           del prompt_ids2jobs[eos_prompt_id]
 
-            # Clean up
-            input_ids.pop()
-            prompt_length.pop()
-            #streamer.pop(i)
-        else:
-            # Sleep for a short duration when there's no work
-            time.sleep(0.1) # Sleep for 100 milliseconds
+            else:
+                # Sleep for a short duration when there's no work
+                time.sleep(0.1) # Sleep for 100 milliseconds
+    except Exception as e:
+        print("Reset server due to ", e)
+        print(traceback.format_exc())
+        for prompt_id in prompt_ids2jobs:
+            job = prompt_ids2jobs[prompt_id]
+            if(job.streamer):
+                ## Generator, yield here..
+                partial_response_data = {
+                    "finish_reason": "stop"
+                }
+
+                responses[prompt_id] = partial_response_data
+            else:
+                print("Error handling for full generation is currently not implemented")
+            generator.cancel(job)
+        prompt_ids2jobs = {}
 
 # Start worker thread
 worker = Thread(target=process_prompts)
 worker.start()
 
+async def format_prompt(messages):
+    formatted_prompt = ""
+    for message in messages:
+        if message.role == "system":
+            formatted_prompt += f"{message.content}\n\n"
+        elif message.role == "user":
+            formatted_prompt += f"### User:\n{message.content}\n\n"
+        elif message.role == "assistant":
+            formatted_prompt += f"### Assistant:\n{message.content}\n\n"
+    # Add the final "### Assistant:\n" to prompt for the next response
+    formatted_prompt += "### Assistant:\n"
+    return formatted_prompt
+
+async def format_prompt_llama3(messages):
+    formatted_prompt = ""
+    system_message_found = False
+
+    # Check for a system message first
+    for message in messages:
+        if message.role == "system":
+            system_message_found = True
+            break
+
+    # If no system message was found, prepend a default one
+    if not system_message_found:
+        formatted_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful AI assistant.<|eot_id|>"
+    for message in messages:
+        if message.role == "system":
+            formatted_prompt += f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{message.content}<|eot_id|>"
+        elif message.role == "user":
+            formatted_prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{message.content}<|eot_id|>"
+        elif message.role == "assistant":
+            formatted_prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n{message.content}<|eot_id|>"
+    # Append the assistant header to prompt the next response
+    formatted_prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    return formatted_prompt
+
+async def format_prompt_yi(messages):
+    formatted_prompt = ""
+    system_message_found = False
+
+    # Check for a system message first
+    for message in messages:
+        if message.role == "system":
+            system_message_found = True
+            break
+
+    # If no system message was found, prepend a default one
+    if not system_message_found:
+        formatted_prompt = "<|im_start|>system\nYou are a helpful AI assistant.<|im_end|>\n"
+    for message in messages:
+        if message.role == "system":
+            formatted_prompt += f"<|im_start|>system\n{message.content}<|im_end|>\n"
+        elif message.role == "user":
+            formatted_prompt += f"<|im_start|>user\n{message.content}<|im_end|>\n"
+        elif message.role == "assistant":
+            formatted_prompt += f"<|im_start|>assistant\n{message.content}<|im_end|>\n"
+    # Append the assistant header to prompt the next response
+    formatted_prompt += "<|im_start|>assistant\n"
+    return formatted_prompt
+
+async def format_prompt_nous(messages):
+    formatted_prompt = ""
+    for message in messages:
+        if message.role == "system":
+            formatted_prompt += f"{message.content}\n"
+        elif message.role == "user":
+            formatted_prompt += f"USER: {message.content}\n"
+        elif message.role == "assistant":
+            formatted_prompt += f"ASSISTANT: {message.content}\n"
+    # Append the final "ASSISTANT: " to prompt the next response
+    formatted_prompt += "ASSISTANT: "
+    return formatted_prompt
+
+async def format_prompt_tess(messages):
+    formatted_prompt = ""
+    for message in messages:
+        if message.role == "system":
+            formatted_prompt += f"SYSTEM: {message.content}\n"
+        elif message.role == "user":
+            formatted_prompt += f"USER: {message.content}\n"
+        elif message.role == "assistant":
+            formatted_prompt += f"ASSISTANT: {message.content}\n"
+    # Append the final "ASSISTANT: " to prompt the next response
+    formatted_prompt += "ASSISTANT: "
+    return formatted_prompt
+
+async def format_prompt_code(messages):
+    formatted_prompt = ""
+    for message in messages:
+        if message.role == "system":
+            formatted_prompt += f"### System Prompt\nYou are an intelligent programming assistant.\n\n"
+        elif message.role == "user":
+            formatted_prompt += f"### User Message\n{message.content}\n\n"
+        elif message.role == "assistant":
+            formatted_prompt += f"### Assistant\n{message.content}\n\n"
+    # Add the final "### Assistant" with ellipsis to prompt for the next response
+    formatted_prompt += "### Assistant\n..."
+    return formatted_prompt
+
+async def format_prompt_zephyr(messages):
+    formatted_prompt = ""
+    for message in messages:
+        if message.role == "system":
+            formatted_prompt += f"<|system|>\n{message.content}\n"
+        elif message.role == "user":
+            formatted_prompt += f"<|user|>\n{message.content}\n"
+        elif message.role == "assistant":
+            formatted_prompt += f"<|assistant|>\n{message.content}\n"
+    # Append the assistant tag to prompt the next response
+    formatted_prompt += "<|assistant|>\n"
+    return formatted_prompt
+
+async def format_prompt_starling(messages):
+    formatted_prompt = ""
+    system_message = ""
+    for message in messages:
+        if message.role == "system":
+            # Save system message to prepend to the first user message
+            system_message += f"{message.content}\n\n"
+        elif message.role == "user":
+            # Prepend system message if it exists
+            if system_message:
+                formatted_prompt += f"GPT4 Correct User: {system_message}{message.content}<|end_of_turn|>"
+                system_message = "" # Clear system message after prepending
+            else:
+                formatted_prompt += f"GPT4 Correct User: {message.content}<|end_of_turn|>"
+        elif message.role == "assistant":
+            formatted_prompt += f"GPT4 Correct Assistant: {message.content}<|end_of_turn|>" # Prep for user follow-up
+    formatted_prompt += "GPT4 Correct Assistant: \n\n"
+    return formatted_prompt
+
+async def format_prompt_mixtral(messages):
+    formatted_prompt = " "
+    system_message = ""
+    for message in messages:
+        if message.role == "system":
+            # Save system message to prepend to the first user message
+            system_message += f"{message.content}\n\n"
+        elif message.role == "user":
+            # Prepend system message if it exists
+            if system_message:
+                formatted_prompt += f"[INST] {system_message}{message.content} [/INST] "
+                system_message = "" # Clear system message after prepending
+            else:
+                formatted_prompt += f"[INST] {message.content} [/INST] "
+        elif message.role == "assistant":
+            formatted_prompt += f" {message.content} " # Prep for user follow-up
+    return formatted_prompt
+
+
+async def format_prompt_commandr(messages):
+    formatted_prompt = ""
+    system_message_found = False
+
+    # Check for a system message first
+    for message in messages:
+        if message.role == "system":
+            system_message_found = True
+            break
+
+    # If no system message was found, prepend a default one
+    if not system_message_found:
+        formatted_prompt += "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful AI assistant.<|END_OF_TURN_TOKEN|>"
+
+    for message in messages:
+        if message.role == "system":
+            formatted_prompt += f"<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{message.content}<|END_OF_TURN_TOKEN|>"
+        elif message.role == "user":
+            formatted_prompt += f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{message.content}<|END_OF_TURN_TOKEN|>"
+        elif message.role == "assistant":
+            formatted_prompt += f"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{message.content}<|END_OF_TURN_TOKEN|>"
+    # Append the chatbot turn token to prompt the next response
+    formatted_prompt += "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
+    return formatted_prompt
+
+
 @app.post('/v1/chat/completions')
 async def mainchat(request: ChatCompletionRequest):
-
     try:
         prompt = ''
         if repo_str == 'Phind-CodeLlama-34B-v2':
             prompt = await format_prompt_code(request.messages)
         elif repo_str == 'zephyr-7b-beta':
             prompt = await format_prompt_zephyr(request.messages)
-        elif repo_str == 'llama3-70b-instruct' or repo_str == 'llama3-70b-instruct-speculative':
+        elif repo_str == 'llama3-70b-instruct':
            prompt = await format_prompt_llama3(request.messages)
         elif repo_str == 'Starling-LM-7B-alpha':
             prompt = await format_prompt_starling(request.messages)
-        elif repo_str == 'Mixtral-8x7B-Instruct-v0.1-GPTQ' or repo_str == 'miqu-exl2-speculative':
+        elif repo_str == 'Mixtral-8x7B-Instruct-v0.1-GPTQ':
            prompt = await format_prompt_mixtral(request.messages)
-        elif repo_str == 'Yi-34B-Chat-GPTQ' or repo_str == 'Nous-Hermes-2-Yi-34B-GPTQ' or repo_str == 'theprofessor-exl2-speculative' or repo_str == 'Yi-34B-Chat':
+        elif repo_str == 'Yi-34B-Chat-GPTQ' or repo_str == 'Nous-Hermes-2-Yi-34B-GPTQ' or repo_str == 'theprofessor-exl2-speculative' or repo_str == 'dbrx-instruct-exl2':
            prompt = await format_prompt_yi(request.messages)
         elif repo_str == 'Nous-Capybara-34B-GPTQ' or repo_str == 'goliath-120b-GPTQ' or repo_str == 'goliath-120b-exl2' or repo_str == 'goliath-120b-exl2-rpcal':
             prompt = await format_prompt_nous(request.messages)
-        elif repo_str == 'tess-xl-exl2' or repo_str == 'tess-xl-exl2-speculative' or repo_str == 'venus-exl2-speculative':
+        elif repo_str == 'tess-xl-exl2' or repo_str == 'tess-xl-exl2-speculative':
            prompt = await format_prompt_tess(request.messages)
-        elif repo_str == 'tinyllama-exl2-speculative':
-            prompt = await format_prompt_zephyr(request.messages)
         elif repo_str == 'commandr-exl2' or repo_str == 'commandr-exl2-speculative':
             prompt = await format_prompt_commandr(request.messages)
         else:
             prompt = await format_prompt(request.messages)
-        status_area.update(f"Prompt: {prompt}")
+        if request.partial_generation is not None:
+            prompt += request.partial_generation
+        print(prompt)
 
         timeout = 180 # seconds
         start_time = time.time()
-        prompt_id = generate_unique_id() # Replace with a function to generate unique IDs
-        prompts.put((prompt_id, prompt, request.max_tokens, request.stream, request.temperature))
+        prompt_id = request.request_id # Unique ID injected into the request body by RequestCancelledMiddleware
+        outlines_dict = {}
+
+        # Adjust temperature if it is 0
+        if request.temperature == 0:
+            request.temperature = 0.001
+
+        if request.stop_at is not None:
+            outlines_dict["stop_at"] = request.stop_at
+        if request.outlines_type is not None:
+            outlines_dict["type"] = request.outlines_type
+        else:
+            outlines_dict["type"] = "text"
+        if outlines_dict["type"] == "choices":
+            assert request.choices is not None
+            outlines_dict["choices"] = request.choices
+        elif outlines_dict["type"] == "json":
+            assert request.json is not None
+            outlines_dict["json"] = request.json
+        elif outlines_dict["type"] == "regex":
+            assert request.regex is not None
+            outlines_dict["regex"] = request.regex
+        else:
+            assert outlines_dict["type"] == "text"
+        prompts.put((prompt_id, prompt, request.max_tokens, request.stream, request.temperature, outlines_dict))
 
         if request.stream:
             #response = StreamingResponse(streaming_request(prompt, request.max_tokens, tempmodel=repo_str, response_format='chat_completion'), media_type="text/event-stream")
@@ -599,9 +857,11 @@ async def mainchat(request: ChatCompletionRequest):
             return responses.pop(prompt_id)
 
     except Exception as e:
+        print(traceback.format_exc())
         raise HTTPException(status_code=500, detail=str(e))
-    return response
+
+
 
 @app.get('/ping')
@@ -633,10 +893,6 @@ async def get_nvidia_smi():
 
 if __name__ == "__main__":
+    import uvicorn
-    uvicorn.run(app, host=host, port=port, log_level="error")
-
-    print(term.enter_fullscreen())
-
-
-
+    uvicorn.run(app, host=host, port=port, log_level="debug")
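
For reference, a minimal client-side sketch of the constrained-generation request fields this diff adds to /v1/chat/completions (outlines_type, choices, regex, json, stop_at, partial_generation). The host, port, and model name below are assumptions for illustration only: the server takes its host and port from config.ini (or --port) and its model name from --repo_str. The request_id field is not set by the client, since RequestCancelledMiddleware injects one into the request body.

# Hypothetical client call; assumes the server is reachable at http://localhost:8000
# and was started with --repo_str llama3-70b-instruct. Adjust both for your deployment.
import requests

payload = {
    "model": "llama3-70b-instruct",
    "messages": [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "Is the sky blue? Answer yes or no."},
    ],
    "max_tokens": 8,
    "stream": False,
    "temperature": 0,            # the endpoint bumps 0 to 0.001 before queuing the job
    "outlines_type": "choices",  # one of "text", "choices", "regex", "json"
    "choices": ["yes", "no"],    # required when outlines_type == "choices"
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=180)
resp.raise_for_status()
print(resp.json())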