Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 74 additions & 23 deletions src/generative_ai_toolkit/agent/bedrock_converse_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,7 @@ def converse(
user_input: str | None,
tools: Sequence[Callable | Tool] | None = None,
stop_event: Event | None = None,
enable_cache: bool | None = False,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: should we do

enable_cache: bool | None = None,

so you can see the difference between explicitly turned off and not set? Otherwise, it could be

enable_cache: bool = False,

) -> str:
"""
Start or continue a conversation with the agent and return the agent's response as string.
Expand Down Expand Up @@ -847,16 +848,33 @@ def converse(
raise ValueError("Missing user input")

if user_input is not None:
self._add_message(
{
"role": "user",
"content": [
{
"text": user_input,
},
],
},
)
if not enable_cache:
self._add_message(
{
"role": "user",
"content": [
{
"text": user_input,
},
],
},
)
else:
self._add_message(
{
"role": "user",
"content": [
{
"text": user_input,
},
{
"cachePoint": {
"type": "default"
}
},
],
},
)

request: ConverseRequestTypeDef = {
"modelId": self.model_id,
Expand All @@ -871,6 +889,12 @@ def converse(
"text": self.system_prompt,
},
]
if enable_cache:
request["system"].append({
"cachePoint": {
"type": "default"
}
})
if self.default_model_request_fields:
request["additionalModelRequestFields"] = self.default_model_request_fields
if self.default_model_response_field_paths:
Expand Down Expand Up @@ -1120,6 +1144,7 @@ def converse_stream(
stream: Literal["traces"],
tools: Sequence[Callable | Tool] | None = None,
stop_event: Event | None = None,
enable_cache: bool | None = False,
) -> Iterable[Trace]:
"""
Start or continue a conversation with the agent.
Expand Down Expand Up @@ -1158,9 +1183,10 @@ def converse_stream(
stream: Literal["traces"] | Literal["text"] = "text",
tools: Sequence[Callable | Tool] | None = None,
stop_event: Event | None = None,
enable_cache: bool | None = False,
) -> Iterable[str] | Iterable[Trace]:
gen = self._converse_stream(
user_input=user_input, tools=tools, stop_event=stop_event
user_input=user_input, tools=tools, stop_event=stop_event, enable_cache=enable_cache
)
if stream == "text":
yield from gen
Expand Down Expand Up @@ -1195,6 +1221,7 @@ def _converse_stream(
user_input: str | None,
tools: Sequence[Callable | Tool] | None = None,
stop_event: Event | None = None,
enable_cache: bool | None = False,
) -> Iterable[str]:
current_trace = self._tracer.current_trace
current_trace.add_attribute("ai.trace.type", "converse-stream")
Expand All @@ -1213,23 +1240,40 @@ def _converse_stream(
current_trace.emit_snapshot()

if self.converse_implementation == "converse":
yield self.converse(user_input, tools=tools, stop_event=stop_event)
yield self.converse(user_input, tools=tools, stop_event=stop_event, enable_cache=enable_cache)
return

if user_input == "":
raise ValueError("Missing user input")

if user_input is not None:
self._add_message(
{
"role": "user",
"content": [
{
"text": user_input,
},
],
},
)
if user_input is not None :
if not enable_cache :
self._add_message(
{
"role": "user",
"content": [
{
"text": user_input,
},
],
},
)
else:
self._add_message(
{
"role": "user",
"content": [
{
"text": user_input,
},
{
"cachePoint": {
"type": "default"
}
},
],
},
)

request: ConverseStreamRequestTypeDef = {
"modelId": self.model_id,
Expand All @@ -1244,6 +1288,13 @@ def _converse_stream(
"text": self.system_prompt,
},
]

if enable_cache:
request["system"].append({
"cachePoint": {
"type": "default"
}
})
if self.default_model_request_fields:
request["additionalModelRequestFields"] = self.default_model_request_fields
if self.default_model_response_field_paths:
Expand Down