From 35158f1a4bab3937631b380274255625607929af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Kabakc=C4=B1?= Date: Thu, 9 Jan 2025 16:12:15 +0000 Subject: [PATCH] Cache universql server in integration test --- tests/integration/utils.py | 63 ++++++++++++++++++++------------------ 1 file changed, 34 insertions(+), 29 deletions(-) diff --git a/tests/integration/utils.py b/tests/integration/utils.py index a856ea6..1c677ec 100644 --- a/tests/integration/utils.py +++ b/tests/integration/utils.py @@ -3,8 +3,10 @@ import sys import threading from contextlib import contextmanager +from traceback import print_exc from typing import Generator +import click import pyarrow import pytest from click.testing import CliRunner @@ -25,9 +27,10 @@ # export SNOWFLAKE_CONNECTION_STRING="account=xxx;user=xxx;password=xxx;warehouse=xxx;database=xxx;schema=xxx" # export UNIVERSQL_CONNECTION_STRING="warehouse=xxx" SNOWFLAKE_CONNECTION_NAME = os.getenv("SNOWFLAKE_CONNECTION_NAME") or "default" +logging.getLogger("snowflake.connector").setLevel(logging.INFO) # Allow Universql to start -os.environ["MAX_CON_RETRY_ATTEMPTS"] = "100" +os.environ["MAX_CON_RETRY_ATTEMPTS"] = "15" SIMPLE_QUERY = """ SELECT 1 as test @@ -84,6 +87,7 @@ -- [1.1,2.2,3]::VECTOR(FLOAT,3) AS sample_vector """ +server_cache = {} @contextmanager def snowflake_connection(**properties) -> Generator: @@ -100,41 +104,47 @@ def universql_connection(**properties) -> SnowflakeConnection: # https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect#connecting-using-the-connections-toml-file print(f"Reading {CONNECTIONS_FILE} with {properties}") connections = CONFIG_MANAGER["connections"] - snowflake_connection_name = _set_connection_name(properties) + snowflake_connection_name = connections.get("snowflake_connection_name", SNOWFLAKE_CONNECTION_NAME) + logger.info(f"Using the {snowflake_connection_name} connection") if snowflake_connection_name not in connections: raise pytest.fail(f"Snowflake connection '{snowflake_connection_name}' not found in config") - connection = connections[snowflake_connection_name] - from universql.main import snowflake - with socketserver.TCPServer(("localhost", 0), None) as s: - free_port = s.server_address[1] - - def start_universql(): - runner = CliRunner() - try: + connection = connections[SNOWFLAKE_CONNECTION_NAME] + account = connection.get('account') + if account in server_cache: + uni_string = {"host": LOCALHOSTCOMPUTING_COM, "port": server_cache[account]} | properties + else: + from universql.main import snowflake + with socketserver.TCPServer(("localhost", 0), None) as s: + free_port = s.server_address[1] + print(f"Reusing existing server running on port {free_port} for account {account}") + + def start_universql(): + runner = CliRunner() invoke = runner.invoke(snowflake, [ - '--account', connection.get('account'), + '--account', account, '--port', free_port, '--catalog', 'snowflake', # AWS_DEFAULT_PROFILE env can be used to pass AWS profile ], + catch_exceptions=False ) - except Exception as e: - pytest.fail(e) + if invoke.exit_code != 0: + raise Exception("Unable to start Universql") - if invoke.exit_code != 0: - pytest.fail("Unable to start Universql") - - thread = threading.Thread(target=start_universql) - thread.daemon = True - thread.start() - # with runner.isolated_filesystem(): - uni_string = {"host": LOCALHOSTCOMPUTING_COM, "port": free_port} | properties + print(f"Starting running on port {free_port} for account {account}") + thread = threading.Thread(target=start_universql) + thread.daemon = True + thread.start() + server_cache[account] = free_port + uni_string = {"host": LOCALHOSTCOMPUTING_COM, "port": free_port} | properties + connect = None try: connect = snowflake_connect(connection_name=snowflake_connection_name, **uni_string) yield connect - finally: # Force stop the thread - connect.close() + finally: + if connect is not None: + connect.close() def execute_query(conn, query: str) -> pyarrow.Table: cur = conn.cursor() @@ -252,9 +262,4 @@ def generate_usql_connection_params(account, user, password, role, database = No if schema is not None: params["schema"] = schema - return params - -def _set_connection_name(connection_dict = {}): - snowflake_connection_name = connection_dict.get("snowflake_connection_name", SNOWFLAKE_CONNECTION_NAME) - logger.info(f"Using the {snowflake_connection_name} connection") - return snowflake_connection_name \ No newline at end of file + return params \ No newline at end of file