Skip to content

IWF-348 define data_attributes by prefix #89

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 24, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions iwf/data_attributes.py
Original file line number Diff line number Diff line change
@@ -7,37 +7,40 @@

class DataAttributes:
_type_store: dict[str, Optional[type]]
_prefix_type_store: dict[str, Optional[type]]
_object_encoder: ObjectEncoder
_current_values: dict[str, Union[EncodedObject, None]]
_updated_values_to_return: dict[str, EncodedObject]

def __init__(
self,
type_store: dict[str, Optional[type]],
prefix_type_store: dict[str, Optional[type]],
object_encoder: ObjectEncoder,
current_values: dict[str, Union[EncodedObject, None]],
):
self._object_encoder = object_encoder
self._type_store = type_store
self._prefix_type_store = prefix_type_store
self._current_values = current_values
self._updated_values_to_return = {}

def get_data_attribute(self, key: str) -> Any:
if key not in self._type_store:
is_registered, registered_type = self._validate_key_and_get_type(key)
if not is_registered:
raise WorkflowDefinitionError(f"data attribute %s is not registered {key}")

encoded_object = self._current_values.get(key)
if encoded_object is None:
return None

registered_type = self._type_store[key]
return self._object_encoder.decode(encoded_object, registered_type)

def set_data_attribute(self, key: str, value: Any):
if key not in self._type_store:
is_registered, registered_type = self._validate_key_and_get_type(key)
if not is_registered:
raise WorkflowDefinitionError(f"data attribute %s is not registered {key}")

registered_type = self._type_store[key]
if registered_type is not None and not isinstance(value, registered_type):
raise WorkflowDefinitionError(
f"data attribute %s is of the right type {registered_type}"
@@ -49,3 +52,13 @@ def set_data_attribute(self, key: str, value: Any):

def get_updated_values_to_return(self) -> dict[str, EncodedObject]:
return self._updated_values_to_return

def _validate_key_and_get_type(self, key) -> tuple[bool, Optional[type]]:
if key in self._type_store:
return (True, self._type_store.get(key))

for prefix, t in self._prefix_type_store.items():
if key.startswith(prefix):
return (True, t)

return (False, None)
7 changes: 7 additions & 0 deletions iwf/persistence_schema.py
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@
class PersistenceFieldType(Enum):
DataAttribute = 1
SearchAttribute = 2
DataAttributePrefix = 3


@dataclass
@@ -27,6 +28,12 @@ def search_attribute_def(cls, key: str, sa_type: SearchAttributeValueType):
key, PersistenceFieldType.SearchAttribute, None, sa_type
)

@classmethod
def data_attribute_prefix_def(cls, key: str, value_type: Optional[type]):
return PersistenceField(
key, PersistenceFieldType.DataAttributePrefix, value_type
)


@dataclass
class PersistenceSchema:
22 changes: 19 additions & 3 deletions iwf/registry.py
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@ class Registry:
_internal_channel_type_store: dict[str, TypeStore]
_signal_channel_type_store: dict[str, dict[str, Optional[type]]]
_data_attribute_types: dict[str, dict[str, Optional[type]]]
_data_attribute_prefix_types: dict[str, dict[str, Optional[type]]]
_search_attribute_types: dict[str, dict[str, SearchAttributeValueType]]
_rpc_infos: dict[str, dict[str, RPCInfo]]

@@ -27,6 +28,7 @@ def __init__(self):
self._internal_channel_type_store = dict()
self._signal_channel_type_store = dict()
self._data_attribute_types = dict()
self._data_attribute_prefix_types = dict()
self._search_attribute_types = {}
self._rpc_infos = dict()

@@ -36,6 +38,7 @@ def add_workflow(self, wf: ObjectWorkflow):
self._register_internal_channels(wf)
self._register_signal_channels(wf)
self._register_data_attributes(wf)
self._register_data_attribute_prefix_types(wf)
self._register_search_attributes(wf)
self._register_workflow_rpcs(wf)

@@ -82,6 +85,11 @@ def get_search_attribute_types(
) -> dict[str, SearchAttributeValueType]:
return self._search_attribute_types[wf_type]

def get_data_attribute_prefix_types(
self, wf_type: str
) -> dict[str, Optional[type]]:
return self._data_attribute_prefix_types[wf_type]

def get_rpc_infos(self, wf_type: str) -> dict[str, RPCInfo]:
return self._rpc_infos[wf_type]

@@ -115,11 +123,11 @@ def _register_signal_channels(self, wf: ObjectWorkflow):

def _register_data_attributes(self, wf: ObjectWorkflow):
wf_type = get_workflow_type(wf)
types: dict[str, Optional[type]] = {}
data_attribute_types: dict[str, Optional[type]] = {}
for field in wf.get_persistence_schema().persistence_fields:
if field.field_type == PersistenceFieldType.DataAttribute:
types[field.key] = field.value_type
self._data_attribute_types[wf_type] = types
data_attribute_types[field.key] = field.value_type
self._data_attribute_types[wf_type] = data_attribute_types

def _register_search_attributes(self, wf: ObjectWorkflow):
wf_type = get_workflow_type(wf)
@@ -138,6 +146,14 @@ def _register_search_attributes(self, wf: ObjectWorkflow):
types[field.key] = sa_type
self._search_attribute_types[wf_type] = types

def _register_data_attribute_prefix_types(self, wf: ObjectWorkflow):
wf_type = get_workflow_type(wf)
data_attribute_prefix_types: dict[str, Optional[type]] = {}
for field in wf.get_persistence_schema().persistence_fields:
if field.field_type == PersistenceFieldType.DataAttributePrefix:
data_attribute_prefix_types[field.key] = field.value_type
self._data_attribute_prefix_types[wf_type] = data_attribute_prefix_types

def _register_workflow_state(self, wf):
wf_type = get_workflow_type(wf)
state_map = {}
13 changes: 13 additions & 0 deletions iwf/tests/test_persistence_data_attributes.py
Original file line number Diff line number Diff line change
@@ -30,6 +30,12 @@
final_initial_da_value_1 = initial_da_value_1
final_initial_da_value_2 = "no-more-init"

test_da_prefix = "test-da-prefix"
test_da_prefix_key_1 = "test-da-prefix-1"
test_da_prefix_key_2 = "test-da-prefix-2"
test_da_prefix_value_1 = "test-da-value-1"
test_da_prefix_value_2 = "test-da-value-2"


class DataAttributeRWState(WorkflowState[None]):
def wait_until(
@@ -60,6 +66,8 @@ def execute(
persistence.set_data_attribute(test_da_1, final_test_da_value_1)
persistence.set_data_attribute(test_da_2, final_test_da_value_2)
persistence.set_data_attribute(initial_da_2, final_initial_da_value_2)
persistence.set_data_attribute(test_da_prefix_key_1, test_da_prefix_value_1)
persistence.set_data_attribute(test_da_prefix_key_2, test_da_prefix_value_2)
return StateDecision.graceful_complete_workflow()


@@ -73,6 +81,7 @@ def get_persistence_schema(self) -> PersistenceSchema:
PersistenceField.data_attribute_def(initial_da_2, str),
PersistenceField.data_attribute_def(test_da_1, str),
PersistenceField.data_attribute_def(test_da_2, int),
PersistenceField.data_attribute_prefix_def(test_da_prefix, str),
)

@rpc()
@@ -82,6 +91,8 @@ def test_persistence_read(self, pers: Persistence):
pers.get_data_attribute(initial_da_2),
pers.get_data_attribute(test_da_1),
pers.get_data_attribute(test_da_2),
pers.get_data_attribute(test_da_prefix_key_1),
pers.get_data_attribute(test_da_prefix_key_2),
)


@@ -115,4 +126,6 @@ def test_persistence_data_attributes_workflow(self):
final_initial_da_value_2,
final_test_da_value_1,
final_test_da_value_2,
test_da_prefix_value_1,
test_da_prefix_value_2,
]
24 changes: 21 additions & 3 deletions iwf/worker_service.py
Original file line number Diff line number Diff line change
@@ -78,6 +78,9 @@ def handle_workflow_worker_rpc(
internal_channel_types = self._registry.get_internal_channel_type_store(wf_type)
signal_channel_types = self._registry.get_signal_channel_types(wf_type)
data_attributes_types = self._registry.get_data_attribute_types(wf_type)
data_attributes_prefix_types = self._registry.get_data_attribute_prefix_types(
wf_type
)

context = _from_idl_context(request.context)
_input = self._options.object_encoder.decode(
@@ -92,7 +95,10 @@ def handle_workflow_worker_rpc(
}

data_attributes = DataAttributes(
data_attributes_types, self._options.object_encoder, current_data_attributes
data_attributes_types,
data_attributes_prefix_types,
self._options.object_encoder,
current_data_attributes,
)

search_attributes_types = self._registry.get_search_attribute_types(wf_type)
@@ -167,6 +173,9 @@ def handle_workflow_state_wait_until(
internal_channel_types = self._registry.get_internal_channel_type_store(wf_type)
signal_channel_types = self._registry.get_signal_channel_types(wf_type)
data_attributes_types = self._registry.get_data_attribute_types(wf_type)
data_attributes_prefix_types = self._registry.get_data_attribute_prefix_types(
wf_type
)

context = _from_idl_context(request.context)
_input = self._options.object_encoder.decode(
@@ -181,7 +190,10 @@ def handle_workflow_state_wait_until(
}

data_attributes = DataAttributes(
data_attributes_types, self._options.object_encoder, current_data_attributes
data_attributes_types,
data_attributes_prefix_types,
self._options.object_encoder,
current_data_attributes,
)

search_attributes_types = self._registry.get_search_attribute_types(wf_type)
@@ -236,6 +248,9 @@ def handle_workflow_state_execute(
internal_channel_types = self._registry.get_internal_channel_type_store(wf_type)
signal_channel_types = self._registry.get_signal_channel_types(wf_type)
data_attributes_types = self._registry.get_data_attribute_types(wf_type)
data_attributes_prefix_types = self._registry.get_data_attribute_prefix_types(
wf_type
)
context = _from_idl_context(request.context)

_input = self._options.object_encoder.decode(
@@ -250,7 +265,10 @@ def handle_workflow_state_execute(
}

data_attributes = DataAttributes(
data_attributes_types, self._options.object_encoder, current_data_attributes
data_attributes_types,
data_attributes_prefix_types,
self._options.object_encoder,
current_data_attributes,
)

search_attributes_types = self._registry.get_search_attribute_types(wf_type)