Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Protocol, Unpack
from typing import TYPE_CHECKING, Any, Protocol, Unpack, runtime_checkable

import boto3
import boto3.session
Expand Down Expand Up @@ -51,6 +51,7 @@ class ConversationPage:
next_page_token: Any | None = None


@runtime_checkable
class ConversationList(Protocol):

@property
Expand Down
52 changes: 47 additions & 5 deletions src/generative_ai_toolkit/ui/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class TraceSummary:
trace_id: str
span_id: str
started_at: datetime
ended_at: datetime | None
duration_ms: int | None
conversation_id: str
auth_context: AuthContext = field(default_factory=lambda: {"principal_id": None})
Expand Down Expand Up @@ -70,6 +71,7 @@ def get_summaries_for_traces(traces: Sequence[Trace]):
span_id=root_trace.span_id,
duration_ms=root_trace.ended_at and root_trace.duration_ms,
started_at=root_trace.started_at,
ended_at=root_trace.ended_at,
all_traces=traces_for_trace_id,
agent_cycle_traces={
trace.span_id: trace
Expand Down Expand Up @@ -475,7 +477,11 @@ def chat_messages_from_trace_summary(
)
if cycle_response:
metadata = metadata.copy()
metadata.pop("status", None)
if not trace.ended_at:
metadata["status"] = "pending"
elif metadata.get("status") == "done":
# Always fold open
metadata.pop("status")
chat_messages.append(
gr.ChatMessage(
role="assistant",
Expand Down Expand Up @@ -537,7 +543,11 @@ def chat_messages_from_trace_summary(
agent_response = trace.attributes.get("ai.agent.cycle.response")
if agent_response:
metadata = get_metadata(trace)
metadata.pop("status", None)
if not trace.ended_at:
metadata["status"] = "pending"
elif metadata.get("status") == "done":
# Always fold open
metadata.pop("status")
chat_messages.append(
gr.ChatMessage(
role="assistant",
Expand All @@ -553,20 +563,29 @@ def chat_messages_from_trace_summary(
return chat_messages


@dataclass
class ChatMessages:
conversation_id: str
principal_id: str | None
messages: Sequence[gr.ChatMessage]
assistant_busy: bool


def chat_messages_from_traces(
traces: Iterable[Trace],
show_traces: Literal["ALL", "CORE", "CONVERSATION_ONLY"] = "CORE",
):
traces = list(traces)
if not traces:
return None, None, []
return ChatMessages("", None, [], False)
summaries = get_summaries_for_traces(traces)
conversations = {
(s.conversation_id, s.auth_context["principal_id"]) for s in summaries
}
if len(conversations) > 1:
raise ValueError("More than one conversation id found")
conversation_id, auth_context = conversations.pop()
conversation_id, principal_id = conversations.pop()
assistant_busy = not bool(summaries and summaries[-1].ended_at)
messages = [
msg
for summary in summaries
Expand All @@ -575,7 +594,7 @@ def chat_messages_from_traces(
include_traces=show_traces,
)
]
return conversation_id, auth_context, messages
return ChatMessages(conversation_id, principal_id, messages, assistant_busy)


def chat_messages_from_conversation_measurements(
Expand Down Expand Up @@ -679,3 +698,26 @@ def format_date(dt: datetime):
) # "Today" / "Yesterday" / "Monday"

return f"{day_text} at {dt.strftime("%X")}"


def find_nearest_folded_open_message(messages: Sequence[gr.ChatMessage]):
search_from = 0
message = messages[-1]
while message:
message_parent_id = message.metadata.get("parent_id")
if message.metadata.get("status") != "done": # Folded open!
return message.metadata.get("id")
elif message_parent_id:
offset, message = next(
(
enumerate(
msg
for msg in reversed(messages[: len(messages) - search_from])
if msg.metadata.get("id") == message_parent_id
)
),
(-1, None),
)
search_from += offset
continue
return
Loading