Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions verifiers/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
Awaitable,
Callable,
Literal,
TypeAlias,
Union,
)

if sys.version_info < (3, 12):
Expand All @@ -13,8 +15,15 @@
from typing import TypedDict

from openai import AsyncOpenAI
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionDeveloperMessageParam,
ChatCompletionFunctionMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionUserMessageParam,
)
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam

# openai types
from openai.types.chat.chat_completion_message_tool_call import (
Expand All @@ -31,8 +40,24 @@
)
from pydantic import BaseModel


class ChatCompletionAssistantMessageParamWithReasoning(
ChatCompletionAssistantMessageParam
):
reasoning_content: str | None


ChatCompletionMessageParamWithReasoning: TypeAlias = Union[
ChatCompletionDeveloperMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionUserMessageParam,
ChatCompletionAssistantMessageParamWithReasoning,
ChatCompletionToolMessageParam,
ChatCompletionFunctionMessageParam,
]

# typing aliases
ChatMessage = ChatCompletionMessageParam
ChatMessage = ChatCompletionMessageParamWithReasoning
MessageType = Literal["chat", "completion"]
ModelResponse = Completion | ChatCompletion | None

Expand Down
21 changes: 15 additions & 6 deletions verifiers/utils/response_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,17 +89,26 @@ async def parse_response_tokens(


async def parse_response_messages(
response: ModelResponse, message_type: MessageType
response: ModelResponse,
message_type: MessageType,
reasoning_field: str = "reasoning_content",
) -> Messages:
response_text = ""
content = ""
reasoning_content = ""
if message_type == "chat":
assert isinstance(response, ChatCompletion)
if response.choices and response.choices[0].message:
response_text = response.choices[0].message.content or ""
content = response.choices[0].message.content or ""
reasoning_content = getattr(
response.choices[0].message, reasoning_field, ""
)

response_message: ChatMessage = {
"role": "assistant",
"content": response_text,
"content": content,
"reasoning_content": reasoning_content,
}

if (
response.choices
and response.choices[0].message
Expand All @@ -113,6 +122,6 @@ async def parse_response_messages(
else:
assert isinstance(response, Completion)
if response.choices and response.choices[0]:
response_text = response.choices[0].text or ""
completion_messages = str(response_text)
content = response.choices[0].text or ""
completion_messages = str(content)
return completion_messages
Loading