30
30
import json
31
31
import logging
32
32
import traceback
33
- from typing import TYPE_CHECKING , Any , NamedTuple
33
+ from typing import TYPE_CHECKING , Any , NamedTuple , cast
34
34
35
35
from sqlalchemy import and_ , delete , exists , func , insert , select , tuple_
36
36
from sqlalchemy .exc import OperationalError
53
53
from airflow .models .errors import ParseImportError
54
54
from airflow .models .trigger import Trigger
55
55
from airflow .sdk .definitions .asset import Asset , AssetAlias , AssetNameRef , AssetUriRef
56
- from airflow .serialization .serialized_objects import BaseSerialization
57
- from airflow .triggers .base import BaseTrigger
56
+ from airflow .serialization .serialized_objects import BaseSerialization , SerializedAssetWatcher
58
57
from airflow .utils .retries import MAX_DB_RETRIES , run_with_db_retries
59
58
from airflow .utils .sqlalchemy import with_row_locks
60
59
from airflow .utils .timezone import utcnow
68
67
69
68
from airflow .models .dagwarning import DagWarning
70
69
from airflow .serialization .serialized_objects import MaybeSerializedDAG
71
- from airflow .triggers .base import BaseTrigger
72
70
from airflow .typing_compat import Self
73
71
74
72
log = logging .getLogger (__name__ )
@@ -747,16 +745,23 @@ def add_asset_trigger_references(
747
745
# Update references from assets being used
748
746
refs_to_add : dict [tuple [str , str ], set [int ]] = {}
749
747
refs_to_remove : dict [tuple [str , str ], set [int ]] = {}
750
- triggers : dict [int , BaseTrigger ] = {}
748
+ triggers : dict [int , dict ] = {}
751
749
752
750
# Optimization: if no asset collected, skip fetching active assets
753
751
active_assets = _find_active_assets (self .assets .keys (), session = session ) if self .assets else {}
754
752
755
753
for name_uri , asset in self .assets .items ():
756
754
# If the asset belong to a DAG not active or paused, consider there is no watcher associated to it
757
- asset_watchers = asset .watchers if name_uri in active_assets else []
758
- trigger_hash_to_trigger_dict : dict [int , BaseTrigger ] = {
759
- self ._get_base_trigger_hash (trigger ): trigger for trigger in asset_watchers
755
+ asset_watchers : list [SerializedAssetWatcher ] = (
756
+ [cast (SerializedAssetWatcher , watcher ) for watcher in asset .watchers ]
757
+ if name_uri in active_assets
758
+ else []
759
+ )
760
+ trigger_hash_to_trigger_dict : dict [int , dict ] = {
761
+ self ._get_trigger_hash (
762
+ watcher .trigger ["classpath" ], watcher .trigger ["kwargs" ]
763
+ ): watcher .trigger
764
+ for watcher in asset_watchers
760
765
}
761
766
triggers .update (trigger_hash_to_trigger_dict )
762
767
trigger_hash_from_asset : set [int ] = set (trigger_hash_to_trigger_dict .keys ())
@@ -783,7 +788,10 @@ def add_asset_trigger_references(
783
788
}
784
789
785
790
all_trigger_keys : set [tuple [str , str ]] = {
786
- self ._encrypt_trigger_kwargs (triggers [trigger_hash ])
791
+ (
792
+ triggers [trigger_hash ]["classpath" ],
793
+ Trigger .encrypt_kwargs (triggers [trigger_hash ]["kwargs" ]),
794
+ )
787
795
for trigger_hashes in refs_to_add .values ()
788
796
for trigger_hash in trigger_hashes
789
797
}
@@ -800,7 +808,9 @@ def add_asset_trigger_references(
800
808
new_trigger_models = [
801
809
trigger
802
810
for trigger in [
803
- Trigger .from_object (triggers [trigger_hash ])
811
+ Trigger (
812
+ classpath = triggers [trigger_hash ]["classpath" ], kwargs = triggers [trigger_hash ]["kwargs" ]
813
+ )
804
814
for trigger_hash in all_trigger_hashes
805
815
if trigger_hash not in orm_triggers
806
816
]
@@ -836,11 +846,6 @@ def add_asset_trigger_references(
836
846
if (asset_model .name , asset_model .uri ) not in self .assets :
837
847
asset_model .triggers = []
838
848
839
- @staticmethod
840
- def _encrypt_trigger_kwargs (trigger : BaseTrigger ) -> tuple [str , str ]:
841
- classpath , kwargs = trigger .serialize ()
842
- return classpath , Trigger .encrypt_kwargs (kwargs )
843
-
844
849
@staticmethod
845
850
def _get_trigger_hash (classpath : str , kwargs : dict [str , Any ]) -> int :
846
851
"""
@@ -852,7 +857,3 @@ def _get_trigger_hash(classpath: str, kwargs: dict[str, Any]) -> int:
852
857
This is not true for event driven scheduling.
853
858
"""
854
859
return hash ((classpath , json .dumps (BaseSerialization .serialize (kwargs )).encode ("utf-8" )))
855
-
856
- def _get_base_trigger_hash (self , trigger : BaseTrigger ) -> int :
857
- classpath , kwargs = trigger .serialize ()
858
- return self ._get_trigger_hash (classpath , kwargs )
0 commit comments