
Commit c6d58ed

pull result directly from duckdb

buremba committed Feb 3, 2025
1 parent ee67bae commit c6d58ed

Showing 6 changed files with 51 additions and 27 deletions.
34 changes: 31 additions & 3 deletions tests/integration/extract.py
@@ -83,13 +83,41 @@ def test_union(self):


     def test_stage(self):
-        with universql_connection(warehouse=None) as conn:
+        with universql_connection(warehouse=None, database="MY_ICEBERG_JINJAT", schema="TPCH_SF1") as conn:
             result = execute_query(conn, """
-            create temp table if not exists table_name1 as select 1 as t;
-            copy into table_name1 FROM @stagename/""")
+            create temp table if not exists clickhouse_public_data as select 1 as t;
+            copy into clickhouse_public_data FROM @clickhouse_public_data_stage/
+            """)
             # result = execute_query(conn, "select * from @iceberg_db.public.landing_stage/initial_objects/device_metadata.csv")
             assert result.num_rows > 0

+    def test_copy_into_for_ryan(self):
+        with universql_connection(snowflake_connection_name='ryan_snowflake', warehouse=None, database="ICEBERG_DB") as conn:
+
+            result = execute_query(conn, """
+                CREATE OR REPLACE TEMPORARY TABLE DEVICE_METADATA_REF (
+                    device_id VARCHAR,
+                    device_name VARCHAR,
+                    device_type VARCHAR,
+                    manufacturer VARCHAR,
+                    model_number VARCHAR,
+                    firmware_version VARCHAR,
+                    installation_date DATE,
+                    location_id VARCHAR,
+                    location_name VARCHAR,
+                    facility_zone VARCHAR,
+                    is_active BOOLEAN,
+                    expected_lifetime_months INT,
+                    maintenance_interval_days INT,
+                    last_maintenance_date DATE
+                );
+                COPY INTO DEVICE_METADATA_REF
+                FROM @iceberg_db.public.landing_stage/initial_objects/device_metadata.csv
+                FILE_FORMAT = (SKIP_HEADER = 1);
+            """)
+            assert result.num_rows != 0

     def test_copy_into(self):
         with universql_connection(warehouse=None) as conn:
             result = execute_query(conn, """
18 changes: 9 additions & 9 deletions universql/plugins/snow.py
@@ -445,14 +445,13 @@ def get_file_info(self, files, ast):

         copy_params = self._extract_copy_params(ast)
         file_format_params = copy_params.get("FILE_FORMAT")
-        with self.source_executor.catalog.cursor() as cursor:
-            for file in files:
-                if file.get("type") == 'STAGE':
-                    stage_info = self.get_stage_info(file, file_format_params, cursor)
-                    stage_info["METADATA"] = stage_info["METADATA"] | file
-                    copy_data[file["stage_name"]] = stage_info
-                else:
-                    raise QueryError("Unable to find type")
+        for file in files:
+            if file.get("type") == 'STAGE':
+                stage_info = self.get_stage_info(file, file_format_params)
+                stage_info["METADATA"] = stage_info["METADATA"] | file
+                copy_data[file["stage_name"]] = stage_info
+            else:
+                raise QueryError("Unable to find type")
         return copy_data

     def _extract_copy_params(self, ast):
@@ -498,7 +497,7 @@ def get_region(self, profile, url, storage_provider):
             region_dict = s3.get_bucket_location(Bucket=bucket_name)
             return region_dict.get('LocationConstraint') or 'us-east-1'

-    def get_stage_info(self, file, file_format_params, cursor):
+    def get_stage_info(self, file, file_format_params):
         """
         Retrieves and processes Snowflake stage metadata.
@@ -518,6 +517,7 @@ def get_stage_info(self, file, file_format_params, cursor):
         if file_format_params is None:
             file_format_params = {}
         stage_name = file["stage_name"]
+        cursor = self.source_executor.catalog.cursor()
         cursor.execute(f"DESCRIBE STAGE {stage_name}")
         stage_info = cursor.fetchall()
         stage_info_dict = {}
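Note on this hunk: get_stage_info now obtains the shared catalog cursor itself instead of receiving one from the caller, so stage metadata lookups reuse the memoized Snowflake connection. Below is a minimal sketch of how DESCRIBE STAGE rows can be folded into a dict, assuming Snowflake's documented result columns (parent_property, property, property_type, property_value, property_default); the helper name is hypothetical and is not the code in this diff:

    def describe_stage_as_dict(cursor, stage_name: str) -> dict:
        # Each DESCRIBE STAGE row describes one property of the stage.
        cursor.execute(f"DESCRIBE STAGE {stage_name}")
        info = {}
        for parent_property, prop, _ptype, value, _default in cursor.fetchall():
            # Group properties under their parent section, e.g. STAGE_LOCATION.
            info.setdefault(parent_property, {})[prop] = value
        return info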
2 changes: 0 additions & 2 deletions universql/protocol/session.py
@@ -39,8 +39,6 @@ def __init__(self, context, session_id, credentials: dict, session_parameters: d
         self.catalog = COMPUTES["snowflake"](self, first_catalog_compute or {})
         self.catalog_executor = self.catalog.executor()
         self.computes = {"snowflake": self.catalog_executor}
-
-        self.last_executor_cursor = None
         self.processing = False
         self.metadata_db = None
         self.transforms : List[UniversqlPlugin] = [transform(self.catalog_executor) for transform in TRANSFORMS]
2 changes: 1 addition & 1 deletion universql/protocol/snowflake.py
@@ -283,7 +283,7 @@ async def query_request(request: Request) -> JSONResponse:
                                              "data": {"sqlState": e.sqlstate}})
     except Exception as e:
         if not isinstance(e, HTTPException):
-            print_exc(limit=1)
+            print_exc(limit=10)
         if query is not None:
             logger.exception(f"Error processing query: {query}")
         else:
11 changes: 4 additions & 7 deletions universql/warehouse/duckdb.py
@@ -4,7 +4,7 @@
 import typing
 from enum import Enum
 from string import Template
-from typing import List
+from typing import List, Sequence, Any

 import duckdb
 import pyiceberg.table
@@ -19,7 +19,6 @@
 from snowflake.connector.options import pyarrow
 from sqlglot.expressions import Insert, Create, Drop, Properties, TemporaryProperty, Schema, Table, Property, \
     Var, Literal, IcebergProperty, Use, ColumnDef, DataType, Copy
-from sqlglot.optimizer.simplify import simplify

 from universql.lake.cloud import s3, gcs, in_lambda
 from universql.protocol.session import UniverSQLSession
@@ -34,7 +33,6 @@ class TableType(Enum):
     ICEBERG = "iceberg"
     LOCAL = "local"

-
 @register(name="duckdb")
 class DuckDBCatalog(ICatalog):

@@ -371,14 +369,13 @@ def execute(self, ast: sqlglot.exp.Expression, catalog_executor: Executor, locat
         else:
             sql = self._sync_and_transform_query(ast, locations).sql(dialect="duckdb", pretty=True)
             self.execute_raw(sql, catalog_executor)

         return None

     def get_as_table(self) -> pyarrow.Table:
-        arrow_table = self.catalog.emulator._arrow_table
-
+        arrow_table = self.catalog.duckdb.fetch_arrow_table()
         if arrow_table is None:
-            raise QueryError("No result returned from DuckDB")
+            arrow_table = self.catalog.duckdb.sql("select 'no response returned' as message").fetch_arrow_table()

         for idx, column in enumerate(self.catalog.duckdb.description):
             array, schema = get_field_from_duckdb(column, arrow_table, idx)
             arrow_table = arrow_table.set_column(idx, schema, array)
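This hunk is the heart of the commit: get_as_table now pulls the Arrow result straight from the DuckDB connection instead of reaching into the emulator's private _arrow_table, and substitutes a one-row placeholder when nothing comes back. A minimal standalone sketch of that API, assuming only the duckdb and pyarrow packages are installed:

    import duckdb

    con = duckdb.connect()

    # A SELECT leaves a result set on the connection that can be fetched as Arrow.
    con.execute("select 42 as answer")
    tbl = con.fetch_arrow_table()

    # Relations expose the same method, which the fallback path above uses to
    # hand callers a one-row message instead of raising.
    fallback = con.sql("select 'no response returned' as message").fetch_arrow_table()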
11 changes: 6 additions & 5 deletions universql/warehouse/snowflake.py
@@ -60,8 +60,8 @@ def clear_cache(self):
     def executor(self) -> Executor:
         return SnowflakeExecutor(self)

-    def cursor(self, create_if_not_exists=True):
-        if self._cursor is not None or not create_if_not_exists:
+    def cursor(self):
+        if self._cursor is not None:
             return self._cursor
         with sentry_sdk.start_span(op="snowflake", name="Initialize Snowflake Connection"):
             try:
@@ -98,12 +98,13 @@ def _get_ref(self, table_information) -> pyiceberg.table.Table:
     def get_table_paths(self, tables: List[sqlglot.exp.Table]) -> Tables:
         if len(tables) == 0:
             return {}
+        cursor = self.cursor()
         sqls = ["SYSTEM$GET_ICEBERG_TABLE_INFORMATION(%s)" for _ in tables]
         values = [table.sql(comments=False, dialect="snowflake") for table in tables]
         final_query = f"SELECT {(', '.join(sqls))}"
         try:
-            self.cursor().execute(final_query, values)
-            result = self.cursor().fetchall()
+            cursor.execute(final_query, values)
+            result = cursor.fetchall()
             return {table: self._get_ref(json.loads(result[0][idx])) for idx, table in
                     enumerate(tables)}
         except DatabaseError as e:
@@ -225,7 +226,7 @@ def get_query_log(self, total_duration) -> str:
         return "Run on Snowflake"

     def close(self):
-        cursor = self.catalog.cursor(create_if_not_exists=False)
+        cursor = self.catalog._cursor
         if cursor is not None:
             cursor.close()
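The cursor changes in this file work as a pair: cursor() now always memoizes one lazily created cursor, and close() inspects the private _cursor field directly, so shutting down never opens a connection just to close it. A minimal sketch of the pattern with hypothetical names, not the project's actual classes:

    class LazyCursorCatalog:
        def __init__(self, connect):
            self._connect = connect  # e.g. snowflake.connector.connect
            self._cursor = None

        def cursor(self):
            if self._cursor is not None:
                return self._cursor
            self._cursor = self._connect().cursor()  # created once, on first use
            return self._cursor

        def close(self):
            if self._cursor is not None:  # never connect just to disconnect
                self._cursor.close()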
