diff --git a/python/examples/healthData/translator.py b/python/examples/healthData/translator.py
index d5fffccc..64fd70bb 100644
--- a/python/examples/healthData/translator.py
+++ b/python/examples/healthData/translator.py
@@ -1,7 +1,7 @@
 import json
 
 from typing_extensions import TypeVar, Any, override, TypedDict, Literal
 
-from typechat import TypeChatValidator, TypeChatLanguageModel, TypeChatTranslator, Result, Failure
+from typechat import TypeChatValidator, TypeChatLanguageModel, TypeChatTranslator, Result, Failure, PromptSection
 
 from datetime import datetime
@@ -27,8 +27,8 @@ def __init__(
         self._additional_agent_instructions = additional_agent_instructions
 
     @override
-    async def translate(self, request: str) -> Result[T]:
-        result = await super().translate(request=request)
+    async def translate(self, request: str, *, prompt_preamble: str | list[PromptSection] | None = None) -> Result[T]:
+        result = await super().translate(request=request, prompt_preamble=prompt_preamble)
         if not isinstance(result, Failure):
             self._chat_history.append(ChatMessage(source="assistant", body=result.value))
         return result
diff --git a/python/src/typechat/__init__.py b/python/src/typechat/__init__.py
index 6fa8f386..9a0e4be0 100644
--- a/python/src/typechat/__init__.py
+++ b/python/src/typechat/__init__.py
@@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: MIT
 
-from typechat._internal.model import TypeChatLanguageModel, create_language_model
+from typechat._internal.model import PromptSection, TypeChatLanguageModel, create_language_model, create_openai_language_model, create_azure_openai_language_model
 from typechat._internal.result import Failure, Result, Success
 from typechat._internal.translator import TypeChatTranslator
 from typechat._internal.ts_conversion import python_type_to_typescript_schema
@@ -17,6 +17,9 @@
     "Failure",
     "Result",
     "python_type_to_typescript_schema",
+    "PromptSection",
     "create_language_model",
-    "process_requests"
+    "create_openai_language_model",
+    "create_azure_openai_language_model",
+    "process_requests",
 ]
diff --git a/python/src/typechat/_internal/model.py b/python/src/typechat/_internal/model.py
index 642f5c69..353ecf0f 100644
--- a/python/src/typechat/_internal/model.py
+++ b/python/src/typechat/_internal/model.py
@@ -6,8 +6,17 @@
 
 import httpx
 
+class PromptSection(TypedDict):
+    """
+    Represents a section of an LLM prompt with an associated role. TypeChat uses the "user" role for
+    prompts it generates and the "assistant" role for previous LLM responses (which will be part of
+    the prompt in repair attempts). TypeChat currently doesn't use the "system" role.
+    """
+    role: Literal["system", "user", "assistant"]
+    content: str
+
 class TypeChatLanguageModel(Protocol):
-    async def complete(self, prompt: str) -> Result[str]:
+    async def complete(self, prompt: str | list[PromptSection]) -> Result[str]:
         """
         Represents a AI language model that can complete prompts.
 
@@ -18,15 +27,6 @@ async def complete(self, prompt: str) -> Result[str]:
         """
         ...
 
-class _PromptSection(TypedDict):
-    """
-    Represents a section of an LLM prompt with an associated role. TypeChat uses the "user" role for
-    prompts it generates and the "assistant" role for previous LLM responses (which will be part of
-    the prompt in repair attempts). TypeChat currently doesn't use the "system" role.
-    """
-    role: Literal["system", "user", "assistant"]
-    content: str
-
 _TRANSIENT_ERROR_CODES = [
     429,
     500,
@@ -51,15 +51,18 @@ def __init__(self, url: str, headers: dict[str, str], default_params: dict[str,
         self._async_client = httpx.AsyncClient()
 
     @override
-    async def complete(self, prompt: str) -> Success[str] | Failure:
+    async def complete(self, prompt: str | list[PromptSection]) -> Success[str] | Failure:
         headers = {
             "Content-Type": "application/json",
             **self.headers,
         }
-        messages = [{"role": "user", "content": prompt}]
+
+        if isinstance(prompt, str):
+            prompt = [{"role": "user", "content": prompt}]
+
         body = {
             **self.default_params,
-            "messages": messages,
+            "messages": prompt,
             "temperature": 0.0,
             "n": 1,
         }
@@ -73,7 +76,7 @@ async def complete(self, prompt: str) -> Success[str] | Failure:
             )
             if response.is_success:
                 json_result = cast(
-                    dict[Literal["choices"], list[dict[Literal["message"], _PromptSection]]],
+                    dict[Literal["choices"], list[dict[Literal["message"], PromptSection]]],
                     response.json()
                 )
                 return Success(json_result["choices"][0]["message"]["content"] or "")
diff --git a/python/src/typechat/_internal/translator.py b/python/src/typechat/_internal/translator.py
index 1d5891b9..06e24795 100644
--- a/python/src/typechat/_internal/translator.py
+++ b/python/src/typechat/_internal/translator.py
@@ -1,6 +1,6 @@
 from typing_extensions import Generic, TypeVar
 
-from typechat._internal.model import TypeChatLanguageModel
+from typechat._internal.model import PromptSection, TypeChatLanguageModel
 from typechat._internal.result import Failure, Result, Success
 from typechat._internal.ts_conversion import python_type_to_typescript_schema
 from typechat._internal.validator import TypeChatValidator
@@ -43,10 +43,11 @@ def __init__(
         if _raise_on_schema_errors and conversion_result.errors:
             error_text = "".join(f"\n- {error}" for error in conversion_result.errors)
             raise ValueError(f"Could not convert Python type to TypeScript schema: \n{error_text}")
+
         self._type_name = conversion_result.typescript_type_reference
         self._schema_str = conversion_result.typescript_schema_str
 
-    async def translate(self, request: str) -> Result[T]:
+    async def translate(self, request: str, *, prompt_preamble: str | list[PromptSection] | None = None) -> Result[T]:
         """
         Translates a natural language request into an object of type `T`. If the JSON object returned
         by the language model fails to validate, repair attempts will be made up until `_max_repair_attempts`.
@@ -55,11 +56,22 @@
 
         Args:
             request: A natural language request.
+            prompt_preamble: An optional string or list of prompt sections to prepend to the generated prompt.\
+                If a string is given, it is converted to a single "user" role prompt section.
         """
         request = self._create_request_prompt(request)
+
+        prompt: str | list[PromptSection]
+        if prompt_preamble is None:
+            prompt = request
+        else:
+            if isinstance(prompt_preamble, str):
+                prompt_preamble = [{"role": "user", "content": prompt_preamble}]
+            prompt = [*prompt_preamble, {"role": "user", "content": request}]
+
         num_repairs_attempted = 0
         while True:
-            completion_response = await self.model.complete(request)
+            completion_response = await self.model.complete(prompt)
             if isinstance(completion_response, Failure):
                 return completion_response