Skip to content

Commit 53e1723

Browse files
authored
AIP-82 Handle trigger serialization (apache#45562)
1 parent b625c70 commit 53e1723

File tree

8 files changed

+122
-35
lines changed

8 files changed

+122
-35
lines changed

airflow/dag_processing/collection.py

+20-19
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import json
3131
import logging
3232
import traceback
33-
from typing import TYPE_CHECKING, Any, NamedTuple
33+
from typing import TYPE_CHECKING, Any, NamedTuple, cast
3434

3535
from sqlalchemy import and_, delete, exists, func, insert, select, tuple_
3636
from sqlalchemy.exc import OperationalError
@@ -53,8 +53,7 @@
5353
from airflow.models.errors import ParseImportError
5454
from airflow.models.trigger import Trigger
5555
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUriRef
56-
from airflow.serialization.serialized_objects import BaseSerialization
57-
from airflow.triggers.base import BaseTrigger
56+
from airflow.serialization.serialized_objects import BaseSerialization, SerializedAssetWatcher
5857
from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries
5958
from airflow.utils.sqlalchemy import with_row_locks
6059
from airflow.utils.timezone import utcnow
@@ -68,7 +67,6 @@
6867

6968
from airflow.models.dagwarning import DagWarning
7069
from airflow.serialization.serialized_objects import MaybeSerializedDAG
71-
from airflow.triggers.base import BaseTrigger
7270
from airflow.typing_compat import Self
7371

7472
log = logging.getLogger(__name__)
@@ -747,16 +745,23 @@ def add_asset_trigger_references(
747745
# Update references from assets being used
748746
refs_to_add: dict[tuple[str, str], set[int]] = {}
749747
refs_to_remove: dict[tuple[str, str], set[int]] = {}
750-
triggers: dict[int, BaseTrigger] = {}
748+
triggers: dict[int, dict] = {}
751749

752750
# Optimization: if no asset collected, skip fetching active assets
753751
active_assets = _find_active_assets(self.assets.keys(), session=session) if self.assets else {}
754752

755753
for name_uri, asset in self.assets.items():
756754
# If the asset belongs to a DAG that is not active or is paused, consider that no watcher is associated with it
757-
asset_watchers = asset.watchers if name_uri in active_assets else []
758-
trigger_hash_to_trigger_dict: dict[int, BaseTrigger] = {
759-
self._get_base_trigger_hash(trigger): trigger for trigger in asset_watchers
755+
asset_watchers: list[SerializedAssetWatcher] = (
756+
[cast(SerializedAssetWatcher, watcher) for watcher in asset.watchers]
757+
if name_uri in active_assets
758+
else []
759+
)
760+
trigger_hash_to_trigger_dict: dict[int, dict] = {
761+
self._get_trigger_hash(
762+
watcher.trigger["classpath"], watcher.trigger["kwargs"]
763+
): watcher.trigger
764+
for watcher in asset_watchers
760765
}
761766
triggers.update(trigger_hash_to_trigger_dict)
762767
trigger_hash_from_asset: set[int] = set(trigger_hash_to_trigger_dict.keys())
@@ -783,7 +788,10 @@ def add_asset_trigger_references(
783788
}
784789

785790
all_trigger_keys: set[tuple[str, str]] = {
786-
self._encrypt_trigger_kwargs(triggers[trigger_hash])
791+
(
792+
triggers[trigger_hash]["classpath"],
793+
Trigger.encrypt_kwargs(triggers[trigger_hash]["kwargs"]),
794+
)
787795
for trigger_hashes in refs_to_add.values()
788796
for trigger_hash in trigger_hashes
789797
}
@@ -800,7 +808,9 @@ def add_asset_trigger_references(
800808
new_trigger_models = [
801809
trigger
802810
for trigger in [
803-
Trigger.from_object(triggers[trigger_hash])
811+
Trigger(
812+
classpath=triggers[trigger_hash]["classpath"], kwargs=triggers[trigger_hash]["kwargs"]
813+
)
804814
for trigger_hash in all_trigger_hashes
805815
if trigger_hash not in orm_triggers
806816
]
@@ -836,11 +846,6 @@ def add_asset_trigger_references(
836846
if (asset_model.name, asset_model.uri) not in self.assets:
837847
asset_model.triggers = []
838848

839-
@staticmethod
840-
def _encrypt_trigger_kwargs(trigger: BaseTrigger) -> tuple[str, str]:
841-
classpath, kwargs = trigger.serialize()
842-
return classpath, Trigger.encrypt_kwargs(kwargs)
843-
844849
@staticmethod
845850
def _get_trigger_hash(classpath: str, kwargs: dict[str, Any]) -> int:
846851
"""
@@ -852,7 +857,3 @@ def _get_trigger_hash(classpath: str, kwargs: dict[str, Any]) -> int:
852857
This is not true for event driven scheduling.
853858
"""
854859
return hash((classpath, json.dumps(BaseSerialization.serialize(kwargs)).encode("utf-8")))
855-
856-
def _get_base_trigger_hash(self, trigger: BaseTrigger) -> int:
857-
classpath, kwargs = trigger.serialize()
858-
return self._get_trigger_hash(classpath, kwargs)

airflow/example_dags/example_asset_with_watchers.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,14 @@
2121
from __future__ import annotations
2222

2323
import os
24-
import tempfile
2524

2625
from airflow.decorators import task
2726
from airflow.models.baseoperator import chain
2827
from airflow.models.dag import DAG
2928
from airflow.providers.standard.triggers.file import FileTrigger
30-
from airflow.sdk.definitions.asset import Asset
29+
from airflow.sdk import Asset, AssetWatcher
3130

32-
file_path = tempfile.NamedTemporaryFile().name
31+
file_path = "/tmp/test"
3332

3433
with DAG(
3534
dag_id="example_create_file",
@@ -44,7 +43,7 @@ def create_file():
4443
chain(create_file())
4544

4645
trigger = FileTrigger(filepath=file_path, poke_interval=10)
47-
asset = Asset("example_asset", watchers=[trigger])
46+
asset = Asset("example_asset", watchers=[AssetWatcher(name="test_file_watcher", trigger=trigger)])
4847

4948
with DAG(
5049
dag_id="example_asset_with_watchers",

airflow/serialization/schema.json

+12
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@
6464
{"type": "null"},
6565
{ "$ref": "#/definitions/dict" }
6666
]
67+
},
68+
"watchers": {
69+
"type": "array",
70+
"items": { "$ref": "#/definitions/trigger" }
6771
}
6872
},
6973
"required": [ "uri", "extra" ]
@@ -126,6 +130,14 @@
126130
],
127131
"additionalProperties": false
128132
},
133+
"trigger": {
134+
"type": "object",
135+
"properties": {
136+
"classpath": { "type": "string" },
137+
"kwargs": { "$ref": "#/definitions/dict" }
138+
},
139+
"required": [ "classpath", "kwargs" ]
140+
},
129141
"dict": {
130142
"description": "A python dictionary containing values of any type",
131143
"type": "object"

airflow/serialization/serialized_objects.py

+44-3
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
AssetAny,
6060
AssetRef,
6161
AssetUniqueKey,
62+
AssetWatcher,
6263
BaseAsset,
6364
)
6465
from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
@@ -251,13 +252,34 @@ def encode_asset_condition(var: BaseAsset) -> dict[str, Any]:
251252
:meta private:
252253
"""
253254
if isinstance(var, Asset):
254-
return {
255+
256+
def _encode_watcher(watcher: AssetWatcher):
    """
    Encode a single asset watcher into a plain mapping.

    The returned dict has two keys: ``"name"`` (the watcher's name) and
    ``"trigger"`` (the watcher's trigger as produced by ``_encode_trigger``).
    """
    watcher_name = watcher.name
    encoded_trigger = _encode_trigger(watcher.trigger)
    return {"name": watcher_name, "trigger": encoded_trigger}
261+
262+
def _encode_trigger(trigger: BaseTrigger | dict):
    """
    Encode a trigger into a ``{"classpath": ..., "kwargs": ...}`` mapping.

    A ``dict`` argument is returned as-is (the very same object, no copy and
    no validation). Any other argument must provide ``serialize()`` returning
    a ``(classpath, kwargs)`` pair, which is repackaged into a new dict.
    """
    if not isinstance(trigger, dict):
        classpath, kwargs = trigger.serialize()
        return {"classpath": classpath, "kwargs": kwargs}
    # Dict input is passed straight through.
    return trigger
270+
271+
asset = {
255272
"__type": DAT.ASSET,
256273
"name": var.name,
257274
"uri": var.uri,
258275
"group": var.group,
259276
"extra": var.extra,
260277
}
278+
279+
if len(var.watchers) > 0:
280+
asset["watchers"] = [_encode_watcher(watcher) for watcher in var.watchers]
281+
282+
return asset
261283
if isinstance(var, AssetAlias):
262284
return {"__type": DAT.ASSET_ALIAS, "name": var.name, "group": var.group}
263285
if isinstance(var, AssetAll):
@@ -283,7 +305,7 @@ def decode_asset_condition(var: dict[str, Any]) -> BaseAsset:
283305
"""
284306
dat = var["__type"]
285307
if dat == DAT.ASSET:
286-
return Asset(name=var["name"], uri=var["uri"], group=var["group"], extra=var["extra"])
308+
return decode_asset(var)
287309
if dat == DAT.ASSET_ALL:
288310
return AssetAll(*(decode_asset_condition(x) for x in var["objects"]))
289311
if dat == DAT.ASSET_ANY:
@@ -295,6 +317,19 @@ def decode_asset_condition(var: dict[str, Any]) -> BaseAsset:
295317
raise ValueError(f"deserialization not implemented for DAT {dat!r}")
296318

297319

320+
def decode_asset(var: dict[str, Any]):
    """
    Rebuild an ``Asset`` from its encoded mapping.

    ``var`` must contain ``"name"``, ``"uri"``, ``"group"`` and ``"extra"``
    (a missing key raises ``KeyError``). ``"watchers"`` is optional; when
    present, each entry's ``"name"`` and ``"trigger"`` are wrapped in a
    ``SerializedAssetWatcher``. When absent, the asset gets an empty list.
    """
    # Required fields are looked up first, in the same order as the keyword arguments.
    asset_kwargs = {key: var[key] for key in ("name", "uri", "group", "extra")}
    asset_kwargs["watchers"] = [
        SerializedAssetWatcher(name=item["name"], trigger=item["trigger"])
        for item in var.get("watchers", [])
    ]
    return Asset(**asset_kwargs)
331+
332+
298333
def encode_outlet_event_accessor(var: OutletEventAccessor) -> dict[str, Any]:
299334
key = var.key
300335
return {
@@ -874,7 +909,7 @@ def deserialize(cls, encoded_var: Any) -> Any:
874909
elif type_ == DAT.XCOM_REF:
875910
return _XComRef(var) # Delay deserializing XComArg objects until we have the entire DAG.
876911
elif type_ == DAT.ASSET:
877-
return Asset(**var)
912+
return decode_asset(var)
878913
elif type_ == DAT.ASSET_ALIAS:
879914
return AssetAlias(**var)
880915
elif type_ == DAT.ASSET_ANY:
@@ -1810,6 +1845,12 @@ def set_ref(task: Operator) -> Operator:
18101845
return group
18111846

18121847

1848+
class SerializedAssetWatcher(AssetWatcher):
    """
    JSON-serializable form of an asset watcher.

    Differs from ``AssetWatcher`` only in the declared type of ``trigger``.
    """

    # Narrowed to a plain dict here.
    trigger: dict
1852+
1853+
18131854
def _has_kubernetes() -> bool:
18141855
global HAS_KUBERNETES
18151856
if "HAS_KUBERNETES" in globals():

task_sdk/src/airflow/sdk/__init__.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
from typing import TYPE_CHECKING
2020

2121
__all__ = [
22+
"__version__",
23+
"Asset",
24+
"AssetWatcher",
2225
"BaseOperator",
2326
"Connection",
2427
"DAG",
@@ -27,7 +30,6 @@
2730
"MappedOperator",
2831
"TaskGroup",
2932
"XComArg",
30-
"__version__",
3133
"dag",
3234
"get_current_context",
3335
"get_parsing_context",
@@ -36,6 +38,7 @@
3638
__version__ = "1.0.0.alpha1"
3739

3840
if TYPE_CHECKING:
41+
from airflow.sdk.definitions.asset import Asset, AssetWatcher
3942
from airflow.sdk.definitions.baseoperator import BaseOperator
4043
from airflow.sdk.definitions.connection import Connection
4144
from airflow.sdk.definitions.context import get_current_context, get_parsing_context
@@ -60,6 +63,8 @@
6063
"dag": ".definitions.dag",
6164
"get_current_context": ".definitions.context",
6265
"get_parsing_context": ".definitions.context",
66+
"Asset": ".definitions.asset",
67+
"AssetWatcher": ".definitions.asset",
6368
}
6469

6570

task_sdk/src/airflow/sdk/definitions/asset/__init__.py

+20-5
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from sqlalchemy.orm import Session
3838

3939
from airflow.models.asset import AssetModel
40+
from airflow.serialization.serialized_objects import SerializedAssetWatcher
4041
from airflow.triggers.base import BaseTrigger
4142

4243
AttrsInstance = attrs.AttrsInstance
@@ -54,6 +55,7 @@
5455
"AssetNameRef",
5556
"AssetRef",
5657
"AssetUriRef",
58+
"AssetWatcher",
5759
]
5860

5961

@@ -252,6 +254,19 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe
252254
raise NotImplementedError
253255

254256

257+
@attrs.define(frozen=True)
class AssetWatcher:
    """A representation of an asset watcher. The name uniquely identifies the watcher."""

    name: str
    # This attribute serves a double purpose.
    # * On a "normal" asset instance loaded from a DAG, it holds the trigger used to monitor an
    #   external resource. In that case ``AssetWatcher`` is used directly by users.
    # * On an asset recreated from a serialized DAG, it holds the serialized data of the trigger.
    #   In that case ``SerializedAssetWatcher`` is used.
    # Both types are kept so that mypy can tell the two situations apart.
    trigger: BaseTrigger | dict
268+
269+
255270
@attrs.define(init=False, unsafe_hash=False)
256271
class Asset(os.PathLike, BaseAsset):
257272
"""A representation of data asset dependencies between workflows."""
@@ -271,7 +286,7 @@ class Asset(os.PathLike, BaseAsset):
271286
factory=dict,
272287
converter=_set_extra_default,
273288
)
274-
watchers: list[BaseTrigger] = attrs.field(
289+
watchers: list[AssetWatcher | SerializedAssetWatcher] = attrs.field(
275290
factory=list,
276291
)
277292

@@ -286,7 +301,7 @@ def __init__(
286301
*,
287302
group: str = ...,
288303
extra: dict | None = None,
289-
watchers: list[BaseTrigger] = ...,
304+
watchers: list[AssetWatcher | SerializedAssetWatcher] = ...,
290305
) -> None:
291306
"""Canonical; both name and uri are provided."""
292307

@@ -297,7 +312,7 @@ def __init__(
297312
*,
298313
group: str = ...,
299314
extra: dict | None = None,
300-
watchers: list[BaseTrigger] = ...,
315+
watchers: list[AssetWatcher | SerializedAssetWatcher] = ...,
301316
) -> None:
302317
"""It's possible to only provide the name, either by keyword or as the only positional argument."""
303318

@@ -308,7 +323,7 @@ def __init__(
308323
uri: str,
309324
group: str = ...,
310325
extra: dict | None = None,
311-
watchers: list[BaseTrigger] = ...,
326+
watchers: list[AssetWatcher | SerializedAssetWatcher] = ...,
312327
) -> None:
313328
"""It's possible to only provide the URI as a keyword argument."""
314329

@@ -319,7 +334,7 @@ def __init__(
319334
*,
320335
group: str | None = None,
321336
extra: dict | None = None,
322-
watchers: list[BaseTrigger] | None = None,
337+
watchers: list[AssetWatcher | SerializedAssetWatcher] | None = None,
323338
) -> None:
324339
if name is None and uri is None:
325340
raise TypeError("Asset() requires either 'name' or 'uri'")

tests/dag_processing/test_collection.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
from airflow.models.serialized_dag import SerializedDagModel
5151
from airflow.operators.empty import EmptyOperator
5252
from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger
53-
from airflow.sdk.definitions.asset import Asset
53+
from airflow.sdk.definitions.asset import Asset, AssetWatcher
5454
from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
5555
from airflow.utils import timezone as tz
5656
from airflow.utils.session import create_session
@@ -131,7 +131,11 @@ def per_test(self) -> Generator:
131131
)
132132
def test_add_asset_trigger_references(self, is_active, is_paused, expected_num_triggers, dag_maker):
133133
trigger = TimeDeltaTrigger(timedelta(seconds=0))
134-
asset = Asset("test_add_asset_trigger_references_asset", watchers=[trigger])
134+
classpath, kwargs = trigger.serialize()
135+
asset = Asset(
136+
"test_add_asset_trigger_references_asset",
137+
watchers=[AssetWatcher(name="test", trigger={"classpath": classpath, "kwargs": kwargs})],
138+
)
135139

136140
with dag_maker(dag_id="test_add_asset_trigger_references_dag", schedule=[asset]) as dag:
137141
EmptyOperator(task_id="mytask")

0 commit comments

Comments
 (0)