Skip to content

Commit

Permalink
Merge pull request #585 from pepebruari/main
Browse files Browse the repository at this point in the history
Add --system-prompt to exo cli
  • Loading branch information
AlexCheema authored Jan 3, 2025
2 parents 03aa6ce + fe50d4d commit 7b16561
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
7 changes: 6 additions & 1 deletion exo/api/chatgpt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __init__(self, request_id: str, timestamp: int, prompt: str):
self.prompt = prompt

class ChatGPTAPI:
def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None, system_prompt: Optional[str] = None):
self.node = node
self.inference_engine_classname = inference_engine_classname
self.response_timeout = response_timeout
Expand All @@ -170,6 +170,7 @@ def __init__(self, node: Node, inference_engine_classname: str, response_timeout
self.prev_token_lens: Dict[str, int] = {}
self.stream_tasks: Dict[str, asyncio.Task] = {}
self.default_model = default_model or "llama-3.2-1b"
self.system_prompt = system_prompt

cors = aiohttp_cors.setup(self.app)
cors_options = aiohttp_cors.ResourceOptions(
Expand Down Expand Up @@ -336,6 +337,10 @@ async def handle_post_chat_completions(self, request):
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")

# Add system prompt if set
if self.system_prompt and not any(msg.role == "system" for msg in chat_request.messages):
chat_request.messages.insert(0, Message("system", self.system_prompt))

prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools)
request_id = str(uuid.uuid4())
if self.on_chat_completion_request:
Expand Down
4 changes: 3 additions & 1 deletion exo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
parser.add_argument("--system-prompt", type=str, default=None, help="System prompt for the ChatGPT API")
args = parser.parse_args()
print(f"Selected inference engine: {args.inference_engine}")

Expand Down Expand Up @@ -146,7 +147,8 @@
inference_engine.__class__.__name__,
response_timeout=args.chatgpt_api_response_timeout,
on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
default_model=args.default_model
default_model=args.default_model,
system_prompt=args.system_prompt
)
node.on_token.register("update_topology_viz").on_next(
lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
Expand Down

0 comments on commit 7b16561

Please sign in to comment.