Skip to content

Fixing OpenAI tracing issues #974

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 58 additions & 24 deletions temporalio/contrib/openai_agents/_trace_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import Any, Mapping, Protocol, Type
from typing import Any, Mapping, Protocol, Type, cast

from agents import custom_span, get_current_span, trace
from agents import CustomSpanData, custom_span, get_current_span, trace
from agents.tracing import (
get_trace_provider,
)
from agents.tracing.spans import NoOpSpan
from agents.tracing.provider import DefaultTraceProvider
from agents.tracing.scope import Scope
from agents.tracing.spans import NoOpSpan, SpanImpl

import temporalio.activity
import temporalio.api.common.v1
Expand Down Expand Up @@ -65,11 +67,15 @@ def context_from_header(
else workflow.info().workflow_type
)
data = (
{"activityId": activity.info().activity_id}
{
"activityId": activity.info().activity_id,
"activity": activity.info().activity_type,
}
if activity.in_activity()
else None
)
if get_trace_provider().get_current_trace() is None:
current_trace = get_trace_provider().get_current_trace()
if current_trace is None:
metadata = {
"temporal:workflowId": activity.info().workflow_id
if activity.in_activity()
Expand All @@ -79,16 +85,21 @@ def context_from_header(
else workflow.info().run_id,
"temporal:workflowType": workflow_type,
}
with trace(
current_trace = trace(
span_info["traceName"],
trace_id=span_info["traceId"],
metadata=metadata,
) as t:
with custom_span(name=span_name, parent=t, data=data):
yield
else:
with custom_span(name=span_name, parent=None, data=data):
yield
)
Scope.set_current_trace(current_trace)
current_span = get_trace_provider().get_current_span()
if current_span is None:
current_span = get_trace_provider().create_span(
span_data=CustomSpanData(name="", data={}), span_id=span_info["spanId"]
)
Scope.set_current_span(current_span)

with custom_span(name=span_name, parent=current_span, data=data):
yield


class OpenAIAgentsTracingInterceptor(
Expand All @@ -115,7 +126,7 @@ class OpenAIAgentsTracingInterceptor(
worker = Worker(client, task_queue="my-task-queue", interceptors=[interceptor])
"""

def __init__( # type: ignore[reportMissingSuperCall]
def __init__(
self,
payload_converter: temporalio.converter.PayloadConverter = temporalio.converter.default().payload_converter,
) -> None:
Expand Down Expand Up @@ -325,32 +336,55 @@ class _ContextPropagationWorkflowOutboundInterceptor(
async def signal_child_workflow(
self, input: temporalio.worker.SignalChildWorkflowInput
) -> None:
set_header_from_context(input, temporalio.workflow.payload_converter())
return await self.next.signal_child_workflow(input)
with custom_span(
name="temporal:signalChildWorkflow",
data={"workflowId": input.child_workflow_id},
):
set_header_from_context(input, temporalio.workflow.payload_converter())
await self.next.signal_child_workflow(input)

async def signal_external_workflow(
self, input: temporalio.worker.SignalExternalWorkflowInput
) -> None:
set_header_from_context(input, temporalio.workflow.payload_converter())
return await self.next.signal_external_workflow(input)
with custom_span(
name="temporal:signalExternalWorkflow",
data={"workflowId": input.workflow_id},
):
set_header_from_context(input, temporalio.workflow.payload_converter())
await self.next.signal_external_workflow(input)

def start_activity(
self, input: temporalio.worker.StartActivityInput
) -> temporalio.workflow.ActivityHandle:
with custom_span(
name=f"temporal:startActivity:{input.activity}",
):
set_header_from_context(input, temporalio.workflow.payload_converter())
return self.next.start_activity(input)
span = custom_span(
name="temporal:startActivity", data={"activity": input.activity}
)
span.start(mark_as_current=True)
set_header_from_context(input, temporalio.workflow.payload_converter())
handle = self.next.start_activity(input)
handle.add_done_callback(lambda _: span.finish())
return handle

async def start_child_workflow(
self, input: temporalio.worker.StartChildWorkflowInput
) -> temporalio.workflow.ChildWorkflowHandle:
span = custom_span(
name="temporal:startChildWorkflow", data={"workflow": input.workflow}
)
span.start(mark_as_current=True)
set_header_from_context(input, temporalio.workflow.payload_converter())
return await self.next.start_child_workflow(input)
handle = await self.next.start_child_workflow(input)
handle.add_done_callback(lambda _: span.finish())
return handle

def start_local_activity(
self, input: temporalio.worker.StartLocalActivityInput
) -> temporalio.workflow.ActivityHandle:
span = custom_span(
name="temporal:startLocalActivity", data={"activity": input.activity}
)
span.start(mark_as_current=True)
set_header_from_context(input, temporalio.workflow.payload_converter())
return self.next.start_local_activity(input)
handle = self.next.start_local_activity(input)
handle.add_done_callback(lambda _: span.finish())
return handle
151 changes: 151 additions & 0 deletions tests/contrib/openai_agents/test_openai_tracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import datetime
import uuid
from datetime import timedelta
from typing import Any, Optional

from agents import Span, Trace, TracingProcessor
from agents.tracing import get_trace_provider

from temporalio.client import Client
from temporalio.contrib.openai_agents import (
ModelActivity,
OpenAIAgentsTracingInterceptor,
TestModelProvider,
set_open_ai_agent_temporal_overrides,
)
from temporalio.contrib.pydantic import pydantic_data_converter
from tests.contrib.openai_agents.test_openai import ResearchWorkflow, TestResearchModel
from tests.helpers import new_worker


class MemoryTracingProcessor(TracingProcessor):
# True for start events, false for end
trace_events: list[tuple[Trace, bool]] = []
span_events: list[tuple[Span, bool]] = []

def on_trace_start(self, trace: Trace) -> None:
self.trace_events.append((trace, True))

def on_trace_end(self, trace: Trace) -> None:
self.trace_events.append((trace, False))

def on_span_start(self, span: Span[Any]) -> None:
self.span_events.append((span, True))

def on_span_end(self, span: Span[Any]) -> None:
self.span_events.append((span, False))

def shutdown(self) -> None:
pass

def force_flush(self) -> None:
pass


async def test_tracing(client: Client):
new_config = client.config()
new_config["data_converter"] = pydantic_data_converter
client = Client(**new_config)

with set_open_ai_agent_temporal_overrides():
provider = get_trace_provider()

processor = MemoryTracingProcessor()
provider.set_processors([processor])

model_activity = ModelActivity(TestModelProvider(TestResearchModel()))
async with new_worker(
client,
ResearchWorkflow,
activities=[model_activity.invoke_model_activity],
interceptors=[OpenAIAgentsTracingInterceptor()],
) as worker:
workflow_handle = await client.start_workflow(
ResearchWorkflow.run,
"Caribbean vacation spots in April, optimizing for surfing, hiking and water sports",
id=f"research-workflow-{uuid.uuid4()}",
task_queue=worker.task_queue,
execution_timeout=timedelta(seconds=120),
)
result = await workflow_handle.result()

# There is one closed root trace
assert len(processor.trace_events) == 2
assert (
processor.trace_events[0][0].trace_id
== processor.trace_events[1][0].trace_id
)
assert processor.trace_events[0][1]
assert not processor.trace_events[1][1]

def paired_span(a: tuple[Span[Any], bool], b: tuple[Span[Any], bool]) -> None:
assert a[0].trace_id == b[0].trace_id
assert a[1]
assert not b[1]

# Initial planner spans - There are only 3 because we don't make an actual model call
paired_span(processor.span_events[0], processor.span_events[5])
assert (
processor.span_events[0][0].span_data.export().get("name") == "PlannerAgent"
)

paired_span(processor.span_events[1], processor.span_events[4])
assert (
processor.span_events[1][0].span_data.export().get("name")
== "temporal:startActivity"
)

paired_span(processor.span_events[2], processor.span_events[3])
assert (
processor.span_events[2][0].span_data.export().get("name")
== "temporal:executeActivity"
)

for span, start in processor.span_events[6:-6]:
span_data = span.span_data.export()

# All spans should be closed
if start:
assert any(
span.span_id == s.span_id and not s_start
for (s, s_start) in processor.span_events
)

# Start activity is always parented to an agent
if span_data.get("name") == "temporal:startActivity":
parents = [
s for (s, _) in processor.span_events if s.span_id == span.parent_id
]
assert (
len(parents) == 2
and parents[0].span_data.export()["type"] == "agent"
)

# Execute is parented to start
if span_data.get("name") == "temporal:executeActivity":
parents = [
s for (s, _) in processor.span_events if s.span_id == span.parent_id
]
assert (
len(parents) == 2
and parents[0].span_data.export()["name"]
== "temporal:startActivity"
)

# Final writer spans - There are only 3 because we don't make an actual model call
paired_span(processor.span_events[-6], processor.span_events[-1])
assert (
processor.span_events[-6][0].span_data.export().get("name") == "WriterAgent"
)

paired_span(processor.span_events[-5], processor.span_events[-2])
assert (
processor.span_events[-5][0].span_data.export().get("name")
== "temporal:startActivity"
)

paired_span(processor.span_events[-4], processor.span_events[-3])
assert (
processor.span_events[-4][0].span_data.export().get("name")
== "temporal:executeActivity"
)