Skip to content

Commit

Permalink
Cache universql server in integration test
Browse files Browse the repository at this point in the history
  • Loading branch information
buremba committed Jan 9, 2025
1 parent 413f275 commit 35158f1
Showing 1 changed file with 34 additions and 29 deletions.
63 changes: 34 additions & 29 deletions tests/integration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import sys
import threading
from contextlib import contextmanager
from traceback import print_exc
from typing import Generator

import click
import pyarrow
import pytest
from click.testing import CliRunner
Expand All @@ -25,9 +27,10 @@
# export SNOWFLAKE_CONNECTION_STRING="account=xxx;user=xxx;password=xxx;warehouse=xxx;database=xxx;schema=xxx"
# export UNIVERSQL_CONNECTION_STRING="warehouse=xxx"
SNOWFLAKE_CONNECTION_NAME = os.getenv("SNOWFLAKE_CONNECTION_NAME") or "default"
logging.getLogger("snowflake.connector").setLevel(logging.INFO)

# Allow Universql to start
os.environ["MAX_CON_RETRY_ATTEMPTS"] = "100"
os.environ["MAX_CON_RETRY_ATTEMPTS"] = "15"

SIMPLE_QUERY = """
SELECT 1 as test
Expand Down Expand Up @@ -84,6 +87,7 @@
-- [1.1,2.2,3]::VECTOR(FLOAT,3) AS sample_vector
"""

server_cache = {}

@contextmanager
def snowflake_connection(**properties) -> Generator:
Expand All @@ -100,41 +104,47 @@ def universql_connection(**properties) -> SnowflakeConnection:
# https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect#connecting-using-the-connections-toml-file
print(f"Reading {CONNECTIONS_FILE} with {properties}")
connections = CONFIG_MANAGER["connections"]
snowflake_connection_name = _set_connection_name(properties)
snowflake_connection_name = connections.get("snowflake_connection_name", SNOWFLAKE_CONNECTION_NAME)
logger.info(f"Using the {snowflake_connection_name} connection")
if snowflake_connection_name not in connections:
raise pytest.fail(f"Snowflake connection '{snowflake_connection_name}' not found in config")
connection = connections[snowflake_connection_name]
from universql.main import snowflake
with socketserver.TCPServer(("localhost", 0), None) as s:
free_port = s.server_address[1]

def start_universql():
runner = CliRunner()
try:
connection = connections[SNOWFLAKE_CONNECTION_NAME]
account = connection.get('account')
if account in server_cache:
uni_string = {"host": LOCALHOSTCOMPUTING_COM, "port": server_cache[account]} | properties
else:
from universql.main import snowflake
with socketserver.TCPServer(("localhost", 0), None) as s:
free_port = s.server_address[1]
print(f"Reusing existing server running on port {free_port} for account {account}")

def start_universql():
runner = CliRunner()
invoke = runner.invoke(snowflake,
[
'--account', connection.get('account'),
'--account', account,
'--port', free_port, '--catalog', 'snowflake',
# AWS_DEFAULT_PROFILE env can be used to pass AWS profile
],
catch_exceptions=False
)
except Exception as e:
pytest.fail(e)
if invoke.exit_code != 0:
raise Exception("Unable to start Universql")

if invoke.exit_code != 0:
pytest.fail("Unable to start Universql")

thread = threading.Thread(target=start_universql)
thread.daemon = True
thread.start()
# with runner.isolated_filesystem():
uni_string = {"host": LOCALHOSTCOMPUTING_COM, "port": free_port} | properties
print(f"Starting running on port {free_port} for account {account}")
thread = threading.Thread(target=start_universql)
thread.daemon = True
thread.start()
server_cache[account] = free_port
uni_string = {"host": LOCALHOSTCOMPUTING_COM, "port": free_port} | properties

connect = None
try:
connect = snowflake_connect(connection_name=snowflake_connection_name, **uni_string)
yield connect
finally: # Force stop the thread
connect.close()
finally:
if connect is not None:
connect.close()

def execute_query(conn, query: str) -> pyarrow.Table:
cur = conn.cursor()
Expand Down Expand Up @@ -252,9 +262,4 @@ def generate_usql_connection_params(account, user, password, role, database = No
if schema is not None:
params["schema"] = schema

return params

def _set_connection_name(connection_dict = {}):
snowflake_connection_name = connection_dict.get("snowflake_connection_name", SNOWFLAKE_CONNECTION_NAME)
logger.info(f"Using the {snowflake_connection_name} connection")
return snowflake_connection_name
return params

0 comments on commit 35158f1

Please sign in to comment.