diff --git a/sqlmesh/core/engine_adapter/athena.py b/sqlmesh/core/engine_adapter/athena.py index abaf7ba281..adab019cc6 100644 --- a/sqlmesh/core/engine_adapter/athena.py +++ b/sqlmesh/core/engine_adapter/athena.py @@ -103,24 +103,49 @@ def _get_data_objects( """ schema_name = to_schema(schema_name) schema = schema_name.db - query = ( - exp.select( - exp.column("table_catalog").as_("catalog"), - exp.column("table_schema", table="t").as_("schema"), - exp.column("table_name", table="t").as_("name"), - exp.case() - .when( - exp.column("table_type", table="t").eq("BASE TABLE"), - exp.Literal.string("table"), - ) - .else_(exp.column("table_type", table="t")) - .as_("type"), + + base_query = exp.select( + exp.column("table_catalog", table="t").as_("catalog"), + exp.column("table_schema", table="t").as_("schema"), + exp.column("table_name", table="t").as_("name"), + exp.case() + .when( + exp.column("table_type", table="t").eq("BASE TABLE"), + exp.Literal.string("table"), ) - .from_(exp.to_table("information_schema.tables", alias="t")) - .where(exp.column("table_schema", table="t").eq(schema)) - ) + .else_(exp.column("table_type", table="t")) + .as_("type"), + ).from_(exp.to_table("information_schema.tables", alias="t")) + if object_names: - query = query.where(exp.column("table_name", table="t").isin(*object_names)) + # Use CTE with VALUES for better performance when filtering specific tables + query = ( + base_query.join( + exp.to_identifier("object_names"), + on=exp.column("table_name", table="t").eq( + exp.column("name", table="object_names") + ), + ) + .where(exp.column("table_schema", table="t").eq(schema)) + .with_( + "object_names", + exp.select(exp.column("name")).from_( + exp.Values( + expressions=[ + exp.Tuple(expressions=[exp.Literal.string(name)]) + for name in object_names + ], + alias=exp.TableAlias( + this="t", + columns=[exp.column("name")], + ), + ) + ), + ) + ) + else: + # Simple query when no specific object names are provided + query = base_query.where(exp.column("table_schema", table="t").eq(schema)) df = self.fetchdf(query)