diff --git a/temporalio/contrib/openai_agents/_trace_interceptor.py b/temporalio/contrib/openai_agents/_trace_interceptor.py
index 483e01147..b96f2f30d 100644
--- a/temporalio/contrib/openai_agents/_trace_interceptor.py
+++ b/temporalio/contrib/openai_agents/_trace_interceptor.py
@@ -3,13 +3,15 @@
 from __future__ import annotations
 
 from contextlib import contextmanager
-from typing import Any, Mapping, Protocol, Type
+from typing import Any, Mapping, Protocol, Type, cast
 
-from agents import custom_span, get_current_span, trace
+from agents import CustomSpanData, custom_span, get_current_span, trace
 from agents.tracing import (
     get_trace_provider,
 )
-from agents.tracing.spans import NoOpSpan
+from agents.tracing.provider import DefaultTraceProvider
+from agents.tracing.scope import Scope
+from agents.tracing.spans import NoOpSpan, SpanImpl
 
 import temporalio.activity
 import temporalio.api.common.v1
@@ -65,11 +67,15 @@ def context_from_header(
         else workflow.info().workflow_type
     )
     data = (
-        {"activityId": activity.info().activity_id}
+        {
+            "activityId": activity.info().activity_id,
+            "activity": activity.info().activity_type,
+        }
         if activity.in_activity()
         else None
    )
-    if get_trace_provider().get_current_trace() is None:
+    current_trace = get_trace_provider().get_current_trace()
+    if current_trace is None:
         metadata = {
             "temporal:workflowId": activity.info().workflow_id
             if activity.in_activity()
@@ -79,16 +85,21 @@ def context_from_header(
             else workflow.info().run_id,
             "temporal:workflowType": workflow_type,
         }
-        with trace(
+        current_trace = trace(
             span_info["traceName"],
             trace_id=span_info["traceId"],
             metadata=metadata,
-        ) as t:
-            with custom_span(name=span_name, parent=t, data=data):
-                yield
-    else:
-        with custom_span(name=span_name, parent=None, data=data):
-            yield
+        )
+        Scope.set_current_trace(current_trace)
+    current_span = get_trace_provider().get_current_span()
+    if current_span is None:
+        current_span = get_trace_provider().create_span(
+            span_data=CustomSpanData(name="", data={}), span_id=span_info["spanId"]
+        )
+        Scope.set_current_span(current_span)
+
+    with custom_span(name=span_name, parent=current_span, data=data):
+        yield
 
 
 class OpenAIAgentsTracingInterceptor(
@@ -115,7 +126,7 @@ class OpenAIAgentsTracingInterceptor(
         worker = Worker(client, task_queue="my-task-queue", interceptors=[interceptor])
     """
 
-    def __init__(  # type: ignore[reportMissingSuperCall]
+    def __init__(
        self,
        payload_converter: temporalio.converter.PayloadConverter = temporalio.converter.default().payload_converter,
    ) -> None:
@@ -325,32 +336,55 @@ class _ContextPropagationWorkflowOutboundInterceptor(
     async def signal_child_workflow(
         self, input: temporalio.worker.SignalChildWorkflowInput
     ) -> None:
-        set_header_from_context(input, temporalio.workflow.payload_converter())
-        return await self.next.signal_child_workflow(input)
+        with custom_span(
+            name="temporal:signalChildWorkflow",
+            data={"workflowId": input.child_workflow_id},
+        ):
+            set_header_from_context(input, temporalio.workflow.payload_converter())
+            await self.next.signal_child_workflow(input)
 
     async def signal_external_workflow(
         self, input: temporalio.worker.SignalExternalWorkflowInput
     ) -> None:
-        set_header_from_context(input, temporalio.workflow.payload_converter())
-        return await self.next.signal_external_workflow(input)
+        with custom_span(
+            name="temporal:signalExternalWorkflow",
+            data={"workflowId": input.workflow_id},
+        ):
+            set_header_from_context(input, temporalio.workflow.payload_converter())
+            await self.next.signal_external_workflow(input)
 
     def start_activity(
         self, input: temporalio.worker.StartActivityInput
     ) -> temporalio.workflow.ActivityHandle:
-        with custom_span(
-            name=f"temporal:startActivity:{input.activity}",
-        ):
-            set_header_from_context(input, temporalio.workflow.payload_converter())
-            return self.next.start_activity(input)
+        span = custom_span(
+            name="temporal:startActivity", data={"activity": input.activity}
+        )
+        span.start(mark_as_current=True)
+        set_header_from_context(input, temporalio.workflow.payload_converter())
+        handle = self.next.start_activity(input)
+        handle.add_done_callback(lambda _: span.finish())
+        return handle
 
     async def start_child_workflow(
         self, input: temporalio.worker.StartChildWorkflowInput
     ) -> temporalio.workflow.ChildWorkflowHandle:
+        span = custom_span(
+            name="temporal:startChildWorkflow", data={"workflow": input.workflow}
+        )
+        span.start(mark_as_current=True)
         set_header_from_context(input, temporalio.workflow.payload_converter())
-        return await self.next.start_child_workflow(input)
+        handle = await self.next.start_child_workflow(input)
+        handle.add_done_callback(lambda _: span.finish())
+        return handle
 
     def start_local_activity(
         self, input: temporalio.worker.StartLocalActivityInput
     ) -> temporalio.workflow.ActivityHandle:
+        span = custom_span(
+            name="temporal:startLocalActivity", data={"activity": input.activity}
+        )
+        span.start(mark_as_current=True)
         set_header_from_context(input, temporalio.workflow.payload_converter())
-        return self.next.start_local_activity(input)
+        handle = self.next.start_local_activity(input)
+        handle.add_done_callback(lambda _: span.finish())
+        return handle
diff --git a/tests/contrib/openai_agents/test_openai_tracing.py b/tests/contrib/openai_agents/test_openai_tracing.py
new file mode 100644
index 000000000..5a7d03785
--- /dev/null
+++ b/tests/contrib/openai_agents/test_openai_tracing.py
@@ -0,0 +1,151 @@
+import datetime
+import uuid
+from datetime import timedelta
+from typing import Any, Optional
+
+from agents import Span, Trace, TracingProcessor
+from agents.tracing import get_trace_provider
+
+from temporalio.client import Client
+from temporalio.contrib.openai_agents import (
+    ModelActivity,
+    OpenAIAgentsTracingInterceptor,
+    TestModelProvider,
+    set_open_ai_agent_temporal_overrides,
+)
+from temporalio.contrib.pydantic import pydantic_data_converter
+from tests.contrib.openai_agents.test_openai import ResearchWorkflow, TestResearchModel
+from tests.helpers import new_worker
+
+
+class MemoryTracingProcessor(TracingProcessor):
+    # True for start events, false for end
+    trace_events: list[tuple[Trace, bool]] = []
+    span_events: list[tuple[Span, bool]] = []
+
+    def on_trace_start(self, trace: Trace) -> None:
+        self.trace_events.append((trace, True))
+
+    def on_trace_end(self, trace: Trace) -> None:
+        self.trace_events.append((trace, False))
+
+    def on_span_start(self, span: Span[Any]) -> None:
+        self.span_events.append((span, True))
+
+    def on_span_end(self, span: Span[Any]) -> None:
+        self.span_events.append((span, False))
+
+    def shutdown(self) -> None:
+        pass
+
+    def force_flush(self) -> None:
+        pass
+
+
+async def test_tracing(client: Client):
+    new_config = client.config()
+    new_config["data_converter"] = pydantic_data_converter
+    client = Client(**new_config)
+
+    with set_open_ai_agent_temporal_overrides():
+        provider = get_trace_provider()
+
+        processor = MemoryTracingProcessor()
+        provider.set_processors([processor])
+
+        model_activity = ModelActivity(TestModelProvider(TestResearchModel()))
+        async with new_worker(
+            client,
+            ResearchWorkflow,
+            activities=[model_activity.invoke_model_activity],
+            interceptors=[OpenAIAgentsTracingInterceptor()],
+        ) as worker:
+            workflow_handle = await client.start_workflow(
+                ResearchWorkflow.run,
+                "Caribbean vacation spots in April, optimizing for surfing, hiking and water sports",
+                id=f"research-workflow-{uuid.uuid4()}",
+                task_queue=worker.task_queue,
+                execution_timeout=timedelta(seconds=120),
+            )
+            result = await workflow_handle.result()
+
+        # There is one closed root trace
+        assert len(processor.trace_events) == 2
+        assert (
+            processor.trace_events[0][0].trace_id
+            == processor.trace_events[1][0].trace_id
+        )
+        assert processor.trace_events[0][1]
+        assert not processor.trace_events[1][1]
+
+        def paired_span(a: tuple[Span[Any], bool], b: tuple[Span[Any], bool]) -> None:
+            assert a[0].trace_id == b[0].trace_id
+            assert a[1]
+            assert not b[1]
+
+        # Initial planner spans - There are only 3 because we don't make an actual model call
+        paired_span(processor.span_events[0], processor.span_events[5])
+        assert (
+            processor.span_events[0][0].span_data.export().get("name") == "PlannerAgent"
+        )
+
+        paired_span(processor.span_events[1], processor.span_events[4])
+        assert (
+            processor.span_events[1][0].span_data.export().get("name")
+            == "temporal:startActivity"
+        )
+
+        paired_span(processor.span_events[2], processor.span_events[3])
+        assert (
+            processor.span_events[2][0].span_data.export().get("name")
+            == "temporal:executeActivity"
+        )
+
+        for span, start in processor.span_events[6:-6]:
+            span_data = span.span_data.export()
+
+            # All spans should be closed
+            if start:
+                assert any(
+                    span.span_id == s.span_id and not s_start
+                    for (s, s_start) in processor.span_events
+                )
+
+            # Start activity is always parented to an agent
+            if span_data.get("name") == "temporal:startActivity":
+                parents = [
+                    s for (s, _) in processor.span_events if s.span_id == span.parent_id
+                ]
+                assert (
+                    len(parents) == 2
+                    and parents[0].span_data.export()["type"] == "agent"
+                )
+
+            # Execute is parented to start
+            if span_data.get("name") == "temporal:executeActivity":
+                parents = [
+                    s for (s, _) in processor.span_events if s.span_id == span.parent_id
+                ]
+                assert (
+                    len(parents) == 2
+                    and parents[0].span_data.export()["name"]
+                    == "temporal:startActivity"
+                )
+
+        # Final writer spans - There are only 3 because we don't make an actual model call
+        paired_span(processor.span_events[-6], processor.span_events[-1])
+        assert (
+            processor.span_events[-6][0].span_data.export().get("name") == "WriterAgent"
+        )
+
+        paired_span(processor.span_events[-5], processor.span_events[-2])
+        assert (
+            processor.span_events[-5][0].span_data.export().get("name")
+            == "temporal:startActivity"
+        )
+
+        paired_span(processor.span_events[-4], processor.span_events[-3])
+        assert (
+            processor.span_events[-4][0].span_data.export().get("name")
+            == "temporal:executeActivity"
+        )
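
Reviewer note: the snippet below is a minimal sketch, not part of the patch, of how the temporal:startActivity and temporal:executeActivity spans introduced here could be observed outside the test harness. It assumes a Temporal server reachable at localhost:7233, reuses the test-only ResearchWorkflow / TestResearchModel and the docstring's "my-task-queue" name, and swaps the test's MemoryTracingProcessor for a print-based processor.

# Sketch only; server address, task queue, and the print-based processor are assumptions.
import asyncio
from typing import Any

from agents import Span, Trace, TracingProcessor
from agents.tracing import get_trace_provider

from temporalio.client import Client
from temporalio.contrib.openai_agents import (
    ModelActivity,
    OpenAIAgentsTracingInterceptor,
    TestModelProvider,
    set_open_ai_agent_temporal_overrides,
)
from temporalio.contrib.pydantic import pydantic_data_converter
from temporalio.worker import Worker
from tests.contrib.openai_agents.test_openai import ResearchWorkflow, TestResearchModel


class PrintingTracingProcessor(TracingProcessor):
    # Print-based stand-in for the in-memory processor used by the test.
    def on_trace_start(self, trace: Trace) -> None:
        print("trace start", trace.trace_id)

    def on_trace_end(self, trace: Trace) -> None:
        print("trace end", trace.trace_id)

    def on_span_start(self, span: Span[Any]) -> None:
        print("span start", span.span_data.export().get("name"))

    def on_span_end(self, span: Span[Any]) -> None:
        print("span end", span.span_data.export().get("name"))

    def shutdown(self) -> None:
        pass

    def force_flush(self) -> None:
        pass


async def main() -> None:
    client = await Client.connect(
        "localhost:7233", data_converter=pydantic_data_converter
    )
    with set_open_ai_agent_temporal_overrides():
        # Spans emitted by the interceptor flow through whatever processors are registered.
        get_trace_provider().set_processors([PrintingTracingProcessor()])
        model_activity = ModelActivity(TestModelProvider(TestResearchModel()))
        worker = Worker(
            client,
            task_queue="my-task-queue",
            workflows=[ResearchWorkflow],
            activities=[model_activity.invoke_model_activity],
            interceptors=[OpenAIAgentsTracingInterceptor()],
        )
        await worker.run()


if __name__ == "__main__":
    asyncio.run(main())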