diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py index 61bff54b..3c540d54 100644 --- a/src/api/models/bedrock.py +++ b/src/api/models/bedrock.py @@ -211,6 +211,7 @@ def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]: response = self._invoke_bedrock(chat_request, stream=True) message_id = self.generate_message_id() stream = response.get("stream") + self.think_emitted = False for chunk in stream: stream_response = self._create_response_stream( model_id=chat_request.model, message_id=message_id, chunk=chunk @@ -235,6 +236,7 @@ def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]: # return an [DONE] message at the end. yield self.stream_response_to_bytes() + self.think_emitted = False # Cleanup def _parse_system_prompts(self, chat_request: ChatRequest) -> list[dict[str, str]]: """Create system prompts. @@ -498,6 +500,9 @@ def _create_response( message.content = c["text"] else: logger.warning("Unknown tag in message content " + ",".join(c.keys())) + if message.reasoning_content: + message.content = f"{message.reasoning_content}{message.content}" + message.reasoning_content = None response = ChatResponse( id=message_id, @@ -566,11 +571,19 @@ def _create_response_stream( content=delta["text"], ) elif "reasoningContent" in delta: - # ignore "signature" in the delta. if "text" in delta["reasoningContent"]: - message = ChatResponseMessage( - reasoning_content=delta["reasoningContent"]["text"], - ) + content = delta["reasoningContent"]["text"] + if not self.think_emitted: + # Port of "content_block_start" with "thinking" + content = "" + content + self.think_emitted = True + message = ChatResponseMessage(content=content) + elif "signature" in delta["reasoningContent"]: + # Port of "signature_delta" + if self.think_emitted: + message = ChatResponseMessage(content="\n \n\n") + else: + return None # Ignore signature if no started else: # tool use index = chunk["contentBlockDelta"]["contentBlockIndex"] - 1