diff --git a/llm_exl2_client_multi_outlines.py b/llm_exl2_client_multi_outlines.py
index fc43250..b89fcb1 100644
--- a/llm_exl2_client_multi_outlines.py
+++ b/llm_exl2_client_multi_outlines.py
@@ -745,4 +745,4 @@ async def get_nvidia_smi():
if __name__ == "__main__":
import uvicorn
- uvicorn.run(app, host=host, port=port, log_level="debug")
+ uvicorn.run(app, host=host, port=port, log_level="debug")
\ No newline at end of file
diff --git a/llm_exl2_dynamic_gen.py b/llm_exl2_dynamic_gen.py
index 1e170fd..7bca8b4 100644
--- a/llm_exl2_dynamic_gen.py
+++ b/llm_exl2_dynamic_gen.py
@@ -1,48 +1,45 @@
-import sys, os
-sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-
-from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache_Q4, ExLlamaV2Tokenizer, ExLlamaV2Lora
-from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2DynamicJob, ExLlamaV2Sampler
-from blessed import Terminal
-import pprint
-
import asyncio
import json
import os
-import logging
import time
import configparser
import argparse
-import tiktoken
-import torch
-import random
from typing import AsyncIterable, List, Generator, Union, Optional
-
-import requests
-import sseclient
+import traceback
import subprocess
-import textwrap
-
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
-import uuid
-import threading
+from transformers import AutoTokenizer
from threading import Thread
import queue
-import uvicorn
-from io import StringIO
-from util import format_prompt_llama3, format_prompt, format_prompt_tess, format_prompt_commandr
-from util_merge import ExLlamaV2MergePassthrough
+import re
+
+
+import sys, os
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from exllamav2 import (
+ ExLlamaV2,
+ ExLlamaV2Config,
+ ExLlamaV2Cache,
+ ExLlamaV2Cache_8bit,
+ ExLlamaV2Cache_Q4,
+ ExLlamaV2Tokenizer,
+)
+
+from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2DynamicJob, ExLlamaV2Sampler
+import uuid
+from blessed import Terminal
+import textwrap
+from outlines.integrations.exllamav2 import RegexFilter, TextFilter, JSONFilter, ChoiceFilter
def generate_unique_id():
return uuid.uuid4()
-# This is a demo and small stress to showcase some of the features of the dynamic batching generator.
-repo_str = 'commandr-exl2'
-
class CompletionRequest(BaseModel):
model: str
prompt: Union[str, List[str]]
@@ -77,7 +74,42 @@ class ChatCompletionRequest(BaseModel):
n: Optional[int] = 1 # default value of 1, batch size
top_p: Optional[float] = 0.0 # default value of 0.0
user: Optional[str] = None
+ stop_at: Optional[str] = None
+ outlines_type: Optional[str] = None
+ choices: Optional[list[str]] = None
+ regex: Optional[str] = None
+ json: Optional[str] = None
+ request_id: Optional[str] = None
+ partial_generation: Optional[str] = None
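+    # Illustrative request body exercising the new constrained-decoding fields
+    # (hypothetical values, not taken from this repository):
+    #   {"model": "llama3-70b-instruct",
+    #    "messages": [{"role": "user", "content": "Red or blue?"}],
+    #    "outlines_type": "choices", "choices": ["red", "blue"],
+    #    "stream": true}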
+
+#repo_str = 'theprofessor-exl2-speculative'
+parser = argparse.ArgumentParser(description='Run server with specified port.')
+
+# Add argument for port with default type as integer
+parser.add_argument('--port', type=int, help='Port to run the server on.')
+parser.add_argument('--max_context', type=int, default=8192, help='Max context length per individual request.')
+parser.add_argument('--repo_str', type=str, default='llama3-70b-instruct', help='The model repository name')
+parser.add_argument('--total_context', type=int, default=32768, help='Total tokens to allocate across all active jobs.')
+parser.add_argument('--max_chunk_size', type=int, default=2048, help='Max chunk size.')
+parser.add_argument('--max_new_tokens', type=int, default=2048, help='Max new tokens.')
+parser.add_argument('--display_mode', type=int, default=1, help='Display mode.')
+parser.add_argument('--use_draft_model', action="store_true", help='Do speculative decoding')
+parser.add_argument('--not_paged', action="store_true", help='Do not do paged attention')
+
+
+
+# Parse the arguments
+args = parser.parse_args()
+repo_str = args.repo_str
+
+config = configparser.ConfigParser()
+config.read('config.ini')
+
+repo_id = config.get(repo_str, 'repo')
+host = config.get('settings', 'host')
+
+port = args.port if args.port is not None else config.getint('settings', 'port')
class StatusArea:
def __init__(self, num_lines):
self.num_lines = num_lines
@@ -169,14 +201,14 @@ def display(self):
if self.console_line is not None:
print(term.move_xy(0, self.console_line) + self.display_text)
-
-parser = argparse.ArgumentParser(description='Run server with specified port.')
-
-# Add argument for port with default type as integer
-parser.add_argument('--port', type=int, help='Port to run the server on.')
-
-# Parse the arguments
-args = parser.parse_args()
+def get_stop_conditions(tokenizer):
+    # Special case: llama3 stops on <|eot_id|> in addition to the EOS token
+ if "llama3" in repo_str:
+ return [tokenizer.single_id("<|eot_id|>"), tokenizer.eos_token_id]
+ # elif prompt_format == "granite":
+ # return [tokenizer.eos_token_id, "\n\nQuestion:"]
+ else:
+ return [tokenizer.eos_token_id]
config = configparser.ConfigParser()
config.read('config.ini')
@@ -191,25 +223,25 @@ def display(self):
# 2: Print completions as jobs finish
# 3: Step over output iteration by iteration
# 4: Space heater mode (no output)
-display_mode = 1
+display_mode = args.display_mode
# Whether to use paged mode or not. The generator is very handicapped in unpaged mode, does not support batching
# or CFG, but it will work without flash-attn 2.5.7+
-paged = True
+paged = not args.not_paged
# Where to find our model
model_dir = repo_id
# Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but
# the total to distribute dynamically over however many jobs are active at once
-total_context = 32768
+total_context = args.total_context
# Max individual context
-max_context = 8192
+max_context = args.max_context
# N-gram or draft model speculative decoding. Largely detrimental to performance at higher batch sizes.
use_ngram = False
-use_draft_model = False
+use_draft_model = args.use_draft_model
if use_draft_model:
model_dir = repo_id
draft_model_dir = specrepo_id
@@ -219,24 +251,14 @@ def display(self):
# Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a
# new job is started, but at the expense of overall prompt ingestion speed.
-max_chunk_size = 2048
+max_chunk_size = args.max_chunk_size
# Max new tokens per completion. For this example applies to all jobs.
-max_new_tokens = 2048
-
-# Use LMFE to constrain the output to JSON format. See schema and details below.
-json_mode = False
+max_new_tokens = args.max_new_tokens
# Demonstrate token healing
healing = True
-# Ban some phrases maybe
-ban_strings = None
-# ban_strings = [
-# "person to person",
-# "one person to another"
-# ]
-
term = Terminal()
@@ -282,11 +304,14 @@ def display(self):
)
model.load_autosplit(cache, progress = True)
-#model.load([16,18,18,20])
# Also, tokenizer
print("Loading tokenizer...")
tokenizer = ExLlamaV2Tokenizer(config)
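+
+# The outlines filters constructed in process_prompts() take a Hugging Face
+# tokenizer, so load one alongside the ExLlamaV2 tokenizer.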
+hf_tokenizer = AutoTokenizer.from_pretrained(model_dir, padding_side="left")
+
# Model Merge
@@ -332,48 +357,65 @@ def display(self):
LLM_LINES = max_batch_size
status_area = StatusArea(STATUS_LINES)
displays = {}
+prompt_ids2jobs = {}
+print("*** Model loaded. Starting inference worker...")
-if json_mode:
- print("Creating jobs... (initializing JSON filters could take a moment.)")
-
-
-def get_stop_conditions(prompt_format, tokenizer):
- if prompt_format == "llama":
- return [tokenizer.eos_token_id]
- elif prompt_format == "llama3":
- return [tokenizer.single_id("<|eot_id|>")]
- elif prompt_format == "granite":
- return [tokenizer.eos_token_id, "\n\nQuestion:"]
-
-
-# Only import lmfe if json_mode is set
-
-if json_mode:
- import json
- from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
- from lmformatenforcer import JsonSchemaParser
- from exllamav2.generator.filters import ExLlamaV2PrefixFilter
- from pydantic import BaseModel
- from typing import Literal
-
- class JSONResponse(BaseModel):
- response: str
- confidence: Literal["low", "medium", "high"]
- is_subjective: Literal["no", "yes", "possibly"]
-
- schema_parser = JsonSchemaParser(JSONResponse.schema())
-
-
+# Adapted from https://github.com/tiangolo/fastapi/discussions/11360
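+# The middleware wraps the ASGI receive channel: each incoming JSON body is
+# tagged with a generated request_id, and when the client disconnects the
+# matching generator job is looked up in prompt_ids2jobs and cancelled.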
+class RequestCancelledMiddleware:
+ def __init__(self, app):
+ self.app = app
+
+ async def __call__(self, scope, receive, send):
+ global prompt_ids2jobs
+ if scope["type"] != "http":
+ await self.app(scope, receive, send)
+ return
+
+ # Let's make a shared queue for the request messages
+ queue = asyncio.Queue()
+ cancelled_request_ids = []
+ async def message_poller(sentinel, handler_task):
+ nonlocal queue
+ request_id = str(generate_unique_id())
+ while True:
+ message = await receive()
+                # Parse the JSON body (when present) and tag it with this request's id
+                if message["type"] == "http.request" and message.get("body"):
+                    message["body"] = json.loads(message["body"].decode("utf8"))
+                    message["body"]["request_id"] = request_id
+                    message["body"] = json.dumps(message["body"]).encode("utf8")
+ if message["type"] == "http.disconnect":
+ cancelled_request_ids.append(request_id)
+ handler_task.cancel()
+ return sentinel # Break the loop
+
+ # Puts the message in the queue
+ await queue.put(message)
+
+ sentinel = object()
+ handler_task = asyncio.create_task(self.app(scope, queue.get, send))
+ asyncio.create_task(message_poller(sentinel, handler_task))
+
+ try:
+ return await handler_task
+ except asyncio.CancelledError:
+            print("Cancelling request due to client disconnect")
+            # Cancel and forget any generator jobs that belong to the disconnected request(s)
+            while len(cancelled_request_ids) > 0:
+                cancelled_id = cancelled_request_ids.pop()
+                if cancelled_id in prompt_ids2jobs:
+                    generator.cancel(prompt_ids2jobs[cancelled_id])
+                    del prompt_ids2jobs[cancelled_id]
-print("*** Loaded.. now Inference...:")
app = FastAPI(title="EXL2")
+app.add_middleware(RequestCancelledMiddleware)
async def stream_response(prompt_id, timeout=180):
global partial_responses
while True:
- await asyncio.sleep(0.001) # Sleep to yield control to the event loop
+ await asyncio.sleep(0.05) # Sleep to yield control to the event loop
# Check if prompt_id exists in partial_responses
if prompt_id in partial_responses:
@@ -388,73 +430,75 @@ async def stream_response(prompt_id, timeout=180):
yield f'data: {{"id":"chatcmpl-{prompt_id}","object":"chat.completion.chunk","created":{int(time.time())},"model":"{repo_str}","choices":[{{"index":0,"delta":{{}},"finish_reason":"stop"}}]}}\n\n'
break
-# Worker thread function
+
def process_prompts():
- global partial_responses
+ global partial_responses
+ global prompt_ids2jobs
+ try:
- # To see what's going on, mode 1
- while True:
- while not prompts.empty() or len(input_ids):
- while len(input_ids) < max_batch_size and not prompts.empty():
- prompt_id, prompt, max_tokens, stream, temperature = prompts.get()
- if json_mode:
- prompt += "\n\n Answer in JSON syntax."
- filters = [
- ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer),
- ExLlamaV2PrefixFilter(model, tokenizer, "{")
- ]
- else:
- filters = None
- ids = tokenizer.encode(prompt, encode_special_tokens = True)
- prompt_tokens = ids.shape[-1]
- new_tokens = prompt_tokens + max_tokens
- #print("Processing prompt: " + str(prompt_id) + " Req tokens: " + str(new_tokens))
- status_area.update(f"Processing prompt: {prompt_id} Req tokens: {new_tokens}", line=STATUS_LINES-1)
- # Truncate if new_tokens exceed max_context
- if new_tokens > max_context:
- # Calculate how many tokens to truncate
- ids = tokenizer.encode("Say, 'Prompt exceeds allowed length. Please try again.'")
- # Update new_tokens after truncation
+ while True:
+ while not prompts.empty() or len(input_ids):
+ while not prompts.empty():
+ prompt_id, prompt, max_tokens, stream, temperature, outlines_dict = prompts.get()
+ stop_at = outlines_dict.get("stop_at", None)
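+                    # Map the requested outlines constraint onto a logits filter; "text" means unconstrained decoding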
+ if outlines_dict["type"] == "choices":
+ filters = [ChoiceFilter(outlines_dict["choices"], hf_tokenizer)]
+ elif outlines_dict["type"] == "json":
+ filters = [JSONFilter(outlines_dict["json"], hf_tokenizer)]
+ elif outlines_dict["type"] == "regex":
+ filters = [RegexFilter(outlines_dict["regex"], hf_tokenizer)]
+ else:
+ filters = []
+ ids = tokenizer.encode(prompt, encode_special_tokens = True)
prompt_tokens = ids.shape[-1]
new_tokens = prompt_tokens + max_tokens
- print("Truncating prompt: " + str(prompt_id) + " Req tokens: " + str(new_tokens))
- prompt_length.append(prompt_tokens)
- input_ids.append(ids)
- #streamer.append(stream)
- #prompt_ids.append(prompt_id)
-
- job = ExLlamaV2DynamicJob(
- input_ids = ids,
- max_new_tokens = max_tokens,
- stop_conditions = get_stop_conditions('llama', tokenizer),
- gen_settings = ExLlamaV2Sampler.Settings(),
- banned_strings = ban_strings,
- filters = filters,
- filter_prefer_eos = True,
- token_healing = healing
- )
-
- job.prompt_length = prompt_tokens
- job.input_ids = ids
- job.streamer = stream
- job.prompt_ids = prompt_id
-
- generator.enqueue(job)
- #displays = { job: JobStatusDisplay(job, line, STATUS_LINES) for line, job in enumerate(jobs) }
- displays[job] = JobStatusDisplay(job, STATUS_LINES)
-
- for index, (job, display) in enumerate(list(displays.items())):
- display.update_position(index%LLM_LINES) # Set position before updating
-
-
- if(len(input_ids)):
- #inputs = torch.cat([x[:, -1:] for x in input_ids], dim = 0)
- #logits = model.forward(inputs, caches, input_mask = None).float().cpu()
-
+ #print("Processing prompt: " + str(prompt_id) + " Req tokens: " + str(new_tokens))
+ status_area.update(f"Processing prompt: {prompt_id} Req tokens: {new_tokens}", line=STATUS_LINES-1)
+                        # If the request would overflow max_context, swap in a short error prompt instead of truncating
+                        if new_tokens > max_context:
+                            ids = tokenizer.encode("Say, 'Prompt exceeds allowed length. Please try again.'")
+                            # Recompute token counts for the replacement prompt
+                            prompt_tokens = ids.shape[-1]
+                            new_tokens = prompt_tokens + max_tokens
+                            print("Prompt too long: " + str(prompt_id) + " Req tokens: " + str(new_tokens))
+                        prompt_length.append(prompt_tokens)
+                        input_ids.append(ids)
+ #streamer.append(stream)
+ #prompt_ids.append(prompt_id)
+
+ preferred_eos = get_stop_conditions(tokenizer)
+
+ if stop_at is not None:
+ preferred_eos.append(stop_at)
+
+ gen_settings = ExLlamaV2Sampler.Settings()
+                        gen_settings.temperature = 1.0 if temperature > 1 else temperature  # Clamp temperature so it never exceeds 1.0
+
+ job = ExLlamaV2DynamicJob(
+ input_ids = ids,
+ max_new_tokens = max_tokens,
+                            stop_conditions = preferred_eos,  # preferred_eos already includes stop_at when one was given
+ gen_settings = gen_settings,
+ filters = filters,
+ token_healing = healing
+ )
+
+ job.prompt_length = prompt_tokens
+ job.input_ids = ids
+ job.streamer = stream
+ job.prompt_ids = prompt_id
+ job.stop_at = stop_at
+
+ generator.enqueue(job)
+ #displays = { job: JobStatusDisplay(job, line, STATUS_LINES) for line, job in enumerate(jobs) }
+ displays[job] = JobStatusDisplay(job, STATUS_LINES)
+
+ for index, (job, display) in enumerate(list(displays.items())):
+ display.update_position(index%LLM_LINES) # Set position before updating
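+                    # Remember the job by request id so RequestCancelledMiddleware can cancel it on client disconnect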
+ prompt_ids2jobs[prompt_id] = job
results = generator.iterate()
for r in results:
- #for i in range(len(input_ids)):
- #r = results[i]
job = r["job"]
displays[job].update(r)
displays[job].display()
@@ -463,21 +507,23 @@ def process_prompts():
outcontent = r.get("text", "")
reason = None
if(job.streamer):
+ if r["eos"] and job.stop_at is not None:
+ outcontent += job.stop_at
partial_response_data = {
- "id": f"chatcmpl-{job.prompt_ids}",
- "object": "chat.completion.chunk",
- "created": int(time.time()),
- "model": repo_str,
- "choices": [
- {
- "index": 0,
- "delta": {
- "content": outcontent
- },
- "finish_reason": reason
- }
- ]
- }
+ "id": f"chatcmpl-{job.prompt_ids}",
+ "object": "chat.completion.chunk",
+ "created": int(time.time()),
+ "model": repo_str,
+ "choices": [
+ {
+ "index": 0,
+ "delta": {
+ "content": outcontent
+ },
+ "finish_reason": reason
+ }
+ ]
+ }
# Initialize a list for new prompt_id or append to existing one
if job.prompt_ids not in partial_responses:
@@ -497,7 +543,6 @@ def process_prompts():
# Calculate token counts
completion_tokens_old = (tokenizer.encode(generated_text)).shape[-1]
- prompt_tokens_old = (tokenizer.encode(prompt)).shape[-1]
completion_tokens = r['new_tokens']
prompt_tokens = r['prompt_tokens']
@@ -515,6 +560,8 @@ def process_prompts():
responses[eos_prompt_id] = partial_response_data
else:# Construct the response based on the format
+ if job.stop_at is not None:
+ generated_text += job.stop_at
response_data = {
"id": f"chatcmpl-{prompt_id}",
"object": "chat.completion",
@@ -534,56 +581,267 @@ def process_prompts():
"total_tokens": full_tokens
}
}
-
responses[eos_prompt_id] = response_data
+ del prompt_ids2jobs[eos_prompt_id]
- # Clean up
- input_ids.pop()
- prompt_length.pop()
- #streamer.pop(i)
- else:
- # Sleep for a short duration when there's no work
- time.sleep(0.1) # Sleep for 100 milliseconds
+ else:
+ # Sleep for a short duration when there's no work
+ time.sleep(0.1) # Sleep for 100 milliseconds
+ except Exception as e:
+        print("Resetting generation worker due to:", e)
+ print(traceback.format_exc())
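+        # On an unexpected error, mark streaming requests as finished and cancel
+        # every in-flight job before the worker thread exits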
+ for prompt_id in prompt_ids2jobs:
+ job = prompt_ids2jobs[prompt_id]
+ if(job.streamer):
+ ## Generator, yield here..
+ partial_response_data = {
+ "finish_reason": "stop"
+ }
+
+ responses[prompt_id] = partial_response_data
+ else:
+            print("Error handling for full (non-streaming) generation is not currently implemented")
+ generator.cancel(job)
+ prompt_ids2jobs = {}
# Start worker thread
worker = Thread(target=process_prompts)
worker.start()
+async def format_prompt(messages):
+ formatted_prompt = ""
+ for message in messages:
+ if message.role == "system":
+ formatted_prompt += f"{message.content}\n\n"
+ elif message.role == "user":
+ formatted_prompt += f"### User:\n{message.content}\n\n"
+ elif message.role == "assistant":
+ formatted_prompt += f"### Assistant:\n{message.content}\n\n"
+ # Add the final "### Assistant:\n" to prompt for the next response
+ formatted_prompt += "### Assistant:\n"
+ return formatted_prompt
+
+async def format_prompt_llama3(messages):
+ formatted_prompt = ""
+ system_message_found = False
+
+ # Check for a system message first
+ for message in messages:
+ if message.role == "system":
+ system_message_found = True
+ break
+
+ # If no system message was found, prepend a default one
+ if not system_message_found:
+ formatted_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful AI assistant.<|eot_id|>"
+ for message in messages:
+ if message.role == "system":
+ formatted_prompt += f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{message.content}<|eot_id|>"
+ elif message.role == "user":
+ formatted_prompt += f"<|start_header_id|>user<|end_header_id|>\n\n{message.content}<|eot_id|>"
+ elif message.role == "assistant":
+ formatted_prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n{message.content}<|eot_id|>"
+    # Append the assistant header to prompt the next response
+ formatted_prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
+ return formatted_prompt
+
+async def format_prompt_yi(messages):
+ formatted_prompt = ""
+ system_message_found = False
+
+ # Check for a system message first
+ for message in messages:
+ if message.role == "system":
+ system_message_found = True
+ break
+
+ # If no system message was found, prepend a default one
+ if not system_message_found:
+ formatted_prompt = "<|im_start|>system\nYou are a helpful AI assistant.<|im_end|>\n"
+ for message in messages:
+ if message.role == "system":
+ formatted_prompt += f"<|im_start|>system\n{message.content}<|im_end|>\n"
+ elif message.role == "user":
+ formatted_prompt += f"<|im_start|>user\n{message.content}<|im_end|>\n"
+ elif message.role == "assistant":
+ formatted_prompt += f"<|im_start|>assistant\n{message.content}<|im_end|>\n"
+    # Append the assistant turn marker to prompt the next response
+ formatted_prompt += "<|im_start|>assistant\n"
+ return formatted_prompt
+
+async def format_prompt_nous(messages):
+ formatted_prompt = ""
+ for message in messages:
+ if message.role == "system":
+ formatted_prompt += f"{message.content}\n"
+ elif message.role == "user":
+ formatted_prompt += f"USER: {message.content}\n"
+ elif message.role == "assistant":
+ formatted_prompt += f"ASSISTANT: {message.content}\n"
+    # Append "ASSISTANT: " to prompt the next response
+ formatted_prompt += "ASSISTANT: "
+ return formatted_prompt
+
+async def format_prompt_tess(messages):
+ formatted_prompt = ""
+ for message in messages:
+ if message.role == "system":
+ formatted_prompt += f"SYSTEM: {message.content}\n"
+ elif message.role == "user":
+ formatted_prompt += f"USER: {message.content}\n"
+ elif message.role == "assistant":
+ formatted_prompt += f"ASSISTANT: {message.content}\n"
+    # Append "ASSISTANT: " to prompt the next response
+ formatted_prompt += "ASSISTANT: "
+ return formatted_prompt
+
+async def format_prompt_code(messages):
+ formatted_prompt = ""
+ for message in messages:
+ if message.role == "system":
+ formatted_prompt += f"### System Prompt\nYou are an intelligent programming assistant.\n\n"
+ elif message.role == "user":
+ formatted_prompt += f"### User Message\n{message.content}\n\n"
+ elif message.role == "assistant":
+ formatted_prompt += f"### Assistant\n{message.content}\n\n"
+ # Add the final "### Assistant" with ellipsis to prompt for the next response
+ formatted_prompt += "### Assistant\n..."
+ return formatted_prompt
+
+async def format_prompt_zephyr(messages):
+ formatted_prompt = ""
+ for message in messages:
+ if message.role == "system":
+ formatted_prompt += f"<|system|>\n{message.content}\n"
+ elif message.role == "user":
+ formatted_prompt += f"<|user|>\n{message.content}\n"
+ elif message.role == "assistant":
+ formatted_prompt += f"<|assistant|>\n{message.content}\n"
+    # Append the assistant tag to prompt the next response
+ formatted_prompt += "<|assistant|>\n"
+ return formatted_prompt
+
+async def format_prompt_starling(messages):
+ formatted_prompt = ""
+ system_message = ""
+ for message in messages:
+ if message.role == "system":
+ # Save system message to prepend to the first user message
+ system_message += f"{message.content}\n\n"
+ elif message.role == "user":
+ # Prepend system message if it exists
+ if system_message:
+ formatted_prompt += f"GPT4 Correct User: {system_message}{message.content}<|end_of_turn|>"
+ system_message = "" # Clear system message after prepending
+ else:
+ formatted_prompt += f"GPT4 Correct User: {message.content}<|end_of_turn|>"
+ elif message.role == "assistant":
+ formatted_prompt += f"GPT4 Correct Assistant: {message.content}<|end_of_turn|>" # Prep for user follow-up
+ formatted_prompt += "GPT4 Correct Assistant: \n\n"
+ return formatted_prompt
+
+async def format_prompt_mixtral(messages):
+ formatted_prompt = " "
+ system_message = ""
+ for message in messages:
+ if message.role == "system":
+ # Save system message to prepend to the first user message
+ system_message += f"{message.content}\n\n"
+ elif message.role == "user":
+ # Prepend system message if it exists
+ if system_message:
+ formatted_prompt += f"[INST] {system_message}{message.content} [/INST] "
+ system_message = "" # Clear system message after prepending
+ else:
+ formatted_prompt += f"[INST] {message.content} [/INST] "
+ elif message.role == "assistant":
+ formatted_prompt += f" {message.content} " # Prep for user follow-up
+ return formatted_prompt
+
+
+async def format_prompt_commandr(messages):
+ formatted_prompt = ""
+ system_message_found = False
+
+ # Check for a system message first
+ for message in messages:
+ if message.role == "system":
+ system_message_found = True
+ break
+
+ # If no system message was found, prepend a default one
+ if not system_message_found:
+        formatted_prompt += "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful AI assistant.<|END_OF_TURN_TOKEN|>"
+
+ for message in messages:
+ if message.role == "system":
+ formatted_prompt += f"<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{message.content}<|END_OF_TURN_TOKEN|>"
+ elif message.role == "user":
+ formatted_prompt += f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{message.content}<|END_OF_TURN_TOKEN|>"
+ elif message.role == "assistant":
+ formatted_prompt += f"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{message.content}<|END_OF_TURN_TOKEN|>"
+    # Append the chatbot turn token to prompt the next response
+ formatted_prompt += "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
+ return formatted_prompt
+
+
@app.post('/v1/chat/completions')
async def mainchat(request: ChatCompletionRequest):
-
try:
prompt = ''
if repo_str == 'Phind-CodeLlama-34B-v2':
prompt = await format_prompt_code(request.messages)
elif repo_str == 'zephyr-7b-beta':
prompt = await format_prompt_zephyr(request.messages)
- elif repo_str == 'llama3-70b-instruct' or repo_str == 'llama3-70b-instruct-speculative':
+ elif repo_str == 'llama3-70b-instruct':
prompt = await format_prompt_llama3(request.messages)
elif repo_str == 'Starling-LM-7B-alpha':
prompt = await format_prompt_starling(request.messages)
- elif repo_str == 'Mixtral-8x7B-Instruct-v0.1-GPTQ' or repo_str == 'miqu-exl2-speculative':
+ elif repo_str == 'Mixtral-8x7B-Instruct-v0.1-GPTQ':
prompt = await format_prompt_mixtral(request.messages)
- elif repo_str == 'Yi-34B-Chat-GPTQ' or repo_str == 'Nous-Hermes-2-Yi-34B-GPTQ' or repo_str == 'theprofessor-exl2-speculative' or repo_str == 'Yi-34B-Chat':
+ elif repo_str == 'Yi-34B-Chat-GPTQ' or repo_str == 'Nous-Hermes-2-Yi-34B-GPTQ' or repo_str == 'theprofessor-exl2-speculative' or repo_str == 'dbrx-instruct-exl2':
prompt = await format_prompt_yi(request.messages)
elif repo_str == 'Nous-Capybara-34B-GPTQ' or repo_str == 'goliath-120b-GPTQ' or repo_str == 'goliath-120b-exl2' or repo_str == 'goliath-120b-exl2-rpcal':
prompt = await format_prompt_nous(request.messages)
- elif repo_str == 'tess-xl-exl2' or repo_str == 'tess-xl-exl2-speculative' or repo_str == 'venus-exl2-speculative':
+ elif repo_str == 'tess-xl-exl2' or repo_str == 'tess-xl-exl2-speculative':
prompt = await format_prompt_tess(request.messages)
- elif repo_str == 'tinyllama-exl2-speculative':
- prompt = await format_prompt_zephyr(request.messages)
elif repo_str == 'commandr-exl2' or repo_str == 'commandr-exl2-speculative':
prompt = await format_prompt_commandr(request.messages)
else:
prompt = await format_prompt(request.messages)
- status_area.update(f"Prompt: {prompt}")
+ if request.partial_generation is not None:
+ prompt += request.partial_generation
+ print(prompt)
timeout = 180 # seconds
start_time = time.time()
- prompt_id = generate_unique_id() # Replace with a function to generate unique IDs
- prompts.put((prompt_id, prompt, request.max_tokens, request.stream, request.temperature))
+        prompt_id = request.request_id  # Unique id injected by RequestCancelledMiddleware
+ outlines_dict = {}
+
+ # Adjust temperature if it is 0
+ if request.temperature == 0:
+ request.temperature = 0.001
+
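+        # Collect the optional constrained-decoding options; the worker thread turns these into outlines filters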
+ if request.stop_at is not None:
+ outlines_dict["stop_at"] = request.stop_at
+ if request.outlines_type is not None:
+ outlines_dict["type"] = request.outlines_type
+ else:
+ outlines_dict["type"] = "text"
+ if outlines_dict["type"] == "choices":
+ assert request.choices is not None
+ outlines_dict["choices"] = request.choices
+ elif outlines_dict["type"] == "json":
+ assert request.json is not None
+ outlines_dict["json"] = request.json
+ elif outlines_dict["type"] == "regex":
+ assert request.regex is not None
+ outlines_dict["regex"] = request.regex
+ else:
+ assert outlines_dict["type"] == "text"
+ prompts.put((prompt_id, prompt, request.max_tokens, request.stream, request.temperature, outlines_dict))
if request.stream:
#response = StreamingResponse(streaming_request(prompt, request.max_tokens, tempmodel=repo_str, response_format='chat_completion'), media_type="text/event-stream")
@@ -599,9 +857,11 @@ async def mainchat(request: ChatCompletionRequest):
return responses.pop(prompt_id)
except Exception as e:
+ print(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
- return response
+
+
@app.get('/ping')
@@ -633,10 +893,6 @@ async def get_nvidia_smi():
if __name__ == "__main__":
+ import uvicorn
- uvicorn.run(app, host=host, port=port, log_level="error")
-
- print(term.enter_fullscreen())
-
-
-
+ uvicorn.run(app, host=host, port=port, log_level="debug")