diff --git a/iwf/data_attributes.py b/iwf/data_attributes.py index c954a7c..a8c3387 100644 --- a/iwf/data_attributes.py +++ b/iwf/data_attributes.py @@ -7,6 +7,7 @@ class DataAttributes: _type_store: dict[str, Optional[type]] + _prefix_type_store: dict[str, Optional[type]] _object_encoder: ObjectEncoder _current_values: dict[str, Union[EncodedObject, None]] _updated_values_to_return: dict[str, EncodedObject] @@ -14,30 +15,32 @@ class DataAttributes: def __init__( self, type_store: dict[str, Optional[type]], + prefix_type_store: dict[str, Optional[type]], object_encoder: ObjectEncoder, current_values: dict[str, Union[EncodedObject, None]], ): self._object_encoder = object_encoder self._type_store = type_store + self._prefix_type_store = prefix_type_store self._current_values = current_values self._updated_values_to_return = {} def get_data_attribute(self, key: str) -> Any: - if key not in self._type_store: + is_registered, registered_type = self._validate_key_and_get_type(key) + if not is_registered: raise WorkflowDefinitionError(f"data attribute %s is not registered {key}") encoded_object = self._current_values.get(key) if encoded_object is None: return None - registered_type = self._type_store[key] return self._object_encoder.decode(encoded_object, registered_type) def set_data_attribute(self, key: str, value: Any): - if key not in self._type_store: + is_registered, registered_type = self._validate_key_and_get_type(key) + if not is_registered: raise WorkflowDefinitionError(f"data attribute %s is not registered {key}") - registered_type = self._type_store[key] if registered_type is not None and not isinstance(value, registered_type): raise WorkflowDefinitionError( f"data attribute %s is of the right type {registered_type}" @@ -49,3 +52,13 @@ def set_data_attribute(self, key: str, value: Any): def get_updated_values_to_return(self) -> dict[str, EncodedObject]: return self._updated_values_to_return + + def _validate_key_and_get_type(self, key) -> tuple[bool, Optional[type]]: + if key in self._type_store: + return (True, self._type_store.get(key)) + + for prefix, t in self._prefix_type_store.items(): + if key.startswith(prefix): + return (True, t) + + return (False, None) diff --git a/iwf/persistence_schema.py b/iwf/persistence_schema.py index 076f6f9..778842d 100644 --- a/iwf/persistence_schema.py +++ b/iwf/persistence_schema.py @@ -8,6 +8,7 @@ class PersistenceFieldType(Enum): DataAttribute = 1 SearchAttribute = 2 + DataAttributePrefix = 3 @dataclass @@ -27,6 +28,12 @@ def search_attribute_def(cls, key: str, sa_type: SearchAttributeValueType): key, PersistenceFieldType.SearchAttribute, None, sa_type ) + @classmethod + def data_attribute_prefix_def(cls, key: str, value_type: Optional[type]): + return PersistenceField( + key, PersistenceFieldType.DataAttributePrefix, value_type + ) + @dataclass class PersistenceSchema: diff --git a/iwf/registry.py b/iwf/registry.py index dd1360c..d9ea5d1 100644 --- a/iwf/registry.py +++ b/iwf/registry.py @@ -17,6 +17,7 @@ class Registry: _internal_channel_type_store: dict[str, TypeStore] _signal_channel_type_store: dict[str, dict[str, Optional[type]]] _data_attribute_types: dict[str, dict[str, Optional[type]]] + _data_attribute_prefix_types: dict[str, dict[str, Optional[type]]] _search_attribute_types: dict[str, dict[str, SearchAttributeValueType]] _rpc_infos: dict[str, dict[str, RPCInfo]] @@ -27,6 +28,7 @@ def __init__(self): self._internal_channel_type_store = dict() self._signal_channel_type_store = dict() self._data_attribute_types = dict() + self._data_attribute_prefix_types = dict() self._search_attribute_types = {} self._rpc_infos = dict() @@ -36,6 +38,7 @@ def add_workflow(self, wf: ObjectWorkflow): self._register_internal_channels(wf) self._register_signal_channels(wf) self._register_data_attributes(wf) + self._register_data_attribute_prefix_types(wf) self._register_search_attributes(wf) self._register_workflow_rpcs(wf) @@ -82,6 +85,11 @@ def get_search_attribute_types( ) -> dict[str, SearchAttributeValueType]: return self._search_attribute_types[wf_type] + def get_data_attribute_prefix_types( + self, wf_type: str + ) -> dict[str, Optional[type]]: + return self._data_attribute_prefix_types[wf_type] + def get_rpc_infos(self, wf_type: str) -> dict[str, RPCInfo]: return self._rpc_infos[wf_type] @@ -115,11 +123,11 @@ def _register_signal_channels(self, wf: ObjectWorkflow): def _register_data_attributes(self, wf: ObjectWorkflow): wf_type = get_workflow_type(wf) - types: dict[str, Optional[type]] = {} + data_attribute_types: dict[str, Optional[type]] = {} for field in wf.get_persistence_schema().persistence_fields: if field.field_type == PersistenceFieldType.DataAttribute: - types[field.key] = field.value_type - self._data_attribute_types[wf_type] = types + data_attribute_types[field.key] = field.value_type + self._data_attribute_types[wf_type] = data_attribute_types def _register_search_attributes(self, wf: ObjectWorkflow): wf_type = get_workflow_type(wf) @@ -138,6 +146,14 @@ def _register_search_attributes(self, wf: ObjectWorkflow): types[field.key] = sa_type self._search_attribute_types[wf_type] = types + def _register_data_attribute_prefix_types(self, wf: ObjectWorkflow): + wf_type = get_workflow_type(wf) + data_attribute_prefix_types: dict[str, Optional[type]] = {} + for field in wf.get_persistence_schema().persistence_fields: + if field.field_type == PersistenceFieldType.DataAttributePrefix: + data_attribute_prefix_types[field.key] = field.value_type + self._data_attribute_prefix_types[wf_type] = data_attribute_prefix_types + def _register_workflow_state(self, wf): wf_type = get_workflow_type(wf) state_map = {} diff --git a/iwf/tests/test_persistence_data_attributes.py b/iwf/tests/test_persistence_data_attributes.py index 5769e25..6175428 100644 --- a/iwf/tests/test_persistence_data_attributes.py +++ b/iwf/tests/test_persistence_data_attributes.py @@ -30,6 +30,12 @@ final_initial_da_value_1 = initial_da_value_1 final_initial_da_value_2 = "no-more-init" +test_da_prefix = "test-da-prefix" +test_da_prefix_key_1 = "test-da-prefix-1" +test_da_prefix_key_2 = "test-da-prefix-2" +test_da_prefix_value_1 = "test-da-value-1" +test_da_prefix_value_2 = "test-da-value-2" + class DataAttributeRWState(WorkflowState[None]): def wait_until( @@ -60,6 +66,8 @@ def execute( persistence.set_data_attribute(test_da_1, final_test_da_value_1) persistence.set_data_attribute(test_da_2, final_test_da_value_2) persistence.set_data_attribute(initial_da_2, final_initial_da_value_2) + persistence.set_data_attribute(test_da_prefix_key_1, test_da_prefix_value_1) + persistence.set_data_attribute(test_da_prefix_key_2, test_da_prefix_value_2) return StateDecision.graceful_complete_workflow() @@ -73,6 +81,7 @@ def get_persistence_schema(self) -> PersistenceSchema: PersistenceField.data_attribute_def(initial_da_2, str), PersistenceField.data_attribute_def(test_da_1, str), PersistenceField.data_attribute_def(test_da_2, int), + PersistenceField.data_attribute_prefix_def(test_da_prefix, str), ) @rpc() @@ -82,6 +91,8 @@ def test_persistence_read(self, pers: Persistence): pers.get_data_attribute(initial_da_2), pers.get_data_attribute(test_da_1), pers.get_data_attribute(test_da_2), + pers.get_data_attribute(test_da_prefix_key_1), + pers.get_data_attribute(test_da_prefix_key_2), ) @@ -115,4 +126,6 @@ def test_persistence_data_attributes_workflow(self): final_initial_da_value_2, final_test_da_value_1, final_test_da_value_2, + test_da_prefix_value_1, + test_da_prefix_value_2, ] diff --git a/iwf/worker_service.py b/iwf/worker_service.py index 16be37f..c06ea60 100644 --- a/iwf/worker_service.py +++ b/iwf/worker_service.py @@ -78,6 +78,9 @@ def handle_workflow_worker_rpc( internal_channel_types = self._registry.get_internal_channel_type_store(wf_type) signal_channel_types = self._registry.get_signal_channel_types(wf_type) data_attributes_types = self._registry.get_data_attribute_types(wf_type) + data_attributes_prefix_types = self._registry.get_data_attribute_prefix_types( + wf_type + ) context = _from_idl_context(request.context) _input = self._options.object_encoder.decode( @@ -92,7 +95,10 @@ def handle_workflow_worker_rpc( } data_attributes = DataAttributes( - data_attributes_types, self._options.object_encoder, current_data_attributes + data_attributes_types, + data_attributes_prefix_types, + self._options.object_encoder, + current_data_attributes, ) search_attributes_types = self._registry.get_search_attribute_types(wf_type) @@ -167,6 +173,9 @@ def handle_workflow_state_wait_until( internal_channel_types = self._registry.get_internal_channel_type_store(wf_type) signal_channel_types = self._registry.get_signal_channel_types(wf_type) data_attributes_types = self._registry.get_data_attribute_types(wf_type) + data_attributes_prefix_types = self._registry.get_data_attribute_prefix_types( + wf_type + ) context = _from_idl_context(request.context) _input = self._options.object_encoder.decode( @@ -181,7 +190,10 @@ def handle_workflow_state_wait_until( } data_attributes = DataAttributes( - data_attributes_types, self._options.object_encoder, current_data_attributes + data_attributes_types, + data_attributes_prefix_types, + self._options.object_encoder, + current_data_attributes, ) search_attributes_types = self._registry.get_search_attribute_types(wf_type) @@ -236,6 +248,9 @@ def handle_workflow_state_execute( internal_channel_types = self._registry.get_internal_channel_type_store(wf_type) signal_channel_types = self._registry.get_signal_channel_types(wf_type) data_attributes_types = self._registry.get_data_attribute_types(wf_type) + data_attributes_prefix_types = self._registry.get_data_attribute_prefix_types( + wf_type + ) context = _from_idl_context(request.context) _input = self._options.object_encoder.decode( @@ -250,7 +265,10 @@ def handle_workflow_state_execute( } data_attributes = DataAttributes( - data_attributes_types, self._options.object_encoder, current_data_attributes + data_attributes_types, + data_attributes_prefix_types, + self._options.object_encoder, + current_data_attributes, ) search_attributes_types = self._registry.get_search_attribute_types(wf_type)