Skip to content

Commit e3a1c8b

Browse files
committed
Support @task.bash with Task SDK
closes apache#48046 closes apache#45639
1 parent 243fe86 commit e3a1c8b

File tree

8 files changed

+85
-73
lines changed

8 files changed

+85
-73
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1356,7 +1356,7 @@ repos:
13561356
name: Check templated fields mapped in operators/sensors
13571357
language: python
13581358
entry: ./scripts/ci/pre_commit/check_template_fields.py
1359-
files: ^(providers/.*/)?airflow/.*/(sensors|operators)/.*\.py$
1359+
files: ^(providers/.*/)?airflow-core/.*/(sensors|operators)/.*\.py$
13601360
additional_dependencies: [ 'rich>=12.4.4' ]
13611361
require_serial: true
13621362
- id: update-migration-references

airflow-core/src/airflow/decorators/bash.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@
2323

2424
from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
2525
from airflow.providers.standard.operators.bash import BashOperator
26+
from airflow.sdk.definitions._internal.types import SET_DURING_EXECUTION
2627
from airflow.utils.context import context_merge
2728
from airflow.utils.operator_helpers import determine_kwargs
28-
from airflow.utils.types import NOTSET
2929

3030
if TYPE_CHECKING:
3131
from airflow.sdk.definitions.context import Context
@@ -49,6 +49,7 @@ class _BashDecoratedOperator(DecoratedOperator, BashOperator):
4949
}
5050

5151
custom_operator_name: str = "@task.bash"
52+
overwrite_rtif_after_execution: bool = True
5253

5354
def __init__(
5455
self,
@@ -69,7 +70,7 @@ def __init__(
6970
python_callable=python_callable,
7071
op_args=op_args,
7172
op_kwargs=op_kwargs,
72-
bash_command=NOTSET,
73+
bash_command=SET_DURING_EXECUTION,
7374
multiple_outputs=False,
7475
**kwargs,
7576
)
@@ -83,6 +84,9 @@ def execute(self, context: Context) -> Any:
8384
if not isinstance(self.bash_command, str) or self.bash_command.strip() == "":
8485
raise TypeError("The returned value from the TaskFlow callable must be a non-empty string.")
8586

87+
self._is_inline_cmd = self._is_inline_command(bash_command=self.bash_command)
88+
context["ti"].render_templates()
89+
8690
return super().execute(context)
8791

8892

airflow-core/tests/unit/decorators/test_bash.py

+29-23
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,11 @@
2929
from airflow.decorators import task
3030
from airflow.exceptions import AirflowException, AirflowSkipException
3131
from airflow.models.renderedtifields import RenderedTaskInstanceFields
32+
from airflow.sdk.definitions._internal.types import SET_DURING_EXECUTION
3233
from airflow.utils import timezone
33-
from airflow.utils.types import NOTSET
3434

3535
from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_rendered_ti_fields
36+
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
3637

3738
if TYPE_CHECKING:
3839
from airflow.models import TaskInstance
@@ -69,7 +70,10 @@ def execute_task(self, task):
6970

7071
@staticmethod
7172
def validate_bash_command_rtif(ti, expected_command):
72-
assert RenderedTaskInstanceFields.get_templated_fields(ti)["bash_command"] == expected_command
73+
if AIRFLOW_V_3_0_PLUS:
74+
assert ti.task.overwrite_rtif_after_execution
75+
else:
76+
assert RenderedTaskInstanceFields.get_templated_fields(ti)["bash_command"] == expected_command
7377

7478
def test_bash_decorator_init(self):
7579
"""Test the initialization of the @task.bash decorator."""
@@ -81,13 +85,13 @@ def bash(): ...
8185
bash_task = bash()
8286

8387
assert bash_task.operator.task_id == "bash"
84-
assert bash_task.operator.bash_command == NOTSET
88+
assert bash_task.operator.bash_command == SET_DURING_EXECUTION
8589
assert bash_task.operator.env is None
8690
assert bash_task.operator.append_env is False
8791
assert bash_task.operator.output_encoding == "utf-8"
8892
assert bash_task.operator.skip_on_exit_code == [99]
8993
assert bash_task.operator.cwd is None
90-
assert bash_task.operator._init_bash_command_not_set is True
94+
assert bash_task.operator._is_inline_cmd is None
9195

9296
@pytest.mark.parametrize(
9397
argnames=["command", "expected_command", "expected_return_val"],
@@ -108,13 +112,12 @@ def bash():
108112

109113
bash_task = bash()
110114

111-
assert bash_task.operator.bash_command == NOTSET
115+
assert bash_task.operator.bash_command == SET_DURING_EXECUTION
112116

113117
ti, return_val = self.execute_task(bash_task)
114118

115119
assert bash_task.operator.bash_command == expected_command
116120
assert return_val == expected_return_val
117-
118121
self.validate_bash_command_rtif(ti, expected_command)
119122

120123
def test_op_args_kwargs(self):
@@ -127,7 +130,7 @@ def bash(id, other_id):
127130

128131
bash_task = bash("world", other_id="2")
129132

130-
assert bash_task.operator.bash_command == NOTSET
133+
assert bash_task.operator.bash_command == SET_DURING_EXECUTION
131134

132135
ti, return_val = self.execute_task(bash_task)
133136

@@ -152,7 +155,7 @@ def bash(foo):
152155

153156
bash_task = bash("foo")
154157

155-
assert bash_task.operator.bash_command == NOTSET
158+
assert bash_task.operator.bash_command == SET_DURING_EXECUTION
156159

157160
ti, return_val = self.execute_task(bash_task)
158161

@@ -178,7 +181,7 @@ def bash():
178181

179182
bash_task = bash()
180183

181-
assert bash_task.operator.bash_command == NOTSET
184+
assert bash_task.operator.bash_command == SET_DURING_EXECUTION
182185

183186
with mock.patch.dict("os.environ", {"AIRFLOW_HOME": "path/to/airflow/home"}):
184187
ti, return_val = self.execute_task(bash_task)
@@ -207,7 +210,7 @@ def bash(code):
207210

208211
bash_task = bash(exit_code)
209212

210-
assert bash_task.operator.bash_command == NOTSET
213+
assert bash_task.operator.bash_command == SET_DURING_EXECUTION
211214

212215
with expected:
213216
ti, return_val = self.execute_task(bash_task)
@@ -251,7 +254,7 @@ def bash(code):
251254

252255
bash_task = bash(exit_code)
253256

254-
assert bash_task.operator.bash_command == NOTSET
257+
assert bash_task.operator.bash_command == SET_DURING_EXECUTION
255258

256259
with expected:
257260
ti, return_val = self.execute_task(bash_task)
@@ -297,7 +300,7 @@ def bash(command_file_name):
297300
with mock.patch.dict("os.environ", {"AIRFLOW_HOME": "path/to/airflow/home"}):
298301
bash_task = bash(f"{cmd_file} ")
299302

300-
assert bash_task.operator.bash_command == NOTSET
303+
assert bash_task.operator.bash_command == SET_DURING_EXECUTION
301304

302305
ti, return_val = self.execute_task(bash_task)
303306

@@ -319,7 +322,7 @@ def bash():
319322

320323
bash_task = bash()
321324

322-
assert bash_task.operator.bash_command == NOTSET
325+
assert bash_task.operator.bash_command == SET_DURING_EXECUTION
323326

324327
ti, return_val = self.execute_task(bash_task)
325328

@@ -339,7 +342,7 @@ def bash():
339342

340343
bash_task = bash()
341344

342-
assert bash_task.operator.bash_command == NOTSET
345+
assert bash_task.operator.bash_command == SET_DURING_EXECUTION
343346

344347
dr = self.dag_maker.create_dagrun()
345348
ti = dr.task_instances[0]
@@ -360,7 +363,7 @@ def bash():
360363

361364
bash_task = bash()
362365

363-
assert bash_task.operator.bash_command == NOTSET
366+
assert bash_task.operator.bash_command == SET_DURING_EXECUTION
364367

365368
dr = self.dag_maker.create_dagrun()
366369
ti = dr.task_instances[0]
@@ -378,7 +381,7 @@ def bash():
378381

379382
bash_task = bash()
380383

381-
assert bash_task.operator.bash_command == NOTSET
384+
assert bash_task.operator.bash_command == SET_DURING_EXECUTION
382385

383386
dr = self.dag_maker.create_dagrun()
384387
ti = dr.task_instances[0]
@@ -401,20 +404,21 @@ def bash():
401404
):
402405
bash_task = bash()
403406

404-
assert bash_task.operator.bash_command == NOTSET
407+
assert bash_task.operator.bash_command == SET_DURING_EXECUTION
405408

406409
ti, _ = self.execute_task(bash_task)
407410

408411
assert bash_task.operator.multiple_outputs is False
409412
self.validate_bash_command_rtif(ti, "echo")
410413

411414
@pytest.mark.parametrize(
412-
"multiple_outputs", [False, pytest.param(None, id="none"), pytest.param(NOTSET, id="not-set")]
415+
"multiple_outputs",
416+
[False, pytest.param(None, id="none"), pytest.param(SET_DURING_EXECUTION, id="not-set")],
413417
)
414418
def test_multiple_outputs(self, multiple_outputs):
415419
"""Verify setting `multiple_outputs` for a @task.bash-decorated function is ignored."""
416420
decorator_kwargs = {}
417-
if multiple_outputs is not NOTSET:
421+
if multiple_outputs is not SET_DURING_EXECUTION:
418422
decorator_kwargs["multiple_outputs"] = multiple_outputs
419423

420424
with self.dag:
@@ -428,7 +432,7 @@ def bash():
428432

429433
bash_task = bash()
430434

431-
assert bash_task.operator.bash_command == NOTSET
435+
assert bash_task.operator.bash_command == SET_DURING_EXECUTION
432436

433437
ti, _ = self.execute_task(bash_task)
434438

@@ -440,7 +444,9 @@ def bash():
440444
argvalues=[
441445
pytest.param(None, pytest.raises(TypeError), id="return_none_typeerror"),
442446
pytest.param(1, pytest.raises(TypeError), id="return_int_typeerror"),
443-
pytest.param(NOTSET, pytest.raises(TypeError), id="return_notset_typeerror"),
447+
pytest.param(
448+
SET_DURING_EXECUTION, pytest.raises(TypeError), id="return_SET_DURING_EXECUTION_typeerror"
449+
),
444450
pytest.param(True, pytest.raises(TypeError), id="return_boolean_typeerror"),
445451
pytest.param("", pytest.raises(TypeError), id="return_empty_string_typerror"),
446452
pytest.param(" ", pytest.raises(TypeError), id="return_spaces_string_typerror"),
@@ -458,7 +464,7 @@ def bash():
458464

459465
bash_task = bash()
460466

461-
assert bash_task.operator.bash_command == NOTSET
467+
assert bash_task.operator.bash_command == SET_DURING_EXECUTION
462468

463469
with expected:
464470
ti, _ = self.execute_task(bash_task)
@@ -475,7 +481,7 @@ def bash():
475481

476482
bash_task = bash()
477483

478-
assert bash_task.operator.bash_command == NOTSET
484+
assert bash_task.operator.bash_command == SET_DURING_EXECUTION
479485

480486
dr = self.dag_maker.create_dagrun()
481487
ti = dr.task_instances[0]

providers/standard/src/airflow/providers/standard/operators/bash.py

+5-47
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,9 @@
2828
from airflow.models.baseoperator import BaseOperator
2929
from airflow.providers.standard.hooks.subprocess import SubprocessHook, SubprocessResult, working_directory
3030
from airflow.utils.operator_helpers import context_to_airflow_vars
31-
from airflow.utils.session import NEW_SESSION, provide_session
32-
from airflow.utils.types import ArgNotSet
3331

3432
if TYPE_CHECKING:
35-
from sqlalchemy.orm import Session as SASession
33+
from airflow.utils.types import ArgNotSet
3634

3735
try:
3836
from airflow.sdk.definitions.context import Context
@@ -182,43 +180,15 @@ def __init__(
182180
self.cwd = cwd
183181
self.append_env = append_env
184182
self.output_processor = output_processor
185-
186-
# When using the @task.bash decorator, the Bash command is not known until the underlying Python
187-
# callable is executed and therefore set to NOTSET initially. This flag is useful during execution to
188-
# determine whether the bash_command value needs to re-rendered.
189-
self._init_bash_command_not_set = isinstance(self.bash_command, ArgNotSet)
190-
191-
# Keep a copy of the original bash_command, without the Jinja template rendered.
192-
# This is later used to determine if the bash_command is a script or an inline string command.
193-
# We do this later, because the bash_command is not available in __init__ when using @task.bash.
194-
self._unrendered_bash_command: str | ArgNotSet = bash_command
183+
self._is_inline_cmd = None
184+
if isinstance(bash_command, str):
185+
self._is_inline_cmd = self._is_inline_command(bash_command=bash_command)
195186

196187
@cached_property
197188
def subprocess_hook(self):
198189
"""Returns hook for running the bash command."""
199190
return SubprocessHook()
200191

201-
# TODO: This should be replaced with Task SDK API call
202-
@staticmethod
203-
@provide_session
204-
def refresh_bash_command(ti, session: SASession = NEW_SESSION) -> None:
205-
"""
206-
Rewrite the underlying rendered bash_command value for a task instance in the metadatabase.
207-
208-
TaskInstance.get_rendered_template_fields() cannot be used because this will retrieve the
209-
RenderedTaskInstanceFields from the metadatabase which doesn't have the runtime-evaluated bash_command
210-
value.
211-
212-
:meta private:
213-
"""
214-
from airflow.models.renderedtifields import RenderedTaskInstanceFields
215-
216-
"""Update rendered task instance fields for cases where runtime evaluated, not templated."""
217-
218-
rtif = RenderedTaskInstanceFields(ti)
219-
RenderedTaskInstanceFields.write(rtif, session=session)
220-
RenderedTaskInstanceFields.delete_old_records(ti.task_id, ti.dag_id, session=session)
221-
222192
def get_env(self, context) -> dict:
223193
"""Build the set of environment variables to be exposed for the bash command."""
224194
system_env = os.environ.copy()
@@ -247,19 +217,7 @@ def execute(self, context: Context):
247217
raise AirflowException(f"The cwd {self.cwd} must be a directory")
248218
env = self.get_env(context)
249219

250-
# Because the bash_command value is evaluated at runtime using the @task.bash decorator, the
251-
# RenderedTaskInstanceField data needs to be rewritten and the bash_command value re-rendered -- the
252-
# latter because the returned command from the decorated callable could contain a Jinja expression.
253-
# Both will ensure the correct Bash command is executed and that the Rendered Template view in the UI
254-
# displays the executed command (otherwise it will display as an ArgNotSet type).
255-
if self._init_bash_command_not_set:
256-
is_inline_command = self._is_inline_command(bash_command=cast(str, self.bash_command))
257-
ti = context["ti"]
258-
self.refresh_bash_command(ti)
259-
else:
260-
is_inline_command = self._is_inline_command(bash_command=cast(str, self._unrendered_bash_command))
261-
262-
if is_inline_command:
220+
if self._is_inline_cmd:
263221
result = self._run_inline_command(bash_path=bash_path, env=env)
264222
else:
265223
result = self._run_rendered_script_file(bash_path=bash_path, env=env)

task-sdk/src/airflow/sdk/definitions/_internal/types.py

+8
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ def deserialize(cls):
4949
"""Sentinel value for argument default. See ``ArgNotSet``."""
5050

5151

52+
class SetDuringExecution(ArgNotSet):
53+
def serialize(self) -> str:
54+
return "DYNAMIC (set during execution)"
55+
56+
57+
SET_DURING_EXECUTION = SetDuringExecution()
58+
59+
5260
if TYPE_CHECKING:
5361
import logging
5462

task-sdk/src/airflow/sdk/definitions/baseoperator.py

+5
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,11 @@ def say_hello_world(**context):
899899
# Defines if the operator supports lineage without manual definitions
900900
supports_lineage: bool = False
901901

902+
# If True, the Rendered Template fields will be overwritten in DB after execution
903+
# This is useful for Taskflow decorators that modify the template fields during execution like
904+
# @task.bash decorator.
905+
overwrite_rtif_after_execution: bool = False
906+
902907
# If True then the class constructor was called
903908
__instantiated: bool = False
904909
# List of args as passed to `init()`, after apply_defaults() has been updated. Used to "recreate" the task

task-sdk/src/airflow/sdk/execution_time/task_runner.py

+10
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,16 @@ def finalize(
889889
log.debug("Setting xcom for operator extra link", link=link, xcom_key=xcom_key)
890890
_xcom_push(ti, key=xcom_key, value=link)
891891

892+
if getattr(ti.task, "overwrite_rtif_after_execution", False):
893+
log.debug("Overwriting Rendered template fields.")
894+
if ti.task.template_fields:
895+
SUPERVISOR_COMMS.send_request(
896+
log=log,
897+
msg=SetRenderedFields(
898+
rendered_fields={field: getattr(ti.task, field) for field in ti.task.template_fields}
899+
),
900+
)
901+
892902
log.debug("Running finalizers", ti=ti)
893903
if state in [TerminalTIState.SUCCESS]:
894904
get_listener_manager().hook.on_task_instance_success(

0 commit comments

Comments
 (0)