diff --git a/plugins/duckdb/README.md b/plugins/duckdb/README.md new file mode 100644 index 000000000..394560326 --- /dev/null +++ b/plugins/duckdb/README.md @@ -0,0 +1,107 @@ +# DuckDB Plugin for Flyte + +Run DuckDB SQL queries as Flyte tasks with DataFrame inputs, parameterized queries, and extension support. + +DuckDB is an embedded analytical database (like SQLite for OLAP). Queries execute locally and synchronously. + +## Installation + +```bash +pip install flyteplugins-duckdb +``` + +## Quick start + +```python +import pandas as pd +from flyteplugins.duckdb import DuckDB + +analyze = DuckDB( + name="analyze", + query="SELECT SUM(a) AS total FROM mydf", + inputs={"mydf": pd.DataFrame}, +) +``` + +## DataFrame inputs + +Pass pandas DataFrames or PyArrow Tables as inputs. They are registered as virtual tables queryable by name: + +```python +import pyarrow as pa + +task = DuckDB( + name="join_tables", + query="SELECT a.name, b.total FROM users a JOIN orders b ON a.id = b.user_id", + inputs={"users": pd.DataFrame, "orders": pa.Table}, +) +``` + +You can also pass `flyte.io.DataFrame` for interoperability with any DataFrame type in the Flyte ecosystem. + +## Parameterized queries + +Use `?` or `$N` placeholders with list parameters: + +```python +task = DuckDB( + name="filtered", + query="SELECT * FROM mydf WHERE age > ?", + inputs={"mydf": pd.DataFrame, "params": list}, +) +``` + +## Multiple queries + +Pass a list of queries. 
All are executed in order and the result of the last query is returned: + +```python +task = DuckDB( + name="etl", + query=[ + "CREATE TABLE staging AS SELECT * FROM raw WHERE active = true", + "SELECT department, COUNT(*) AS cnt FROM staging GROUP BY department", + ], + inputs={"raw": pd.DataFrame}, +) +``` + +## Runtime queries + +Omit `query` and provide it at execution time via a `query` string input: + +```python +task = DuckDB( + name="dynamic", + inputs={"mydf": pd.DataFrame, "query": str}, +) +``` + +## Extensions + +DuckDB extensions are auto-installed and loaded before query execution: + +```python +from flyteplugins.duckdb import DuckDBConfig + +task = DuckDB( + name="s3_query", + query="SELECT * FROM 's3://bucket/data.parquet' LIMIT 100", + config=DuckDBConfig(extensions=["httpfs"]), +) +``` + +Common extensions: `httpfs`, `json`, `spatial`, `excel`, `parquet`. + +## Configuration + +```python +from flyteplugins.duckdb import DuckDBConfig + +config = DuckDBConfig( + database_path=":memory:", # default; or "/path/to/file.duckdb" + extensions=["httpfs", "json"], +) + +task = DuckDB(name="my_task", query="SELECT 1", config=config) +``` diff --git a/plugins/duckdb/pyproject.toml b/plugins/duckdb/pyproject.toml new file mode 100644 index 000000000..c65a01d0f --- /dev/null +++ b/plugins/duckdb/pyproject.toml @@ -0,0 +1,81 @@ +[project] +name = "flyteplugins-duckdb" +dynamic = ["version"] +description = "DuckDB plugin for flyte" +readme = "README.md" +authors = [{ name = "Andre Ahlert", email = "andreahlert@users.noreply.github.com" }] +requires-python = ">=3.10" +dependencies = [ + "duckdb", + "flyte", + "pandas", + "pyarrow", +] + +[dependency-groups] +dev = [ + "pytest>=8.3.5", + "pytest-asyncio>=0.26.0", +] + +[build-system] +requires = ["setuptools", "setuptools_scm"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +include-package-data = true +license-files = ["licenses/*.txt", "LICENSE"] + +[tool.setuptools.packages.find] +where = ["src"] 
+include = ["flyteplugins*"]
+
+[tool.setuptools_scm]
+root = "../../"
+
+[tool.pytest.ini_options]
+norecursedirs = []
+log_cli = true
+log_cli_level = 20
+markers = []
+asyncio_default_fixture_loop_scope = "function"
+
+[tool.coverage.run]
+branch = true
+
+[tool.ruff]
+line-length = 120
+
+[tool.ruff.lint]
+select = [
+    "E",
+    "W",
+    "F",
+    "I",
+    "PLW",
+    "YTT",
+    "ASYNC",
+    "C4",
+    "T10",
+    "EXE",
+    "ISC",
+    "LOG",
+    "PIE",
+    "Q",
+    "RSE",
+    "FLY",
+    "PGH",
+    "PLC",
+    "PLE",
+    "PLW",
+    "FURB",
+    "RUF",
+]
+ignore = ["PGH003", "PLC0415"]
+
+[tool.ruff.lint.per-file-ignores]
+"examples/*" = ["E402"]
+"tests/*" = ["ASYNC230", "ASYNC240"]
+
+[tool.uv.sources]
+flyte = { path = "../../", editable = true }
diff --git a/plugins/duckdb/src/flyteplugins/duckdb/__init__.py b/plugins/duckdb/src/flyteplugins/duckdb/__init__.py
new file mode 100644
index 000000000..f6a4d6639
--- /dev/null
+++ b/plugins/duckdb/src/flyteplugins/duckdb/__init__.py
@@ -0,0 +1,3 @@
+__all__ = ["DuckDB", "DuckDBConfig"]
+
+from flyteplugins.duckdb.task import DuckDB, DuckDBConfig
diff --git a/plugins/duckdb/src/flyteplugins/duckdb/task.py b/plugins/duckdb/src/flyteplugins/duckdb/task.py
new file mode 100644
index 000000000..b0dc9dd2f
--- /dev/null
+++ b/plugins/duckdb/src/flyteplugins/duckdb/task.py
@@ -0,0 +1,202 @@
+import json
+import typing
+from dataclasses import dataclass
+
+from flyte._utils import lazy_module
+from flyte.extend import TaskTemplate
+from flyte.io import DataFrame
+from flyte.models import NativeInterface, SerializationContext
+from flyteidl2.core import tasks_pb2
+
+if typing.TYPE_CHECKING:
+    import pandas as pd
+    import pyarrow as pa
+else:
+    pd = lazy_module("pandas")
+    pa = lazy_module("pyarrow")
+
+duckdb = lazy_module("duckdb")
+
+
+@dataclass
+class DuckDBConfig:
+    """Configuration for a DuckDB task.
+
+    Args:
+        database_path: Path to a DuckDB database file, or ":memory:" for in-memory.
+        extensions: List of DuckDB extensions to install and load before query execution
+            (e.g., ["httpfs", "spatial", "json"]).
+    """
+
+    database_path: str = ":memory:"
+    extensions: typing.Optional[typing.List[str]] = None
+
+
+class DuckDB(TaskTemplate):
+    """Run SQL queries against DuckDB as a Flyte task.
+
+    DuckDB is an embedded analytical database (like SQLite for OLAP). Queries execute
+    locally and synchronously, with no remote credentials or polling required.
+
+    Supports DataFrame inputs (registered as virtual tables in DuckDB), parameterized
+    queries with ``?`` or ``$N`` placeholders, extension loading, and multi-query execution.
+
+    Args:
+        name: Task name.
+        query: SQL query string or list of queries to execute in sequence. The result of
+            the last query is returned. If None, must be provided at runtime via a
+            ``query`` string input.
+        inputs: Input name-to-type mapping. DataFrame types (``pd.DataFrame``,
+            ``pa.Table``, ``flyte.io.DataFrame``) are registered as queryable virtual
+            tables. ``list`` or ``str`` types are used as query parameters.
+        config: Optional DuckDB configuration. Defaults to in-memory database.
+
+    Example::
+
+        import pandas as pd
+        from flyteplugins.duckdb import DuckDB
+
+        analyze = DuckDB(
+            name="analyze",
+            query="SELECT SUM(a) AS total FROM mydf",
+            inputs={"mydf": pd.DataFrame},
+        )
+    """
+
+    _TASK_TYPE = "duckdb"
+
+    def __init__(
+        self,
+        name: str,
+        query: typing.Optional[typing.Union[str, typing.List[str]]] = None,
+        inputs: typing.Optional[typing.Dict[str, type]] = None,
+        config: typing.Optional[DuckDBConfig] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            name=name,
+            task_type=self._TASK_TYPE,
+            image=None,
+            interface=NativeInterface(
+                {k: (v, None) for k, v in inputs.items()} if inputs else {},
+                {"result": DataFrame},
+            ),
+            **kwargs,
+        )
+        self._query = query
+        self._config = config or DuckDBConfig()
+
+    async def execute(self, **kwargs) -> DataFrame:
+        """Execute the configured (or runtime-provided) queries and return the result.
+
+        DataFrame-like inputs are registered as virtual tables under their input
+        names, ``list``/``str`` inputs become query parameters, and a ``query``
+        string input overrides the query configured at definition time.
+
+        Raises:
+            ValueError: If no query is available, the query list is empty, a string
+                parameter input is not valid JSON, or an input type is unsupported.
+        """
+        con = duckdb.connect(database=self._config.database_path)
+        try:
+            # Extensions must be installed/loaded on this connection before queries run.
+            for ext in self._config.extensions or []:
+                con.install_extension(ext)
+                con.load_extension(ext)
+
+            params = None
+            query = self._query
+
+            for key, val in kwargs.items():
+                if key == "query" and isinstance(val, str):
+                    # A runtime 'query' input takes precedence over the definition-time query.
+                    query = val
+                elif isinstance(val, (pd.DataFrame, pa.Table)):
+                    # Register as a virtual table queryable by the input name.
+                    con.register(key, val)
+                elif isinstance(val, DataFrame):
+                    # flyte.io.DataFrame: prefer the in-memory value; otherwise materialize.
+                    raw = val.val
+                    if raw is not None:
+                        if isinstance(raw, pa.Table):
+                            arrow_table = raw
+                        elif isinstance(raw, pd.DataFrame):
+                            arrow_table = pa.Table.from_pandas(raw)
+                        else:
+                            arrow_table = pa.table(raw)
+                    else:
+                        arrow_table = await val.open(pa.Table).all()
+                    con.register(key, arrow_table)
+                elif isinstance(val, list):
+                    params = val
+                elif isinstance(val, str):
+                    # Any other string input is interpreted as JSON-encoded parameters.
+                    try:
+                        params = json.loads(val)
+                    except json.JSONDecodeError as e:
+                        raise ValueError(f"String input '{key}' is not valid JSON-encoded query parameters.") from e
+                else:
+                    raise ValueError(f"Unsupported input type for '{key}': {type(val)}")
+
+            if query is None:
+                raise ValueError("A query must be provided at task definition or at runtime via a 'query' input.")
+
+            queries = query if isinstance(query, list) else [query]
+            if not queries:
+                raise ValueError("Query list must not be empty.")
+            result = self._execute_queries(con, queries, params)
+            return DataFrame.wrap_df(result.to_arrow_table())
+        finally:
+            con.close()
+
+    def _execute_queries(self, con, queries: typing.List[str], params=None):
+        """Execute queries in sequence, returning the DuckDB result of the last one.
+
+        When params is a nested list (params[0] is a list), each parameterized query
+        consumes the next element from params in order. Otherwise all parameterized
+        queries share the same params list.
+
+        Note: placeholder detection is a plain substring scan for "?"/"$", so a
+        literal "?" or "$" inside a string constant of a non-parameterized query
+        will also receive the shared params when params are provided.
+        """
+        multiple_params = bool(params) and isinstance(params[0], list)
+        counter = -1
+        result = None
+
+        for query in queries:
+            has_placeholders = "?" in query or "$" in query
+
+            if has_placeholders and params is not None:
+                if multiple_params:
+                    counter += 1
+                    if counter >= len(params):
+                        raise ValueError(f"Not enough parameter sets for parameterized query #{counter + 1}.")
+                    current_params = params[counter]
+                else:
+                    current_params = params
+
+                # INSERT binds one parameter set per row (executemany); everything
+                # else binds a single parameter set.
+                if query.lstrip().lower().startswith("insert"):
+                    result = con.executemany(query, current_params)
+                else:
+                    result = con.execute(query, current_params)
+            else:
+                result = con.execute(query)
+
+        return result
+
+    def custom_config(self, sctx: SerializationContext) -> typing.Optional[typing.Dict[str, typing.Any]]:
+        """Serialize the DuckDB connection settings into the task spec."""
+        config: typing.Dict[str, typing.Any] = {"database_path": self._config.database_path}
+        if self._config.extensions:
+            config["extensions"] = self._config.extensions
+        return config
+
+    def sql(self, sctx: SerializationContext) -> typing.Optional[tasks_pb2.Sql]:
+        """Expose the (last) SQL statement in the task spec when known at definition time."""
+        if self._query is None:
+            return None
+        statement = self._query[-1] if isinstance(self._query, list) else self._query
+        return tasks_pb2.Sql(statement=statement, dialect=tasks_pb2.Sql.Dialect.ANSI)
diff --git a/plugins/duckdb/tests/__init__.py b/plugins/duckdb/tests/__init__.py
new file mode 100644
index 000000000..8b1378917
--- /dev/null
+++ b/plugins/duckdb/tests/__init__.py
@@ -0,0 +1 @@
+
diff --git a/plugins/duckdb/tests/test_task.py b/plugins/duckdb/tests/test_task.py
new file mode 100644
index
000000000..58bf9283d
--- /dev/null
+++ b/plugins/duckdb/tests/test_task.py
@@ -0,0 +1,340 @@
+"""Unit tests for the DuckDB plugin: config, task creation, serialization, execution."""
+
+import json
+
+import pandas as pd
+import pyarrow as pa
+import pytest
+from flyte.io import DataFrame
+from flyte.models import SerializationContext
+from flyteidl2.core.tasks_pb2 import Sql
+
+from flyteplugins.duckdb import DuckDB, DuckDBConfig
+
+SCTX = SerializationContext(version="test")
+
+
+def _make_task(**kwargs) -> DuckDB:
+    defaults = {"name": "test", "query": "SELECT 1"}
+    defaults.update(kwargs)
+    return DuckDB(**defaults)
+
+
+# ---------------------------------------------------------------------------
+# Config
+# ---------------------------------------------------------------------------
+
+
+class TestDuckDBConfig:
+    def test_defaults(self):
+        config = DuckDBConfig()
+        assert config.database_path == ":memory:"
+        assert config.extensions is None
+
+    def test_custom(self):
+        config = DuckDBConfig(database_path="/tmp/test.duckdb", extensions=["httpfs", "json"])
+        assert config.database_path == "/tmp/test.duckdb"
+        assert config.extensions == ["httpfs", "json"]
+
+
+# ---------------------------------------------------------------------------
+# Task creation
+# ---------------------------------------------------------------------------
+
+
+class TestDuckDBTask:
+    def test_task_type(self):
+        task = _make_task()
+        assert task._TASK_TYPE == "duckdb"
+        assert task.task_type == "duckdb"
+
+    def test_default_config(self):
+        task = _make_task()
+        assert task._config.database_path == ":memory:"
+        assert task._config.extensions is None
+
+    def test_custom_config(self):
+        config = DuckDBConfig(database_path="/data/test.duckdb", extensions=["json"])
+        task = _make_task(config=config)
+        assert task._config.database_path == "/data/test.duckdb"
+        assert task._config.extensions == ["json"]
+
+    def test_no_image(self):
+        task = _make_task()
+        assert task.image is None
+
+
+# ---------------------------------------------------------------------------
+# Serialization: custom_config
+# ---------------------------------------------------------------------------
+
+
+class TestCustomConfig:
+    def test_default(self):
+        task = _make_task()
+        config = task.custom_config(SCTX)
+        assert config == {"database_path": ":memory:"}
+
+    def test_with_extensions(self):
+        task = _make_task(config=DuckDBConfig(extensions=["httpfs", "spatial"]))
+        config = task.custom_config(SCTX)
+        assert config == {"database_path": ":memory:", "extensions": ["httpfs", "spatial"]}
+
+    def test_full(self):
+        task = _make_task(config=DuckDBConfig(database_path="/tmp/t.duckdb", extensions=["json"]))
+        config = task.custom_config(SCTX)
+        assert config == {"database_path": "/tmp/t.duckdb", "extensions": ["json"]}
+
+
+# ---------------------------------------------------------------------------
+# Serialization: sql
+# ---------------------------------------------------------------------------
+
+
+class TestSql:
+    def test_single_query(self):
+        task = _make_task(query="SELECT * FROM users")
+        sql = task.sql(SCTX)
+        assert sql.statement == "SELECT * FROM users"
+        assert sql.dialect == Sql.Dialect.ANSI
+
+    def test_multi_query_returns_last(self):
+        task = _make_task(query=["CREATE TABLE t (id INT)", "SELECT * FROM t"])
+        sql = task.sql(SCTX)
+        assert sql.statement == "SELECT * FROM t"
+
+    def test_no_query_returns_none(self):
+        task = DuckDB(name="dynamic", query=None, inputs={"query": str})
+        assert task.sql(SCTX) is None
+
+
+# ---------------------------------------------------------------------------
+# Execute: basic queries
+# ---------------------------------------------------------------------------
+
+
+class TestExecute:
+    @pytest.mark.asyncio
+    async def test_simple_select(self):
+        task = _make_task(query="SELECT 42 AS answer")
+        result = await task.execute()
+        df = result.val.to_pandas()
+        assert df["answer"].iloc[0] == 42
+
+    @pytest.mark.asyncio
+    async def test_range_query(self):
+        task = _make_task(query="SELECT * FROM range(5)")
+        result = await task.execute()
+        df = result.val.to_pandas()
+        assert len(df) == 5
+
+    @pytest.mark.asyncio
+    async def test_no_query_raises(self):
+        task = DuckDB(name="empty", query=None)
+        with pytest.raises(ValueError, match="query must be provided"):
+            await task.execute()
+
+    @pytest.mark.asyncio
+    async def test_empty_query_list_raises(self):
+        task = DuckDB(name="empty_list", query=[])
+        with pytest.raises(ValueError, match="must not be empty"):
+            await task.execute()
+
+    @pytest.mark.asyncio
+    async def test_ddl_only_query_returns_empty(self):
+        task = _make_task(query="CREATE TABLE t (id INTEGER)")
+        result = await task.execute()
+        assert result.val is not None
+
+
+# ---------------------------------------------------------------------------
+# Execute: DataFrame inputs
+# ---------------------------------------------------------------------------
+
+
+class TestDataFrameInputs:
+    @pytest.mark.asyncio
+    async def test_pandas_input(self):
+        task = _make_task(
+            query="SELECT SUM(a) AS total FROM mydf",
+            inputs={"mydf": pd.DataFrame},
+        )
+        df = pd.DataFrame({"a": [1, 2, 3]})
+        result = await task.execute(mydf=df)
+        out = result.val.to_pandas()
+        assert out["total"].iloc[0] == 6
+
+    @pytest.mark.asyncio
+    async def test_arrow_input(self):
+        task = _make_task(
+            query="SELECT * FROM arrow_table WHERE i = 2",
+            inputs={"arrow_table": pa.Table},
+        )
+        table = pa.table({"i": [1, 2, 3], "j": ["a", "b", "c"]})
+        result = await task.execute(arrow_table=table)
+        out = result.val.to_pandas()
+        assert len(out) == 1
+        assert out["j"].iloc[0] == "b"
+
+    @pytest.mark.asyncio
+    async def test_flyte_dataframe_input(self):
+        task = _make_task(
+            query="SELECT SUM(a) AS total FROM mydf",
+            inputs={"mydf": DataFrame},
+        )
+        raw = pd.DataFrame({"a": [10, 20, 30]})
+        fdf = DataFrame.from_df(raw)
+        result = await task.execute(mydf=fdf)
+        out = result.val.to_pandas()
+        assert out["total"].iloc[0] == 60
+
+    @pytest.mark.asyncio
+    async def test_multiple_dataframe_inputs(self):
+        task = _make_task(
+            query="SELECT a.x, b.y FROM df_a a JOIN df_b b ON a.id = b.id",
+            inputs={"df_a": pd.DataFrame, "df_b": pd.DataFrame},
+        )
+        df_a = pd.DataFrame({"id": [1, 2], "x": ["foo", "bar"]})
+        df_b = pd.DataFrame({"id": [1, 2], "y": [100, 200]})
+        result = await task.execute(df_a=df_a, df_b=df_b)
+        out = result.val.to_pandas()
+        assert len(out) == 2
+        assert set(out["y"]) == {100, 200}
+
+
+# ---------------------------------------------------------------------------
+# Execute: parameterized queries
+# ---------------------------------------------------------------------------
+
+
+class TestInsertDetection:
+    @pytest.mark.asyncio
+    async def test_column_named_insert_not_treated_as_insert(self):
+        """A SELECT on a column named 'insert_date' should use execute(), not executemany()."""
+        task = _make_task(
+            query=[
+                "CREATE TABLE log (insert_date DATE, val INTEGER)",
+                "INSERT INTO log VALUES ('2026-01-01', 1)",
+                "SELECT insert_date FROM log WHERE val = ?",
+            ],
+            inputs={"params": list},
+        )
+        result = await task.execute(params=[1])
+        out = result.val.to_pandas()
+        assert len(out) == 1
+
+
+class TestParameterizedQueries:
+    @pytest.mark.asyncio
+    async def test_positional_params(self):
+        task = _make_task(
+            query="SELECT * FROM range(10) WHERE range > ?",
+            inputs={"params": list},
+        )
+        result = await task.execute(params=[5])
+        out = result.val.to_pandas()
+        assert len(out) == 4
+        assert all(out["range"] > 5)
+
+    @pytest.mark.asyncio
+    async def test_dollar_params(self):
+        task = _make_task(
+            query="SELECT $1 AS col1, $2 AS col2",
+            inputs={"params": list},
+        )
+        result = await task.execute(params=["hello", "world"])
+        out = result.val.to_pandas()
+        assert out["col1"].iloc[0] == "hello"
+        assert out["col2"].iloc[0] == "world"
+
+    @pytest.mark.asyncio
+    async def test_json_string_params(self):
+        task = _make_task(
+            query="SELECT $1 AS val",
+            inputs={"params": str},
+        )
+        result = await task.execute(params=json.dumps(["test_value"]))
+        out = result.val.to_pandas()
+        assert out["val"].iloc[0] == "test_value"
+
+
+# ---------------------------------------------------------------------------
+# Execute: multi-query
+# ---------------------------------------------------------------------------
+
+
+class TestMultiQuery:
+    @pytest.mark.asyncio
+    async def test_create_insert_select(self):
+        task = _make_task(
+            query=[
+                "CREATE TABLE items (name VARCHAR, price INTEGER)",
+                "INSERT INTO items VALUES ('apple', 1), ('banana', 2)",
+                "SELECT SUM(price) AS total FROM items",
+            ],
+        )
+        result = await task.execute()
+        out = result.val.to_pandas()
+        assert out["total"].iloc[0] == 3
+
+    @pytest.mark.asyncio
+    async def test_multi_query_with_multi_params(self):
+        task = _make_task(
+            query=[
+                "CREATE TABLE items (name VARCHAR, price DECIMAL(10,2))",
+                "INSERT INTO items VALUES (?, ?)",
+                "SELECT $1 AS col1, $2 AS col2",
+            ],
+            inputs={"params": str},
+        )
+        params = [[["apple", 1.0], ["banana", 2.0]], ["hello", "world"]]
+        result = await task.execute(params=json.dumps(params))
+        out = result.val.to_pandas()
+        assert out["col1"].iloc[0] == "hello"
+        assert out["col2"].iloc[0] == "world"
+
+
+# ---------------------------------------------------------------------------
+# Execute: runtime query
+# ---------------------------------------------------------------------------
+
+
+class TestRuntimeQuery:
+    @pytest.mark.asyncio
+    async def test_query_from_input(self):
+        task = DuckDB(
+            name="dynamic",
+            query=None,
+            inputs={"mydf": pd.DataFrame, "query": str},
+        )
+        df = pd.DataFrame({"x": [10, 20, 30]})
+        result = await task.execute(mydf=df, query="SELECT MAX(x) AS max_x FROM mydf")
+        out = result.val.to_pandas()
+        assert out["max_x"].iloc[0] == 30
+
+    @pytest.mark.asyncio
+    async def test_runtime_query_overrides_default(self):
+        task = _make_task(
+            query="SELECT 1 AS original",
+            inputs={"query": str},
+        )
+        result = await task.execute(query="SELECT 99 AS overridden")
+        out = result.val.to_pandas()
+        assert out["overridden"].iloc[0] == 99
+
+
+# ---------------------------------------------------------------------------
+# Execute: extensions
+# ---------------------------------------------------------------------------
+
+
+class TestExtensions:
+    @pytest.mark.asyncio
+    async def test_json_extension(self):
+        task = _make_task(
+            query="SELECT json_extract('{\"key\": \"value\"}', '$.key') AS val",
+            config=DuckDBConfig(extensions=["json"]),
+        )
+        result = await task.execute()
+        out = result.val.to_pandas()
+        assert "value" in str(out["val"].iloc[0])