diff --git a/ai21/clients/common/beta/assistant/threads.py b/ai21/clients/common/beta/assistant/threads.py index ff67ed60..68cc9460 100644 --- a/ai21/clients/common/beta/assistant/threads.py +++ b/ai21/clients/common/beta/assistant/threads.py @@ -1,11 +1,13 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import List +from typing import List, Optional from ai21.clients.common.beta.assistant.messages import BaseMessages -from ai21.models.assistant.message import Message +from ai21.models.assistant.message import Message, modify_message_content from ai21.models.responses.thread_response import ThreadResponse +from ai21.types import NOT_GIVEN, NotGiven +from ai21.utils.typing import remove_not_given class BaseThreads(ABC): @@ -15,11 +17,19 @@ class BaseThreads(ABC): @abstractmethod def create( self, - messages: List[Message], + messages: List[Message] | NotGiven = NOT_GIVEN, **kwargs, ) -> ThreadResponse: pass + def _create_body(self, messages: List[Message] | NotGiven, **kwargs) -> Optional[dict]: + body = remove_not_given({"messages": messages, **kwargs}) + + if "messages" in body: + body["messages"] = [modify_message_content(message) for message in body["messages"]] + + return body + @abstractmethod def retrieve(self, thread_id: str) -> ThreadResponse: pass diff --git a/ai21/clients/studio/resources/beta/assistant/thread.py b/ai21/clients/studio/resources/beta/assistant/thread.py index 54ef0c1d..91279317 100644 --- a/ai21/clients/studio/resources/beta/assistant/thread.py +++ b/ai21/clients/studio/resources/beta/assistant/thread.py @@ -8,8 +8,9 @@ from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource from ai21.http_client.async_http_client import AsyncAI21HTTPClient from ai21.http_client.http_client import AI21HTTPClient -from ai21.models.assistant.message import Message, modify_message_content +from ai21.models.assistant.message import Message from ai21.models.responses.thread_response import ThreadResponse +from ai21.types import NOT_GIVEN, NotGiven class Threads(StudioResource, BaseThreads): @@ -21,10 +22,10 @@ def __init__(self, client: AI21HTTPClient): def create( self, - messages: List[Message], + messages: List[Message] | NotGiven = NOT_GIVEN, **kwargs, ) -> ThreadResponse: - body = dict(messages=[modify_message_content(message) for message in messages]) + body = self._create_body(messages=messages, **kwargs) return self._post(path=f"/{self._module_name}", body=body, response_cls=ThreadResponse) @@ -41,10 +42,10 @@ def __init__(self, client: AsyncAI21HTTPClient): async def create( self, - messages: List[Message], + messages: List[Message] | NotGiven = NOT_GIVEN, **kwargs, ) -> ThreadResponse: - body = dict(messages=[modify_message_content(message) for message in messages]) + body = self._create_body(messages=messages, **kwargs) return await self._post(path=f"/{self._module_name}", body=body, response_cls=ThreadResponse)