diff --git a/sdk/python/feast/infra/online_stores/dynamodb.py b/sdk/python/feast/infra/online_stores/dynamodb.py
index 13dd949059c..75ef2ae10c7 100644
--- a/sdk/python/feast/infra/online_stores/dynamodb.py
+++ b/sdk/python/feast/infra/online_stores/dynamodb.py
@@ -15,12 +15,18 @@
 import contextlib
 import itertools
 import logging
-from collections import OrderedDict, defaultdict
+from collections import OrderedDict
 from datetime import datetime
 from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
 
 from aiobotocore.config import AioConfig
 from pydantic import StrictBool, StrictStr
+from tenacity import (
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
 
 from feast import Entity, FeatureView, utils
 from feast.infra.infra_object import DYNAMODB_INFRA_OBJECT_CLASS_TYPE, InfraObject
@@ -74,7 +80,10 @@ class DynamoDBOnlineStoreConfig(FeastConfigBaseModel):
     """Whether to read from Dynamodb by forcing consistent reads"""
 
     tags: Union[Dict[str, str], None] = None
-    """AWS resource tags added to each table"""
+    """Key-value pairs added to each feature-view"""
+
+    tag_aws_resources: StrictBool = False
+    """Add the feature-view tags to the underlying AWS dynamodb tables"""
 
     session_based_auth: bool = False
     """AWS session based client authentication"""
@@ -138,38 +147,6 @@ async def close(self):
     def async_supported(self) -> SupportedAsyncMethods:
         return SupportedAsyncMethods(read=True, write=True)
 
-    @staticmethod
-    def _table_tags(online_config, table_instance) -> list[dict[str, str]]:
-        table_instance_tags = table_instance.tags or {}
-        online_tags = online_config.tags or {}
-
-        common_tags = [
-            {"Key": key, "Value": table_instance_tags.get(key) or value}
-            for key, value in online_tags.items()
-        ]
-        table_tags = [
-            {"Key": key, "Value": value}
-            for key, value in table_instance_tags.items()
-            if key not in online_tags
-        ]
-
-        return common_tags + table_tags
-
-    @staticmethod
-    def _update_tags(dynamodb_client, table_name: str, new_tags: list[dict[str, str]]):
-        table_arn = dynamodb_client.describe_table(TableName=table_name)["Table"][
-            "TableArn"
-        ]
-        current_tags = dynamodb_client.list_tags_of_resource(ResourceArn=table_arn)[
-            "Tags"
-        ]
-        if current_tags:
-            remove_keys = [tag["Key"] for tag in current_tags]
-            dynamodb_client.untag_resource(ResourceArn=table_arn, TagKeys=remove_keys)
-
-        if new_tags:
-            dynamodb_client.tag_resource(ResourceArn=table_arn, Tags=new_tags)
-
     def update(
         self,
         config: RepoConfig,
@@ -189,59 +166,25 @@ def update(
         """
         online_config = config.online_store
         assert isinstance(online_config, DynamoDBOnlineStoreConfig)
-        dynamodb_client = self._get_dynamodb_client(
-            online_config.region,
-            online_config.endpoint_url,
-            online_config.session_based_auth,
-        )
+
         dynamodb_resource = self._get_dynamodb_resource(
             online_config.region,
             online_config.endpoint_url,
             online_config.session_based_auth,
         )
 
-        do_tag_updates = defaultdict(bool)
-        for table_instance in tables_to_keep:
-            # Add Tags attribute to creation request only if configured to prevent
-            # TagResource permission issues, even with an empty Tags array.
-            table_tags = self._table_tags(online_config, table_instance)
-            kwargs = {"Tags": table_tags} if table_tags else {}
+        def get_table_manager(table):
+            return _DynamoTableManager(
+                dynamodb_resource=dynamodb_resource,
+                config=config,
+                feature_view=table,
+            )
 
-            table_name = _get_table_name(online_config, config, table_instance)
-            try:
-                dynamodb_resource.create_table(
-                    TableName=table_name,
-                    KeySchema=[{"AttributeName": "entity_id", "KeyType": "HASH"}],
-                    AttributeDefinitions=[
-                        {"AttributeName": "entity_id", "AttributeType": "S"}
-                    ],
-                    BillingMode="PAY_PER_REQUEST",
-                    **kwargs,
-                )
+        for table in tables_to_keep:
+            get_table_manager(table).update()
 
-            except ClientError as ce:
-                do_tag_updates[table_name] = True
-
-                # If the table creation fails with ResourceInUseException,
-                # it means the table already exists or is being created.
-                # Otherwise, re-raise the exception
-                if ce.response["Error"]["Code"] != "ResourceInUseException":
-                    raise
-
-        for table_instance in tables_to_keep:
-            table_name = _get_table_name(online_config, config, table_instance)
-            dynamodb_client.get_waiter("table_exists").wait(TableName=table_name)
-            # once table is confirmed to exist, update the tags.
-            # tags won't be updated in the create_table call if the table already exists
-            if do_tag_updates[table_name]:
-                tags = self._table_tags(online_config, table_instance)
-                self._update_tags(dynamodb_client, table_name, tags)
-
-        for table_to_delete in tables_to_delete:
-            _delete_table_idempotent(
-                dynamodb_resource,
-                _get_table_name(online_config, config, table_to_delete),
-            )
+        for table in tables_to_delete:
+            get_table_manager(table).delete()
 
     def teardown(
         self,
@@ -265,9 +208,11 @@ def teardown(
         )
 
         for table in tables:
-            _delete_table_idempotent(
-                dynamodb_resource, _get_table_name(online_config, config, table)
-            )
+            _DynamoTableManager(
+                dynamodb_resource=dynamodb_resource,
+                config=config,
+                feature_view=table,
+            ).delete()
 
     def online_write_batch(
         self,
@@ -845,3 +790,123 @@ def _latest_data_to_write(
     as_hashable = ((d[0].SerializeToString(), d) for d in data)
     sorted_data = sorted(as_hashable, key=lambda ah: (ah[0], ah[1][2]))
     return (v for _, v in OrderedDict((ah[0], ah[1]) for ah in sorted_data).items())
+
+
+class RetryableBotoError(Exception):
+    """Base class for boto errors that ``_update_tags`` retries."""
+
+
+class LimitExceededException(RetryableBotoError):
+    """Raised when ``tag_resource`` fails with a ``LimitExceededException`` code."""
+
+
+class _DynamoTableManager:
+    """Create, tag, and delete the DynamoDB table backing one feature view."""
+
+    def __init__(
+        self, dynamodb_resource, config: RepoConfig, feature_view: FeatureView
+    ):
+        self.config = config
+        self.feature_view = feature_view
+        self._dynamodb_resource = dynamodb_resource
+
+    @property
+    def _dynamodb_client(self):
+        return self._dynamodb_resource.meta.client
+
+    @property
+    def table_name(self) -> str:
+        return _get_table_name(self.config.online_store, self.config, self.feature_view)
+
+    def table_tags(self) -> list[dict[str, str]]:
+        """Merge store-level and feature-view tags into a Key/Value list.
+
+        A non-empty feature-view value wins over the store-level value for a
+        key present in both.
+        """
+        table_instance_tags = self.feature_view.tags or {}
+        online_tags = self.config.online_store.tags or {}
+
+        common_tags = [
+            {"Key": key, "Value": table_instance_tags.get(key) or value}
+            for key, value in online_tags.items()
+        ]
+        table_tags = [
+            {"Key": key, "Value": value}
+            for key, value in table_instance_tags.items()
+            if key not in online_tags
+        ]
+
+        return common_tags + table_tags
+
+    @retry(
+        wait=wait_exponential(multiplier=1, max=4),
+        retry=retry_if_exception_type(RetryableBotoError),
+        stop=stop_after_attempt(3),
+        reraise=True,
+    )
+    def _update_tags(self, new_tags: list[dict[str, str]]):
+        """Replace every tag on the table with ``new_tags``.
+
+        Current tags are removed before the new ones are applied. A
+        ``LimitExceededException`` from ``tag_resource`` is converted into a
+        retryable error; any other ``ClientError`` propagates to the caller.
+        """
+        table_arn = self._dynamodb_client.describe_table(TableName=self.table_name)[
+            "Table"
+        ]["TableArn"]
+        current_tags = self._dynamodb_client.list_tags_of_resource(
+            ResourceArn=table_arn
+        )["Tags"]
+        if current_tags:
+            remove_keys = [tag["Key"] for tag in current_tags]
+            self._dynamodb_client.untag_resource(
+                ResourceArn=table_arn, TagKeys=remove_keys
+            )
+
+        if new_tags:
+            try:
+                self._dynamodb_client.tag_resource(ResourceArn=table_arn, Tags=new_tags)
+            except ClientError as ce:
+                if ce.response["Error"]["Code"] == "LimitExceededException":
+                    raise LimitExceededException from ce
+                # Not retryable: re-raise rather than silently dropping the failure.
+                raise
+
+    def update(self):
+        """Create the table if it is missing and wait until it exists.
+
+        With ``tag_aws_resources`` enabled, a new table is tagged through
+        ``create_table`` and an already existing table is re-tagged afterwards.
+        """
+        # Add Tags attribute to creation request only if configured to prevent
+        # TagResource permission issues, even with an empty Tags array.
+        do_tag_update = self.config.online_store.tag_aws_resources
+        table_tags = self.table_tags()
+        kwargs = {"Tags": table_tags} if table_tags and do_tag_update else {}
+        try:
+            self._dynamodb_resource.create_table(
+                TableName=self.table_name,
+                KeySchema=[{"AttributeName": "entity_id", "KeyType": "HASH"}],
+                AttributeDefinitions=[
+                    {"AttributeName": "entity_id", "AttributeType": "S"}
+                ],
+                BillingMode="PAY_PER_REQUEST",
+                **kwargs,
+            )
+            do_tag_update = False
+        except ClientError as ce:
+            # If the table creation fails with ResourceInUseException,
+            # it means the table already exists or is being created.
+            # Otherwise, re-raise the exception
+            if ce.response["Error"]["Code"] != "ResourceInUseException":
+                raise
+
+        # tags won't be updated in the create_table call if the table already exists
+        self._dynamodb_client.get_waiter("table_exists").wait(TableName=self.table_name)
+        if do_tag_update:
+            self._update_tags(table_tags)
+
+    def delete(self) -> None:
+        """Delete the table through ``_delete_table_idempotent``."""
+        _delete_table_idempotent(self._dynamodb_resource, self.table_name)
diff --git a/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py b/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py
index 91f0474ab93..5c7ca3f0e8d 100644
--- a/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py
+++ b/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py
@@ -12,6 +12,7 @@
     DynamoDBOnlineStore,
     DynamoDBOnlineStoreConfig,
     DynamoDBTable,
+    _DynamoTableManager,
     _latest_data_to_write,
 )
 from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
@@ -251,14 +252,15 @@ def test_dynamodb_online_store_update(repo_config, dynamodb_online_store):
     assert len(existing_tables) == 1
     assert existing_tables[0] == f"test_aws.{db_table_keep_name}"
 
-    assert _get_tags(dynamodb_client, existing_tables[0]) == [
-        {"Key": "some", "Value": "tag"}
-    ]
+    # default behavior: no dynamo table tags
+    assert _get_tags(dynamodb_client, existing_tables[0]) == []
 
 
 @mock_dynamodb
 def test_dynamodb_online_store_update_tags(repo_config, dynamodb_online_store):
     """Test DynamoDBOnlineStore update method."""
+    repo_config.online_config.tag_aws_resources = True
+
     # create dummy table to update with new tags and tag values
     table_name = f"{TABLE_NAME}_keep_update_tags"
     create_test_table(PROJECT, table_name, REGION)
@@ -335,12 +337,14 @@ def test_dynamodb_online_store_update_tags(repo_config, dynamodb_online_store):
     ],
 )
 def test_dynamodb_online_store_tag_priority(
-    global_tags, table_tags, expected, dynamodb_online_store
+    repo_config, global_tags, table_tags, expected
 ):
-    actual = dynamodb_online_store._table_tags(
-        MockOnlineConfig(tags=global_tags),
+    repo_config.online_config = MockOnlineConfig(tags=global_tags)
+    actual = _DynamoTableManager(
+        None,
+        repo_config,
         MockFeatureView(name="table", tags=table_tags),
-    )
+    ).table_tags()
     assert actual == expected
 
 