diff --git a/pyproject.toml b/pyproject.toml index 2dc122f14..bb6edb63c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -165,6 +165,7 @@ reportAny = "none" reportCallInDefaultInitializer = "none" reportExplicitAny = "none" reportIgnoreCommentWithoutRule = "none" +reportImplicitAbstractClass = "none" reportImplicitOverride = "none" reportImplicitStringConcatenation = "none" reportImportCycles = "none" @@ -184,11 +185,6 @@ exclude = [ "temporalio/bridge/proto", "tests/worker/workflow_sandbox/testmodules/proto", "temporalio/bridge/worker.py", - "temporalio/contrib/opentelemetry.py", - "temporalio/contrib/pydantic.py", - "temporalio/converter.py", - "temporalio/testing/_workflow.py", - "temporalio/worker/_activity.py", "temporalio/worker/_replayer.py", "temporalio/worker/_worker.py", "temporalio/worker/workflow_sandbox/_importer.py", @@ -203,9 +199,7 @@ exclude = [ "tests/contrib/pydantic/workflows.py", "tests/test_converter.py", "tests/test_service.py", - "tests/test_workflow.py", "tests/worker/test_activity.py", - "tests/worker/test_workflow.py", "tests/worker/workflow_sandbox/test_importer.py", "tests/worker/workflow_sandbox/test_restrictions.py", # TODO: these pass locally but fail in CI with diff --git a/temporalio/client.py b/temporalio/client.py index 868e77a26..6c5363ce8 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -2889,7 +2889,7 @@ def _from_raw_info( cls, info: temporalio.api.workflow.v1.WorkflowExecutionInfo, converter: temporalio.converter.DataConverter, - **additional_fields, + **additional_fields: Any, ) -> WorkflowExecution: return cls( close_time=info.close_time.ToDatetime().replace(tzinfo=timezone.utc) diff --git a/temporalio/contrib/openai_agents/_heartbeat_decorator.py b/temporalio/contrib/openai_agents/_heartbeat_decorator.py index bce015ed8..0ddee8ad4 100644 --- a/temporalio/contrib/openai_agents/_heartbeat_decorator.py +++ b/temporalio/contrib/openai_agents/_heartbeat_decorator.py @@ -10,7 +10,7 @@ def _auto_heartbeater(fn: F) -> F: # Propagate type hints from the original callable. @wraps(fn) - async def wrapper(*args, **kwargs): + async def wrapper(*args: Any, **kwargs: Any) -> Any: heartbeat_timeout = activity.info().heartbeat_timeout heartbeat_task = None if heartbeat_timeout: diff --git a/temporalio/contrib/openai_agents/_openai_runner.py b/temporalio/contrib/openai_agents/_openai_runner.py index f43d01388..7a5153141 100644 --- a/temporalio/contrib/openai_agents/_openai_runner.py +++ b/temporalio/contrib/openai_agents/_openai_runner.py @@ -1,11 +1,9 @@ from dataclasses import replace -from datetime import timedelta -from typing import Optional, Union +from typing import Any, Union from agents import ( Agent, RunConfig, - RunHooks, RunResult, RunResultStreaming, TContext, @@ -14,10 +12,8 @@ from agents.run import DEFAULT_AGENT_RUNNER, DEFAULT_MAX_TURNS, AgentRunner from temporalio import workflow -from temporalio.common import Priority, RetryPolicy from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters from temporalio.contrib.openai_agents._temporal_model_stub import _TemporalModelStub -from temporalio.workflow import ActivityCancellationType, VersioningIntent class TemporalOpenAIRunner(AgentRunner): @@ -36,7 +32,7 @@ async def run( self, starting_agent: Agent[TContext], input: Union[str, list[TResponseInputItem]], - **kwargs, + **kwargs: Any, ) -> RunResult: """Run the agent in a Temporal workflow.""" if not workflow.in_workflow(): @@ -82,7 +78,7 @@ def run_sync( self, starting_agent: Agent[TContext], input: Union[str, list[TResponseInputItem]], - **kwargs, + **kwargs: Any, ) -> RunResult: """Run the agent synchronously (not supported in Temporal workflows).""" if not workflow.in_workflow(): @@ -97,7 +93,7 @@ def run_streamed( self, starting_agent: Agent[TContext], input: Union[str, list[TResponseInputItem]], - **kwargs, + **kwargs: Any, ) -> RunResultStreaming: """Run the agent with streaming responses (not supported in Temporal workflows).""" if not workflow.in_workflow(): diff --git a/temporalio/contrib/openai_agents/_temporal_trace_provider.py b/temporalio/contrib/openai_agents/_temporal_trace_provider.py index 1d9b09866..7518a9a94 100644 --- a/temporalio/contrib/openai_agents/_temporal_trace_provider.py +++ b/temporalio/contrib/openai_agents/_temporal_trace_provider.py @@ -1,7 +1,8 @@ """Provides support for integration with OpenAI Agents SDK tracing across workflows""" import uuid -from typing import Any, Optional, Union, cast +from types import TracebackType +from typing import Any, Optional, cast from agents import SpanData, Trace, TracingProcessor from agents.tracing import ( @@ -184,6 +185,11 @@ def __enter__(self): """Enter the context of the Temporal trace provider.""" return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType, + ): """Exit the context of the Temporal trace provider.""" self._multi_processor.shutdown() diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py index 84773fd43..380b666dc 100644 --- a/temporalio/contrib/opentelemetry.py +++ b/temporalio/contrib/opentelemetry.py @@ -71,7 +71,7 @@ class should return the workflow interceptor subclass from custom attributes desired. """ - def __init__( + def __init__( # type: ignore[reportMissingSuperCall] self, tracer: Optional[opentelemetry.trace.Tracer] = None, *, @@ -125,11 +125,10 @@ def workflow_interceptor_class( :py:meth:`temporalio.worker.Interceptor.workflow_interceptor_class`. """ # Set the externs needed - # TODO(cretz): MyPy works w/ spread kwargs instead of direct passing input.unsafe_extern_functions.update( - **_WorkflowExternFunctions( - __temporal_opentelemetry_completed_span=self._completed_workflow_span, - ) + { + "__temporal_opentelemetry_completed_span": self._completed_workflow_span, + } ) return TracingWorkflowInboundInterceptor diff --git a/temporalio/converter.py b/temporalio/converter.py index b4f005802..cd7c2bbf5 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -508,7 +508,7 @@ def default(self, o: Any) -> Any: if isinstance(o, datetime): return o.isoformat() # Dataclass support - if dataclasses.is_dataclass(o): + if dataclasses.is_dataclass(o) and not isinstance(o, type): return dataclasses.asdict(o) # Support for Pydantic v1's dict method dict_fn = getattr(o, "dict", None) @@ -1701,7 +1701,7 @@ def value_to_type( arg_type = type_args[i] elif type_args[-1] is Ellipsis: # Ellipsis means use the second to last one - arg_type = type_args[-2] + arg_type = type_args[-2] # type: ignore else: raise TypeError( f"Type {hint} only expecting {len(type_args)} values, got at least {i + 1}" diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index e0f28b28f..f40ed460d 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -178,7 +178,7 @@ class WorkflowRunOperationContext(StartOperationContext): This API is experimental and unstable. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: """Initialize the workflow run operation context.""" super().__init__(*args, **kwargs) self._temporal_context = _TemporalStartOperationContext.get() diff --git a/temporalio/testing/_workflow.py b/temporalio/testing/_workflow.py index 85c5404ea..3d5359c3c 100644 --- a/temporalio/testing/_workflow.py +++ b/temporalio/testing/_workflow.py @@ -534,7 +534,7 @@ def assert_error_as_app_error(self) -> Iterator[None]: class _TimeSkippingClientInterceptor(temporalio.client.Interceptor): - def __init__(self, env: _EphemeralServerWorkflowEnvironment) -> None: + def __init__(self, env: _EphemeralServerWorkflowEnvironment) -> None: # type: ignore[reportMissingSuperCall] self.env = env def intercept_client( @@ -563,7 +563,7 @@ async def start_workflow( class _TimeSkippingWorkflowHandle(temporalio.client.WorkflowHandle): - env: _EphemeralServerWorkflowEnvironment + env: _EphemeralServerWorkflowEnvironment # type: ignore[reportUninitializedInstanceAttribute] async def result( self, diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 61a072186..c76c8f005 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -14,19 +14,15 @@ import threading import warnings from abc import ABC, abstractmethod +from collections.abc import Iterator, Sequence from contextlib import contextmanager from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone from typing import ( Any, Callable, - Dict, - Iterator, NoReturn, Optional, - Sequence, - Tuple, - Type, Union, ) @@ -34,12 +30,6 @@ import google.protobuf.timestamp_pb2 import temporalio.activity -import temporalio.api.common.v1 -import temporalio.bridge.client -import temporalio.bridge.proto -import temporalio.bridge.proto.activity_result -import temporalio.bridge.proto.activity_task -import temporalio.bridge.proto.common import temporalio.bridge.runtime import temporalio.bridge.worker import temporalio.client @@ -76,7 +66,7 @@ def __init__( self._task_queue = task_queue self._activity_executor = activity_executor self._shared_state_manager = shared_state_manager - self._running_activities: Dict[bytes, _RunningActivity] = {} + self._running_activities: dict[bytes, _RunningActivity] = {} self._data_converter = data_converter self._interceptors = interceptors self._metric_meter = metric_meter @@ -90,7 +80,7 @@ def __init__( self._client = client # Validate and build activity dict - self._activities: Dict[str, temporalio.activity._Definition] = {} + self._activities: dict[str, temporalio.activity._Definition] = {} self._dynamic_activity: Optional[temporalio.activity._Definition] = None for activity in activities: # Get definition @@ -178,7 +168,7 @@ async def raise_from_exception_queue() -> NoReturn: self._handle_cancel_activity_task(task.task_token, task.cancel) else: raise RuntimeError(f"Unrecognized activity task: {task}") - except temporalio.bridge.worker.PollShutdownError: + except temporalio.bridge.worker.PollShutdownError: # type: ignore[reportPrivateLocalImportUsage] exception_task.cancel() return except Exception as err: @@ -195,12 +185,12 @@ async def drain_poll_queue(self) -> None: try: # Just take all tasks and say we can't handle them task = await self._bridge_worker().poll_activity_task() - completion = temporalio.bridge.proto.ActivityTaskCompletion( + completion = temporalio.bridge.proto.ActivityTaskCompletion( # type: ignore[reportAttributeAccessIssue] task_token=task.task_token ) completion.result.failed.failure.message = "Worker shutting down" await self._bridge_worker().complete_activity_task(completion) - except temporalio.bridge.worker.PollShutdownError: + except temporalio.bridge.worker.PollShutdownError: # type: ignore[reportPrivateLocalImportUsage] return # Only call this after run()/drain_poll_queue() have returned. This will not @@ -214,7 +204,9 @@ async def wait_all_completed(self) -> None: await asyncio.gather(*running_tasks, return_exceptions=False) def _handle_cancel_activity_task( - self, task_token: bytes, cancel: temporalio.bridge.proto.activity_task.Cancel + self, + task_token: bytes, + cancel: temporalio.bridge.proto.activity_task.Cancel, # type: ignore[reportAttributeAccessIssue] ) -> None: """Request cancellation of a running activity task.""" activity = self._running_activities.get(task_token) @@ -262,7 +254,9 @@ async def _heartbeat_async( # Perform the heartbeat try: - heartbeat = temporalio.bridge.proto.ActivityHeartbeat(task_token=task_token) + heartbeat = temporalio.bridge.proto.ActivityHeartbeat( # type: ignore[reportAttributeAccessIssue] + task_token=task_token + ) if details: # Convert to core payloads heartbeat.details.extend(await self._data_converter.encode(details)) @@ -284,7 +278,7 @@ async def _heartbeat_async( async def _handle_start_activity_task( self, task_token: bytes, - start: temporalio.bridge.proto.activity_task.Start, + start: temporalio.bridge.proto.activity_task.Start, # type: ignore[reportAttributeAccessIssue] running_activity: _RunningActivity, ) -> None: """Handle a start activity task. @@ -296,7 +290,7 @@ async def _handle_start_activity_task( # We choose to surround interceptor creation and activity invocation in # a try block so we can mark the workflow as failed on any error instead # of having error handling in the interceptor - completion = temporalio.bridge.proto.ActivityTaskCompletion( + completion = temporalio.bridge.proto.ActivityTaskCompletion( # type: ignore[reportAttributeAccessIssue] task_token=task_token ) try: @@ -413,7 +407,7 @@ async def _handle_start_activity_task( async def _execute_activity( self, - start: temporalio.bridge.proto.activity_task.Start, + start: temporalio.bridge.proto.activity_task.Start, # type: ignore[reportAttributeAccessIssue] running_activity: _RunningActivity, task_token: bytes, ) -> Any: @@ -649,14 +643,14 @@ class _ThreadExceptionRaiser: def __init__(self) -> None: self._lock = threading.Lock() self._thread_id: Optional[int] = None - self._pending_exception: Optional[Type[Exception]] = None + self._pending_exception: Optional[type[Exception]] = None self._shield_depth = 0 def set_thread_id(self, thread_id: int) -> None: with self._lock: self._thread_id = thread_id - def raise_in_thread(self, exc_type: Type[Exception]) -> None: + def raise_in_thread(self, exc_type: type[Exception]) -> None: with self._lock: self._pending_exception = exc_type self._raise_in_thread_if_pending_unlocked() @@ -812,7 +806,7 @@ def _execute_sync_activity( cancelled_event: threading.Event, worker_shutdown_event: threading.Event, payload_converter_class_or_instance: Union[ - Type[temporalio.converter.PayloadConverter], + type[temporalio.converter.PayloadConverter], temporalio.converter.PayloadConverter, ], runtime_metric_meter: Optional[temporalio.common.MetricMeter], @@ -824,13 +818,10 @@ def _execute_sync_activity( thread_id = threading.current_thread().ident if thread_id is not None: cancel_thread_raiser.set_thread_id(thread_id) - heartbeat_fn: Callable[..., None] if isinstance(heartbeat, SharedHeartbeatSender): - # To make mypy happy - heartbeat_sender = heartbeat - heartbeat_fn = lambda *details: heartbeat_sender.send_heartbeat( - info.task_token, *details - ) + + def heartbeat_fn(*details: Any) -> None: + heartbeat.send_heartbeat(info.task_token, *details) else: heartbeat_fn = heartbeat temporalio.activity._Context.set( @@ -940,11 +931,11 @@ def __init__( self._mgr = mgr self._queue_poller_executor = queue_poller_executor # 1000 in-flight heartbeats should be plenty - self._heartbeat_queue: queue.Queue[Tuple[bytes, Sequence[Any]]] = mgr.Queue( + self._heartbeat_queue: queue.Queue[tuple[bytes, Sequence[Any]]] = mgr.Queue( 1000 ) - self._heartbeats: Dict[bytes, Callable[..., None]] = {} - self._heartbeat_completions: Dict[bytes, Callable] = {} + self._heartbeats: dict[bytes, Callable[..., None]] = {} + self._heartbeat_completions: dict[bytes, Callable] = {} def new_event(self) -> threading.Event: return self._mgr.Event() @@ -1002,7 +993,7 @@ def _heartbeat_processor(self) -> None: class _MultiprocessingSharedHeartbeatSender(SharedHeartbeatSender): def __init__( - self, heartbeat_queue: queue.Queue[Tuple[bytes, Sequence[Any]]] + self, heartbeat_queue: queue.Queue[tuple[bytes, Sequence[Any]]] ) -> None: super().__init__() self._heartbeat_queue = heartbeat_queue diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 75e80b3e1..81f21588f 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -197,7 +197,7 @@ def create_instance(self, det: WorkflowInstanceDetails) -> WorkflowInstance: _ExceptionHandler: TypeAlias = Callable[[asyncio.AbstractEventLoop, _Context], Any] -class _WorkflowInstanceImpl( +class _WorkflowInstanceImpl( # type: ignore[reportImplicitAbstractClass] WorkflowInstance, temporalio.workflow._Runtime, asyncio.AbstractEventLoop ): def __init__(self, det: WorkflowInstanceDetails) -> None: diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index 57dc5c252..b6e5b3dbf 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -114,10 +114,10 @@ class TestHelloModel(StaticTestModel): class HelloWorldAgent: @workflow.run async def run(self, prompt: str) -> str: - agent = Agent( + agent = Agent[None]( name="Assistant", instructions="You only respond in haikus.", - ) # type: Agent + ) result = await Runner.run(starting_agent=agent, input=prompt) return result.final_output @@ -358,7 +358,7 @@ class TestNexusWeatherModel(StaticTestModel): class ToolsWorkflow: @workflow.run async def run(self, question: str) -> str: - agent: Agent = Agent( + agent = Agent[str]( name="Tools Workflow", instructions="You are a helpful agent.", tools=[ @@ -390,7 +390,7 @@ async def run(self, question: str) -> str: class NexusToolsWorkflow: @workflow.run async def run(self, question: str) -> str: - agent = Agent( + agent = Agent[str]( name="Nexus Tools Workflow", instructions="You are a helpful agent.", tools=[ @@ -401,7 +401,7 @@ async def run(self, question: str) -> str: schedule_to_close_timeout=timedelta(seconds=10), ), ], - ) # type: Agent + ) result = await Runner.run( starting_agent=agent, input=question, context="Stormy" ) @@ -747,25 +747,25 @@ async def test_research_workflow(client: Client, use_local_model: bool): def orchestrator_agent() -> Agent: - spanish_agent = Agent( + spanish_agent = Agent[None]( name="spanish_agent", instructions="You translate the user's message to Spanish", handoff_description="An english to spanish translator", - ) # type: Agent + ) - french_agent = Agent( + french_agent = Agent[None]( name="french_agent", instructions="You translate the user's message to French", handoff_description="An english to french translator", - ) # type: Agent + ) - italian_agent = Agent( + italian_agent = Agent[None]( name="italian_agent", instructions="You translate the user's message to Italian", handoff_description="An english to italian translator", - ) # type: Agent + ) - orchestrator_agent = Agent( + orchestrator_agent = Agent[None]( name="orchestrator_agent", instructions=( "You are a translation agent. You use the tools given to you to translate." @@ -786,7 +786,7 @@ def orchestrator_agent() -> Agent: tool_description="Translate the user's message to Italian", ), ], - ) # type: Agent + ) return orchestrator_agent diff --git a/tests/test_envconfig.py b/tests/test_envconfig.py index c755549aa..dc675dc77 100644 --- a/tests/test_envconfig.py +++ b/tests/test_envconfig.py @@ -205,7 +205,7 @@ def test_load_profile_disable_env(base_config_file: Path): assert config.get("target_host") == "default-address" -def test_load_profile_disable_file(monkeypatch): +def test_load_profile_disable_file(monkeypatch): # type: ignore[reportMissingParameterType] """Test that `disable_file` loads configuration only from environment.""" monkeypatch.setattr("pathlib.Path.exists", lambda _: False) env = {"TEMPORAL_ADDRESS": "env-address"} @@ -268,7 +268,7 @@ def test_load_profiles_no_env_override(tmp_path: Path, monkeypatch): assert connect_config.get("target_host") == "default-address" -def test_load_profiles_no_config_file(monkeypatch): +def test_load_profiles_no_config_file(monkeypatch): # type: ignore[reportMissingParameterType] """Test that load_profiles works when no config file is found.""" monkeypatch.setattr("pathlib.Path.exists", lambda _: False) monkeypatch.setattr(os, "environ", {}) @@ -276,7 +276,7 @@ def test_load_profiles_no_config_file(monkeypatch): assert not client_config.profiles -def test_load_profiles_discovery(tmp_path: Path, monkeypatch): +def test_load_profiles_discovery(tmp_path: Path, monkeypatch): # type: ignore[reportMissingParameterType] """Test file discovery via environment variables.""" config_file = tmp_path / "config.toml" config_file.write_text(TOML_CONFIG_BASE) diff --git a/tests/test_type_errors.py b/tests/test_type_errors.py index 3c700ff37..d8e6e2afb 100644 --- a/tests/test_type_errors.py +++ b/tests/test_type_errors.py @@ -83,10 +83,6 @@ def _test_type_errors( f"{test_file}:{line_num}: Expected error matching '{expected_pattern}' but got '{actual_msg}'" ) - for line_num, actual_msg in sorted(actual_errors.items()): - if line_num not in expected_errors: - pytest.fail(f"{test_file}:{line_num}: Unexpected type error: {actual_msg}") - def _has_type_error_assertions(test_file: Path) -> bool: """Check if a file contains any type error assertions.""" diff --git a/tests/test_workflow.py b/tests/test_workflow.py index cfeeb91b7..f5ef923b2 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -225,8 +225,7 @@ def update2(self, arg1: str): pass # Intentionally missing decorator - # assert-type-error-pyright: "overrides symbol of same name" - def base_update(self): # type: ignore + def base_update(self): # type: ignore[reportIncompatibleVariableOverride] pass @@ -289,7 +288,7 @@ def run(self): def test_workflow_defn_non_async_run(): with pytest.raises(ValueError) as err: - # assert-type-error-pyright: "Argument .+ cannot be assigned to parameter" + # assert-type-error-pyright: 'Argument .+ cannot be assigned to parameter "fn"' workflow.run(NonAsyncRun.run) # type: ignore assert "must be an async function" in str(err.value) @@ -351,10 +350,10 @@ async def run(self): def some_dynamic1(self): pass - def some_dynamic2(self, no_vararg): + def some_dynamic2(self, no_vararg): # type: ignore[reportMissingParameterType] pass - def old_dynamic(self, name, *args): + def old_dynamic(self, name, *args): # type: ignore[reportMissingParameterType] pass @@ -384,10 +383,10 @@ def test_workflow_defn_dynamic_handler_warnings(): class _TestParametersIdenticalUpToNaming: - def a1(self, a): + def a1(self, a): # type: ignore[reportMissingParameterType] pass - def a2(self, b): + def a2(self, b): # type: ignore[reportMissingParameterType] pass def b1(self, a: int): @@ -402,19 +401,19 @@ def c1(self, a1: int, a2: str) -> str: def c2(self, b1: int, b2: str) -> int: return 0 - def d1(self, a1, a2: str) -> None: + def d1(self, a1, a2: str) -> None: # type: ignore[reportMissingParameterType] pass - def d2(self, b1, b2: str) -> str: + def d2(self, b1, b2: str) -> str: # type: ignore[reportMissingParameterType] return "" - def e1(self, a1, a2: str = "") -> None: + def e1(self, a1, a2: str = "") -> None: # type: ignore[reportMissingParameterType] return None - def e2(self, b1, b2: str = "") -> str: + def e2(self, b1, b2: str = "") -> str: # type: ignore[reportMissingParameterType] return "" - def f1(self, a1, a2: str = "a") -> None: + def f1(self, a1, a2: str = "a") -> None: # type: ignore[reportMissingParameterType] return None diff --git a/tests/worker/test_worker.py b/tests/worker/test_worker.py index 7c08cfa37..3e1a1c8f7 100644 --- a/tests/worker/test_worker.py +++ b/tests/worker/test_worker.py @@ -429,7 +429,7 @@ def release_slot(self, ctx: SlotReleaseContext) -> None: self.seen_release_info_nonempty = True self.releases += 1 - def reserve_asserts(self, ctx): + def reserve_asserts(self, ctx: SlotReserveContext) -> None: assert ctx.task_queue is not None assert ctx.worker_identity is not None assert ctx.worker_build_id is not None @@ -816,7 +816,7 @@ async def run(self, args: Sequence[RawValue]) -> str: async def _test_worker_deployment_dynamic_workflow( client: Client, env: WorkflowEnvironment, - workflow_class, + workflow_class: type[Any], expected_versioning_behavior: temporalio.api.enums.v1.VersioningBehavior.ValueType, ): if env.supports_time_skipping: @@ -1146,7 +1146,7 @@ def __init__(self, worker: Worker) -> None: def __enter__(self) -> WorkerFailureInjector: return self - def __exit__(self, *args, **kwargs) -> None: + def __exit__(self, *args: Any, **kwargs: Any) -> None: self.workflow.shutdown() self.activity.shutdown() diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 8b333cbae..52dd97d8e 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -42,6 +42,8 @@ from typing_extensions import Literal, Protocol, runtime_checkable import temporalio.activity +import temporalio.api.sdk.v1 +import temporalio.client import temporalio.worker import temporalio.workflow from temporalio import activity, workflow @@ -51,6 +53,7 @@ from temporalio.api.sdk.v1 import EnhancedStackTrace from temporalio.api.workflowservice.v1 import ( GetWorkflowExecutionHistoryRequest, + PollWorkflowExecutionUpdateResponse, ResetStickyTaskQueueRequest, ) from temporalio.bridge.proto.workflow_activation import WorkflowActivation @@ -1023,6 +1026,8 @@ async def test_workflow_simple_child(client: Client): @workflow.defn class LongSleepWorkflow: + _started = False + @workflow.run async def run(self) -> None: self._started = True @@ -1085,6 +1090,8 @@ async def wait_forever() -> NoReturn: @workflow.defn class UncaughtCancelWorkflow: + _started = False + @workflow.run async def run(self, activity: bool) -> NoReturn: self._started = True @@ -1099,6 +1106,7 @@ async def run(self, activity: bool) -> NoReturn: True, id=f"{workflow.info().workflow_id}_child", ) + raise RuntimeError("Unreachable") @workflow.query def started(self) -> bool: @@ -1133,6 +1141,7 @@ async def started() -> bool: class CancelChildWorkflow: def __init__(self) -> None: self._ready = False + self._task: Optional[asyncio.Task[Any]] = None @workflow.run async def run(self, use_execute: bool) -> None: @@ -1155,6 +1164,7 @@ def ready(self) -> bool: @workflow.signal async def cancel_child(self) -> None: + assert self._task self._task.cancel() @@ -2338,7 +2348,7 @@ def complete(self) -> None: ... class DataClassTypedWorkflowAbstract(ABC): @workflow.run @abstractmethod - async def run(self, arg: MyDataClass) -> MyDataClass: ... + async def run(self, param: MyDataClass) -> MyDataClass: ... @workflow.signal @abstractmethod @@ -2922,7 +2932,7 @@ async def waiting_signal() -> bool: ) # Send signal to both and check results - await pre_patch_handle.signal(PatchMemoizedWorkflowPatched.signal) + await pre_patch_handle.signal(PatchMemoizedWorkflowUnpatched.signal) await post_patch_handle.signal(PatchMemoizedWorkflowPatched.signal) # Confirm expected values @@ -3342,6 +3352,8 @@ async def test_workflow_query_does_not_run_condition(client: Client): @workflow.defn class CancelSignalAndTimerFiredInSameTaskWorkflow: + timer_task: asyncio.Task[None] # type: ignore[reportUninitializedInstanceVariable] + @workflow.run async def run(self) -> None: # Start a 1 hour timer @@ -3424,6 +3436,7 @@ async def run(self) -> NoReturn: ) except ActivityError: raise MyCustomError("workflow error!") + raise RuntimeError("Unreachable") class CustomFailureConverter(DefaultFailureConverterWithEncodedAttributes): @@ -4790,7 +4803,9 @@ async def test_workflow_update_timeout_or_cancel(client: Client): called = asyncio.Event() unpatched_call = client.workflow_service.poll_workflow_execution_update - async def patched_call(*args, **kwargs): + async def patched_call( + *args: Any, **kwargs: Any + ) -> PollWorkflowExecutionUpdateResponse: called.set() return await unpatched_call(*args, **kwargs) @@ -5855,7 +5870,7 @@ async def _run_workflow_and_get_warning(self) -> bool: ): with pytest.WarningsRecorder() as warnings: if self.handler_type == "-update-": - assert update_task + assert update_task # type: ignore[reportUnboundVariable] if self.handler_waiting == "-wait-all-handlers-finish-": await update_task else: @@ -6053,7 +6068,9 @@ async def test_update_completion_is_honored_when_after_workflow_return_2( @workflow.defn class FirstCompletionCommandIsHonoredWorkflow: - def __init__(self, main_workflow_returns_before_signal_completions=False) -> None: + def __init__( + self, main_workflow_returns_before_signal_completions: bool = False + ) -> None: self.seen_first_signal = False self.seen_second_signal = False self.main_workflow_returns_before_signal_completions = ( @@ -6589,7 +6606,7 @@ async def bad_failure_converter_activity() -> None: @workflow.defn(sandboxed=False) class BadFailureConverterWorkflow: @workflow.run - async def run(self, fail_workflow_task) -> None: + async def run(self, fail_workflow_task: bool) -> None: if fail_workflow_task: raise BadFailureConverterError else: @@ -6911,7 +6928,7 @@ def __init__(self) -> None: @workflow.run async def run( self, - _: Optional[UseLockOrSemaphoreWorkflowParameters] = None, + params: Optional[UseLockOrSemaphoreWorkflowParameters] = None, ) -> LockOrSemaphoreWorkflowConcurrencySummary: await workflow.wait_condition(lambda: self.workflow_may_exit) return LockOrSemaphoreWorkflowConcurrencySummary( @@ -7110,7 +7127,7 @@ async def test_update_handler_semaphore_acquisition_respects_timeout( @workflow.defn class TimeoutErrorWorkflow: @workflow.run - async def run(self, scenario) -> None: + async def run(self, scenario: str) -> None: if scenario == "workflow.wait_condition": await workflow.wait_condition(lambda: False, timeout=0.01) elif scenario == "asyncio.wait_for": @@ -7509,7 +7526,7 @@ async def run(self) -> Optional[temporalio.workflow.RootInfo]: @workflow.defn class ExposeRootWorkflow: @workflow.run - async def run(self, child_wf_id) -> Optional[temporalio.workflow.RootInfo]: + async def run(self, child_wf_id: str) -> Optional[temporalio.workflow.RootInfo]: return await workflow.execute_child_workflow( ExposeRootChildWorkflow.run, id=child_wf_id ) @@ -8045,7 +8062,7 @@ async def test_workflow_logging_trace_identifier(client: Client): ) as worker: await client.execute_workflow( TaskFailOnceWorkflow.run, - id=f"workflow_failure_trace_identifier", + id="workflow_failure_trace_identifier", task_queue=worker.task_queue, ) @@ -8085,7 +8102,7 @@ async def test_in_workflow_sync(client: Client): ) as worker: res = await client.execute_workflow( UseInWorkflow.run, - id=f"test_in_workflow_sync", + id="test_in_workflow_sync", task_queue=worker.task_queue, execution_timeout=timedelta(minutes=1), )