Skip to content

Commit 182ebbd

Browse files
authored
Merge pull request #259 from AI21Labs/thread-no-msgs
chore: ♻️ thread messages should be optional
2 parents 0195f62 + 7198db9 commit 182ebbd

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed
+13-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from __future__ import annotations
22

33
from abc import ABC, abstractmethod
4-
from typing import List
4+
from typing import List, Optional
55

66
from ai21.clients.common.beta.assistant.messages import BaseMessages
7-
from ai21.models.assistant.message import Message
7+
from ai21.models.assistant.message import Message, modify_message_content
88
from ai21.models.responses.thread_response import ThreadResponse
9+
from ai21.types import NOT_GIVEN, NotGiven
10+
from ai21.utils.typing import remove_not_given
911

1012

1113
class BaseThreads(ABC):
@@ -15,11 +17,19 @@ class BaseThreads(ABC):
1517
@abstractmethod
1618
def create(
1719
self,
18-
messages: List[Message],
20+
messages: List[Message] | NotGiven = NOT_GIVEN,
1921
**kwargs,
2022
) -> ThreadResponse:
2123
pass
2224

25+
def _create_body(self, messages: List[Message] | NotGiven, **kwargs) -> Optional[dict]:
26+
body = remove_not_given({"messages": messages, **kwargs})
27+
28+
if "messages" in body:
29+
body["messages"] = [modify_message_content(message) for message in body["messages"]]
30+
31+
return body
32+
2333
@abstractmethod
2434
def retrieve(self, thread_id: str) -> ThreadResponse:
2535
pass

ai21/clients/studio/resources/beta/assistant/thread.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource
99
from ai21.http_client.async_http_client import AsyncAI21HTTPClient
1010
from ai21.http_client.http_client import AI21HTTPClient
11-
from ai21.models.assistant.message import Message, modify_message_content
11+
from ai21.models.assistant.message import Message
1212
from ai21.models.responses.thread_response import ThreadResponse
13+
from ai21.types import NOT_GIVEN, NotGiven
1314

1415

1516
class Threads(StudioResource, BaseThreads):
@@ -21,10 +22,10 @@ def __init__(self, client: AI21HTTPClient):
2122

2223
def create(
2324
self,
24-
messages: List[Message],
25+
messages: List[Message] | NotGiven = NOT_GIVEN,
2526
**kwargs,
2627
) -> ThreadResponse:
27-
body = dict(messages=[modify_message_content(message) for message in messages])
28+
body = self._create_body(messages=messages, **kwargs)
2829

2930
return self._post(path=f"/{self._module_name}", body=body, response_cls=ThreadResponse)
3031

@@ -41,10 +42,10 @@ def __init__(self, client: AsyncAI21HTTPClient):
4142

4243
async def create(
4344
self,
44-
messages: List[Message],
45+
messages: List[Message] | NotGiven = NOT_GIVEN,
4546
**kwargs,
4647
) -> ThreadResponse:
47-
body = dict(messages=[modify_message_content(message) for message in messages])
48+
body = self._create_body(messages=messages, **kwargs)
4849

4950
return await self._post(path=f"/{self._module_name}", body=body, response_cls=ThreadResponse)
5051

0 commit comments

Comments
 (0)