Skip to content
38 changes: 37 additions & 1 deletion sqlmesh/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

from hyperscript import h

if t.TYPE_CHECKING:
import pandas as pd

try:
from IPython.core.display import display # type: ignore
except ImportError:
Expand Down Expand Up @@ -690,7 +693,16 @@ def render(self, context: Context, line: str) -> None:
def fetchdf(self, context: Context, line: str, sql: str) -> None:
"""Fetches a dataframe from sql, optionally storing it in a variable."""
args = parse_argstring(self.fetchdf, line)
df = context.fetchdf(sql)

# Check if we're using Athena and use PandasCursor directly
if (
hasattr(context.engine_adapter, "DIALECT")
and context.engine_adapter.DIALECT == "athena"
):
df = self._fetchdf_athena_pandas_cursor(context, sql)
else:
df = context.fetchdf(sql)

if args.df_var:
self._shell.user_ns[args.df_var] = df
self.display(df)
Expand Down Expand Up @@ -1196,6 +1208,30 @@ def destroy(self, context: Context, line: str) -> None:
"""Removes all project resources, engine-managed objects, state tables and clears the SQLMesh cache."""
context.destroy()

def _fetchdf_athena_pandas_cursor(self, context: Context, sql: str) -> pd.DataFrame:
"""Special implementation for Athena using PandasCursor with SQLGlot transpilation"""

try:
from pyathena.pandas.cursor import PandasCursor
from pyathena import connect
except ImportError as e:
raise MagicError(f"PyAthena with pandas support is required: {e}")

try:
conn_config = context.config.get_connection(context.config.default_connection)
connection_kwargs = {
k: v
for k, v in conn_config.dict().items()
if k in conn_config._connection_kwargs_keys and v is not None
}
cursor = connect(cursor_class=PandasCursor, **connection_kwargs).cursor()
return cursor.execute(sql).as_pandas()

except Exception as e:
# Fall back to the regular fetchdf method if PandasCursor fails
context.console.log_error(f"PandasCursor failed, falling back to standard method: {e}")
return context.fetchdf(sql)


def register_magics() -> None:
try:
Expand Down