diff --git a/tests/integration/extract.py b/tests/integration/extract.py index c504d7d..528965e 100644 --- a/tests/integration/extract.py +++ b/tests/integration/extract.py @@ -81,9 +81,11 @@ def test_union(self): result = execute_query(conn, "select 1 union all select 2") assert result.num_rows == 2 + def test_stage(self): with universql_connection(warehouse=None) as conn: - result = execute_query(conn, "select * from @iceberg_db.public.landing_stage/initial_objects/device_metadata.csv") + result = execute_query(conn, "copy into table_name FROM @stagename/") + # result = execute_query(conn, "select * from @iceberg_db.public.landing_stage/initial_objects/device_metadata.csv") assert result.num_rows > 0 def test_copy_into(self): diff --git a/tests/plugins/snow.py b/tests/plugins/snow.py index ee88261..724cca3 100644 --- a/tests/plugins/snow.py +++ b/tests/plugins/snow.py @@ -1,19 +1,18 @@ import uuid +import sqlglot from sqlglot import parse_one -from universql.plugins.snow import TableSampleUniversqlPlugin, SnowflakeStageUniversqlPlugin +from universql.plugins.snow import SnowflakeStageUniversqlPlugin from universql.protocol.session import UniverSQLSession from universql.warehouse.duckdb import DuckDBExecutor, DuckDBCatalog from universql.warehouse.snowflake import SnowflakeExecutor, SnowflakeCatalog +one = sqlglot.parse_one("select 1", dialect="snowflake") +compute = {"warehouse": None} session = UniverSQLSession( - {'account': 'dhb43249.us-east-1', - 'cache_directory': '/Users/bkabak/.universql/cache', - 'catalog': 'snowflake', - 'home_directory': '/Users/bkabak'}, uuid.uuid4(), {"warehouse": "duckdb()"}, {}) -compute = {"warehouse": "duckdb()"} + {'account': 'dhb43249.us-east-1', 'catalog': 'snowflake'}, uuid.uuid4(), {"warehouse": "duckdb()"}, {}) plugin = SnowflakeStageUniversqlPlugin(SnowflakeExecutor(SnowflakeCatalog(session, compute))) -sql = (parse_one("select * from @stagename", dialect="snowflake") +sql = (parse_one("copy into table_name FROM @stagename/test", dialect="snowflake") .transform(plugin.transform_sql, DuckDBExecutor(DuckDBCatalog(session, compute)))) print(sql.sql(dialect="duckdb")) \ No newline at end of file diff --git a/universql/plugins/snow.py b/universql/plugins/snow.py index 510f61f..3286fc2 100644 --- a/universql/plugins/snow.py +++ b/universql/plugins/snow.py @@ -58,9 +58,6 @@ 'Windows': "%USERPROFILE%\.aws\credentials", } - - - DISALLOWED_PARAMS_BY_FORMAT = { "JSON": { "ignore_errors": ["ALWAYS_REMOVE"], @@ -83,7 +80,7 @@ 'YY': '%y', "MMMM": "%B", 'MM': '%m', - 'MON': "%b", #in snowflake this means full or abbreviated; duckdb doesn't have an either or option + 'MON': "%b", # in snowflake this means full or abbreviated; duckdb doesn't have an either or option "DD": "%d", "DY": "%a", "HH24": "%24", @@ -256,7 +253,7 @@ "duckdb_property_name": None, "duckdb_property_type": None }, - "ALLOW_DUPLICATE": { # duckdb only takes the last value + "ALLOW_DUPLICATE": { # duckdb only takes the last value "duckdb_property_name": None, "duckdb_property_type": None }, @@ -268,42 +265,26 @@ "duckdb_property_name": None, "duckdb_property_type": None }, - "STRIP_NULL_VALUES": { # would need to be handled after a successful copy + "STRIP_NULL_VALUES": { # would need to be handled after a successful copy "duckdb_property_name": None, "duckdb_property_type": None }, - "STRIP_OUTER_ARRAY": { # needs to use json_array_elements() after loading + "STRIP_OUTER_ARRAY": { # needs to use json_array_elements() after loading "duckdb_property_name": None, "duckdb_property_type": None } } + class SnowflakeStageUniversqlPlugin(UniversqlPlugin): def __init__(self, source_executor: SnowflakeExecutor): super().__init__(source_executor) - def transform_sql(self, expression: Expression, target_executor: DuckDBExecutor) -> Expression: # Ensure the root node is a Copy node if not isinstance(expression, sqlglot.exp.Copy): return expression - def get_references(expression: Expression): - if isinstance(expression, sqlglot.exp.Table) and isinstance(expression.this, - sqlglot.exp.Var) and str( - expression.this.this).startswith('@'): - pass - - if isinstance(expression, sqlglot.exp.Table) and ( - isinstance(expression.this, Identifier) or isinstance(expression.this, Var)) and not any( - expression == parent.args.get('format') for parent in expression.walk(bfs=True)): - pass - - return expression - - expression.transform(get_references, copy=False) - - cache_directory = self.source_executor.catalog.context.get('cache_directory') file_cache_directories = [] @@ -334,24 +315,26 @@ def get_references(expression: Expression): return expression - - def get_file_info(self, files, ast): copy_data = {} if len(files) == 0: return {} + + cursor = None + try: copy_params = self._extract_copy_params(ast) file_format_params = copy_params.get("FILE_FORMAT") - cursor = self.cursor() + cursor = self.source_executor.catalog.cursor() for file in files: if file.get("type") == 'STAGE': stage_info = self.get_stage_info(file, file_format_params, cursor) stage_info["METADATA"] = stage_info["METADATA"] | file copy_data[file["stage_name"]] = stage_info finally: - cursor.close() + if cursor is not None: + cursor.close() return copy_data def _extract_copy_params(self, ast): @@ -455,7 +438,8 @@ def convert_to_duckdb_properties(self, copy_properties): metadata = {} for snowflake_property_name, snowflake_property_info in copy_properties.items(): - converted_properties = self.convert_properties(file_format, snowflake_property_name, snowflake_property_info) + converted_properties = self.convert_properties(file_format, snowflake_property_name, + snowflake_property_info) duckdb_property_name, property_values = next(iter(converted_properties.items())) if property_values["duckdb_property_type"] == 'METADATA': metadata[duckdb_property_name] = property_values["duckdb_property_value"]