Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix change quotes #13

Merged
merged 14 commits into from
Jan 9, 2025
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,6 @@ celerybeat.pid
venv/*
.venv/*
.db/*
.env
.env
universql.metadata.sqlite
credentials/
187 changes: 187 additions & 0 deletions tests/integration/object_identifiers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import pytest

from tests.integration.utils import execute_query, universql_connection, snowflake_connection, generate_select_statement_combos
from dotenv import load_dotenv
import os
import logging
from pprint import pp

logger = logging.getLogger(__name__)

class TestObjectIdentifiers:
    """Integration tests for Snowflake object-identifier resolution
    (database/schema/table casing and quoting) through Universql.

    Account, credential, storage and warehouse settings are read from
    environment variables loaded out of a local .env file.
    """

    load_dotenv()

    ACCOUNT = os.getenv("TEST_ACCOUNT_IDENTIFIER")
    TEST_USER = os.getenv("TEST_USER")
    TEST_USER_PASSWORD = os.getenv("TEST_PASSWORD")
    TEST_ROLE = os.getenv("TEST_ROLE")

    STORAGE_LOCATION_NAME = os.getenv("STORAGE_LOCATION_NAME")
    STORAGE_PROVIDER = os.getenv("STORAGE_PROVIDER")
    STORAGE_AWS_ROLE_ARN = os.getenv("STORAGE_AWS_ROLE_ARN")
    STORAGE_BASE_URL = os.getenv("STORAGE_BASE_URL")
    STORAGE_AWS_EXTERNAL_ID = os.getenv("STORAGE_AWS_EXTERNAL_ID")
    EXTERNAL_VOLUME_NAME = os.getenv("EXTERNAL_VOLUME_NAME")
    BUCKET_NAME = os.getenv("BUCKET_NAME")
    TEST_WAREHOUSE = os.getenv("TEST_WAREHOUSE")

    def test_setup(self):
        """Create the databases, schemas, iceberg tables, role and grants used
        by the identifier tests, then smoke-test SELECTs under the test role.

        Requires:
          - a connections file at ~/.snowflake/connections.toml
          - a connection in that file called "integration_test_snowflake_direct"
            specifying a Snowflake compute warehouse
          - the connected user must have access to the ACCOUNTADMIN role
        """
        raw_query = f"""
        -- CREATE OBJECTS
        USE ROLE ACCOUNTADMIN;
        CREATE OR REPLACE DATABASE universql1;
        CREATE OR REPLACE DATABASE universql2;
        CREATE SCHEMA IF NOT EXISTS universql1.same_schema;
        CREATE SCHEMA IF NOT EXISTS universql1.different_schema;
        CREATE SCHEMA IF NOT EXISTS universql2.another_schema;
        CREATE EXTERNAL VOLUME IF NOT EXISTS {self.EXTERNAL_VOLUME_NAME}
            STORAGE_LOCATIONS = (
                (
                    NAME = '{self.STORAGE_LOCATION_NAME}',
                    STORAGE_PROVIDER = '{self.STORAGE_PROVIDER}',
                    STORAGE_AWS_ROLE_ARN = '{self.STORAGE_AWS_ROLE_ARN}',
                    STORAGE_BASE_URL = '{self.STORAGE_BASE_URL}',
                    STORAGE_AWS_EXTERNAL_ID = '{self.STORAGE_AWS_EXTERNAL_ID}'
                )
            )
            ALLOW_WRITES = TRUE;

        CREATE OR REPLACE ICEBERG TABLE universql1.same_schema.dim_devices
            external_volume = {self.EXTERNAL_VOLUME_NAME}
            catalog = 'SNOWFLAKE'
            BASE_LOCATION = 's3://{self.BUCKET_NAME}/tests/1/same_schema/dim_devices'
            AS select 1;

        CREATE OR REPLACE ICEBERG TABLE universql1.different_schema.different_dim_devices
            external_volume = {self.EXTERNAL_VOLUME_NAME}
            catalog = 'SNOWFLAKE'
            BASE_LOCATION = 's3://{self.BUCKET_NAME}/tests/1/different_schema/different_dim_devices'
            AS select 1;

        CREATE OR REPLACE ICEBERG TABLE universql2.another_schema.another_dim_devices
            external_volume = {self.EXTERNAL_VOLUME_NAME}
            catalog = 'SNOWFLAKE'
            BASE_LOCATION = 's3://{self.BUCKET_NAME}/tests/2/another_schema/another_dim_devices'
            AS select 1;

        CREATE OR REPLACE ROLE {self.TEST_ROLE};
        GRANT ROLE {self.TEST_ROLE} TO USER {self.TEST_USER};

        GRANT USAGE ON EXTERNAL VOLUME {self.EXTERNAL_VOLUME_NAME} TO ROLE {self.TEST_ROLE};
        GRANT USAGE ON DATABASE universql1 TO ROLE {self.TEST_ROLE};
        GRANT USAGE ON DATABASE universql2 TO ROLE {self.TEST_ROLE};
        GRANT USAGE ON ALL SCHEMAS IN DATABASE universql1 TO ROLE {self.TEST_ROLE};
        GRANT USAGE ON ALL SCHEMAS IN DATABASE universql2 TO ROLE {self.TEST_ROLE};
        GRANT SELECT ON universql1.same_schema.dim_devices TO ROLE {self.TEST_ROLE};
        GRANT SELECT ON universql1.different_schema.different_dim_devices TO ROLE {self.TEST_ROLE};
        GRANT SELECT ON universql2.another_schema.another_dim_devices TO ROLE {self.TEST_ROLE};
        GRANT USAGE ON WAREHOUSE {self.TEST_WAREHOUSE} TO ROLE {self.TEST_ROLE};

        USE ROLE {self.TEST_ROLE};
        USE DATABASE universql1;
        SELECT * FROM universql1.same_schema.dim_devices;
        SELECT * FROM universql1.different_schema.different_dim_devices;
        SELECT * FROM universql2.another_schema.another_dim_devices;
        """

        # The connector executes one statement per call, so split on ";".
        queries = raw_query.split(";")
        connection_params = {"snowflake_connection_name": "integration_test_snowflake_direct"}
        with snowflake_connection(**connection_params) as conn:
            cursor = conn.cursor()
            failed_queries = []
            for query in queries:
                try:
                    cursor.execute(query)
                    logger.info(query)
                except Exception as e:
                    failed_queries.append(f"{query} | FAILED - {str(e)}")
                    logger.info(f"{query} | FAILED - {str(e)}")
            if failed_queries:
                error_message = "The following queries failed:"
                for query in failed_queries:
                    # Fixed: the original used the plain literal "\n{query}"
                    # (missing f-prefix), so the failed queries never made it
                    # into the exception message.
                    error_message = f"{error_message}\n{query}"
                raise Exception(error_message)

    def test_querying_in_connected_db_and_schema(self):
        """Run SELECTs against every casing/quoting combination of the test
        tables through Universql and fail with a summary if any combination
        errors out.

        Requires:
          - a connections file at ~/.snowflake/connections.toml
          - a connection in that file called "integration_test_universql"
            specifying that the warehouse is local()
          - the connected user must be the same as for test_setup
        """
        connected_db = "universql1"
        connected_schema = "same_schema"

        combos = [
            {
                "database": "universql1",
                "schema": "same_schema",
                "table": "dim_devices"
            },
            {
                "database": "universql1",
                "schema": "different_schema",
                "table": "different_dim_devices"
            },
            {
                "database": "universql2",
                "schema": "another_schema",
                "table": "another_dim_devices"
            },
        ]

        select_statements = generate_select_statement_combos(combos, connected_db, connected_schema)

        successful_queries = []
        failed_queries = []
        counter = 0

        connection_params = {
            "snowflake_connection_name": "integration_test_universql",
            "database": connected_db,
            "schema": connected_schema
        }

        with universql_connection(**connection_params) as conn:
            for query in select_statements:
                logger.info(f"current counter: {counter}")
                counter += 1
                # NOTE(review): when all combos are run there were "connection
                # refused" errors starting around query 23 (25 with a sleep
                # added) — investigate server capacity before relying on this.
                try:
                    result = execute_query(conn, query)
                    successful_queries.append(query)
                    logger.info(f"QUERY PASSED: {query}")
                    logger.info(result)
                except Exception as e:
                    logger.info(f"QUERY FAILED: {query}")
                    failed_queries.append(f"{query} | FAILED - {str(e)}")

        logger.info("test_querying_in_connected_db_and_schema")
        logger.info("Successful Queries:")
        for query in successful_queries:
            logger.info(query)
        if len(failed_queries) > 0:
            error_message = f"The following {len(failed_queries)} queries failed:"
            for query in failed_queries:
                error_message = f"{error_message}\n{query}"
            logger.error(error_message)
            raise Exception(error_message)
155 changes: 110 additions & 45 deletions tests/integration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@
from snowflake.connector import connect as snowflake_connect, SnowflakeConnection
from snowflake.connector.config_manager import CONFIG_MANAGER
from snowflake.connector.constants import CONNECTIONS_FILE
from itertools import product
import toml
import logging

logger = logging.getLogger(__name__)



from universql.util import LOCALHOSTCOMPUTING_COM

Expand Down Expand Up @@ -78,65 +85,57 @@
"""


@contextmanager
def snowflake_connection(**properties) -> Generator:
    """Yield a direct Snowflake connection built from a connections.toml entry.

    The entry name is taken from ``properties["snowflake_connection_name"]``
    when present, otherwise the module default (see ``_set_connection_name``).
    The connection is always closed on exit, even if the caller's body raises.
    """
    print(f"Reading {CONNECTIONS_FILE} with {properties}")
    snowflake_connection_name = _set_connection_name(properties)
    # Drop our private key before forwarding: snowflake_connect does not
    # recognize "snowflake_connection_name" as a connection parameter.
    # (**properties is a fresh dict, so the caller's kwargs are untouched.)
    properties.pop("snowflake_connection_name", None)
    conn = snowflake_connect(connection_name=snowflake_connection_name, **properties)
    try:
        yield conn
    finally:
        conn.close()

@contextmanager
def universql_connection(**properties) -> SnowflakeConnection:
    """Start a Universql proxy on a free local port and yield a Snowflake
    connection pointed at it.

    The underlying account is read from the connections.toml entry named by
    ``properties["snowflake_connection_name"]`` (or the module default).
    https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect#connecting-using-the-connections-toml-file
    """
    print(f"Reading {CONNECTIONS_FILE} with {properties}")
    connections = CONFIG_MANAGER["connections"]
    snowflake_connection_name = _set_connection_name(properties)
    if snowflake_connection_name not in connections:
        # pytest.fail raises Failed itself; no `raise` needed in front of it.
        pytest.fail(f"Snowflake connection '{snowflake_connection_name}' not found in config")
    connection = connections[snowflake_connection_name]
    from universql.main import snowflake
    # Bind to port 0 so the OS picks a free port, then release it for Universql.
    with socketserver.TCPServer(("localhost", 0), None) as s:
        free_port = s.server_address[1]

    def start_universql():
        # Runs the Universql CLI in-process; lives in a daemon thread so it
        # dies with the test process.
        runner = CliRunner()
        try:
            invoke = runner.invoke(
                snowflake,
                [
                    '--account', connection.get('account'),
                    '--port', free_port, '--catalog', 'snowflake',
                    # AWS_DEFAULT_PROFILE env can be used to pass AWS profile
                ],
            )
        except Exception as e:
            pytest.fail(str(e))  # pytest.fail expects a message string

        if invoke.exit_code != 0:
            pytest.fail("Unable to start Universql")

    thread = threading.Thread(target=start_universql, daemon=True)
    thread.start()
    uni_string = {"host": LOCALHOSTCOMPUTING_COM, "port": free_port} | properties

    connect = None
    try:
        connect = snowflake_connect(connection_name=snowflake_connection_name, **uni_string)
        yield connect
    finally:
        # Guard against NameError when snowflake_connect itself raised
        # before `connect` was bound.
        if connect is not None:
            connect.close()


def execute_query(conn, query: str) -> pyarrow.Table:
cur = conn.cursor()
try:
Expand Down Expand Up @@ -193,3 +192,69 @@ def cleanup_table(conn, table_name: str):
cur.close()
except Exception as e:
print(f"Error during cleanup: {e}")

def generate_name_variants(name, include_blank = False):
    """Return casing/quoting variants of a Snowflake identifier.

    Produces, in order: lowercase, UPPERCASE, Capitalized, and the
    double-quoted uppercase form (quoted identifiers are case-sensitive in
    Snowflake, and unquoted ones fold to upper case).

    Note: ``include_blank`` is accepted but currently unused — presumably
    intended to add an empty variant for optional qualifiers; TODO confirm
    before relying on it.
    """
    uppercase = name.upper()
    # Removed leftover debug print(): this helper is called inside loops and
    # was spamming stdout on every invocation.
    return [name.lower(), uppercase, name.capitalize(), '"' + uppercase + '"']

def generate_select_statement_combos(sets_of_identifiers, connected_db = None, connected_schema = None):
    """Build SELECT statements for every casing/quoting variant of each
    identifier set.

    For each dict in *sets_of_identifiers* (keys: "database", "schema",
    "table"):
      - bare ``SELECT * FROM <table>`` forms are only emitted when the set's
        database and schema match the connected ones (otherwise they would
        not resolve);
      - ``schema.table`` forms are emitted when the database matches the
        connected one;
      - fully-qualified ``db.schema.table`` forms are always emitted when a
        database is given.

    Raises if a set has no table, or a database without a schema.
    """
    select_statements = []
    for identifier_set in sets_of_identifiers:  # renamed: `set` shadowed the builtin
        set_of_select_statements = []
        database = identifier_set.get("database")
        schema = identifier_set.get("schema")
        table = identifier_set.get("table")
        if table is None:
            # Fixed typo in message: "stametent" -> "statement".
            raise Exception("No table name provided for a select statement combo.")
        table_variants = generate_name_variants(table)
        if database == connected_db and schema == connected_schema:
            for table_variant in table_variants:
                set_of_select_statements.append(f"SELECT * FROM {table_variant}")

        if schema is not None:
            schema_variants = generate_name_variants(schema)
            if database == connected_db:
                for schema_name, table_name in product(schema_variants, table_variants):
                    set_of_select_statements.append(f"SELECT * FROM {schema_name}.{table_name}")
        elif database is not None:
            raise Exception("You must provide a schema name if you provide a database name.")

        if database is not None:
            database_variants = generate_name_variants(database)
            for db_name, schema_name, table_name in product(database_variants, schema_variants, table_variants):
                set_of_select_statements.append(f"SELECT * FROM {db_name}.{schema_name}.{table_name}")

        select_statements = select_statements + set_of_select_statements
        logger.info(f"database: {database}, schema: {schema}, table: {table}")
        for statement in set_of_select_statements:
            logger.info(statement)

    return select_statements

def generate_usql_connection_params(account, user, password, role, database = None, schema = None):
    """Assemble keyword arguments for a Universql connection.

    The warehouse is always the literal "local()" (Universql's local
    execution target). ``database`` and ``schema`` are included only when
    provided.
    """
    params = dict(
        account=account,
        user=user,
        password=password,
        role=role,
        warehouse="local()",
    )
    optional = {"database": database, "schema": schema}
    params.update({key: value for key, value in optional.items() if value is not None})
    return params

def _set_connection_name(connection_dict = None):
    """Resolve which connections.toml entry to use.

    Returns ``connection_dict["snowflake_connection_name"]`` when supplied,
    falling back to the module-level default SNOWFLAKE_CONNECTION_NAME.
    """
    # Fixed: mutable default argument ({}) replaced with None sentinel.
    if connection_dict is None:
        connection_dict = {}
    snowflake_connection_name = connection_dict.get("snowflake_connection_name", SNOWFLAKE_CONNECTION_NAME)
    logger.info(f"Using the {snowflake_connection_name} connection")
    return snowflake_connection_name
Loading
Loading