Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow models to take lists of PromptSections, allow translators to take preambles. #203

Merged
merged 3 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/examples/healthData/translator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from typing_extensions import TypeVar, Any, override, TypedDict, Literal

from typechat import TypeChatValidator, TypeChatLanguageModel, TypeChatTranslator, Result, Failure
from typechat import TypeChatValidator, TypeChatLanguageModel, TypeChatTranslator, Result, Failure, PromptSection

from datetime import datetime

Expand All @@ -27,8 +27,8 @@ def __init__(
self._additional_agent_instructions = additional_agent_instructions

@override
async def translate(self, request: str) -> Result[T]:
result = await super().translate(request=request)
async def translate(self, request: str, *, prompt_preamble: str | list[PromptSection] | None = None) -> Result[T]:
result = await super().translate(request=request, prompt_preamble=prompt_preamble)
if not isinstance(result, Failure):
self._chat_history.append(ChatMessage(source="assistant", body=result.value))
return result
Expand Down
7 changes: 5 additions & 2 deletions python/src/typechat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: MIT

from typechat._internal.model import TypeChatLanguageModel, create_language_model
from typechat._internal.model import PromptSection, TypeChatLanguageModel, create_language_model, create_openai_language_model, create_azure_openai_language_model
from typechat._internal.result import Failure, Result, Success
from typechat._internal.translator import TypeChatTranslator
from typechat._internal.ts_conversion import python_type_to_typescript_schema
Expand All @@ -17,6 +17,9 @@
"Failure",
"Result",
"python_type_to_typescript_schema",
"PromptSection",
"create_language_model",
"process_requests"
"create_openai_language_model",
"create_azure_openai_language_model",
"process_requests",
]
31 changes: 17 additions & 14 deletions python/src/typechat/_internal/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,17 @@

import httpx

class PromptSection(TypedDict):
    """
    Represents a section of an LLM prompt with an associated role. TypeChat uses the "user" role for
    prompts it generates and the "assistant" role for previous LLM responses (which will be part of
    the prompt in repair attempts). TypeChat currently doesn't use the "system" role.
    """
    # The chat role this section is attributed to when it is sent to the model.
    role: Literal["system", "user", "assistant"]
    # The text content of this prompt section.
    content: str

class TypeChatLanguageModel(Protocol):
async def complete(self, prompt: str) -> Result[str]:
async def complete(self, prompt: str | list[PromptSection]) -> Result[str]:
"""
Represents a AI language model that can complete prompts.

Expand All @@ -18,15 +27,6 @@ async def complete(self, prompt: str) -> Result[str]:
"""
...

class _PromptSection(TypedDict):
    """
    Represents a section of an LLM prompt with an associated role. TypeChat uses the "user" role for
    prompts it generates and the "assistant" role for previous LLM responses (which will be part of
    the prompt in repair attempts). TypeChat currently doesn't use the "system" role.
    """
    # Which chat participant this section is attributed to.
    # NOTE(review): structurally identical to the public PromptSection TypedDict; this is the
    # module-private variant.
    role: Literal["system", "user", "assistant"]
    # The message text for this section.
    content: str

_TRANSIENT_ERROR_CODES = [
429,
500,
Expand All @@ -51,15 +51,18 @@ def __init__(self, url: str, headers: dict[str, str], default_params: dict[str,
self._async_client = httpx.AsyncClient()

@override
async def complete(self, prompt: str) -> Success[str] | Failure:
async def complete(self, prompt: str | list[PromptSection]) -> Success[str] | Failure:
headers = {
"Content-Type": "application/json",
**self.headers,
}
messages = [{"role": "user", "content": prompt}]

if isinstance(prompt, str):
prompt = [{"role": "user", "content": prompt}]

body = {
**self.default_params,
"messages": messages,
"messages": prompt,
"temperature": 0.0,
"n": 1,
}
Expand All @@ -73,7 +76,7 @@ async def complete(self, prompt: str) -> Success[str] | Failure:
)
if response.is_success:
json_result = cast(
dict[Literal["choices"], list[dict[Literal["message"], _PromptSection]]],
dict[Literal["choices"], list[dict[Literal["message"], PromptSection]]],
response.json()
)
return Success(json_result["choices"][0]["message"]["content"] or "")
Expand Down
18 changes: 15 additions & 3 deletions python/src/typechat/_internal/translator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing_extensions import Generic, TypeVar

from typechat._internal.model import TypeChatLanguageModel
from typechat._internal.model import PromptSection, TypeChatLanguageModel
from typechat._internal.result import Failure, Result, Success
from typechat._internal.ts_conversion import python_type_to_typescript_schema
from typechat._internal.validator import TypeChatValidator
Expand Down Expand Up @@ -43,10 +43,11 @@ def __init__(
if _raise_on_schema_errors and conversion_result.errors:
error_text = "".join(f"\n- {error}" for error in conversion_result.errors)
raise ValueError(f"Could not convert Python type to TypeScript schema: \n{error_text}")

self._type_name = conversion_result.typescript_type_reference
self._schema_str = conversion_result.typescript_schema_str

async def translate(self, request: str) -> Result[T]:
async def translate(self, request: str, *, prompt_preamble: str | list[PromptSection] | None = None) -> Result[T]:
"""
Translates a natural language request into an object of type `T`. If the JSON object returned by
the language model fails to validate, repair attempts will be made up until `_max_repair_attempts`.
Expand All @@ -55,11 +56,22 @@ async def translate(self, request: str) -> Result[T]:

Args:
request: A natural language request.
prompt_preamble: An optional string or list of prompt sections to prepend to the generated prompt.\
If a string is given, it is converted to a single "user" role prompt section.
"""
request = self._create_request_prompt(request)

prompt: str | list[PromptSection]
if prompt_preamble is None:
prompt = request
else:
if isinstance(prompt_preamble, str):
prompt_preamble = [{"role": "user", "content": prompt_preamble}]
prompt = [*prompt_preamble, {"role": "user", "content": request}]

num_repairs_attempted = 0
while True:
completion_response = await self.model.complete(request)
completion_response = await self.model.complete(prompt)
if isinstance(completion_response, Failure):
return completion_response

Expand Down
Loading