diff --git a/src/generative_ai_toolkit/agent/bedrock_converse_agent.py b/src/generative_ai_toolkit/agent/bedrock_converse_agent.py index 8238916..9c40e20 100644 --- a/src/generative_ai_toolkit/agent/bedrock_converse_agent.py +++ b/src/generative_ai_toolkit/agent/bedrock_converse_agent.py @@ -794,6 +794,7 @@ def converse( user_input: str | None, tools: Sequence[Callable | Tool] | None = None, stop_event: Event | None = None, + enable_cache: bool | None = False, ) -> str: """ Start or continue a conversation with the agent and return the agent's response as string. @@ -847,16 +848,33 @@ def converse( raise ValueError("Missing user input") if user_input is not None: - self._add_message( - { - "role": "user", - "content": [ - { - "text": user_input, - }, - ], - }, - ) + if not enable_cache: + self._add_message( + { + "role": "user", + "content": [ + { + "text": user_input, + }, + ], + }, + ) + else: + self._add_message( + { + "role": "user", + "content": [ + { + "text": user_input, + }, + { + "cachePoint": { + "type": "default" + } + }, + ], + }, + ) request: ConverseRequestTypeDef = { "modelId": self.model_id, @@ -871,6 +889,12 @@ def converse( "text": self.system_prompt, }, ] + if enable_cache: + request["system"].append({ + "cachePoint": { + "type": "default" + } + }) if self.default_model_request_fields: request["additionalModelRequestFields"] = self.default_model_request_fields if self.default_model_response_field_paths: @@ -1120,6 +1144,7 @@ def converse_stream( stream: Literal["traces"], tools: Sequence[Callable | Tool] | None = None, stop_event: Event | None = None, + enable_cache: bool | None = False, ) -> Iterable[Trace]: """ Start or continue a conversation with the agent. @@ -1158,9 +1183,10 @@ def converse_stream( stream: Literal["traces"] | Literal["text"] = "text", tools: Sequence[Callable | Tool] | None = None, stop_event: Event | None = None, + enable_cache: bool | None = False, ) -> Iterable[str] | Iterable[Trace]: gen = self._converse_stream( - user_input=user_input, tools=tools, stop_event=stop_event + user_input=user_input, tools=tools, stop_event=stop_event, enable_cache=enable_cache ) if stream == "text": yield from gen @@ -1195,6 +1221,7 @@ def _converse_stream( user_input: str | None, tools: Sequence[Callable | Tool] | None = None, stop_event: Event | None = None, + enable_cache: bool | None = False, ) -> Iterable[str]: current_trace = self._tracer.current_trace current_trace.add_attribute("ai.trace.type", "converse-stream") @@ -1213,23 +1240,40 @@ def _converse_stream( current_trace.emit_snapshot() if self.converse_implementation == "converse": - yield self.converse(user_input, tools=tools, stop_event=stop_event) + yield self.converse(user_input, tools=tools, stop_event=stop_event, enable_cache=enable_cache) return if user_input == "": raise ValueError("Missing user input") - if user_input is not None: - self._add_message( - { - "role": "user", - "content": [ - { - "text": user_input, - }, - ], - }, - ) + if user_input is not None : + if not enable_cache : + self._add_message( + { + "role": "user", + "content": [ + { + "text": user_input, + }, + ], + }, + ) + else: + self._add_message( + { + "role": "user", + "content": [ + { + "text": user_input, + }, + { + "cachePoint": { + "type": "default" + } + }, + ], + }, + ) request: ConverseStreamRequestTypeDef = { "modelId": self.model_id, @@ -1244,6 +1288,13 @@ def _converse_stream( "text": self.system_prompt, }, ] + + if enable_cache: + request["system"].append({ + "cachePoint": { + "type": "default" + } + }) if self.default_model_request_fields: request["additionalModelRequestFields"] = self.default_model_request_fields if self.default_model_response_field_paths: