From 32c6f0d11767c6cacb559a5b103b9ea86eeac91a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Kabakc=C4=B1?= Date: Wed, 8 Jan 2025 22:32:49 +0000 Subject: [PATCH 1/4] fix github actions --- .github/workflows/test.yml | 3 +-- universql/util.py | 12 ++++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7063f39..5cf76e3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -30,6 +30,5 @@ jobs: env: SNOWFLAKE_CONNECTIONS_BASE64: ${{ secrets.SNOWFLAKE_CONNECTIONS_BASE64 }} run: | - rm /home/runner/.config/snowflake/connections.toml - export SNOWFLAKE_CONNECTIONS=$(echo $SNOWFLAKE_CONNECTIONS_BASE64 | base64 --decode) + export SNOWFLAKE_CONNECTIONS=$(printf '%s' "$SNOWFLAKE_CONNECTIONS_BASE64" | base64 -d | tr '\r' '\n') poetry run pytest tests/integration/* \ No newline at end of file diff --git a/universql/util.py b/universql/util.py index c31ba3a..9f9c54d 100644 --- a/universql/util.py +++ b/universql/util.py @@ -397,11 +397,11 @@ def calculate_script_cost(duration_second, electricity_rate=0.15, pc_lifetime_ye pattern = r'(\w+)(?:\(([^)]*)\))' -def parse_compute(value): - if value is not None: - matches = re.findall(pattern, value) +def parse_compute(warehouse): + if warehouse is not None: + matches = re.findall(pattern, warehouse) if len(matches) == 0: - matches = (('local', ''), ('snowflake', f'warehouse={value}')) + matches = (('local', ''), ('snowflake', f'warehouse={warehouse}')) else: # try locally if warehouse is not provided matches = (('local', ''), ) @@ -412,8 +412,8 @@ def parse_compute(value): if args_str: for arg in args_str.split(','): if '=' in arg: - key, value = arg.split('=', 1) - args[key.strip()] = value.strip() + key, warehouse = arg.split('=', 1) + args[key.strip()] = warehouse.strip() else: args[arg.strip()] = None # Handle arguments without '=' result.append({'name': func_name, 'args': args}) From b816a2edc0dd89ccb367678a732948d0d5dd132e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Kabakc=C4=B1?= Date: Wed, 8 Jan 2025 22:40:49 +0000 Subject: [PATCH 2/4] determine if the snowflake query requires a warehouse --- universql/warehouse/snowflake.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/universql/warehouse/snowflake.py b/universql/warehouse/snowflake.py index 9baf77c..e7be085 100644 --- a/universql/warehouse/snowflake.py +++ b/universql/warehouse/snowflake.py @@ -18,7 +18,8 @@ from snowflake.connector import NotSupportedError, DatabaseError from snowflake.connector.constants import FIELD_TYPES, FieldType from sqlglot.expressions import Literal, Var, Property, IcebergProperty, Properties, ColumnDef, DataType, \ - Schema, TransientProperty, TemporaryProperty, Select, Column, Alias, Anonymous, parse_identifier, Subquery + Schema, TransientProperty, TemporaryProperty, Select, Column, Alias, Anonymous, parse_identifier, Subquery, Show, \ + Use from universql.warehouse import ICatalog, Executor, Locations, Tables from universql.util import SNOWFLAKE_HOST, QueryError, prepend_to_lines, get_friendly_time_since @@ -119,7 +120,8 @@ def get_volume_lake_path(self, volume: str) -> str: volume_location = cursor.fetchall() # Find the active storage location name - active_storage_name = next((item[3] for item in volume_location if item[1] == 'ACTIVE' and item[0] == 'STORAGE_LOCATIONS'), None) + active_storage_name = next( + (item[3] for item in volume_location if item[1] == 'ACTIVE' and item[0] == 'STORAGE_LOCATIONS'), None) # Extract the STORAGE_BASE_URL from the corresponding storage location storage_base_url = None @@ -275,8 +277,7 @@ def _convert_snowflake_to_iceberg_type(self, snowflake_type: FieldType) -> str: return 'TEXT' return snowflake_type.name - def execute_raw(self, compiled_sql: str) -> None: - run_on_warehouse = self.catalog.compute.get('warehouse') is not None + def execute_raw(self, compiled_sql: str, run_on_warehouse=None) -> None: try: emoji = "☁️(user cloud services)" if not run_on_warehouse else "💰(used warehouse)" logger.info(f"[{self.catalog.session_id}] Running on Snowflake.. {emoji} \n {compiled_sql}") @@ -290,7 +291,7 @@ def execute(self, ast: sqlglot.exp.Expression, locations: Tables) -> \ compiled_sql = (ast # .transform(self.default_create_table_as_iceberg) .sql(dialect="snowflake", pretty=True)) - self.execute_raw(compiled_sql) + self.execute_raw(compiled_sql, run_on_warehouse=not isinstance(ast, Show) and not isinstance(ast, Use)) return None def get_query_log(self, total_duration) -> str: From 93e4f66b447848563b5e4e7279beb318789b2e28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Kabakc=C4=B1?= Date: Wed, 8 Jan 2025 22:41:51 +0000 Subject: [PATCH 3/4] Enhance logging in the tests --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5cf76e3..68c7fda 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -31,4 +31,4 @@ jobs: SNOWFLAKE_CONNECTIONS_BASE64: ${{ secrets.SNOWFLAKE_CONNECTIONS_BASE64 }} run: | export SNOWFLAKE_CONNECTIONS=$(printf '%s' "$SNOWFLAKE_CONNECTIONS_BASE64" | base64 -d | tr '\r' '\n') - poetry run pytest tests/integration/* \ No newline at end of file + poetry run pytest tests/integration/* --log-cli-level=DEBUG \ No newline at end of file From 6a0e3915c44f24a00dacb630bb6c3d07f5bd54b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Kabakc=C4=B1?= Date: Wed, 8 Jan 2025 23:28:57 +0000 Subject: [PATCH 4/4] Reuse universql servers in e2e --- .github/workflows/test.yml | 2 +- tests/integration/extract.py | 7 +++-- tests/integration/utils.py | 60 +++++++++++++++++++++--------------- tests/snow_client.py | 1 - 4 files changed, 40 insertions(+), 30 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 68c7fda..60077ee 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,4 +1,4 @@ -name: Release +name: Test on: push: diff --git a/tests/integration/extract.py b/tests/integration/extract.py index cdb3b07..528a9ef 100644 --- a/tests/integration/extract.py +++ b/tests/integration/extract.py @@ -39,6 +39,7 @@ def test_simple_select(self): universql_result = execute_query(conn, SIMPLE_QUERY) print(universql_result) + @pytest.mark.skip(reason="Stages are not implemented yet") def test_from_stage(self): with universql_connection() as conn: universql_result = execute_query(conn, "select count(*) from @stage/iceberg_stage") @@ -67,9 +68,9 @@ def test_switch_schema(self): universql_result = execute_query(conn, "SHOW SCHEMAS") assert universql_result.num_rows > 0, f"The query did not return any rows!" - def test_in_schema(self): - with universql_connection(schema="public", warehouse="local()") as conn: - universql_result = execute_query(conn, "select count(*) from table_in_public_schema") + def test_in_database(self): + with universql_connection(database="public") as conn: + universql_result = execute_query(conn, "select * from information_schema.columns") print(universql_result) def test_qualifiers(self): diff --git a/tests/integration/utils.py b/tests/integration/utils.py index cd7b65d..20fce65 100644 --- a/tests/integration/utils.py +++ b/tests/integration/utils.py @@ -85,6 +85,7 @@ def snowflake_connection(**properties) -> Generator: yield conn conn.close() +server_cache = {} @contextmanager def universql_connection(**properties) -> SnowflakeConnection: @@ -94,36 +95,45 @@ def universql_connection(**properties) -> SnowflakeConnection: if SNOWFLAKE_CONNECTION_NAME not in connections: raise pytest.fail(f"Snowflake connection '{SNOWFLAKE_CONNECTION_NAME}' not found in config") connection = connections[SNOWFLAKE_CONNECTION_NAME] - from universql.main import snowflake - with socketserver.TCPServer(("localhost", 0), None) as s: - free_port = s.server_address[1] - - def start_universql(): - runner = CliRunner() - try: - invoke = runner.invoke(snowflake, - [ - '--account', connection.get('account'), - '--port', free_port, '--catalog', 'snowflake', - # AWS_DEFAULT_PROFILE env can be used to pass AWS profile - ], - ) - except Exception as e: - pytest.fail(e) - - if invoke.exit_code != 0: - pytest.fail("Unable to start Universql") - - thread = threading.Thread(target=start_universql) - thread.daemon = True - thread.start() + account = connection.get('account') + + if account in server_cache: + uni_string = {"host": LOCALHOSTCOMPUTING_COM, "port": server_cache[account]} | properties + else: + from universql.main import snowflake + with socketserver.TCPServer(("localhost", 0), None) as s: + free_port = s.server_address[1] + print(f"Reusing existing server running on port {free_port} for account {account}") + + def start_universql(): + runner = CliRunner() + try: + invoke = runner.invoke(snowflake, + [ + '--account', account, + '--port', free_port, '--catalog', 'snowflake', + # AWS_DEFAULT_PROFILE env can be used to pass AWS profile + ], + ) + except Exception as e: + pytest.fail(e) + + if invoke.exit_code != 0: + pytest.fail("Unable to start Universql") + + + print(f"Starting running on port {free_port} for account {account}") + thread = threading.Thread(target=start_universql) + thread.daemon = True + thread.start() + server_cache[account] = free_port # with runner.isolated_filesystem(): - uni_string = {"host": LOCALHOSTCOMPUTING_COM, "port": free_port} | properties + uni_string = {"host": LOCALHOSTCOMPUTING_COM, "port": free_port} | properties try: connect = snowflake_connect(connection_name=SNOWFLAKE_CONNECTION_NAME, **uni_string) yield connect - finally: # Force stop the thread + finally: connect.close() diff --git a/tests/snow_client.py b/tests/snow_client.py index 552a488..fffd3c5 100644 --- a/tests/snow_client.py +++ b/tests/snow_client.py @@ -35,7 +35,6 @@ def measure_all(cur, query): def tt(max_iteration): connection = snowflake.connector.connect(connection_name="ryan_snowflake", port='8084', host='localhostcomputing.com', - warehouse='local()', database='test', # host='4ho74nvv4nyxhqxmcxrnpiid2m0bpcet.lambda-url.us-east-1.on.aws', )