Skip to content

Commit

Permalink
fix compile issues
Browse files Browse the repository at this point in the history
  • Loading branch information
buremba committed Jan 29, 2025
1 parent 49f7b8d commit cf34fd6
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 37 deletions.
4 changes: 3 additions & 1 deletion tests/integration/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,11 @@ def test_union(self):
result = execute_query(conn, "select 1 union all select 2")
assert result.num_rows == 2


def test_stage(self):
with universql_connection(warehouse=None) as conn:
result = execute_query(conn, "select * from @iceberg_db.public.landing_stage/initial_objects/device_metadata.csv")
result = execute_query(conn, "copy into table_name FROM @stagename/")
# result = execute_query(conn, "select * from @iceberg_db.public.landing_stage/initial_objects/device_metadata.csv")
assert result.num_rows > 0

def test_copy_into(self):
Expand Down
13 changes: 6 additions & 7 deletions tests/plugins/snow.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import uuid

import sqlglot
from sqlglot import parse_one

from universql.plugins.snow import TableSampleUniversqlPlugin, SnowflakeStageUniversqlPlugin
from universql.plugins.snow import SnowflakeStageUniversqlPlugin
from universql.protocol.session import UniverSQLSession
from universql.warehouse.duckdb import DuckDBExecutor, DuckDBCatalog
from universql.warehouse.snowflake import SnowflakeExecutor, SnowflakeCatalog

one = sqlglot.parse_one("select 1", dialect="snowflake")
compute = {"warehouse": None}
session = UniverSQLSession(
{'account': 'dhb43249.us-east-1',
'cache_directory': '/Users/bkabak/.universql/cache',
'catalog': 'snowflake',
'home_directory': '/Users/bkabak'}, uuid.uuid4(), {"warehouse": "duckdb()"}, {})
compute = {"warehouse": "duckdb()"}
{'account': 'dhb43249.us-east-1', 'catalog': 'snowflake'}, uuid.uuid4(), {"warehouse": "duckdb()"}, {})
plugin = SnowflakeStageUniversqlPlugin(SnowflakeExecutor(SnowflakeCatalog(session, compute)))
sql = (parse_one("select * from @stagename", dialect="snowflake")
sql = (parse_one("copy into table_name FROM @stagename/test", dialect="snowflake")
.transform(plugin.transform_sql, DuckDBExecutor(DuckDBCatalog(session, compute))))
print(sql.sql(dialect="duckdb"))
42 changes: 13 additions & 29 deletions universql/plugins/snow.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,6 @@
'Windows': "%USERPROFILE%\.aws\credentials",
}




DISALLOWED_PARAMS_BY_FORMAT = {
"JSON": {
"ignore_errors": ["ALWAYS_REMOVE"],
Expand All @@ -83,7 +80,7 @@
'YY': '%y',
"MMMM": "%B",
'MM': '%m',
'MON': "%b", #in snowflake this means full or abbreviated; duckdb doesn't have an either or option
'MON': "%b", # in snowflake this means full or abbreviated; duckdb doesn't have an either or option
"DD": "%d",
"DY": "%a",
"HH24": "%24",
Expand Down Expand Up @@ -256,7 +253,7 @@
"duckdb_property_name": None,
"duckdb_property_type": None
},
"ALLOW_DUPLICATE": { # duckdb only takes the last value
"ALLOW_DUPLICATE": { # duckdb only takes the last value
"duckdb_property_name": None,
"duckdb_property_type": None
},
Expand All @@ -268,42 +265,26 @@
"duckdb_property_name": None,
"duckdb_property_type": None
},
"STRIP_NULL_VALUES": { # would need to be handled after a successful copy
"STRIP_NULL_VALUES": { # would need to be handled after a successful copy
"duckdb_property_name": None,
"duckdb_property_type": None
},
"STRIP_OUTER_ARRAY": { # needs to use json_array_elements() after loading
"STRIP_OUTER_ARRAY": { # needs to use json_array_elements() after loading
"duckdb_property_name": None,
"duckdb_property_type": None
}
}


class SnowflakeStageUniversqlPlugin(UniversqlPlugin):
def __init__(self, source_executor: SnowflakeExecutor):
super().__init__(source_executor)


def transform_sql(self, expression: Expression, target_executor: DuckDBExecutor) -> Expression:
# Ensure the root node is a Copy node
if not isinstance(expression, sqlglot.exp.Copy):
return expression

def get_references(expression: Expression):
if isinstance(expression, sqlglot.exp.Table) and isinstance(expression.this,
sqlglot.exp.Var) and str(
expression.this.this).startswith('@'):
pass

if isinstance(expression, sqlglot.exp.Table) and (
isinstance(expression.this, Identifier) or isinstance(expression.this, Var)) and not any(
expression == parent.args.get('format') for parent in expression.walk(bfs=True)):
pass

return expression

expression.transform(get_references, copy=False)


cache_directory = self.source_executor.catalog.context.get('cache_directory')
file_cache_directories = []

Expand Down Expand Up @@ -334,24 +315,26 @@ def get_references(expression: Expression):

return expression



def get_file_info(self, files, ast):
copy_data = {}

if len(files) == 0:
return {}

cursor = None

try:
copy_params = self._extract_copy_params(ast)
file_format_params = copy_params.get("FILE_FORMAT")
cursor = self.cursor()
cursor = self.source_executor.catalog.cursor()
for file in files:
if file.get("type") == 'STAGE':
stage_info = self.get_stage_info(file, file_format_params, cursor)
stage_info["METADATA"] = stage_info["METADATA"] | file
copy_data[file["stage_name"]] = stage_info
finally:
cursor.close()
if cursor is not None:
cursor.close()
return copy_data

def _extract_copy_params(self, ast):
Expand Down Expand Up @@ -455,7 +438,8 @@ def convert_to_duckdb_properties(self, copy_properties):
metadata = {}

for snowflake_property_name, snowflake_property_info in copy_properties.items():
converted_properties = self.convert_properties(file_format, snowflake_property_name, snowflake_property_info)
converted_properties = self.convert_properties(file_format, snowflake_property_name,
snowflake_property_info)
duckdb_property_name, property_values = next(iter(converted_properties.items()))
if property_values["duckdb_property_type"] == 'METADATA':
metadata[duckdb_property_name] = property_values["duckdb_property_value"]
Expand Down

0 comments on commit cf34fd6

Please sign in to comment.