Commit

merged main, kept changes from tests.integration/utils.py

ryanwithawhy committed Jan 9, 2025
2 parents 72d7960 + 6a0e391 commit a9afda8
Showing 5 changed files with 19 additions and 19 deletions.
7 changes: 3 additions & 4 deletions .github/workflows/test.yml
@@ -1,4 +1,4 @@
-name: Release
+name: Test

 on:
   push:
@@ -30,6 +30,5 @@ jobs:
       env:
         SNOWFLAKE_CONNECTIONS_BASE64: ${{ secrets.SNOWFLAKE_CONNECTIONS_BASE64 }}
       run: |
-        rm /home/runner/.config/snowflake/connections.toml
-        export SNOWFLAKE_CONNECTIONS=$(echo $SNOWFLAKE_CONNECTIONS_BASE64 | base64 --decode)
-        poetry run pytest tests/integration/*
+        export SNOWFLAKE_CONNECTIONS=$(printf '%s' "$SNOWFLAKE_CONNECTIONS_BASE64" | base64 -d | tr '\r' '\n')
+        poetry run pytest tests/integration/* --log-cli-level=DEBUG
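Note on the new decode step: printf hands the secret to base64 untouched (echo can mangle it on some shells), and tr '\r' '\n' normalizes any carriage returns in the secret into newlines. A minimal Python sketch of the same round trip, assuming the SNOWFLAKE_CONNECTIONS_BASE64 secret is simply a base64-encoded copy of a local connections.toml; the file path and variable names below are illustrative, not part of this commit:

import base64
from pathlib import Path

# Hypothetical location of the local config; adjust as needed.
toml_path = Path.home() / ".config" / "snowflake" / "connections.toml"

# What would be stored in the SNOWFLAKE_CONNECTIONS_BASE64 secret.
encoded = base64.b64encode(toml_path.read_bytes()).decode("ascii")

# What the workflow step reconstructs before exporting SNOWFLAKE_CONNECTIONS:
# base64-decode, then map every CR to LF (the tr '\r' '\n' part).
decoded = base64.b64decode(encoded).decode("utf-8").replace("\r", "\n")
print(decoded.splitlines()[0] if decoded else "(empty)")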
7 changes: 4 additions & 3 deletions tests/integration/extract.py
@@ -39,6 +39,7 @@ def test_simple_select(self):
             universql_result = execute_query(conn, SIMPLE_QUERY)
             print(universql_result)

+    @pytest.mark.skip(reason="Stages are not implemented yet")
     def test_from_stage(self):
         with universql_connection() as conn:
             universql_result = execute_query(conn, "select count(*) from @stage/iceberg_stage")
@@ -67,9 +68,9 @@ def test_switch_schema(self):
            universql_result = execute_query(conn, "SHOW SCHEMAS")
            assert universql_result.num_rows > 0, f"The query did not return any rows!"

-    def test_in_schema(self):
-        with universql_connection(schema="public", warehouse="local()") as conn:
-            universql_result = execute_query(conn, "select count(*) from table_in_public_schema")
+    def test_in_database(self):
+        with universql_connection(database="public") as conn:
+            universql_result = execute_query(conn, "select * from information_schema.columns")
             print(universql_result)

     def test_qualifiers(self):
1 change: 0 additions & 1 deletion tests/snow_client.py
@@ -35,7 +35,6 @@ def measure_all(cur, query):
 def tt(max_iteration):
     connection = snowflake.connector.connect(connection_name="ryan_snowflake",
                                              port='8084', host='localhostcomputing.com',
-                                             warehouse='local()',
                                              database='test',
                                              # host='4ho74nvv4nyxhqxmcxrnpiid2m0bpcet.lambda-url.us-east-1.on.aws',
                                              )
12 changes: 6 additions & 6 deletions universql/util.py
@@ -397,11 +397,11 @@ def calculate_script_cost(duration_second, electricity_rate=0.15, pc_lifetime_ye
 pattern = r'(\w+)(?:\(([^)]*)\))'


-def parse_compute(value):
-    if value is not None:
-        matches = re.findall(pattern, value)
+def parse_compute(warehouse):
+    if warehouse is not None:
+        matches = re.findall(pattern, warehouse)
         if len(matches) == 0:
-            matches = (('local', ''), ('snowflake', f'warehouse={value}'))
+            matches = (('local', ''), ('snowflake', f'warehouse={warehouse}'))
         else:
             # try locally if warehouse is not provided
             matches = (('local', ''), )
@@ -412,8 +412,8 @@ def parse_compute(value):
        if args_str:
            for arg in args_str.split(','):
                if '=' in arg:
-                    key, value = arg.split('=', 1)
-                    args[key.strip()] = value.strip()
+                    key, warehouse = arg.split('=', 1)
+                    args[key.strip()] = warehouse.strip()
                else:
                    args[arg.strip()] = None  # Handle arguments without '='
        result.append({'name': func_name, 'args': args})
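For context on the value-to-warehouse rename in parse_compute: the observable behavior is unchanged. Below is a self-contained sketch reconstructed only from the lines visible in these hunks; the loop over matches, the result list, and the return statement are assumed scaffolding that the diff does not show, and the sample inputs are illustrative.

import re

pattern = r'(\w+)(?:\(([^)]*)\))'

def parse_compute(warehouse):
    if warehouse is not None:
        matches = re.findall(pattern, warehouse)
        if len(matches) == 0:
            # A bare name like "COMPUTE_WH" falls back to local + snowflake.
            matches = (('local', ''), ('snowflake', f'warehouse={warehouse}'))
    else:
        # try locally if warehouse is not provided
        matches = (('local', ''),)

    result = []
    for func_name, args_str in matches:
        args = {}
        if args_str:
            for arg in args_str.split(','):
                if '=' in arg:
                    # Shadows the parameter, mirroring the renamed code above.
                    key, warehouse = arg.split('=', 1)
                    args[key.strip()] = warehouse.strip()
                else:
                    args[arg.strip()] = None  # Handle arguments without '='
        result.append({'name': func_name, 'args': args})
    return result

print(parse_compute("local()"))     # [{'name': 'local', 'args': {}}]
print(parse_compute(None))          # [{'name': 'local', 'args': {}}]
print(parse_compute("COMPUTE_WH"))  # [{'name': 'local', 'args': {}},
                                    #  {'name': 'snowflake', 'args': {'warehouse': 'COMPUTE_WH'}}]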
11 changes: 6 additions & 5 deletions universql/warehouse/snowflake.py
@@ -18,7 +18,8 @@
 from snowflake.connector import NotSupportedError, DatabaseError
 from snowflake.connector.constants import FIELD_TYPES, FieldType
 from sqlglot.expressions import Literal, Var, Property, IcebergProperty, Properties, ColumnDef, DataType, \
-    Schema, TransientProperty, TemporaryProperty, Select, Column, Alias, Anonymous, parse_identifier, Subquery
+    Schema, TransientProperty, TemporaryProperty, Select, Column, Alias, Anonymous, parse_identifier, Subquery, Show, \
+    Use

 from universql.warehouse import ICatalog, Executor, Locations, Tables
 from universql.util import SNOWFLAKE_HOST, QueryError, prepend_to_lines, get_friendly_time_since
@@ -119,7 +120,8 @@ def get_volume_lake_path(self, volume: str) -> str:
         volume_location = cursor.fetchall()

         # Find the active storage location name
-        active_storage_name = next((item[3] for item in volume_location if item[1] == 'ACTIVE' and item[0] == 'STORAGE_LOCATIONS'), None)
+        active_storage_name = next(
+            (item[3] for item in volume_location if item[1] == 'ACTIVE' and item[0] == 'STORAGE_LOCATIONS'), None)

         # Extract the STORAGE_BASE_URL from the corresponding storage location
         storage_base_url = None
@@ -275,8 +277,7 @@ def _convert_snowflake_to_iceberg_type(self, snowflake_type: FieldType) -> str:
             return 'TEXT'
         return snowflake_type.name

-    def execute_raw(self, compiled_sql: str) -> None:
-        run_on_warehouse = self.catalog.compute.get('warehouse') is not None
+    def execute_raw(self, compiled_sql: str, run_on_warehouse=None) -> None:
         try:
             emoji = "☁️(user cloud services)" if not run_on_warehouse else "💰(used warehouse)"
             logger.info(f"[{self.catalog.session_id}] Running on Snowflake.. {emoji} \n {compiled_sql}")
@@ -290,7 +291,7 @@ def execute(self, ast: sqlglot.exp.Expression, locations: Tables) -> \
         compiled_sql = (ast
                         # .transform(self.default_create_table_as_iceberg)
                         .sql(dialect="snowflake", pretty=True))
-        self.execute_raw(compiled_sql)
+        self.execute_raw(compiled_sql, run_on_warehouse=not isinstance(ast, Show) and not isinstance(ast, Use))
         return None

     def get_query_log(self, total_duration) -> str:
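The net effect of the execute_raw change: execute() now decides per statement whether a warehouse is needed, and SHOW / USE statements are sent with run_on_warehouse=False so they stay on cloud services. A small sketch of that decision, assuming the pinned sqlglot release parses Snowflake SHOW and USE statements into the Show and Use nodes imported above; the sample statements are illustrative only.

import sqlglot
from sqlglot.expressions import Show, Use

def needs_warehouse(sql: str) -> bool:
    # Mirrors the new call site: SHOW and USE never need a running warehouse.
    ast = sqlglot.parse_one(sql, read="snowflake")
    return not isinstance(ast, Show) and not isinstance(ast, Use)

for stmt in ("SHOW SCHEMAS", "USE DATABASE test", "SELECT count(*) FROM t"):
    target = "warehouse" if needs_warehouse(stmt) else "cloud services only"
    print(f"{stmt!r} -> {target}")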
