Skip to content

Commit 90bbdaf

Browse files
the-cybersapienAditya Aggarwal
andauthored
Add support for python_callable_file for BranchPythonOperator (astronomer#84)
* Add support for python_callable_file for BranchPythonOperator * Fix pylint errors in dagbuilder * Reformat using black Co-authored-by: Aditya Aggarwal <[email protected]>
1 parent c9a1d20 commit 90bbdaf

File tree

5 files changed

+95
-55
lines changed

5 files changed

+95
-55
lines changed

dagfactory/dagbuilder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from airflow.models import Variable
1010
from airflow.contrib.operators.kubernetes_pod_operator import KubernetesPodOperator
1111
from airflow.models import BaseOperator
12-
from airflow.operators.python_operator import PythonOperator
12+
from airflow.operators.python_operator import PythonOperator, BranchPythonOperator
1313
from airflow.utils.module_loading import import_string
1414
from airflow import __version__ as AIRFLOW_VERSION
1515

@@ -179,13 +179,13 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
179179
except Exception as err:
180180
raise Exception(f"Failed to import operator: {operator}") from err
181181
try:
182-
if operator_obj == PythonOperator:
182+
if operator_obj in [PythonOperator, BranchPythonOperator]:
183183
if not task_params.get("python_callable_name") and not task_params.get(
184184
"python_callable_file"
185185
):
186186
raise Exception(
187-
"Failed to create task. PythonOperator requires `python_callable_name` \
188-
and `python_callable_file` parameters."
187+
"Failed to create task. PythonOperator and BranchPythonOperator requires \
188+
`python_callable_name` and `python_callable_file` parameters."
189189
)
190190
task_params["python_callable"]: Callable = utils.get_python_callable(
191191
task_params["python_callable_name"],

examples/customized/operators/breakfast_operators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
class MakeBreadOperator(BaseOperator):
5-
template_fields = ('bread_type',)
5+
template_fields = ("bread_type",)
66

77
def __init__(self, bread_type, *args, **kwargs):
88
super(MakeBreadOperator, self).__init__(*args, **kwargs)
@@ -13,7 +13,7 @@ def execute(self, context):
1313

1414

1515
class MakeCoffeeOperator(BaseOperator):
16-
template_fields = ('coffee_type',)
16+
template_fields = ("coffee_type",)
1717

1818
def __init__(self, coffee_type, *args, **kwargs):
1919
super(MakeCoffeeOperator, self).__init__(*args, **kwargs)

tests/fixtures/doc_md_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
def mydocmdbuilder(**kwargs):
2-
return f"{kwargs}"
2+
return f"{kwargs}"

tests/test_dagbuilder.py

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@
3333
"default_args": {"owner": "custom_owner"},
3434
"description": "this is an example dag",
3535
"schedule_interval": "0 3 * * *",
36-
"tags" : ["tag1","tag2"],
36+
"tags": ["tag1", "tag2"],
3737
"tasks": {
3838
"task_1": {
3939
"operator": "airflow.operators.bash_operator.BashOperator",
4040
"bash_command": "echo 1",
41-
"execution_timeout_secs" : 5,
41+
"execution_timeout_secs": 5,
4242
},
4343
"task_2": {
4444
"operator": "airflow.operators.bash_operator.BashOperator",
@@ -63,7 +63,7 @@
6363
"task_group_2": {
6464
"dependencies": ["task_group_1"],
6565
},
66-
"task_group_3": {}
66+
"task_group_3": {},
6767
},
6868
"tasks": {
6969
"task_1": {
@@ -110,15 +110,15 @@
110110
},
111111
"description": "this is an example dag",
112112
"schedule_interval": "0 3 * * *",
113-
"tags" : ["tag1","tag2"],
113+
"tags": ["tag1", "tag2"],
114114
"on_failure_callback": f"{__name__}.print_context_callback",
115115
"on_success_callback": f"{__name__}.print_context_callback",
116116
"sla_miss_callback": f"{__name__}.print_context_callback",
117117
"tasks": {
118118
"task_1": {
119119
"operator": "airflow.operators.bash_operator.BashOperator",
120120
"bash_command": "echo 1",
121-
"execution_timeout_secs" : 5,
121+
"execution_timeout_secs": 5,
122122
"on_failure_callback": f"{__name__}.print_context_callback",
123123
"on_success_callback": f"{__name__}.print_context_callback",
124124
"on_execute_callback": f"{__name__}.print_context_callback",
@@ -175,7 +175,7 @@ def test_get_dag_params():
175175
"task_1": {
176176
"operator": "airflow.operators.bash_operator.BashOperator",
177177
"bash_command": "echo 1",
178-
"execution_timeout_secs": 5
178+
"execution_timeout_secs": 5,
179179
},
180180
"task_2": {
181181
"operator": "airflow.operators.bash_operator.BashOperator",
@@ -202,7 +202,11 @@ def test_get_dag_params_no_start_date():
202202
def test_make_task_valid():
203203
td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG, DEFAULT_CONFIG)
204204
operator = "airflow.operators.bash_operator.BashOperator"
205-
task_params = {"task_id": "test_task", "bash_command": "echo 1","execution_timeout_secs":5}
205+
task_params = {
206+
"task_id": "test_task",
207+
"bash_command": "echo 1",
208+
"execution_timeout_secs": 5,
209+
}
206210
actual = td.make_task(operator, task_params)
207211
assert actual.task_id == "test_task"
208212
assert actual.bash_command == "echo 1"
@@ -258,8 +262,8 @@ def test_build():
258262
assert isinstance(actual["dag"], DAG)
259263
assert len(actual["dag"].tasks) == 3
260264
assert actual["dag"].task_dict["task_1"].downstream_task_ids == {"task_2", "task_3"}
261-
if version.parse(AIRFLOW_VERSION) >= version.parse('1.10.8') :
262-
assert actual["dag"].tags == ["tag1","tag2"]
265+
if version.parse(AIRFLOW_VERSION) >= version.parse("1.10.8"):
266+
assert actual["dag"].tags == ["tag1", "tag2"]
263267

264268

265269
def test_get_dag_params():
@@ -274,7 +278,10 @@ def test_get_dag_params():
274278
},
275279
"schedule_interval": "0 3 * * *",
276280
"task_groups": {
277-
"task_group_1": {"tooltip": "this is a task group", "dependencies": ["task_1"]},
281+
"task_group_1": {
282+
"tooltip": "this is a task group",
283+
"dependencies": ["task_1"],
284+
},
278285
"task_group_2": {"dependencies": ["task_group_1"]},
279286
"task_group_3": {},
280287
},
@@ -332,17 +339,24 @@ def test_build_task_groups():
332339
td.build()
333340
else:
334341
actual = td.build()
335-
task_group_1 = {t for t in actual["dag"].task_dict if t.startswith("task_group_1")}
336-
task_group_2 = {t for t in actual["dag"].task_dict if t.startswith("task_group_2")}
342+
task_group_1 = {
343+
t for t in actual["dag"].task_dict if t.startswith("task_group_1")
344+
}
345+
task_group_2 = {
346+
t for t in actual["dag"].task_dict if t.startswith("task_group_2")
347+
}
337348
assert actual["dag_id"] == "test_dag"
338349
assert isinstance(actual["dag"], DAG)
339350
assert len(actual["dag"].tasks) == 6
340-
assert actual["dag"].task_dict["task_1"].downstream_task_ids == {"task_group_1.task_2"}
351+
assert actual["dag"].task_dict["task_1"].downstream_task_ids == {
352+
"task_group_1.task_2"
353+
}
341354
assert actual["dag"].task_dict["task_group_1.task_2"].downstream_task_ids == {
342355
"task_group_1.task_3"
343356
}
344357
assert actual["dag"].task_dict["task_group_1.task_3"].downstream_task_ids == {
345-
"task_4", "task_group_2.task_5",
358+
"task_4",
359+
"task_group_2.task_5",
346360
}
347361
assert actual["dag"].task_dict["task_group_2.task_5"].downstream_task_ids == {
348362
"task_group_2.task_6",
@@ -360,7 +374,9 @@ def test_make_task_groups():
360374
}
361375
dag = "dag"
362376
task_groups = dagbuilder.DagBuilder.make_task_groups(task_group_dict, dag)
363-
expected = MockTaskGroup(tooltip="this is a task group", group_id="task_group", dag=dag)
377+
expected = MockTaskGroup(
378+
tooltip="this is a task group", group_id="task_group", dag=dag
379+
)
364380
if version.parse(AIRFLOW_VERSION) < version.parse("2.0.0"):
365381
assert task_groups == {}
366382
else:
@@ -371,6 +387,7 @@ def test_make_task_groups_empty():
371387
task_groups = dagbuilder.DagBuilder.make_task_groups({}, None)
372388
assert task_groups == {}
373389

390+
374391
def print_context_callback(context, **kwargs):
375392
print(context)
376393

@@ -393,10 +410,11 @@ def test_make_task_with_callback():
393410
assert isinstance(actual, PythonOperator)
394411
assert callable(actual.on_failure_callback)
395412
assert callable(actual.on_success_callback)
396-
if version.parse(AIRFLOW_VERSION) >= version.parse('2.0.0') :
413+
if version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"):
397414
assert callable(actual.on_execute_callback)
398415
assert callable(actual.on_retry_callback)
399416

417+
400418
def test_make_dag_with_callback():
401419
td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_CALLBACK, DEFAULT_CONFIG)
402-
td.build()
420+
td.build()

0 commit comments

Comments
 (0)