From 96e8e8999acfa1ab70fa311925e96494608ccb18 Mon Sep 17 00:00:00 2001 From: Will Sweet Date: Wed, 30 Jul 2025 13:01:48 -0400 Subject: [PATCH 1/5] use PandasCursor with Athena for %%fetchdf magic only --- sqlmesh/magics.py | 74 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 73 insertions(+), 1 deletion(-) diff --git a/sqlmesh/magics.py b/sqlmesh/magics.py index e74019a743..1b1b50b39e 100644 --- a/sqlmesh/magics.py +++ b/sqlmesh/magics.py @@ -611,7 +611,13 @@ def render(self, context: Context, line: str) -> None: def fetchdf(self, context: Context, line: str, sql: str) -> None: """Fetches a dataframe from sql, optionally storing it in a variable.""" args = parse_argstring(self.fetchdf, line) - df = context.fetchdf(sql) + + # Check if we're using Athena and use PandasCursor directly + if hasattr(context.engine_adapter, 'DIALECT') and context.engine_adapter.DIALECT == 'athena': + df = self._fetchdf_athena_pandas_cursor(context, sql) + else: + df = context.fetchdf(sql) + if args.df_var: self._shell.user_ns[args.df_var] = df self.display(df) @@ -1147,6 +1153,72 @@ def destroy(self, context: Context, line: str) -> None: """Removes all project resources, engine-managed objects, state tables and clears the SQLMesh cache.""" context.destroy() + def _fetchdf_athena_pandas_cursor(self, context: Context, sql: str) -> "pd.DataFrame": + """Special implementation for Athena using PandasCursor with SQLGlot transpilation""" + import pandas as pd + + try: + from pyathena.pandas.cursor import PandasCursor + from pyathena import connect + except ImportError as e: + raise MagicError(f"PyAthena with pandas support is required: {e}") + + # Use SQLMesh's transpilation to convert SQL to Athena dialect + # This handles features like QUALIFY that need transpilation + try: + # Parse the SQL string into a SQLGlot expression first + from sqlmesh.core.dialect import parse + parsed_expressions = parse(sql, default_dialect=context.config.dialect) + + # Get the first expression (should be a SELECT statement) + if parsed_expressions: + transpiled_sql = context.engine_adapter._to_sql(parsed_expressions[0], quote=False) + else: + raise ValueError("No valid SQL expressions found") + + except Exception as e: + context.console.log_error(f"SQL transpilation failed: {e}") + # Fall back to the regular fetchdf method if transpilation fails + return context.fetchdf(sql) + + # Get the connection configuration for Athena + conn_config = context.config.get_connection(context.config.default_connection) + + # Build connection kwargs using the same logic as SQLMesh + connection_kwargs = { + k: v for k, v in conn_config.dict().items() + if k in conn_config._connection_kwargs_keys and v is not None + } + + # Create connection with PandasCursor specifically + try: + with connect( + cursor_class=PandasCursor, + **connection_kwargs + ) as conn: + with conn.cursor() as cursor: + cursor.execute(transpiled_sql) + + # PyAthena PandasCursor needs to be converted to DataFrame manually + # It returns data but we need to use pandas.DataFrame constructor + data = cursor.fetchall() + + if data: + # Get column names from cursor description + columns = [desc[0] for desc in cursor.description] if cursor.description else None + df = pd.DataFrame(data, columns=columns) + else: + # Empty result set + columns = [desc[0] for desc in cursor.description] if cursor.description else [] + df = pd.DataFrame(columns=columns) + + return df + + except Exception as e: + # Fall back to the regular fetchdf method if PandasCursor fails + context.console.log_error(f"PandasCursor failed, falling back to standard method: {e}") + return context.fetchdf(sql) + def register_magics() -> None: try: From 959530490e1b8093f3f46970df3b82d5754f684a Mon Sep 17 00:00:00 2001 From: Will Sweet Date: Wed, 30 Jul 2025 13:59:39 -0400 Subject: [PATCH 2/5] simplify --- sqlmesh/magics.py | 56 ++++++----------------------------------------- 1 file changed, 7 insertions(+), 49 deletions(-) diff --git a/sqlmesh/magics.py b/sqlmesh/magics.py index 1b1b50b39e..e36bb87b71 100644 --- a/sqlmesh/magics.py +++ b/sqlmesh/magics.py @@ -1163,57 +1163,15 @@ def _fetchdf_athena_pandas_cursor(self, context: Context, sql: str) -> "pd.DataF except ImportError as e: raise MagicError(f"PyAthena with pandas support is required: {e}") - # Use SQLMesh's transpilation to convert SQL to Athena dialect - # This handles features like QUALIFY that need transpilation try: - # Parse the SQL string into a SQLGlot expression first - from sqlmesh.core.dialect import parse - parsed_expressions = parse(sql, default_dialect=context.config.dialect) - - # Get the first expression (should be a SELECT statement) - if parsed_expressions: - transpiled_sql = context.engine_adapter._to_sql(parsed_expressions[0], quote=False) - else: - raise ValueError("No valid SQL expressions found") - - except Exception as e: - context.console.log_error(f"SQL transpilation failed: {e}") - # Fall back to the regular fetchdf method if transpilation fails - return context.fetchdf(sql) + conn_config = context.config.get_connection(context.config.default_connection) + connection_kwargs = { + k: v for k, v in conn_config.dict().items() + if k in conn_config._connection_kwargs_keys and v is not None + } + cursor = connect(cursor_class=PandasCursor, **connection_kwargs).cursor() + return cursor.execute(sql).as_pandas() - # Get the connection configuration for Athena - conn_config = context.config.get_connection(context.config.default_connection) - - # Build connection kwargs using the same logic as SQLMesh - connection_kwargs = { - k: v for k, v in conn_config.dict().items() - if k in conn_config._connection_kwargs_keys and v is not None - } - - # Create connection with PandasCursor specifically - try: - with connect( - cursor_class=PandasCursor, - **connection_kwargs - ) as conn: - with conn.cursor() as cursor: - cursor.execute(transpiled_sql) - - # PyAthena PandasCursor needs to be converted to DataFrame manually - # It returns data but we need to use pandas.DataFrame constructor - data = cursor.fetchall() - - if data: - # Get column names from cursor description - columns = [desc[0] for desc in cursor.description] if cursor.description else None - df = pd.DataFrame(data, columns=columns) - else: - # Empty result set - columns = [desc[0] for desc in cursor.description] if cursor.description else [] - df = pd.DataFrame(columns=columns) - - return df - except Exception as e: # Fall back to the regular fetchdf method if PandasCursor fails context.console.log_error(f"PandasCursor failed, falling back to standard method: {e}") From 95758a103f43c8893439b11aadd224905ec5d55c Mon Sep 17 00:00:00 2001 From: Will Sweet Date: Wed, 30 Jul 2025 16:53:15 -0400 Subject: [PATCH 3/5] remove pandas import --- sqlmesh/magics.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sqlmesh/magics.py b/sqlmesh/magics.py index e36bb87b71..3918f62728 100644 --- a/sqlmesh/magics.py +++ b/sqlmesh/magics.py @@ -1155,7 +1155,6 @@ def destroy(self, context: Context, line: str) -> None: def _fetchdf_athena_pandas_cursor(self, context: Context, sql: str) -> "pd.DataFrame": """Special implementation for Athena using PandasCursor with SQLGlot transpilation""" - import pandas as pd try: from pyathena.pandas.cursor import PandasCursor From 8d361ce18f5152ccd7f2abe8d08e088832ab0b5c Mon Sep 17 00:00:00 2001 From: Will Sweet Date: Thu, 31 Jul 2025 08:59:25 -0400 Subject: [PATCH 4/5] import pd for type checking --- sqlmesh/magics.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sqlmesh/magics.py b/sqlmesh/magics.py index a7bb397a66..5997522ede 100644 --- a/sqlmesh/magics.py +++ b/sqlmesh/magics.py @@ -12,6 +12,9 @@ from hyperscript import h +if t.TYPE_CHECKING: + import pandas as pd + try: from IPython.core.display import display # type: ignore except ImportError: @@ -1195,7 +1198,7 @@ def destroy(self, context: Context, line: str) -> None: """Removes all project resources, engine-managed objects, state tables and clears the SQLMesh cache.""" context.destroy() - def _fetchdf_athena_pandas_cursor(self, context: Context, sql: str) -> "pd.DataFrame": + def _fetchdf_athena_pandas_cursor(self, context: Context, sql: str) -> pd.DataFrame: """Special implementation for Athena using PandasCursor with SQLGlot transpilation""" try: From da1a27a4253bb66dc74e2b4d85ee65f75495fad5 Mon Sep 17 00:00:00 2001 From: Will Sweet Date: Thu, 31 Jul 2025 09:37:31 -0400 Subject: [PATCH 5/5] format --- sqlmesh/magics.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/sqlmesh/magics.py b/sqlmesh/magics.py index 5997522ede..1d75f386d0 100644 --- a/sqlmesh/magics.py +++ b/sqlmesh/magics.py @@ -686,13 +686,16 @@ def render(self, context: Context, line: str) -> None: def fetchdf(self, context: Context, line: str, sql: str) -> None: """Fetches a dataframe from sql, optionally storing it in a variable.""" args = parse_argstring(self.fetchdf, line) - + # Check if we're using Athena and use PandasCursor directly - if hasattr(context.engine_adapter, 'DIALECT') and context.engine_adapter.DIALECT == 'athena': + if ( + hasattr(context.engine_adapter, "DIALECT") + and context.engine_adapter.DIALECT == "athena" + ): df = self._fetchdf_athena_pandas_cursor(context, sql) else: df = context.fetchdf(sql) - + if args.df_var: self._shell.user_ns[args.df_var] = df self.display(df) @@ -1200,22 +1203,23 @@ def destroy(self, context: Context, line: str) -> None: def _fetchdf_athena_pandas_cursor(self, context: Context, sql: str) -> pd.DataFrame: """Special implementation for Athena using PandasCursor with SQLGlot transpilation""" - + try: from pyathena.pandas.cursor import PandasCursor from pyathena import connect except ImportError as e: raise MagicError(f"PyAthena with pandas support is required: {e}") - + try: conn_config = context.config.get_connection(context.config.default_connection) connection_kwargs = { - k: v for k, v in conn_config.dict().items() + k: v + for k, v in conn_config.dict().items() if k in conn_config._connection_kwargs_keys and v is not None } cursor = connect(cursor_class=PandasCursor, **connection_kwargs).cursor() return cursor.execute(sql).as_pandas() - + except Exception as e: # Fall back to the regular fetchdf method if PandasCursor fails context.console.log_error(f"PandasCursor failed, falling back to standard method: {e}")