Skip to content

Commit 382ee2c

Browse files
committed
Fix type errors
Paramaterize OpenAI agents with context type in test
1 parent 4949c1e commit 382ee2c

File tree

16 files changed

+108
-107
lines changed

16 files changed

+108
-107
lines changed

pyproject.toml

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ reportAny = "none"
165165
reportCallInDefaultInitializer = "none"
166166
reportExplicitAny = "none"
167167
reportIgnoreCommentWithoutRule = "none"
168+
reportImplicitAbstractClass = "none"
168169
reportImplicitOverride = "none"
169170
reportImplicitStringConcatenation = "none"
170171
reportImportCycles = "none"
@@ -184,11 +185,6 @@ exclude = [
184185
"temporalio/bridge/proto",
185186
"tests/worker/workflow_sandbox/testmodules/proto",
186187
"temporalio/bridge/worker.py",
187-
"temporalio/contrib/opentelemetry.py",
188-
"temporalio/contrib/pydantic.py",
189-
"temporalio/converter.py",
190-
"temporalio/testing/_workflow.py",
191-
"temporalio/worker/_activity.py",
192188
"temporalio/worker/_replayer.py",
193189
"temporalio/worker/_worker.py",
194190
"temporalio/worker/workflow_sandbox/_importer.py",
@@ -203,9 +199,7 @@ exclude = [
203199
"tests/contrib/pydantic/workflows.py",
204200
"tests/test_converter.py",
205201
"tests/test_service.py",
206-
"tests/test_workflow.py",
207202
"tests/worker/test_activity.py",
208-
"tests/worker/test_workflow.py",
209203
"tests/worker/workflow_sandbox/test_importer.py",
210204
"tests/worker/workflow_sandbox/test_restrictions.py",
211205
# TODO: these pass locally but fail in CI with

temporalio/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2859,7 +2859,7 @@ def _from_raw_info(
28592859
cls,
28602860
info: temporalio.api.workflow.v1.WorkflowExecutionInfo,
28612861
converter: temporalio.converter.DataConverter,
2862-
**additional_fields,
2862+
**additional_fields: Any,
28632863
) -> WorkflowExecution:
28642864
return cls(
28652865
close_time=info.close_time.ToDatetime().replace(tzinfo=timezone.utc)

temporalio/contrib/openai_agents/_heartbeat_decorator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
def _auto_heartbeater(fn: F) -> F:
1111
# Propagate type hints from the original callable.
1212
@wraps(fn)
13-
async def wrapper(*args, **kwargs):
13+
async def wrapper(*args: Any, **kwargs: Any) -> Any:
1414
heartbeat_timeout = activity.info().heartbeat_timeout
1515
heartbeat_task = None
1616
if heartbeat_timeout:

temporalio/contrib/openai_agents/_openai_runner.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
from dataclasses import replace
2-
from datetime import timedelta
3-
from typing import Optional, Union
2+
from typing import Any, Union
43

54
from agents import (
65
Agent,
76
RunConfig,
8-
RunHooks,
97
RunResult,
108
RunResultStreaming,
119
TContext,
@@ -14,10 +12,8 @@
1412
from agents.run import DEFAULT_AGENT_RUNNER, DEFAULT_MAX_TURNS, AgentRunner
1513

1614
from temporalio import workflow
17-
from temporalio.common import Priority, RetryPolicy
1815
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
1916
from temporalio.contrib.openai_agents._temporal_model_stub import _TemporalModelStub
20-
from temporalio.workflow import ActivityCancellationType, VersioningIntent
2117

2218

2319
class TemporalOpenAIRunner(AgentRunner):
@@ -36,7 +32,7 @@ async def run(
3632
self,
3733
starting_agent: Agent[TContext],
3834
input: Union[str, list[TResponseInputItem]],
39-
**kwargs,
35+
**kwargs: Any,
4036
) -> RunResult:
4137
"""Run the agent in a Temporal workflow."""
4238
if not workflow.in_workflow():
@@ -82,7 +78,7 @@ def run_sync(
8278
self,
8379
starting_agent: Agent[TContext],
8480
input: Union[str, list[TResponseInputItem]],
85-
**kwargs,
81+
**kwargs: Any,
8682
) -> RunResult:
8783
"""Run the agent synchronously (not supported in Temporal workflows)."""
8884
if not workflow.in_workflow():
@@ -97,7 +93,7 @@ def run_streamed(
9793
self,
9894
starting_agent: Agent[TContext],
9995
input: Union[str, list[TResponseInputItem]],
100-
**kwargs,
96+
**kwargs: Any,
10197
) -> RunResultStreaming:
10298
"""Run the agent with streaming responses (not supported in Temporal workflows)."""
10399
if not workflow.in_workflow():

temporalio/contrib/openai_agents/_temporal_trace_provider.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Provides support for integration with OpenAI Agents SDK tracing across workflows"""
22

33
import uuid
4-
from typing import Any, Optional, Union, cast
4+
from types import TracebackType
5+
from typing import Any, Optional, cast
56

67
from agents import SpanData, Trace, TracingProcessor
78
from agents.tracing import (
@@ -184,6 +185,11 @@ def __enter__(self):
184185
"""Enter the context of the Temporal trace provider."""
185186
return self
186187

187-
def __exit__(self, exc_type, exc_val, exc_tb):
188+
def __exit__(
189+
self,
190+
exc_type: type[BaseException],
191+
exc_val: BaseException,
192+
exc_tb: TracebackType,
193+
):
188194
"""Exit the context of the Temporal trace provider."""
189195
self._multi_processor.shutdown()

temporalio/contrib/opentelemetry.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class should return the workflow interceptor subclass from
7171
custom attributes desired.
7272
"""
7373

74-
def __init__(
74+
def __init__( # type: ignore[reportMissingSuperCall]
7575
self,
7676
tracer: Optional[opentelemetry.trace.Tracer] = None,
7777
*,
@@ -125,11 +125,10 @@ def workflow_interceptor_class(
125125
:py:meth:`temporalio.worker.Interceptor.workflow_interceptor_class`.
126126
"""
127127
# Set the externs needed
128-
# TODO(cretz): MyPy works w/ spread kwargs instead of direct passing
129128
input.unsafe_extern_functions.update(
130-
**_WorkflowExternFunctions(
131-
__temporal_opentelemetry_completed_span=self._completed_workflow_span,
132-
)
129+
{
130+
"__temporal_opentelemetry_completed_span": self._completed_workflow_span,
131+
}
133132
)
134133
return TracingWorkflowInboundInterceptor
135134

temporalio/converter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@ def default(self, o: Any) -> Any:
508508
if isinstance(o, datetime):
509509
return o.isoformat()
510510
# Dataclass support
511-
if dataclasses.is_dataclass(o):
511+
if dataclasses.is_dataclass(o) and not isinstance(o, type):
512512
return dataclasses.asdict(o)
513513
# Support for Pydantic v1's dict method
514514
dict_fn = getattr(o, "dict", None)
@@ -1701,7 +1701,7 @@ def value_to_type(
17011701
arg_type = type_args[i]
17021702
elif type_args[-1] is Ellipsis:
17031703
# Ellipsis means use the second to last one
1704-
arg_type = type_args[-2]
1704+
arg_type = type_args[-2] # type: ignore
17051705
else:
17061706
raise TypeError(
17071707
f"Type {hint} only expecting {len(type_args)} values, got at least {i + 1}"

temporalio/nexus/_operation_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ class WorkflowRunOperationContext(StartOperationContext):
178178
This API is experimental and unstable.
179179
"""
180180

181-
def __init__(self, *args, **kwargs):
181+
def __init__(self, *args: Any, **kwargs: Any) -> None:
182182
"""Initialize the workflow run operation context."""
183183
super().__init__(*args, **kwargs)
184184
self._temporal_context = _TemporalStartOperationContext.get()

temporalio/testing/_workflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@ def assert_error_as_app_error(self) -> Iterator[None]:
534534

535535

536536
class _TimeSkippingClientInterceptor(temporalio.client.Interceptor):
537-
def __init__(self, env: _EphemeralServerWorkflowEnvironment) -> None:
537+
def __init__(self, env: _EphemeralServerWorkflowEnvironment) -> None: # type: ignore[reportMissingSuperCall]
538538
self.env = env
539539

540540
def intercept_client(
@@ -563,7 +563,7 @@ async def start_workflow(
563563

564564

565565
class _TimeSkippingWorkflowHandle(temporalio.client.WorkflowHandle):
566-
env: _EphemeralServerWorkflowEnvironment
566+
env: _EphemeralServerWorkflowEnvironment # type: ignore[reportUninitializedInstanceAttribute]
567567

568568
async def result(
569569
self,

temporalio/worker/_activity.py

Lines changed: 25 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -14,32 +14,22 @@
1414
import threading
1515
import warnings
1616
from abc import ABC, abstractmethod
17+
from collections.abc import Iterator, Sequence
1718
from contextlib import contextmanager
1819
from dataclasses import dataclass, field
1920
from datetime import datetime, timedelta, timezone
2021
from typing import (
2122
Any,
2223
Callable,
23-
Dict,
24-
Iterator,
2524
NoReturn,
2625
Optional,
27-
Sequence,
28-
Tuple,
29-
Type,
3026
Union,
3127
)
3228

3329
import google.protobuf.duration_pb2
3430
import google.protobuf.timestamp_pb2
3531

3632
import temporalio.activity
37-
import temporalio.api.common.v1
38-
import temporalio.bridge.client
39-
import temporalio.bridge.proto
40-
import temporalio.bridge.proto.activity_result
41-
import temporalio.bridge.proto.activity_task
42-
import temporalio.bridge.proto.common
4333
import temporalio.bridge.runtime
4434
import temporalio.bridge.worker
4535
import temporalio.client
@@ -76,7 +66,7 @@ def __init__(
7666
self._task_queue = task_queue
7767
self._activity_executor = activity_executor
7868
self._shared_state_manager = shared_state_manager
79-
self._running_activities: Dict[bytes, _RunningActivity] = {}
69+
self._running_activities: dict[bytes, _RunningActivity] = {}
8070
self._data_converter = data_converter
8171
self._interceptors = interceptors
8272
self._metric_meter = metric_meter
@@ -90,7 +80,7 @@ def __init__(
9080
self._client = client
9181

9282
# Validate and build activity dict
93-
self._activities: Dict[str, temporalio.activity._Definition] = {}
83+
self._activities: dict[str, temporalio.activity._Definition] = {}
9484
self._dynamic_activity: Optional[temporalio.activity._Definition] = None
9585
for activity in activities:
9686
# Get definition
@@ -178,7 +168,7 @@ async def raise_from_exception_queue() -> NoReturn:
178168
self._handle_cancel_activity_task(task.task_token, task.cancel)
179169
else:
180170
raise RuntimeError(f"Unrecognized activity task: {task}")
181-
except temporalio.bridge.worker.PollShutdownError:
171+
except temporalio.bridge.worker.PollShutdownError: # type: ignore[reportPrivateLocalImportUsage]
182172
exception_task.cancel()
183173
return
184174
except Exception as err:
@@ -195,12 +185,12 @@ async def drain_poll_queue(self) -> None:
195185
try:
196186
# Just take all tasks and say we can't handle them
197187
task = await self._bridge_worker().poll_activity_task()
198-
completion = temporalio.bridge.proto.ActivityTaskCompletion(
188+
completion = temporalio.bridge.proto.ActivityTaskCompletion( # type: ignore[reportAttributeAccessIssue]
199189
task_token=task.task_token
200190
)
201191
completion.result.failed.failure.message = "Worker shutting down"
202192
await self._bridge_worker().complete_activity_task(completion)
203-
except temporalio.bridge.worker.PollShutdownError:
193+
except temporalio.bridge.worker.PollShutdownError: # type: ignore[reportPrivateLocalImportUsage]
204194
return
205195

206196
# Only call this after run()/drain_poll_queue() have returned. This will not
@@ -214,7 +204,9 @@ async def wait_all_completed(self) -> None:
214204
await asyncio.gather(*running_tasks, return_exceptions=False)
215205

216206
def _handle_cancel_activity_task(
217-
self, task_token: bytes, cancel: temporalio.bridge.proto.activity_task.Cancel
207+
self,
208+
task_token: bytes,
209+
cancel: temporalio.bridge.proto.activity_task.Cancel, # type: ignore[reportAttributeAccessIssue]
218210
) -> None:
219211
"""Request cancellation of a running activity task."""
220212
activity = self._running_activities.get(task_token)
@@ -262,7 +254,9 @@ async def _heartbeat_async(
262254

263255
# Perform the heartbeat
264256
try:
265-
heartbeat = temporalio.bridge.proto.ActivityHeartbeat(task_token=task_token)
257+
heartbeat = temporalio.bridge.proto.ActivityHeartbeat( # type: ignore[reportAttributeAccessIssue]
258+
task_token=task_token
259+
)
266260
if details:
267261
# Convert to core payloads
268262
heartbeat.details.extend(await self._data_converter.encode(details))
@@ -284,7 +278,7 @@ async def _heartbeat_async(
284278
async def _handle_start_activity_task(
285279
self,
286280
task_token: bytes,
287-
start: temporalio.bridge.proto.activity_task.Start,
281+
start: temporalio.bridge.proto.activity_task.Start, # type: ignore[reportAttributeAccessIssue]
288282
running_activity: _RunningActivity,
289283
) -> None:
290284
"""Handle a start activity task.
@@ -296,7 +290,7 @@ async def _handle_start_activity_task(
296290
# We choose to surround interceptor creation and activity invocation in
297291
# a try block so we can mark the workflow as failed on any error instead
298292
# of having error handling in the interceptor
299-
completion = temporalio.bridge.proto.ActivityTaskCompletion(
293+
completion = temporalio.bridge.proto.ActivityTaskCompletion( # type: ignore[reportAttributeAccessIssue]
300294
task_token=task_token
301295
)
302296
try:
@@ -413,7 +407,7 @@ async def _handle_start_activity_task(
413407

414408
async def _execute_activity(
415409
self,
416-
start: temporalio.bridge.proto.activity_task.Start,
410+
start: temporalio.bridge.proto.activity_task.Start, # type: ignore[reportAttributeAccessIssue]
417411
running_activity: _RunningActivity,
418412
task_token: bytes,
419413
) -> Any:
@@ -649,14 +643,14 @@ class _ThreadExceptionRaiser:
649643
def __init__(self) -> None:
650644
self._lock = threading.Lock()
651645
self._thread_id: Optional[int] = None
652-
self._pending_exception: Optional[Type[Exception]] = None
646+
self._pending_exception: Optional[type[Exception]] = None
653647
self._shield_depth = 0
654648

655649
def set_thread_id(self, thread_id: int) -> None:
656650
with self._lock:
657651
self._thread_id = thread_id
658652

659-
def raise_in_thread(self, exc_type: Type[Exception]) -> None:
653+
def raise_in_thread(self, exc_type: type[Exception]) -> None:
660654
with self._lock:
661655
self._pending_exception = exc_type
662656
self._raise_in_thread_if_pending_unlocked()
@@ -812,7 +806,7 @@ def _execute_sync_activity(
812806
cancelled_event: threading.Event,
813807
worker_shutdown_event: threading.Event,
814808
payload_converter_class_or_instance: Union[
815-
Type[temporalio.converter.PayloadConverter],
809+
type[temporalio.converter.PayloadConverter],
816810
temporalio.converter.PayloadConverter,
817811
],
818812
runtime_metric_meter: Optional[temporalio.common.MetricMeter],
@@ -824,13 +818,10 @@ def _execute_sync_activity(
824818
thread_id = threading.current_thread().ident
825819
if thread_id is not None:
826820
cancel_thread_raiser.set_thread_id(thread_id)
827-
heartbeat_fn: Callable[..., None]
828821
if isinstance(heartbeat, SharedHeartbeatSender):
829-
# To make mypy happy
830-
heartbeat_sender = heartbeat
831-
heartbeat_fn = lambda *details: heartbeat_sender.send_heartbeat(
832-
info.task_token, *details
833-
)
822+
823+
def heartbeat_fn(*details: Any) -> None:
824+
heartbeat.send_heartbeat(info.task_token, *details)
834825
else:
835826
heartbeat_fn = heartbeat
836827
temporalio.activity._Context.set(
@@ -940,11 +931,11 @@ def __init__(
940931
self._mgr = mgr
941932
self._queue_poller_executor = queue_poller_executor
942933
# 1000 in-flight heartbeats should be plenty
943-
self._heartbeat_queue: queue.Queue[Tuple[bytes, Sequence[Any]]] = mgr.Queue(
934+
self._heartbeat_queue: queue.Queue[tuple[bytes, Sequence[Any]]] = mgr.Queue(
944935
1000
945936
)
946-
self._heartbeats: Dict[bytes, Callable[..., None]] = {}
947-
self._heartbeat_completions: Dict[bytes, Callable] = {}
937+
self._heartbeats: dict[bytes, Callable[..., None]] = {}
938+
self._heartbeat_completions: dict[bytes, Callable] = {}
948939

949940
def new_event(self) -> threading.Event:
950941
return self._mgr.Event()
@@ -1002,7 +993,7 @@ def _heartbeat_processor(self) -> None:
1002993

1003994
class _MultiprocessingSharedHeartbeatSender(SharedHeartbeatSender):
1004995
def __init__(
1005-
self, heartbeat_queue: queue.Queue[Tuple[bytes, Sequence[Any]]]
996+
self, heartbeat_queue: queue.Queue[tuple[bytes, Sequence[Any]]]
1006997
) -> None:
1007998
super().__init__()
1008999
self._heartbeat_queue = heartbeat_queue

0 commit comments

Comments
 (0)