diff --git a/tests/integration/extract.py b/tests/integration/extract.py
index 3e834ea..bcf101b 100644
--- a/tests/integration/extract.py
+++ b/tests/integration/extract.py
@@ -83,13 +83,41 @@ def test_union(self):

     def test_stage(self):
-        with universql_connection(warehouse=None) as conn:
+        with universql_connection(warehouse=None, database="MY_ICEBERG_JINJAT", schema="TPCH_SF1") as conn:
             result = execute_query(conn, """
-            create temp table if not exists table_name1 as select 1 as t;
-            copy into table_name1 FROM @stagename/""")
+            create temp table if not exists clickhouse_public_data as select 1 as t;
+            copy into clickhouse_public_data FROM @clickhouse_public_data_stage/
+            """)
             # result = execute_query(conn, "select * from @iceberg_db.public.landing_stage/initial_objects/device_metadata.csv")
             assert result.num_rows > 0

+    def test_copy_into_for_ryan(self):
+        with universql_connection(snowflake_connection_name='ryan_snowflake', warehouse=None, database="ICEBERG_DB") as conn:
+            result = execute_query(conn, """
+            CREATE OR REPLACE TEMPORARY TABLE DEVICE_METADATA_REF (
+                device_id VARCHAR,
+                device_name VARCHAR,
+                device_type VARCHAR,
+                manufacturer VARCHAR,
+                model_number VARCHAR,
+                firmware_version VARCHAR,
+                installation_date DATE,
+                location_id VARCHAR,
+                location_name VARCHAR,
+                facility_zone VARCHAR,
+                is_active BOOLEAN,
+                expected_lifetime_months INT,
+                maintenance_interval_days INT,
+                last_maintenance_date DATE
+            );
+
+            COPY INTO DEVICE_METADATA_REF
+            FROM @iceberg_db.public.landing_stage/initial_objects/device_metadata.csv
+            FILE_FORMAT = (SKIP_HEADER = 1);
+            """)
+            assert result.num_rows != 0
+
     def test_copy_into(self):
         with universql_connection(warehouse=None) as conn:
             result = execute_query(conn, """
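
Note: the tests above drive Snowflake's COPY INTO ... FROM @stage path through UniverSQL; the plugin change below resolves the stage with DESCRIBE STAGE and reads the staged files directly. As a rough mental model (not the plugin's actual code path), the CSV copy amounts to the DuckDB sketch below. The local file path and simplified schema are assumptions for illustration; the real location comes from the stage metadata.

    import duckdb

    con = duckdb.connect()
    # 'device_metadata.csv' is a stand-in for the staged object; the real
    # location is resolved from DESCRIBE STAGE output in universql/plugins/snow.py.
    con.sql("""
        CREATE TEMP TABLE device_metadata_ref AS
        SELECT * FROM read_csv('device_metadata.csv', header = true)  -- header=true plays the role of SKIP_HEADER = 1
    """)
    print(con.sql("SELECT count(*) FROM device_metadata_ref").fetchall())
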
diff --git a/universql/plugins/snow.py b/universql/plugins/snow.py
index 8c95f37..0d93947 100644
--- a/universql/plugins/snow.py
+++ b/universql/plugins/snow.py
@@ -445,14 +445,13 @@ def get_file_info(self, files, ast):
         copy_params = self._extract_copy_params(ast)
         file_format_params = copy_params.get("FILE_FORMAT")
-        with self.source_executor.catalog.cursor() as cursor:
-            for file in files:
-                if file.get("type") == 'STAGE':
-                    stage_info = self.get_stage_info(file, file_format_params, cursor)
-                    stage_info["METADATA"] = stage_info["METADATA"] | file
-                    copy_data[file["stage_name"]] = stage_info
-                else:
-                    raise QueryError("Unable to find type")
+        for file in files:
+            if file.get("type") == 'STAGE':
+                stage_info = self.get_stage_info(file, file_format_params)
+                stage_info["METADATA"] = stage_info["METADATA"] | file
+                copy_data[file["stage_name"]] = stage_info
+            else:
+                raise QueryError("Unable to find type")
         return copy_data

     def _extract_copy_params(self, ast):
@@ -498,7 +497,7 @@ def get_region(self, profile, url, storage_provider):
             region_dict = s3.get_bucket_location(Bucket=bucket_name)
             return region_dict.get('LocationConstraint') or 'us-east-1'

-    def get_stage_info(self, file, file_format_params, cursor):
+    def get_stage_info(self, file, file_format_params):
         """
         Retrieves and processes Snowflake stage metadata.
@@ -518,6 +517,7 @@ def get_stage_info(self, file, file_format_params, cursor):
         if file_format_params is None:
             file_format_params = {}
         stage_name = file["stage_name"]
+        cursor = self.source_executor.catalog.cursor()
         cursor.execute(f"DESCRIBE STAGE {stage_name}")
         stage_info = cursor.fetchall()
         stage_info_dict = {}
diff --git a/universql/protocol/session.py b/universql/protocol/session.py
index d0d090b..61146b5 100644
--- a/universql/protocol/session.py
+++ b/universql/protocol/session.py
@@ -39,8 +39,6 @@ def __init__(self, context, session_id, credentials: dict, session_parameters: d
         self.catalog = COMPUTES["snowflake"](self, first_catalog_compute or {})
         self.catalog_executor = self.catalog.executor()
         self.computes = {"snowflake": self.catalog_executor}
-
-        self.last_executor_cursor = None
         self.processing = False
         self.metadata_db = None
         self.transforms : List[UniversqlPlugin] = [transform(self.catalog_executor) for transform in TRANSFORMS]
diff --git a/universql/protocol/snowflake.py b/universql/protocol/snowflake.py
index 9678174..9a77e3c 100644
--- a/universql/protocol/snowflake.py
+++ b/universql/protocol/snowflake.py
@@ -283,7 +283,7 @@ async def query_request(request: Request) -> JSONResponse:
                                      "data": {"sqlState": e.sqlstate}})
     except Exception as e:
         if not isinstance(e, HTTPException):
-            print_exc(limit=1)
+            print_exc(limit=10)
         if query is not None:
             logger.exception(f"Error processing query: {query}")
         else:
diff --git a/universql/warehouse/duckdb.py b/universql/warehouse/duckdb.py
index 7fa7e85..48967c9 100644
--- a/universql/warehouse/duckdb.py
+++ b/universql/warehouse/duckdb.py
@@ -4,7 +4,7 @@
 import typing
 from enum import Enum
 from string import Template
-from typing import List
+from typing import List, Sequence, Any

 import duckdb
 import pyiceberg.table
@@ -19,7 +19,6 @@ from snowflake.connector.options import pyarrow
 from sqlglot.expressions import Insert, Create, Drop, Properties, TemporaryProperty, Schema, Table, Property, \
     Var, Literal, IcebergProperty, Use, ColumnDef, DataType, Copy
-from sqlglot.optimizer.simplify import simplify

 from universql.lake.cloud import s3, gcs, in_lambda
 from universql.protocol.session import UniverSQLSession
@@ -34,7 +33,6 @@ class TableType(Enum):
     ICEBERG = "iceberg"
     LOCAL = "local"

-
 @register(name="duckdb")
 class DuckDBCatalog(ICatalog):
@@ -371,14 +369,13 @@ def execute(self, ast: sqlglot.exp.Expression, catalog_executor: Executor, locat
         else:
             sql = self._sync_and_transform_query(ast, locations).sql(dialect="duckdb", pretty=True)
             self.execute_raw(sql, catalog_executor)
-        return None

     def get_as_table(self) -> pyarrow.Table:
-        arrow_table = self.catalog.emulator._arrow_table
-
+        arrow_table = self.catalog.duckdb.fetch_arrow_table()
         if arrow_table is None:
-            raise QueryError("No result returned from DuckDB")
+            arrow_table = self.catalog.duckdb.sql("select 'no response returned' as message").fetch_arrow_table()
+
         for idx, column in enumerate(self.catalog.duckdb.description):
             array, schema = get_field_from_duckdb(column, arrow_table, idx)
             arrow_table = arrow_table.set_column(idx, schema, array)
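
Note: get_as_table now fetches the Arrow result straight from the DuckDB connection and substitutes a one-row placeholder instead of raising when a statement produced nothing to return. A minimal sketch of that guard, assuming a plain duckdb connection rather than the catalog wrapper:

    import duckdb
    import pyarrow

    def get_as_table(con: duckdb.DuckDBPyConnection) -> pyarrow.Table:
        # Same guard as the diff: if the last statement left nothing to fetch,
        # hand back a one-row placeholder rather than raising a QueryError.
        arrow_table = con.fetch_arrow_table()
        if arrow_table is None:
            arrow_table = con.sql("select 'no response returned' as message").fetch_arrow_table()
        return arrow_table

    con = duckdb.connect()
    con.execute("select 42 as answer")
    print(get_as_table(con))
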
diff --git a/universql/warehouse/snowflake.py b/universql/warehouse/snowflake.py
index 64d833f..f8dca97 100644
--- a/universql/warehouse/snowflake.py
+++ b/universql/warehouse/snowflake.py
@@ -60,8 +60,8 @@ def clear_cache(self):
     def executor(self) -> Executor:
         return SnowflakeExecutor(self)

-    def cursor(self, create_if_not_exists=True):
-        if self._cursor is not None or not create_if_not_exists:
+    def cursor(self):
+        if self._cursor is not None:
             return self._cursor
         with sentry_sdk.start_span(op="snowflake", name="Initialize Snowflake Connection"):
             try:
@@ -98,12 +98,13 @@ def _get_ref(self, table_information) -> pyiceberg.table.Table:
     def get_table_paths(self, tables: List[sqlglot.exp.Table]) -> Tables:
         if len(tables) == 0:
             return {}
+        cursor = self.cursor()
         sqls = ["SYSTEM$GET_ICEBERG_TABLE_INFORMATION(%s)" for _ in tables]
         values = [table.sql(comments=False, dialect="snowflake") for table in tables]
         final_query = f"SELECT {(', '.join(sqls))}"
         try:
-            self.cursor().execute(final_query, values)
-            result = self.cursor().fetchall()
+            cursor.execute(final_query, values)
+            result = cursor.fetchall()
             return {table: self._get_ref(json.loads(result[0][idx]))
                     for idx, table in enumerate(tables)}
         except DatabaseError as e:
@@ -225,7 +226,7 @@ def get_query_log(self, total_duration) -> str:
         return "Run on Snowflake"

     def close(self):
-        cursor = self.catalog.cursor(create_if_not_exists=False)
+        cursor = self.catalog._cursor
         if cursor is not None:
             cursor.close()
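
Note: the last file settles on a lazily memoized cursor: cursor() opens the Snowflake connection on first use and caches it, and close() reads the private field directly so tearing down an idle catalog never opens a connection first (replacing the old create_if_not_exists=False flag). A minimal sketch of the pattern, with illustrative names rather than the actual class:

    class Catalog:
        def __init__(self, connection):
            self._connection = connection
            self._cursor = None

        def cursor(self):
            # Lazily open and memoize a single cursor; repeat calls are free.
            if self._cursor is not None:
                return self._cursor
            self._cursor = self._connection.cursor()
            return self._cursor

        def close(self):
            # Inspect the private field so closing never triggers a connect.
            if self._cursor is not None:
                self._cursor.close()
                self._cursor = None
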