Skip to content

Commit 60f67d9

Browse files
authored
Fixing OpenAI tracing issues (#974)
* Fixing tracing issues. Propagate span parentage, remove startActivity
* Remove traceback
* Add start traces with correct durations
* Add tracing test
* Remove cast
* Remove flaky test check
1 parent 4949c1e commit 60f67d9

File tree

2 files changed

+209
-24
lines changed

2 files changed

+209
-24
lines changed

temporalio/contrib/openai_agents/_trace_interceptor.py

Lines changed: 58 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
from __future__ import annotations
44

55
from contextlib import contextmanager
6-
from typing import Any, Mapping, Protocol, Type
6+
from typing import Any, Mapping, Protocol, Type, cast
77

8-
from agents import custom_span, get_current_span, trace
8+
from agents import CustomSpanData, custom_span, get_current_span, trace
99
from agents.tracing import (
1010
get_trace_provider,
1111
)
12-
from agents.tracing.spans import NoOpSpan
12+
from agents.tracing.provider import DefaultTraceProvider
13+
from agents.tracing.scope import Scope
14+
from agents.tracing.spans import NoOpSpan, SpanImpl
1315

1416
import temporalio.activity
1517
import temporalio.api.common.v1
@@ -65,11 +67,15 @@ def context_from_header(
6567
else workflow.info().workflow_type
6668
)
6769
data = (
68-
{"activityId": activity.info().activity_id}
70+
{
71+
"activityId": activity.info().activity_id,
72+
"activity": activity.info().activity_type,
73+
}
6974
if activity.in_activity()
7075
else None
7176
)
72-
if get_trace_provider().get_current_trace() is None:
77+
current_trace = get_trace_provider().get_current_trace()
78+
if current_trace is None:
7379
metadata = {
7480
"temporal:workflowId": activity.info().workflow_id
7581
if activity.in_activity()
@@ -79,16 +85,21 @@ def context_from_header(
7985
else workflow.info().run_id,
8086
"temporal:workflowType": workflow_type,
8187
}
82-
with trace(
88+
current_trace = trace(
8389
span_info["traceName"],
8490
trace_id=span_info["traceId"],
8591
metadata=metadata,
86-
) as t:
87-
with custom_span(name=span_name, parent=t, data=data):
88-
yield
89-
else:
90-
with custom_span(name=span_name, parent=None, data=data):
91-
yield
92+
)
93+
Scope.set_current_trace(current_trace)
94+
current_span = get_trace_provider().get_current_span()
95+
if current_span is None:
96+
current_span = get_trace_provider().create_span(
97+
span_data=CustomSpanData(name="", data={}), span_id=span_info["spanId"]
98+
)
99+
Scope.set_current_span(current_span)
100+
101+
with custom_span(name=span_name, parent=current_span, data=data):
102+
yield
92103

93104

94105
class OpenAIAgentsTracingInterceptor(
@@ -115,7 +126,7 @@ class OpenAIAgentsTracingInterceptor(
115126
worker = Worker(client, task_queue="my-task-queue", interceptors=[interceptor])
116127
"""
117128

118-
def __init__( # type: ignore[reportMissingSuperCall]
129+
def __init__(
119130
self,
120131
payload_converter: temporalio.converter.PayloadConverter = temporalio.converter.default().payload_converter,
121132
) -> None:
@@ -325,32 +336,55 @@ class _ContextPropagationWorkflowOutboundInterceptor(
325336
async def signal_child_workflow(
326337
self, input: temporalio.worker.SignalChildWorkflowInput
327338
) -> None:
328-
set_header_from_context(input, temporalio.workflow.payload_converter())
329-
return await self.next.signal_child_workflow(input)
339+
with custom_span(
340+
name="temporal:signalChildWorkflow",
341+
data={"workflowId": input.child_workflow_id},
342+
):
343+
set_header_from_context(input, temporalio.workflow.payload_converter())
344+
await self.next.signal_child_workflow(input)
330345

331346
async def signal_external_workflow(
332347
self, input: temporalio.worker.SignalExternalWorkflowInput
333348
) -> None:
334-
set_header_from_context(input, temporalio.workflow.payload_converter())
335-
return await self.next.signal_external_workflow(input)
349+
with custom_span(
350+
name="temporal:signalExternalWorkflow",
351+
data={"workflowId": input.workflow_id},
352+
):
353+
set_header_from_context(input, temporalio.workflow.payload_converter())
354+
await self.next.signal_external_workflow(input)
336355

337356
def start_activity(
338357
self, input: temporalio.worker.StartActivityInput
339358
) -> temporalio.workflow.ActivityHandle:
340-
with custom_span(
341-
name=f"temporal:startActivity:{input.activity}",
342-
):
343-
set_header_from_context(input, temporalio.workflow.payload_converter())
344-
return self.next.start_activity(input)
359+
span = custom_span(
360+
name="temporal:startActivity", data={"activity": input.activity}
361+
)
362+
span.start(mark_as_current=True)
363+
set_header_from_context(input, temporalio.workflow.payload_converter())
364+
handle = self.next.start_activity(input)
365+
handle.add_done_callback(lambda _: span.finish())
366+
return handle
345367

346368
async def start_child_workflow(
347369
self, input: temporalio.worker.StartChildWorkflowInput
348370
) -> temporalio.workflow.ChildWorkflowHandle:
371+
span = custom_span(
372+
name="temporal:startChildWorkflow", data={"workflow": input.workflow}
373+
)
374+
span.start(mark_as_current=True)
349375
set_header_from_context(input, temporalio.workflow.payload_converter())
350-
return await self.next.start_child_workflow(input)
376+
handle = await self.next.start_child_workflow(input)
377+
handle.add_done_callback(lambda _: span.finish())
378+
return handle
351379

352380
def start_local_activity(
353381
self, input: temporalio.worker.StartLocalActivityInput
354382
) -> temporalio.workflow.ActivityHandle:
383+
span = custom_span(
384+
name="temporal:startLocalActivity", data={"activity": input.activity}
385+
)
386+
span.start(mark_as_current=True)
355387
set_header_from_context(input, temporalio.workflow.payload_converter())
356-
return self.next.start_local_activity(input)
388+
handle = self.next.start_local_activity(input)
389+
handle.add_done_callback(lambda _: span.finish())
390+
return handle
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
import datetime
2+
import uuid
3+
from datetime import timedelta
4+
from typing import Any, Optional
5+
6+
from agents import Span, Trace, TracingProcessor
7+
from agents.tracing import get_trace_provider
8+
9+
from temporalio.client import Client
10+
from temporalio.contrib.openai_agents import (
11+
ModelActivity,
12+
OpenAIAgentsTracingInterceptor,
13+
TestModelProvider,
14+
set_open_ai_agent_temporal_overrides,
15+
)
16+
from temporalio.contrib.pydantic import pydantic_data_converter
17+
from tests.contrib.openai_agents.test_openai import ResearchWorkflow, TestResearchModel
18+
from tests.helpers import new_worker
19+
20+
21+
class MemoryTracingProcessor(TracingProcessor):
    """Tracing processor that records every trace and span callback in memory.

    Each recorded event is an ``(object, is_start)`` tuple, appended in the
    order the callbacks are invoked: ``True`` marks a start event and
    ``False`` marks an end event.
    """

    # True for start events, false for end
    trace_events: list[tuple[Trace, bool]]
    span_events: list[tuple[Span, bool]]

    def __init__(self) -> None:
        """Create a processor with empty, instance-private event lists."""
        super().__init__()
        # Lists are created per instance on purpose: list literals in the
        # class body would be shared by every processor, so events recorded
        # by one instance would show up in all the others.
        self.trace_events = []
        self.span_events = []

    def on_trace_start(self, trace: Trace) -> None:
        """Record that ``trace`` started."""
        self.trace_events.append((trace, True))

    def on_trace_end(self, trace: Trace) -> None:
        """Record that ``trace`` ended."""
        self.trace_events.append((trace, False))

    def on_span_start(self, span: Span[Any]) -> None:
        """Record that ``span`` started."""
        self.span_events.append((span, True))

    def on_span_end(self, span: Span[Any]) -> None:
        """Record that ``span`` ended."""
        self.span_events.append((span, False))

    def shutdown(self) -> None:
        """Do nothing; there are no resources to release."""
        pass

    def force_flush(self) -> None:
        """Do nothing; events are stored as soon as they are received."""
        pass
43+
44+
45+
async def test_tracing(client: Client):
    """Run ``ResearchWorkflow`` under the tracing interceptor and check the events.

    A ``MemoryTracingProcessor`` is installed as the only processor on the
    trace provider, the workflow is executed on a worker configured with
    ``OpenAIAgentsTracingInterceptor``, and the captured trace and span events
    are then checked for start/end pairing, span names and parentage.

    Args:
        client: Temporal client; it is rebuilt below from its own config with
            the pydantic data converter swapped in.
    """
    new_config = client.config()
    new_config["data_converter"] = pydantic_data_converter
    client = Client(**new_config)

    def exported_name(event: tuple[Span[Any], bool]) -> Any:
        """Return the ``name`` entry of the event's exported span data, if any."""
        return event[0].span_data.export().get("name")

    def paired_span(a: tuple[Span[Any], bool], b: tuple[Span[Any], bool]) -> None:
        """Assert ``a`` is a start event and ``b`` an end event of the same trace."""
        assert a[0].trace_id == b[0].trace_id
        assert a[1]
        assert not b[1]

    with set_open_ai_agent_temporal_overrides():
        provider = get_trace_provider()

        processor = MemoryTracingProcessor()
        provider.set_processors([processor])

        model_activity = ModelActivity(TestModelProvider(TestResearchModel()))
        async with new_worker(
            client,
            ResearchWorkflow,
            activities=[model_activity.invoke_model_activity],
            interceptors=[OpenAIAgentsTracingInterceptor()],
        ) as worker:
            workflow_handle = await client.start_workflow(
                ResearchWorkflow.run,
                "Caribbean vacation spots in April, optimizing for surfing, hiking and water sports",
                id=f"research-workflow-{uuid.uuid4()}",
                task_queue=worker.task_queue,
                execution_timeout=timedelta(seconds=120),
            )
            # Only completion matters; the workflow's return value is not inspected.
            await workflow_handle.result()

        trace_events = processor.trace_events
        span_events = processor.span_events

        def assert_span(start_idx: int, end_idx: int, expected_name: str) -> None:
            """Check the events at the two indexes pair up and carry ``expected_name``."""
            paired_span(span_events[start_idx], span_events[end_idx])
            assert exported_name(span_events[start_idx]) == expected_name

        def parents_of(child: Span[Any]) -> list[Span[Any]]:
            """Return the span of every recorded event whose id is ``child.parent_id``."""
            return [s for (s, _) in span_events if s.span_id == child.parent_id]

        # There is one closed root trace
        assert len(trace_events) == 2
        assert trace_events[0][0].trace_id == trace_events[1][0].trace_id
        assert trace_events[0][1]
        assert not trace_events[1][1]

        # Initial planner spans - There are only 3 because we don't make an actual model call
        assert_span(0, 5, "PlannerAgent")
        assert_span(1, 4, "temporal:startActivity")
        assert_span(2, 3, "temporal:executeActivity")

        # Everything between the planner's and the writer's six events
        for span, start in span_events[6:-6]:
            name = span.span_data.export().get("name")

            # All spans should be closed
            if start:
                assert any(
                    span.span_id == s.span_id and not s_start
                    for (s, s_start) in span_events
                )

            # Start activity is always parented to an agent
            if name == "temporal:startActivity":
                parents = parents_of(span)
                # Two entries: the parent's start event and its end event.
                assert (
                    len(parents) == 2
                    and parents[0].span_data.export()["type"] == "agent"
                )

            # Execute is parented to start
            if name == "temporal:executeActivity":
                parents = parents_of(span)
                assert (
                    len(parents) == 2
                    and parents[0].span_data.export()["name"]
                    == "temporal:startActivity"
                )

        # Final writer spans - There are only 3 because we don't make an actual model call
        assert_span(-6, -1, "WriterAgent")
        assert_span(-5, -2, "temporal:startActivity")
        assert_span(-4, -3, "temporal:executeActivity")

0 commit comments

Comments
 (0)