Skip to content

Commit

Permalink
object identifiers tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanwithawhy committed Jan 8, 2025
1 parent 6f397a7 commit 29e8c01
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 72 deletions.
117 changes: 68 additions & 49 deletions tests/integration/object_identifiers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import snowflake.connector

from tests.integration.utils import execute_query, dynamic_universql_connection, SIMPLE_QUERY, generate_select_statement_combos, generate_usql_connection_params
from tests.integration.utils import execute_query, dynamic_universql_connection, universql_connection, snowflake_connection, SIMPLE_QUERY, generate_select_statement_combos, generate_usql_connection_params
from dotenv import load_dotenv
import os
import logging
Expand Down Expand Up @@ -55,16 +55,16 @@ def test_setup(self):
BASE_LOCATION = 's3://{self.BUCKET_NAME}/tests/1/same_schema/dim_devices'
AS select 1;
CREATE OR REPLACE ICEBERG TABLE universql1.different_schema.dim_devices
CREATE OR REPLACE ICEBERG TABLE universql1.different_schema.different_dim_devices
external_volume = {self.EXTERNAL_VOLUME_NAME}
catalog = 'SNOWFLAKE'
BASE_LOCATION = 's3://{self.BUCKET_NAME}/tests/1/different_schema/dim_devices'
BASE_LOCATION = 's3://{self.BUCKET_NAME}/tests/1/different_schema/different_dim_devices'
AS select 1;
CREATE OR REPLACE ICEBERG TABLE universql2.another_schema.dim_devices
CREATE OR REPLACE ICEBERG TABLE universql2.another_schema.another_dim_devices
external_volume = {self.EXTERNAL_VOLUME_NAME}
catalog = 'SNOWFLAKE'
BASE_LOCATION = 's3://{self.BUCKET_NAME}/tests/2/another_schema/dim_devices'
BASE_LOCATION = 's3://{self.BUCKET_NAME}/tests/2/another_schema/another_dim_devices'
AS select 1;
CREATE OR REPLACE ROLE {self.TEST_ROLE};
Expand All @@ -76,57 +76,82 @@ def test_setup(self):
GRANT USAGE ON ALL SCHEMAS IN DATABASE universql1 TO ROLE {self.TEST_ROLE};
GRANT USAGE ON ALL SCHEMAS IN DATABASE universql2 TO ROLE {self.TEST_ROLE};
GRANT SELECT ON universql1.same_schema.dim_devices TO ROLE {self.TEST_ROLE};
GRANT SELECT ON universql1.different_schema.dim_devices TO ROLE {self.TEST_ROLE};
GRANT SELECT ON universql2.another_schema.dim_devices TO ROLE {self.TEST_ROLE};
GRANT SELECT ON universql1.different_schema.different_dim_devices TO ROLE {self.TEST_ROLE};
GRANT SELECT ON universql2.another_schema.another_dim_devices TO ROLE {self.TEST_ROLE};
GRANT USAGE ON WAREHOUSE {self.TEST_WAREHOUSE} TO ROLE {self.TEST_ROLE};
USE ROLE {self.TEST_ROLE};
USE DATABASE universql1;
SELECT * FROM universql1.same_schema.dim_devices;
SELECT * FROM universql1.different_schema.dim_devices;
SELECT * FROM universql2.another_schema.dim_devices;
SELECT * FROM universql1.different_schema.different_dim_devices;
SELECT * FROM universql2.another_schema.another_dim_devices;
"""

queries = raw_query.split(";")
connection_params = generate_usql_connection_params(self.ACCOUNT, self.TEST_USER, self.TEST_USER_PASSWORD, 'ACCOUNTADMIN')
connection_params["warehouse"] = self.TEST_WAREHOUSE
snowflake_conn = snowflake.connector.connect(**connection_params)
cursor = snowflake_conn.cursor()
failed_queries = []
for query in queries:
try:
cursor.execute(query)
logger.info(query)
except Exception as e:
failed_queries.append(f"{query} | FAILED - {str(e)}")
logger.info(f"{query} | FAILED - {str(e)}")
if len(failed_queries) > 0:
error_message = "The following queries failed:"
for query in failed_queries:
error_message = error_message + "\n{query}"
raise Exception(error_message)
# connection_params = generate_usql_connection_params(self.ACCOUNT, self.TEST_USER, self.TEST_USER_PASSWORD, 'ACCOUNTADMIN')
# connection_params["warehouse"] = self.TEST_WAREHOUSE
connection_params = {}
connection_params["snowflake_connection_name"] = "integration_test_snowflake_direct"
with snowflake_connection(**connection_params) as conn:
cursor = conn.cursor()
failed_queries = []
for query in queries:
try:
cursor.execute(query)
logger.info(query)
except Exception as e:
failed_queries.append(f"{query} | FAILED - {str(e)}")
logger.info(f"{query} | FAILED - {str(e)}")
if len(failed_queries) > 0:
error_message = "The following queries failed:"
for query in failed_queries:
error_message = error_message + "\n{query}"
raise Exception(error_message)

def test_querying_in_connected_db_and_schema(self):
database = "universql1"
schema = "same_schema"
table = "dim_devices"

fully_qualified_queries = generate_select_statement_combos(table, schema, database)
no_db_queries = generate_select_statement_combos(table, schema)
no_schema_queries = generate_select_statement_combos(table)
all_queries = fully_qualified_queries + no_db_queries + no_schema_queries
all_queries_no_duplicates = sorted(list(set(all_queries)))
for query in all_queries_no_duplicates:
logger.info(f"{query}: TBE")
connected_db = "universql1"
connected_schema = "same_schema"

combos = [
{
"database": "universql1",
"schema": "same_schema",
"table": "dim_devices"
},
{
"database": "universql1",
"schema": "different_schema",
"table": "different_dim_devices"
},
{
"database": "universql2",
"schema": "another_schema",
"table": "another_dim_devices"
},
]

select_statements = generate_select_statement_combos(combos, connected_db, connected_schema)


# no_db_queries = generate_select_statement_combos(table, schema)
# no_schema_queries = generate_select_statement_combos(table)
# all_queries = fully_qualified_queries + no_db_queries + no_schema_queries
# all_queries_no_duplicates = sorted(list(set(all_queries)))
# for query in select_statements:
# logger.info(f"{query}: TBE")
successful_queries = []
failed_queries = []
counter = 0

connection_params = generate_usql_connection_params(self.ACCOUNT, self.TEST_USER, self.TEST_USER_PASSWORD, self.TEST_ROLE, database, schema)

connection_params = {
"snowflake_connection_name": "integration_test_universql",
"database": connected_db,
"schema": connected_schema
}

# create toml file
with dynamic_universql_connection(**connection_params) as conn:
for query in all_queries_no_duplicates:
with universql_connection(**connection_params) as conn:
for query in select_statements:
logger.info(f"current counter: {counter}")
counter += 1
# if counter > 20:
Expand All @@ -147,14 +172,8 @@ def test_querying_in_connected_db_and_schema(self):
for query in successful_queries:
logger.info(query)
if len(failed_queries) > 0:
error_message = "The following queries failed:"
error_message = f"The following {len(failed_queries)} queries failed:"
for query in failed_queries:
error_message = f"{error_message}\n{query}"
logger.error(error_message)
raise Exception(error_message)

# WARNING snowflake.connector.vendored.urllib3.connectionpool:connectionpool.py:824 Retrying (Retry(total=0, connect=None, read=None, redirect=None, status=None)) after connection broken by 'NewConnectionError('<snowflake.connector.vendored.urllib3.connection.HTTPSConnection object at 0x14d169670>: Failed to establish a new connection: [Errno 61] Connection refused')': /session/v1/login-request?request_id=b806a1b2-0462-4d76-a9c8-348981837587&databaseName=universql1&schemaName=same_schema&warehouse=local%28%29&roleName=general_purpose
# WARNING 🧵:snowflake.py:387 Failed to set signal handler for SIGINT: signal only works in main thread of the main interpreter
# WARNING snowflake.connector.vendored.urllib3.connectionpool:connectionpool.py:824 Retrying (Retry(total=0, connect=None, read=None, redirect=None, status=None)) after connection broken by 'NewConnectionError('<snowflake.connector.vendored.urllib3.connection.HTTPSConnection object at 0x14d147680>: Failed to establish a new connection: [Errno 61] Connection refused')': /session/v1/login-request?request_id=4171a507-63b3-4349-85d1-4a27a68650a1&databaseName=universql1&schemaName=same_schema&warehouse=local%28%29&roleName=general_purpose
# WARNING snowflake.connector.vendored.urllib3.connectionpool:connectionpool.py:824 Retrying (Retry(total=0, connect=None, read=None, redirect=None, status=None)) after connection broken by 'NewConnectionError('<snowflake.connector.vendored.urllib3.connection.HTTPSConnection object at 0x14d168890>: Failed to establish a new connection: [Errno 61] Connection refused')': /session/v1/login-request?request_id=ff44232e-b3a0-42f4-8a56-2d7f9d37747d&databaseName=universql1&schemaName=same_schema&warehouse=local%28%29&roleName=general_purpose
# keeps repeating
raise Exception(error_message)
69 changes: 46 additions & 23 deletions tests/integration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,20 +85,22 @@
"""


@pytest.fixture(scope="session")
@contextmanager
def snowflake_connection(**properties) -> Generator:
print(f"Reading {CONNECTIONS_FILE} with {properties}")
conn = snowflake_connect(connection_name=SNOWFLAKE_CONNECTION_NAME, **properties)
yield conn
conn.close()

snowflake_connection_name = _set_connection_name(properties)
conn = snowflake_connect(connection_name=snowflake_connection_name, **properties)
try:
yield conn
finally:
conn.close()

@contextmanager
def universql_connection(**properties) -> SnowflakeConnection:
# https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect#connecting-using-the-connections-toml-file
print(f"Reading {CONNECTIONS_FILE} with {properties}")
connections = CONFIG_MANAGER["connections"]
snowflake_connection_name = properties.get("snowflake_connection_name", SNOWFLAKE_CONNECTION_NAME)
snowflake_connection_name = _set_connection_name(properties)
if snowflake_connection_name not in connections:
raise pytest.fail(f"Snowflake connection '{snowflake_connection_name}' not found in config")
connection = connections[snowflake_connection_name]
Expand Down Expand Up @@ -229,26 +231,42 @@ def generate_name_variants(name, include_blank = False):
print([lowercase, uppercase, mixed_case, in_quotes])
return [lowercase, uppercase, mixed_case, in_quotes]

def generate_select_statement_combos(table, schema = None, database = None):
def generate_select_statement_combos(sets_of_identifiers, connected_db = None, connected_schema = None):
select_statements = []
table_variants = generate_name_variants(table)

if database is not None:
database_variants = generate_name_variants(database)
schema_variants = generate_name_variants(schema)
object_name_combos = product(database_variants, schema_variants, table_variants)
for db_name, schema_name, table_name in object_name_combos:
select_statements.append(f"SELECT * FROM {db_name}.{schema_name}.{table_name}")
else:
for set in sets_of_identifiers:
set_of_select_statements = []
database = set.get("database")
schema = set.get("schema")
table = set.get("table")
if table is not None:
table_variants = generate_name_variants(table)
if database == connected_db and schema == connected_schema:
for table_variant in table_variants:
set_of_select_statements.append(f"SELECT * FROM {table_variant}")
else:
raise Exception("No table name provided for a select stametent combo.")

if schema is not None:
schema_variants = generate_name_variants(schema)
object_name_combos = product(schema_variants, table_variants)
for schema_name, table_name in object_name_combos:
select_statements.append(f"SELECT * FROM {schema_name}.{table_name}")
if database == connected_db:
object_name_combos = product(schema_variants, table_variants)
for schema_name, table_name in object_name_combos:
set_of_select_statements.append(f"SELECT * FROM {schema_name}.{table_name}")
else:
for table_variant in table_variants:
select_statements.append(f"SELECT * FROM {table_variant}")

if database is not None:
raise Exception("You must provide a schema name if you provide a database name.")

if database is not None:
database_variants = generate_name_variants(database)
object_name_combos = product(database_variants, schema_variants, table_variants)
for db_name, schema_name, table_name in object_name_combos:
set_of_select_statements.append(f"SELECT * FROM {db_name}.{schema_name}.{table_name}")
select_statements = select_statements + set_of_select_statements
logger.info(f"database: {database}, schema: {schema}, table: {table}")
for statement in set_of_select_statements:
logger.info(statement)
# logger.info(f"database: {database}, schema: {schema}, table: {table}")

return select_statements

def generate_usql_connection_params(account, user, password, role, database = None, schema = None):
Expand All @@ -264,4 +282,9 @@ def generate_usql_connection_params(account, user, password, role, database = No
if schema is not None:
params["schema"] = schema

return params
return params

def _set_connection_name(connection_dict = {}):
snowflake_connection_name = connection_dict.get("snowflake_connection_name", SNOWFLAKE_CONNECTION_NAME)
logger.info(f"Using the {snowflake_connection_name} connection")
return snowflake_connection_name

0 comments on commit 29e8c01

Please sign in to comment.