diff --git a/awswrangler/athena/_utils.py b/awswrangler/athena/_utils.py index 7a0e74aac..fa3ba15d0 100644 --- a/awswrangler/athena/_utils.py +++ b/awswrangler/athena/_utils.py @@ -44,6 +44,13 @@ _logger: logging.Logger = logging.getLogger(__name__) +class _MergeClause(TypedDict, total=False): + when: Literal["MATCHED", "NOT MATCHED", "NOT MATCHED BY SOURCE"] + condition: str | None + action: Literal["UPDATE", "DELETE", "INSERT"] + columns: list[str] | None + + class _QueryMetadata(NamedTuple): execution_id: str dtype: dict[str, str] diff --git a/awswrangler/athena/_write_iceberg.py b/awswrangler/athena/_write_iceberg.py index b5f05eb57..9bfb01b01 100644 --- a/awswrangler/athena/_write_iceberg.py +++ b/awswrangler/athena/_write_iceberg.py @@ -14,11 +14,7 @@ from awswrangler import _data_types, _utils, catalog, exceptions, s3 from awswrangler._config import apply_configs from awswrangler.athena._executions import wait_query -from awswrangler.athena._utils import ( - _get_workgroup_config, - _start_query_execution, - _WorkGroupConfig, -) +from awswrangler.athena._utils import _get_workgroup_config, _MergeClause, _start_query_execution, _WorkGroupConfig from awswrangler.typing import GlueTableSettings _logger: logging.Logger = logging.getLogger(__name__) @@ -219,7 +215,10 @@ def _validate_args( mode: Literal["append", "overwrite", "overwrite_partitions"], partition_cols: list[str] | None, merge_cols: list[str] | None, - merge_condition: Literal["update", "ignore"], + merge_on_clause: str | None, + merge_condition: Literal["update", "ignore", "conditional_merge"], + merge_conditional_clauses: list[_MergeClause] | None, + merge_match_nulls: bool, ) -> None: if df.empty is True: raise exceptions.EmptyDataFrame("DataFrame cannot be empty.") @@ -229,6 +228,10 @@ def _validate_args( "Either path or workgroup path must be specified to store the temporary results." ) + _validate_merge_arguments( + merge_cols, merge_on_clause, merge_condition, merge_conditional_clauses, merge_match_nulls + ) + if mode == "overwrite_partitions": if not partition_cols: raise exceptions.InvalidArgumentCombination( @@ -239,11 +242,61 @@ def _validate_args( "When mode is 'overwrite_partitions' merge_cols must not be specified." ) - if merge_cols and merge_condition not in ["update", "ignore"]: + +def _validate_merge_arguments( + merge_cols: list[str] | None, + merge_on_clause: str | None, + merge_condition: Literal["update", "ignore", "conditional_merge"], + merge_conditional_clauses: list[_MergeClause] | None, + merge_match_nulls: bool, +) -> None: + if merge_cols and merge_on_clause: + raise exceptions.InvalidArgumentCombination( + "Cannot specify both merge_cols and merge_on_clause. Use either merge_cols for simple equality matching or merge_on_clause for custom logic." + ) + + if merge_on_clause and merge_match_nulls: + raise exceptions.InvalidArgumentCombination("merge_match_nulls can only be used together with merge_cols.") + + if merge_conditional_clauses and merge_condition != "conditional_merge": + raise exceptions.InvalidArgumentCombination( + "merge_conditional_clauses can only be used when merge_condition is 'conditional_merge'." + ) + + if (merge_cols or merge_on_clause) and merge_condition not in ["update", "ignore", "conditional_merge"]: raise exceptions.InvalidArgumentValue( - f"Invalid merge_condition: {merge_condition}. Valid values: ['update', 'ignore']" + f"Invalid merge_condition: {merge_condition}. Valid values: ['update', 'ignore', 'conditional_merge']" ) + if merge_condition == "conditional_merge": + if not merge_conditional_clauses: + raise exceptions.InvalidArgumentCombination( + "merge_conditional_clauses must be provided when merge_condition is 'conditional_merge'." + ) + + seen_not_matched = False + for i, clause in enumerate(merge_conditional_clauses): + if "when" not in clause: + raise exceptions.InvalidArgumentValue(f"merge_conditional_clauses[{i}] must contain 'when' field.") + if "action" not in clause: + raise exceptions.InvalidArgumentValue(f"merge_conditional_clauses[{i}] must contain 'action' field.") + if clause["when"] not in ["MATCHED", "NOT MATCHED", "NOT MATCHED BY SOURCE"]: + raise exceptions.InvalidArgumentValue( + f"merge_conditional_clauses[{i}]['when'] must be one of ['MATCHED', 'NOT MATCHED', 'NOT MATCHED BY SOURCE']." + ) + if clause["action"] not in ["UPDATE", "DELETE", "INSERT"]: + raise exceptions.InvalidArgumentValue( + f"merge_conditional_clauses[{i}]['action'] must be one of ['UPDATE', 'DELETE', 'INSERT']." + ) + + if clause["when"] in ["NOT MATCHED", "NOT MATCHED BY SOURCE"]: + seen_not_matched = True + elif clause["when"] == "MATCHED" and seen_not_matched: + raise exceptions.InvalidArgumentValue( + f"merge_conditional_clauses[{i}]['when'] is MATCHED but appears after a NOT MATCHED clause. " + "WHEN MATCHED must come before WHEN NOT MATCHED or WHEN NOT MATCHED BY SOURCE." + ) + def _merge_iceberg( df: pd.DataFrame, @@ -251,7 +304,9 @@ def _merge_iceberg( table: str, source_table: str, merge_cols: list[str] | None = None, - merge_condition: Literal["update", "ignore"] = "update", + merge_on_clause: str | None = None, + merge_condition: Literal["update", "ignore", "conditional_merge"] = "update", + merge_conditional_clauses: list[_MergeClause] | None = None, merge_match_nulls: bool = False, kms_key: str | None = None, boto3_session: boto3.Session | None = None, @@ -278,11 +333,27 @@ def _merge_iceberg( source_table: str AWS Glue/Athena source table name. merge_cols: List[str], optional - List of column names that will be used for conditional inserts and updates. + List of column names that will be used for conditional inserts and updates. Cannot be used together with ``merge_on_clause``. https://docs.aws.amazon.com/athena/latest/ug/merge-into-statement.html + merge_on_clause: str, optional + Custom ON clause for the MERGE statement. If specified, this string will be used as the ON condition + between the target and source tables, allowing for complex join logic beyond simple equality on columns. + Cannot be used together with ``merge_cols``. + It must produce at most one match per target row. Using OR conditions may result in merge failures. merge_condition: str, optional - The condition to be used in the MERGE INTO statement. Valid values: ['update', 'ignore']. + The condition to be used in the MERGE INTO statement. Valid values: ['update', 'ignore', 'conditional_merge']. + - 'update': Update matched rows and insert non-matched rows. + - 'ignore': Only insert non-matched rows. + - 'conditional_merge': Use custom conditional clauses for merge actions. + merge_conditional_clauses : List[dict], optional + List of dictionaries specifying custom conditional clauses for the MERGE statement. + Each dictionary should have: + - 'when': One of ['MATCHED', 'NOT MATCHED', 'NOT MATCHED BY SOURCE'] + - 'condition': (optional) Additional SQL condition for the clause + - 'action': One of ['UPDATE', 'DELETE', 'INSERT'] + - 'columns': (optional) List of columns to update or insert + Used only when merge_condition is 'conditional_merge'. merge_match_nulls: bool, optional Instruct whether to have nulls in the merge condition match other nulls kms_key : str, optional @@ -306,26 +377,66 @@ def _merge_iceberg( wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup) sql_statement: str - if merge_cols: - if merge_condition == "update": - match_condition = f"""WHEN MATCHED THEN - UPDATE SET {", ".join([f'"{x}" = source."{x}"' for x in df.columns])}""" - else: - match_condition = "" + if merge_cols or merge_on_clause: + if merge_on_clause: + on_condition = merge_on_clause + elif merge_cols is not None: + if merge_match_nulls: + merge_conditions = [f'(target."{x}" IS NOT DISTINCT FROM source."{x}")' for x in merge_cols] + else: + merge_conditions = [f'(target."{x}" = source."{x}")' for x in merge_cols] + on_condition = " AND ".join(merge_conditions) + + # Build WHEN clauses based on merge_condition + when_clauses = [] - if merge_match_nulls: - merge_conditions = [f'(target."{x}" IS NOT DISTINCT FROM source."{x}")' for x in merge_cols] - else: - merge_conditions = [f'(target."{x}" = source."{x}")' for x in merge_cols] + if merge_condition == "update": + when_clauses.append(f"""WHEN MATCHED THEN + UPDATE SET {", ".join([f'"{x}" = source."{x}"' for x in df.columns])}""") + when_clauses.append(f"""WHEN NOT MATCHED THEN + INSERT ({", ".join([f'"{x}"' for x in df.columns])}) + VALUES ({", ".join([f'source."{x}"' for x in df.columns])})""") + elif merge_condition == "ignore": + when_clauses.append(f"""WHEN NOT MATCHED THEN + INSERT ({", ".join([f'"{x}"' for x in df.columns])}) + VALUES ({", ".join([f'source."{x}"' for x in df.columns])})""") + + elif merge_condition == "conditional_merge" and merge_conditional_clauses is not None: + for clause in merge_conditional_clauses: + when_type = clause["when"] + action = clause["action"] + condition = clause.get("condition") + columns = clause.get("columns") + + # Build WHEN clause + when_part = f"WHEN {when_type}" + if condition: + when_part += f" AND {condition}" + + # Build action + if action == "UPDATE": + update_columns = columns or df.columns.tolist() + update_sets = [f'"{col}" = source."{col}"' for col in update_columns] + when_part += f" THEN UPDATE SET {', '.join(update_sets)}" + + elif action == "DELETE": + when_part += " THEN DELETE" + + elif action == "INSERT": + insert_columns = columns or df.columns.tolist() + column_list = ", ".join([f'"{col}"' for col in insert_columns]) + values_list = ", ".join([f'source."{col}"' for col in insert_columns]) + when_part += f" THEN INSERT ({column_list}) VALUES ({values_list})" + + when_clauses.append(when_part) + + joined_clauses = "\n ".join(when_clauses) sql_statement = f""" MERGE INTO "{database}"."{table}" target USING "{database}"."{source_table}" source - ON {" AND ".join(merge_conditions)} - {match_condition} - WHEN NOT MATCHED THEN - INSERT ({", ".join([f'"{x}"' for x in df.columns])}) - VALUES ({", ".join([f'source."{x}"' for x in df.columns])}) + ON {on_condition} + {joined_clauses} """ else: sql_statement = f""" @@ -361,7 +472,9 @@ def to_iceberg( # noqa: PLR0913 table_location: str | None = None, partition_cols: list[str] | None = None, merge_cols: list[str] | None = None, - merge_condition: Literal["update", "ignore"] = "update", + merge_on_clause: str | None = None, + merge_condition: Literal["update", "ignore", "conditional_merge"] = "update", + merge_conditional_clauses: list[_MergeClause] | None = None, merge_match_nulls: bool = False, keep_files: bool = True, data_source: str | None = None, @@ -381,9 +494,9 @@ def to_iceberg( # noqa: PLR0913 glue_table_settings: GlueTableSettings | None = None, ) -> None: """ - Insert into Athena Iceberg table using INSERT INTO ... SELECT. Will create Iceberg table if it does not exist. + Write a Pandas DataFrame to an Athena Iceberg table, supporting table creation, schema evolution, and advanced merge operations. - Creates temporary external table, writes staged files and inserts via INSERT INTO ... SELECT. + This function inserts data into an Athena Iceberg table, creating the table if it does not exist. It supports multiple write modes (append, overwrite, overwrite_partitions), schema evolution, and conditional merge logic using Athena's MERGE INTO statement. Advanced options allow for custom merge conditions, partitioning, and table properties. Parameters ---------- @@ -410,8 +523,21 @@ def to_iceberg( # noqa: PLR0913 List of column names that will be used for conditional inserts and updates. https://docs.aws.amazon.com/athena/latest/ug/merge-into-statement.html + merge_on_clause + Custom ON clause for the MERGE statement. If specified, this string will be used as the ON condition + between the target and source tables, allowing for complex join logic beyond simple equality on columns. + Cannot be used together with ``merge_cols``. + It must produce at most one match per target row. Using OR conditions may result in merge failures. merge_condition - The condition to be used in the MERGE INTO statement. Valid values: ['update', 'ignore']. + The condition to be used in the MERGE INTO statement. Valid values: ['update', 'ignore', 'conditional_merge']. + merge_conditional_clauses + List of dictionaries specifying custom conditional clauses for the MERGE statement. + Each dictionary should have: + - 'when': One of ['MATCHED', 'NOT MATCHED', 'NOT MATCHED BY SOURCE'] + - 'action': One of ['UPDATE', 'DELETE', 'INSERT'] + - 'condition': (optional) Additional SQL condition for the clause + - 'columns': (optional) List of columns to update or insert + Used only when merge_condition is 'conditional_merge'. merge_match_nulls Instruct whether to have nulls in the merge condition match other nulls keep_files @@ -498,7 +624,10 @@ def to_iceberg( # noqa: PLR0913 mode=mode, partition_cols=partition_cols, merge_cols=merge_cols, + merge_on_clause=merge_on_clause, merge_condition=merge_condition, + merge_conditional_clauses=merge_conditional_clauses, + merge_match_nulls=merge_match_nulls, ) glue_table_settings = glue_table_settings if glue_table_settings else {} @@ -621,7 +750,9 @@ def to_iceberg( # noqa: PLR0913 table=table, source_table=temp_table, merge_cols=merge_cols, + merge_on_clause=merge_on_clause, merge_condition=merge_condition, + merge_conditional_clauses=merge_conditional_clauses, merge_match_nulls=merge_match_nulls, kms_key=kms_key, boto3_session=boto3_session, diff --git a/tests/unit/test_athena_iceberg.py b/tests/unit/test_athena_iceberg.py index 9954375b0..004a3a22f 100644 --- a/tests/unit/test_athena_iceberg.py +++ b/tests/unit/test_athena_iceberg.py @@ -1005,6 +1005,122 @@ def test_athena_delete_from_iceberg_empty_df_error( ) +def test_to_iceberg_merge_cols_and_merge_on_clause_error( + path: str, path2: str, glue_database: str, glue_table: str +) -> None: + df = pd.DataFrame({"id": [1], "val": ["a"]}) + with pytest.raises(wr.exceptions.InvalidArgumentCombination): + wr.athena.to_iceberg( + df=df, + database=glue_database, + table=glue_table, + table_location=path, + temp_path=path2, + merge_cols=["id"], + merge_on_clause="id = source.id", + ) + + +def test_to_iceberg_merge_match_nulls_with_merge_on_clause_error( + path: str, path2: str, glue_database: str, glue_table: str +) -> None: + df = pd.DataFrame({"id": [1], "val": ["a"]}) + with pytest.raises(wr.exceptions.InvalidArgumentCombination): + wr.athena.to_iceberg( + df=df, + database=glue_database, + table=glue_table, + table_location=path, + temp_path=path2, + merge_on_clause="id = source.id", + merge_match_nulls=True, + ) + + +def test_to_iceberg_merge_conditional_clauses_without_conditional_merge_error( + path: str, path2: str, glue_database: str, glue_table: str +) -> None: + df = pd.DataFrame({"id": [1], "val": ["a"]}) + with pytest.raises(wr.exceptions.InvalidArgumentCombination): + wr.athena.to_iceberg( + df=df, + database=glue_database, + table=glue_table, + table_location=path, + temp_path=path2, + merge_cols=["id"], + merge_conditional_clauses=[{"when": "MATCHED", "action": "UPDATE"}], + merge_condition="update", + ) + + +def test_to_iceberg_conditional_merge_without_clauses_error( + path: str, path2: str, glue_database: str, glue_table: str +) -> None: + df = pd.DataFrame({"id": [1], "val": ["a"]}) + with pytest.raises(wr.exceptions.InvalidArgumentCombination): + wr.athena.to_iceberg( + df=df, + database=glue_database, + table=glue_table, + table_location=path, + temp_path=path2, + merge_cols=["id"], + merge_condition="conditional_merge", + ) + + +def test_to_iceberg_invalid_merge_condition_error(path: str, path2: str, glue_database: str, glue_table: str) -> None: + df = pd.DataFrame({"id": [1], "val": ["a"]}) + with pytest.raises(wr.exceptions.InvalidArgumentValue): + wr.athena.to_iceberg( + df=df, + database=glue_database, + table=glue_table, + table_location=path, + temp_path=path2, + merge_cols=["id"], + merge_condition="not_a_valid_condition", + ) + + +def test_to_iceberg_conditional_merge_happy_path(path: str, path2: str, glue_database: str, glue_table: str) -> None: + df = pd.DataFrame({"id": [1, 2], "val": ["a", "b"]}) + wr.athena.to_iceberg( + df=df, + database=glue_database, + table=glue_table, + table_location=path, + temp_path=path2, + keep_files=False, + ) + df2 = pd.DataFrame({"id": [1, 3], "val": ["c", "d"]}) + clauses = [ + {"when": "MATCHED", "action": "UPDATE", "columns": ["val"]}, + {"when": "NOT MATCHED", "action": "INSERT"}, + ] + wr.athena.to_iceberg( + df=df2, + database=glue_database, + table=glue_table, + table_location=path, + temp_path=path2, + merge_cols=["id"], + merge_condition="conditional_merge", + merge_conditional_clauses=clauses, + keep_files=False, + ) + df_out = wr.athena.read_sql_query( + sql=f'SELECT * FROM "{glue_table}" ORDER BY id', + database=glue_database, + ctas_approach=False, + unload_approach=False, + ) + # id=1 should be updated, id=2 should remain, id=3 should be inserted + expected = pd.DataFrame({"id": [1, 2, 3], "val": ["c", "b", "d"]}) + assert_pandas_equals(expected, df_out.reset_index(drop=True)) + + def test_athena_iceberg_use_partition_function( path: str, path2: str,