diff --git a/README.md b/README.md index 1149927c6..e5990f402 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,7 @@ Basic SQLAlchemy driver for [DuckDB](https://duckdb.org/) - [Alembic Integration](#alembic-integration) - [Preloading extensions (experimental)](#preloading-extensions-experimental) - [Registering Filesystems](#registering-filesystems) + - [Running actions right after connecting](#running-actions-right-after-connecting) - [The name](#the-name) @@ -176,7 +177,8 @@ create_engine( 'preload_extensions': ['https'], 'config': { 's3_region': 'ap-southeast-1' - } + }, + 'pre_actions': ["ATTACH 'file.db' AS file_db';"] } ) ``` @@ -199,6 +201,24 @@ create_engine( ) ``` +## Running actions right after connecting +You can run arbitrary SQL commands right after connecting by passing a list of SQL commands to the `pre_actions` parameter in `connect_args` + +```python +from sqlalchemy import create_engine +create_engine( + 'duckdb:///:memory:', + connect_args={ + 'pre_actions': [ + "ATTACH 'file.db' AS file_db';", + "SET some_config_option='some_value';" + ] + } +) +``` + ## The name Yes, I'm aware this package should be named `duckdb-driver` or something, I wasn't thinking when I named it and it's too hard to change the name now + + diff --git a/duckdb_engine/__init__.py b/duckdb_engine/__init__.py index e6b1680e2..f779018bd 100644 --- a/duckdb_engine/__init__.py +++ b/duckdb_engine/__init__.py @@ -285,6 +285,7 @@ def type_descriptor(self, typeobj: Type[sqltypes.TypeEngine]) -> Any: # type: i def connect(self, *cargs: Any, **cparams: Any) -> "Connection": core_keys = get_core_config() preload_extensions = cparams.pop("preload_extensions", []) + pre_actions = cparams.pop("pre_actions", []) config = dict(cparams.get("config", {})) cparams["config"] = config config.update(cparams.pop("url_config", {})) @@ -306,6 +307,9 @@ def connect(self, *cargs: Any, **cparams: Any) -> "Connection": for filesystem in filesystems: conn.register_filesystem(filesystem) + for action in pre_actions: + conn.execute(action) + apply_config(self, conn, ext) return ConnectionWrapper(conn) diff --git a/duckdb_engine/tests/test_basic.py b/duckdb_engine/tests/test_basic.py index 15fda1fea..d70af2021 100644 --- a/duckdb_engine/tests/test_basic.py +++ b/duckdb_engine/tests/test_basic.py @@ -279,6 +279,20 @@ def test_preload_extension() -> None: ) +def test_pre_actions() -> None: + engine = create_engine( + "duckdb:///", + connect_args={ + "pre_actions": ["INSTALL SPATIAL", "LOAD SPATIAL"], + "config": {"s3_region": "ap-southeast-2", "s3_use_ssl": True}, + }, + ) + + # check that we can use spatial functions + with engine.connect() as conn: + conn.execute(text("SELECT ST_Affine(ST_Point(1, 1),1, 0, 0, 1, 2, 3);")) + + @fixture def inspector(engine: Engine, session: Session) -> Inspector: cmds = [