Skip to content

Commit

Permalink
determine if the snowflake query requires a warehouse
Browse files Browse the repository at this point in the history
  • Loading branch information
buremba committed Jan 8, 2025
1 parent 32c6f0d commit b816a2e
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions universql/warehouse/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from snowflake.connector import NotSupportedError, DatabaseError
from snowflake.connector.constants import FIELD_TYPES, FieldType
from sqlglot.expressions import Literal, Var, Property, IcebergProperty, Properties, ColumnDef, DataType, \
Schema, TransientProperty, TemporaryProperty, Select, Column, Alias, Anonymous, parse_identifier, Subquery
Schema, TransientProperty, TemporaryProperty, Select, Column, Alias, Anonymous, parse_identifier, Subquery, Show, \
Use

from universql.warehouse import ICatalog, Executor, Locations, Tables
from universql.util import SNOWFLAKE_HOST, QueryError, prepend_to_lines, get_friendly_time_since
Expand Down Expand Up @@ -119,7 +120,8 @@ def get_volume_lake_path(self, volume: str) -> str:
volume_location = cursor.fetchall()

# Find the active storage location name
active_storage_name = next((item[3] for item in volume_location if item[1] == 'ACTIVE' and item[0] == 'STORAGE_LOCATIONS'), None)
active_storage_name = next(
(item[3] for item in volume_location if item[1] == 'ACTIVE' and item[0] == 'STORAGE_LOCATIONS'), None)

# Extract the STORAGE_BASE_URL from the corresponding storage location
storage_base_url = None
Expand Down Expand Up @@ -275,8 +277,7 @@ def _convert_snowflake_to_iceberg_type(self, snowflake_type: FieldType) -> str:
return 'TEXT'
return snowflake_type.name

def execute_raw(self, compiled_sql: str) -> None:
run_on_warehouse = self.catalog.compute.get('warehouse') is not None
def execute_raw(self, compiled_sql: str, run_on_warehouse=None) -> None:
try:
emoji = "☁️(user cloud services)" if not run_on_warehouse else "💰(used warehouse)"
logger.info(f"[{self.catalog.session_id}] Running on Snowflake.. {emoji} \n {compiled_sql}")
Expand All @@ -290,7 +291,7 @@ def execute(self, ast: sqlglot.exp.Expression, locations: Tables) -> \
compiled_sql = (ast
# .transform(self.default_create_table_as_iceberg)
.sql(dialect="snowflake", pretty=True))
self.execute_raw(compiled_sql)
self.execute_raw(compiled_sql, run_on_warehouse=not isinstance(ast, Show) and not isinstance(ast, Use))
return None

def get_query_log(self, total_duration) -> str:
Expand Down

0 comments on commit b816a2e

Please sign in to comment.