
Fix more type errors #977


Open · wants to merge 3 commits into base: main
10 changes: 3 additions & 7 deletions pyproject.toml
@@ -14,6 +14,7 @@ dependencies = [
"nexus-rpc>=1.1.0",
"protobuf>=3.20,<6",
"python-dateutil>=2.8.2,<3 ; python_version < '3.11'",
"temporalio-xray",
"types-protobuf>=3.20",
"typing-extensions>=4.2.0,<5",
]
@@ -165,6 +166,7 @@ reportAny = "none"
reportCallInDefaultInitializer = "none"
reportExplicitAny = "none"
reportIgnoreCommentWithoutRule = "none"
reportImplicitAbstractClass = "none"
reportImplicitOverride = "none"
reportImplicitStringConcatenation = "none"
reportImportCycles = "none"
@@ -184,11 +186,6 @@ exclude = [
"temporalio/bridge/proto",
"tests/worker/workflow_sandbox/testmodules/proto",
"temporalio/bridge/worker.py",
"temporalio/contrib/opentelemetry.py",
"temporalio/contrib/pydantic.py",
"temporalio/converter.py",
"temporalio/testing/_workflow.py",
"temporalio/worker/_activity.py",
"temporalio/worker/_replayer.py",
"temporalio/worker/_worker.py",
"temporalio/worker/workflow_sandbox/_importer.py",
@@ -203,9 +200,7 @@ exclude = [
"tests/contrib/pydantic/workflows.py",
"tests/test_converter.py",
"tests/test_service.py",
"tests/test_workflow.py",
"tests/worker/test_activity.py",
"tests/worker/test_workflow.py",
"tests/worker/workflow_sandbox/test_importer.py",
"tests/worker/workflow_sandbox/test_restrictions.py",
# TODO: these pass locally but fail in CI with
@@ -239,3 +234,4 @@ package = false

[tool.uv.sources]
nexus-rpc = { git = "https://github.com/nexus-rpc/sdk-python.git", rev = "35f574c711193a6e2560d3e6665732a5bb7ae92c" }
temporalio-xray = { path = "../xray/sdks/python" }
2 changes: 1 addition & 1 deletion temporalio/client.py
@@ -2859,7 +2859,7 @@ def _from_raw_info(
cls,
info: temporalio.api.workflow.v1.WorkflowExecutionInfo,
converter: temporalio.converter.DataConverter,
**additional_fields,
**additional_fields: Any,
) -> WorkflowExecution:
return cls(
close_time=info.close_time.ToDatetime().replace(tzinfo=timezone.utc)
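The annotation on the catch-all parameter above is the recurring pattern in this PR: giving **kwargs-style parameters an explicit Any type so strict pyright settings stop flagging them as untyped. A minimal sketch of the idea, using hypothetical names rather than the SDK's:

from typing import Any


def build_info(run_id: str, **additional_fields: Any) -> dict[str, Any]:
    # Annotating **kwargs as Any keeps call sites flexible while giving the
    # type checker an explicit (if permissive) type for every extra value.
    return {"run_id": run_id, **additional_fields}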
2 changes: 1 addition & 1 deletion temporalio/contrib/openai_agents/_heartbeat_decorator.py
@@ -10,7 +10,7 @@
def _auto_heartbeater(fn: F) -> F:
# Propagate type hints from the original callable.
@wraps(fn)
async def wrapper(*args, **kwargs):
async def wrapper(*args: Any, **kwargs: Any) -> Any:
heartbeat_timeout = activity.info().heartbeat_timeout
heartbeat_task = None
if heartbeat_timeout:
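The wrapper's Any annotations here are internal only, since the decorator itself is declared fn: F -> F. For comparison, a hedged sketch of how a similar wrapper could be typed end to end with ParamSpec (Python 3.10+, or typing_extensions on older versions); this is only an illustration of the alternative, not how the SDK's decorator is written:

from functools import wraps
from typing import Awaitable, Callable, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


def passthrough(fn: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
    # The ParamSpec ties the wrapper's parameters to the wrapped callable's,
    # so callers keep full signature checking instead of (*args: Any, ...).
    @wraps(fn)
    async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        # A real decorator would add its behavior here (e.g. heartbeating).
        return await fn(*args, **kwargs)

    return wrapper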
12 changes: 4 additions & 8 deletions temporalio/contrib/openai_agents/_openai_runner.py
@@ -1,11 +1,9 @@
from dataclasses import replace
from datetime import timedelta
from typing import Optional, Union
from typing import Any, Union

from agents import (
Agent,
RunConfig,
RunHooks,
RunResult,
RunResultStreaming,
TContext,
@@ -14,10 +12,8 @@
from agents.run import DEFAULT_AGENT_RUNNER, DEFAULT_MAX_TURNS, AgentRunner

from temporalio import workflow
from temporalio.common import Priority, RetryPolicy
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
from temporalio.contrib.openai_agents._temporal_model_stub import _TemporalModelStub
from temporalio.workflow import ActivityCancellationType, VersioningIntent


class TemporalOpenAIRunner(AgentRunner):
@@ -36,7 +32,7 @@ async def run(
self,
starting_agent: Agent[TContext],
input: Union[str, list[TResponseInputItem]],
**kwargs,
**kwargs: Any,
) -> RunResult:
"""Run the agent in a Temporal workflow."""
if not workflow.in_workflow():
@@ -82,7 +78,7 @@ def run_sync(
self,
starting_agent: Agent[TContext],
input: Union[str, list[TResponseInputItem]],
**kwargs,
**kwargs: Any,
) -> RunResult:
"""Run the agent synchronously (not supported in Temporal workflows)."""
if not workflow.in_workflow():
@@ -97,7 +93,7 @@ def run_streamed(
self,
starting_agent: Agent[TContext],
input: Union[str, list[TResponseInputItem]],
**kwargs,
**kwargs: Any,
) -> RunResultStreaming:
"""Run the agent with streaming responses (not supported in Temporal workflows)."""
if not workflow.in_workflow():
10 changes: 8 additions & 2 deletions temporalio/contrib/openai_agents/_temporal_trace_provider.py
@@ -1,7 +1,8 @@
"""Provides support for integration with OpenAI Agents SDK tracing across workflows"""

import uuid
from typing import Any, Optional, Union, cast
from types import TracebackType
from typing import Any, Optional, cast

from agents import SpanData, Trace, TracingProcessor
from agents.tracing import (
@@ -184,6 +185,11 @@ def __enter__(self):
"""Enter the context of the Temporal trace provider."""
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: type[BaseException],
exc_val: BaseException,
exc_tb: TracebackType,
):
"""Exit the context of the Temporal trace provider."""
self._multi_processor.shutdown()
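
For context on the __exit__ annotation above: the fully general context-manager signature in typeshed marks all three arguments Optional, because __exit__ is also called with (None, None, None) when the with-block exits normally. A small standalone sketch:

from types import TracebackType
from typing import Optional


class ManagedResource:
    def __enter__(self) -> "ManagedResource":
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # Receives (None, None, None) on a clean exit, the exception triple
        # otherwise; returning None means exceptions are not suppressed.
        pass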
9 changes: 4 additions & 5 deletions temporalio/contrib/opentelemetry.py
@@ -71,7 +71,7 @@ class should return the workflow interceptor subclass from
custom attributes desired.
"""

def __init__(
def __init__( # type: ignore[reportMissingSuperCall]
self,
tracer: Optional[opentelemetry.trace.Tracer] = None,
*,
@@ -125,11 +125,10 @@ def workflow_interceptor_class(
:py:meth:`temporalio.worker.Interceptor.workflow_interceptor_class`.
"""
# Set the externs needed
# TODO(cretz): MyPy works w/ spread kwargs instead of direct passing
input.unsafe_extern_functions.update(
**_WorkflowExternFunctions(
__temporal_opentelemetry_completed_span=self._completed_workflow_span,
)
{
"__temporal_opentelemetry_completed_span": self._completed_workflow_span,
}
Comment on lines -130 to +131

Member:
What is the harm in using the constructor form of the TypedDict here instead of the dict literal with string literals?

Contributor (author):
I couldn't see how to make it work with pyright (and your comment already indicated mypy problems):

$ uv run pyright
/Users/dan/src/temporalio/sdk-python/temporalio/contrib/opentelemetry.py
  /Users/dan/src/temporalio/sdk-python/temporalio/contrib/opentelemetry.py:130:15 - error: Argument of type "object" cannot be assigned to parameter "kwargs" of type "(...) -> Unknown" in function "update"
    Type "object" is not assignable to type "(...) -> Unknown" (reportArgumentType)

Member:
Yeah, unfortunate that we can't use that nice constructor-like kwarg approach afforded to us by typed dicts

)
return TracingWorkflowInboundInterceptor

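To illustrate the thread above with hypothetical names: passing a plain dict literal positionally to dict.update keeps the declared value types, while spreading a TypedDict instance as **kwargs is treated by pyright as unpacking a Mapping[str, object], which appears to be why the constructor form was rejected here.

from typing import Callable, TypedDict


class _ExternFunctions(TypedDict):
    completed_span: Callable[..., None]


def _record_span(*args: object) -> None:
    pass


externs: dict[str, Callable[..., None]] = {}

# Dict literal passed positionally: each value keeps its Callable type.
externs.update({"completed_span": _record_span})

# TypedDict constructor spread as **kwargs: the unpacked values are typed as
# object, which strict pyright settings reject for a Callable-valued dict.
externs.update(**_ExternFunctions(completed_span=_record_span))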
4 changes: 2 additions & 2 deletions temporalio/converter.py
@@ -508,7 +508,7 @@ def default(self, o: Any) -> Any:
if isinstance(o, datetime):
return o.isoformat()
# Dataclass support
if dataclasses.is_dataclass(o):
if dataclasses.is_dataclass(o) and not isinstance(o, type):
return dataclasses.asdict(o)
# Support for Pydantic v1's dict method
dict_fn = getattr(o, "dict", None)
@@ -1701,7 +1701,7 @@ def value_to_type(
arg_type = type_args[i]
elif type_args[-1] is Ellipsis:
# Ellipsis means use the second to last one
arg_type = type_args[-2]
arg_type = type_args[-2] # type: ignore
else:
raise TypeError(
f"Type {hint} only expecting {len(type_args)} values, got at least {i + 1}"
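The added isinstance check above matters because dataclasses.is_dataclass returns True for both a dataclass type and its instances, while dataclasses.asdict accepts only instances. A small sketch of the distinction:

import dataclasses


@dataclasses.dataclass
class Point:
    x: int
    y: int


# is_dataclass is True for the class object and for instances alike.
assert dataclasses.is_dataclass(Point)
assert dataclasses.is_dataclass(Point(1, 2))

# asdict only accepts instances, so the "and not isinstance(o, type)" guard
# ensures only instances reach it.
assert dataclasses.asdict(Point(1, 2)) == {"x": 1, "y": 2}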
2 changes: 1 addition & 1 deletion temporalio/nexus/_operation_context.py
@@ -178,7 +178,7 @@ class WorkflowRunOperationContext(StartOperationContext):
This API is experimental and unstable.
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize the workflow run operation context."""
super().__init__(*args, **kwargs)
self._temporal_context = _TemporalStartOperationContext.get()
4 changes: 2 additions & 2 deletions temporalio/testing/_workflow.py
@@ -534,7 +534,7 @@ def assert_error_as_app_error(self) -> Iterator[None]:


class _TimeSkippingClientInterceptor(temporalio.client.Interceptor):
def __init__(self, env: _EphemeralServerWorkflowEnvironment) -> None:
def __init__(self, env: _EphemeralServerWorkflowEnvironment) -> None: # type: ignore[reportMissingSuperCall]
self.env = env

def intercept_client(
@@ -563,7 +563,7 @@ async def start_workflow(


class _TimeSkippingWorkflowHandle(temporalio.client.WorkflowHandle):
env: _EphemeralServerWorkflowEnvironment
env: _EphemeralServerWorkflowEnvironment # type: ignore[reportUninitializedInstanceAttribute]

async def result(
self,
59 changes: 25 additions & 34 deletions temporalio/worker/_activity.py
@@ -14,32 +14,22 @@
import threading
import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterator, Sequence
from contextlib import contextmanager
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import (
Any,
Callable,
Dict,
Iterator,
NoReturn,
Optional,
Sequence,
Tuple,
Type,
Union,
)

import google.protobuf.duration_pb2
import google.protobuf.timestamp_pb2

import temporalio.activity
import temporalio.api.common.v1
import temporalio.bridge.client
import temporalio.bridge.proto
import temporalio.bridge.proto.activity_result
import temporalio.bridge.proto.activity_task
import temporalio.bridge.proto.common
import temporalio.bridge.runtime
import temporalio.bridge.worker
import temporalio.client
@@ -76,7 +66,7 @@ def __init__(
self._task_queue = task_queue
self._activity_executor = activity_executor
self._shared_state_manager = shared_state_manager
self._running_activities: Dict[bytes, _RunningActivity] = {}
self._running_activities: dict[bytes, _RunningActivity] = {}
self._data_converter = data_converter
self._interceptors = interceptors
self._metric_meter = metric_meter
@@ -90,7 +80,7 @@ def __init__(
self._client = client

# Validate and build activity dict
self._activities: Dict[str, temporalio.activity._Definition] = {}
self._activities: dict[str, temporalio.activity._Definition] = {}
self._dynamic_activity: Optional[temporalio.activity._Definition] = None
for activity in activities:
# Get definition
@@ -178,7 +168,7 @@ async def raise_from_exception_queue() -> NoReturn:
self._handle_cancel_activity_task(task.task_token, task.cancel)
else:
raise RuntimeError(f"Unrecognized activity task: {task}")
except temporalio.bridge.worker.PollShutdownError:
except temporalio.bridge.worker.PollShutdownError: # type: ignore[reportPrivateLocalImportUsage]
exception_task.cancel()
return
except Exception as err:
@@ -195,12 +185,12 @@ async def drain_poll_queue(self) -> None:
try:
# Just take all tasks and say we can't handle them
task = await self._bridge_worker().poll_activity_task()
completion = temporalio.bridge.proto.ActivityTaskCompletion(
completion = temporalio.bridge.proto.ActivityTaskCompletion( # type: ignore[reportAttributeAccessIssue]
task_token=task.task_token
)
completion.result.failed.failure.message = "Worker shutting down"
await self._bridge_worker().complete_activity_task(completion)
except temporalio.bridge.worker.PollShutdownError:
except temporalio.bridge.worker.PollShutdownError: # type: ignore[reportPrivateLocalImportUsage]
return

# Only call this after run()/drain_poll_queue() have returned. This will not
Expand All @@ -214,7 +204,9 @@ async def wait_all_completed(self) -> None:
await asyncio.gather(*running_tasks, return_exceptions=False)

def _handle_cancel_activity_task(
self, task_token: bytes, cancel: temporalio.bridge.proto.activity_task.Cancel
self,
task_token: bytes,
cancel: temporalio.bridge.proto.activity_task.Cancel, # type: ignore[reportAttributeAccessIssue]
) -> None:
"""Request cancellation of a running activity task."""
activity = self._running_activities.get(task_token)
@@ -262,7 +254,9 @@ async def _heartbeat_async(

# Perform the heartbeat
try:
heartbeat = temporalio.bridge.proto.ActivityHeartbeat(task_token=task_token)
heartbeat = temporalio.bridge.proto.ActivityHeartbeat( # type: ignore[reportAttributeAccessIssue]
task_token=task_token
)
if details:
# Convert to core payloads
heartbeat.details.extend(await self._data_converter.encode(details))
@@ -284,7 +278,7 @@ async def _heartbeat_async(
async def _handle_start_activity_task(
self,
task_token: bytes,
start: temporalio.bridge.proto.activity_task.Start,
start: temporalio.bridge.proto.activity_task.Start, # type: ignore[reportAttributeAccessIssue]
running_activity: _RunningActivity,
) -> None:
"""Handle a start activity task.
Expand All @@ -296,7 +290,7 @@ async def _handle_start_activity_task(
# We choose to surround interceptor creation and activity invocation in
# a try block so we can mark the workflow as failed on any error instead
# of having error handling in the interceptor
completion = temporalio.bridge.proto.ActivityTaskCompletion(
completion = temporalio.bridge.proto.ActivityTaskCompletion( # type: ignore[reportAttributeAccessIssue]
task_token=task_token
)
try:
@@ -413,7 +407,7 @@

async def _execute_activity(
self,
start: temporalio.bridge.proto.activity_task.Start,
start: temporalio.bridge.proto.activity_task.Start, # type: ignore[reportAttributeAccessIssue]
running_activity: _RunningActivity,
task_token: bytes,
) -> Any:
@@ -649,14 +643,14 @@ class _ThreadExceptionRaiser:
def __init__(self) -> None:
self._lock = threading.Lock()
self._thread_id: Optional[int] = None
self._pending_exception: Optional[Type[Exception]] = None
self._pending_exception: Optional[type[Exception]] = None
self._shield_depth = 0

def set_thread_id(self, thread_id: int) -> None:
with self._lock:
self._thread_id = thread_id

def raise_in_thread(self, exc_type: Type[Exception]) -> None:
def raise_in_thread(self, exc_type: type[Exception]) -> None:
with self._lock:
self._pending_exception = exc_type
self._raise_in_thread_if_pending_unlocked()
@@ -812,7 +806,7 @@ def _execute_sync_activity(
cancelled_event: threading.Event,
worker_shutdown_event: threading.Event,
payload_converter_class_or_instance: Union[
Type[temporalio.converter.PayloadConverter],
type[temporalio.converter.PayloadConverter],
temporalio.converter.PayloadConverter,
],
runtime_metric_meter: Optional[temporalio.common.MetricMeter],
@@ -824,13 +818,10 @@
thread_id = threading.current_thread().ident
if thread_id is not None:
cancel_thread_raiser.set_thread_id(thread_id)
heartbeat_fn: Callable[..., None]
if isinstance(heartbeat, SharedHeartbeatSender):
# To make mypy happy
heartbeat_sender = heartbeat
heartbeat_fn = lambda *details: heartbeat_sender.send_heartbeat(
info.task_token, *details
)

def heartbeat_fn(*details: Any) -> None:
heartbeat.send_heartbeat(info.task_token, *details)
else:
heartbeat_fn = heartbeat
temporalio.activity._Context.set(
@@ -940,11 +931,11 @@ def __init__(
self._mgr = mgr
self._queue_poller_executor = queue_poller_executor
# 1000 in-flight heartbeats should be plenty
self._heartbeat_queue: queue.Queue[Tuple[bytes, Sequence[Any]]] = mgr.Queue(
self._heartbeat_queue: queue.Queue[tuple[bytes, Sequence[Any]]] = mgr.Queue(
1000
)
self._heartbeats: Dict[bytes, Callable[..., None]] = {}
self._heartbeat_completions: Dict[bytes, Callable] = {}
self._heartbeats: dict[bytes, Callable[..., None]] = {}
self._heartbeat_completions: dict[bytes, Callable] = {}

def new_event(self) -> threading.Event:
return self._mgr.Event()
@@ -1002,7 +993,7 @@ def _heartbeat_processor(self) -> None:

class _MultiprocessingSharedHeartbeatSender(SharedHeartbeatSender):
def __init__(
self, heartbeat_queue: queue.Queue[Tuple[bytes, Sequence[Any]]]
self, heartbeat_queue: queue.Queue[tuple[bytes, Sequence[Any]]]
) -> None:
super().__init__()
self._heartbeat_queue = heartbeat_queue
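Two recurring cleanups in this file: typing.Dict/Tuple/Type become the builtin generics available since Python 3.9, and the heartbeat lambda becomes a nested def, since a def can carry parameter and return annotations while a lambda cannot. A minimal sketch of the latter pattern, with hypothetical names:

from typing import Any, Callable


class HeartbeatSender:
    def send_heartbeat(self, task_token: bytes, *details: Any) -> None:
        print(f"heartbeat {task_token!r}: {details}")


def make_heartbeat_fn(
    sender: HeartbeatSender, task_token: bytes
) -> Callable[..., None]:
    # A nested def can be fully annotated (unlike a lambda), so the checker
    # verifies the forwarding call without a separate Callable declaration.
    def heartbeat_fn(*details: Any) -> None:
        sender.send_heartbeat(task_token, *details)

    return heartbeat_fn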