class ChangeColumnAutoincrement(DDLElement):
    """DDL element describing the autoincrement setting of a table's columns.

    ``column_names`` and ``autoincrement`` are paired by position: the option
    at index ``i`` belongs to the column at index ``i``.

    :param table_name: Name of the table whose columns are affected.
    :param schema: Schema that contains the table.
    :param column_names: Names of the columns, in table order.
    :param autoincrement: One autoincrement option (``bool`` or ``str``) per
        entry of ``column_names``.
    :raises ValueError: If both sequences do not have the same length.
    """

    def __init__(
        self,
        table_name: str,
        schema: Schema,
        column_names: list[str],
        autoincrement: list[bool | str],
    ):
        # Callers may pass any iterable of column names (for example the
        # columns of a dataframe). Copy both arguments into lists so the
        # element owns stable, indexable sequences of known length.
        column_names = list(column_names)
        autoincrement = list(autoincrement)
        if len(column_names) != len(autoincrement):
            raise ValueError(
                "column_names and autoincrement must have the same length, got"
                f" {len(column_names)} and {len(autoincrement)}"
            )

        self.table_name = table_name
        self.schema = schema
        self.column_names = column_names
        self.autoincrement = autoincrement
get_forced_nullability_columns( ] return nullable_cols, non_nullable_cols - def add_indexes_and_set_nullable( + def postprocess_table_creation( self, table: Table, schema: Schema, @@ -156,7 +156,7 @@ def add_indexes_and_set_nullable( on_empty_table: bool | None = None, table_cols: Iterable[str] | None = None, ): - super().add_indexes_and_set_nullable( + super().postprocess_table_creation( table, schema, on_empty_table=on_empty_table, table_cols=table_cols ) table_name = self.engine.dialect.identifier_preparer.quote(table.name) diff --git a/src/pydiverse/pipedag/backend/table/sql/dialects/mssql.py b/src/pydiverse/pipedag/backend/table/sql/dialects/mssql.py index 4260dd49..d2816532 100644 --- a/src/pydiverse/pipedag/backend/table/sql/dialects/mssql.py +++ b/src/pydiverse/pipedag/backend/table/sql/dialects/mssql.py @@ -127,7 +127,7 @@ def get_forced_nullability_columns( # the list of nullable columns as well return self._process_table_nullable_parameters(table, table_cols) - def add_indexes_and_set_nullable( + def postprocess_table_creation( self, table: Table, schema: Schema, diff --git a/src/pydiverse/pipedag/backend/table/sql/dialects/postgres.py b/src/pydiverse/pipedag/backend/table/sql/dialects/postgres.py index b567b9cb..3f6b8aed 100644 --- a/src/pydiverse/pipedag/backend/table/sql/dialects/postgres.py +++ b/src/pydiverse/pipedag/backend/table/sql/dialects/postgres.py @@ -130,7 +130,7 @@ def _execute_materialize( # Create empty table cls._dialect_create_empty_table(store, df, table, schema, dtypes) - store.add_indexes_and_set_nullable( + store.postprocess_table_creation( table, schema, on_empty_table=True, table_cols=df.columns ) diff --git a/src/pydiverse/pipedag/backend/table/sql/hooks.py b/src/pydiverse/pipedag/backend/table/sql/hooks.py index 0c158782..0f5350d1 100644 --- a/src/pydiverse/pipedag/backend/table/sql/hooks.py +++ b/src/pydiverse/pipedag/backend/table/sql/hooks.py @@ -99,7 +99,7 @@ def materialize( suffix=suffix, ) ) - 
def get_autoincrement_options(
    self, table: Table, table_cols: Iterable[str]
) -> list[str | bool]:
    """Return the autoincrement option of every column in ``table_cols``.

    This method only delegates to ``_process_table_autoincrement_options``.

    :param table: Table whose ``autoincrement`` attribute is evaluated.
    :param table_cols: Column names of the table.
    :return: One option per column, in the order of ``table_cols``.
    :raises ValueError: If ``table.autoincrement`` names unknown columns.
    """
    return self._process_table_autoincrement_options(table, table_cols)


@staticmethod
def _process_table_autoincrement_options(
    table: Table, table_cols: Iterable[str]
) -> list[str | bool]:
    """Expand ``table.autoincrement`` into one option per table column.

    :param table: Object providing ``name`` and ``autoincrement``;
        ``autoincrement`` is either ``None`` or a mapping from column name to
        an autoincrement option.
    :param table_cols: Column names of the table. Any iterable is accepted,
        including one-shot iterators.
    :return: A list with one entry per column in the order of ``table_cols``.
        Columns without an explicit option get ``False``.
    :raises ValueError: If ``table.autoincrement`` contains a column name that
        is not part of ``table_cols``.
    """
    # ``table_cols`` is only promised to be an iterable. It is needed more
    # than once below (membership check, output order, error message), so
    # materialize it once. Otherwise a one-shot iterator would be exhausted
    # by the first pass and the result would silently be empty.
    columns = list(table_cols)

    if table.autoincrement is None:
        # if autoincrement not specified set to False for all columns
        return [False] * len(columns)

    name = f'"{table.name}"'
    autoincrement_cols = set(table.autoincrement)
    if invalid_cols := autoincrement_cols - set(columns):
        raise ValueError(
            f"The columns {invalid_cols} in Table({name},"
            f" autoincrement={autoincrement_cols}) aren't contained in the table"
            f" columns: {columns}"
        )

    # cols that were not specified are set to autoincrement=False
    return [table.autoincrement.get(col, False) for col in columns]
# Validate the ``autoincrement`` argument: it is either None or a dict that
# maps column names (str) to an autoincrement option (str or bool).
if self.autoincrement is not None:
    # Build a real exception object, as the other argument checks of this
    # constructor do. Raising a plain string is itself an error in Python 3
    # and would hide the message below from the user.
    type_error = TypeError(
        "Table argument: autoincrement must be of type dict[str, str | bool]"
    )
    if not isinstance(self.autoincrement, dict):
        raise type_error
    for key, val in self.autoincrement.items():
        if not isinstance(key, str):
            raise type_error
        if not isinstance(val, (str, bool)):
            raise type_error

from pydiverse.pipedag.backend.table.sql import ExternalTableReference

# ExternalTableReference can reference a table from an external schema