diff --git a/universql/warehouse/snowflake.py b/universql/warehouse/snowflake.py index 9baf77c..e7be085 100644 --- a/universql/warehouse/snowflake.py +++ b/universql/warehouse/snowflake.py @@ -18,7 +18,8 @@ from snowflake.connector import NotSupportedError, DatabaseError from snowflake.connector.constants import FIELD_TYPES, FieldType from sqlglot.expressions import Literal, Var, Property, IcebergProperty, Properties, ColumnDef, DataType, \ - Schema, TransientProperty, TemporaryProperty, Select, Column, Alias, Anonymous, parse_identifier, Subquery + Schema, TransientProperty, TemporaryProperty, Select, Column, Alias, Anonymous, parse_identifier, Subquery, Show, \ + Use from universql.warehouse import ICatalog, Executor, Locations, Tables from universql.util import SNOWFLAKE_HOST, QueryError, prepend_to_lines, get_friendly_time_since @@ -119,7 +120,8 @@ def get_volume_lake_path(self, volume: str) -> str: volume_location = cursor.fetchall() # Find the active storage location name - active_storage_name = next((item[3] for item in volume_location if item[1] == 'ACTIVE' and item[0] == 'STORAGE_LOCATIONS'), None) + active_storage_name = next( + (item[3] for item in volume_location if item[1] == 'ACTIVE' and item[0] == 'STORAGE_LOCATIONS'), None) # Extract the STORAGE_BASE_URL from the corresponding storage location storage_base_url = None @@ -275,8 +277,7 @@ def _convert_snowflake_to_iceberg_type(self, snowflake_type: FieldType) -> str: return 'TEXT' return snowflake_type.name - def execute_raw(self, compiled_sql: str) -> None: - run_on_warehouse = self.catalog.compute.get('warehouse') is not None + def execute_raw(self, compiled_sql: str, run_on_warehouse=None) -> None: try: emoji = "☁️(user cloud services)" if not run_on_warehouse else "💰(used warehouse)" logger.info(f"[{self.catalog.session_id}] Running on Snowflake.. {emoji} \n {compiled_sql}") @@ -290,7 +291,7 @@ def execute(self, ast: sqlglot.exp.Expression, locations: Tables) -> \ compiled_sql = (ast # .transform(self.default_create_table_as_iceberg) .sql(dialect="snowflake", pretty=True)) - self.execute_raw(compiled_sql) + self.execute_raw(compiled_sql, run_on_warehouse=not isinstance(ast, Show) and not isinstance(ast, Use)) return None def get_query_log(self, total_duration) -> str: