From 147731fdd7a6888eb953cbc175cbce8c00467ce6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20Kabakc=C4=B1?= Date: Sun, 4 Aug 2024 22:56:48 +0100 Subject: [PATCH] use fakesnow for better transpilation --- .github/workflows/release.yml | 1 + .gitignore | 3 +- README.md | 34 --- poetry.lock | 180 ++++++----- pyproject.toml | 2 +- tests/sqlglot_tests.py | 55 +++- universql/catalog/__init__.py | 3 +- universql/catalog/snow/show_iceberg_tables.py | 221 +++++++++----- universql/lake/cloud.py | 7 +- universql/server.py | 2 +- universql/util.py | 4 - universql/warehouse/duckdb.py | 288 ++++++++++-------- 12 files changed, 463 insertions(+), 337 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 3634b12..3c52aac 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,6 +9,7 @@ jobs: pip: name: Release PyPI runs-on: ubuntu-latest + environment: deploy steps: - name: Check out the repository uses: actions/checkout@v3 diff --git a/.gitignore b/.gitignore index 4be605b..a6c2c9d 100644 --- a/.gitignore +++ b/.gitignore @@ -103,4 +103,5 @@ celerybeat.pid .metabase/* .clickhouse/* .DS_STORE -.certs/* \ No newline at end of file +.certs/* +.rill/* \ No newline at end of file diff --git a/README.md b/README.md index 56bdda8..6f80b80 100644 --- a/README.md +++ b/README.md @@ -143,40 +143,6 @@ It gives you free https connection to your local server and it's the default hos For Catalog, [Snowflake](https://docs.snowflake.com/en/sql-reference/sql/create-iceberg-table-snowflake) and [Object Store](https://docs.snowflake.com/en/sql-reference/sql/create-iceberg-table-iceberg-files) catalogs are supported at the moment. For Data lake, S3 and GCS supported. -## Can't query all Snowflake types locally - -Here is a Markdown table of some Snowflake data types with a "Supported" column. The checkbox indicates whether the type is supported or not. Please replace the checkboxes with the correct values according to your project's support for each data type. - -| Snowflake Data Type | Supported | -| --- |--------------------------------| -| NUMBER | ✓ | -| DECIMAL | ✓ | -| INT | ✓ | -| BIGINT | ✓ | -| SMALLINT | ✓ | -| TINYINT | ✓ | -| FLOAT | ✓ | -| DOUBLE | ✓ | -| VARCHAR | ✓ | -| CHAR | ✓ | -| STRING | ✓ | -| TEXT | ✓ | -| BOOLEAN | ✓ | -| DATE | ✓ | -| DATETIME | ✓ | -| TIME | ✓ | -| TIMESTAMP | ✓ | -| TIMESTAMP_LTZ | ✗ ¹ | -| TIMESTAMP_NTZ | ✗ ¹ | -| TIMESTAMP_TZ | ✗¹ | -| VARIANT | ✓ | -| OBJECT | ✓ | -| ARRAY | ✓ | -| GEOGRAPHY | ✗ ¹ | -| VECTOR | ✗ ¹ | - -¹: No Support in DuckDB yet. - ## Can't query native Snowflake tables locally UniverSQL doesn't support querying native Snowflake tables as they're not accessible from outside of Snowflake. If you try to query a Snowflake table directly, it will return an error. diff --git a/poetry.lock b/poetry.lock index 52b3c16..7463da1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -311,6 +311,17 @@ files = [ {file = "asn1crypto-1.5.1.tar.gz", hash = "sha256:13ae38502be632115abf8a24cbe5f4da52e3b5231990aff31123c805306ccb9c"}, ] +[[package]] +name = "astroid" +version = "3.2.4" +description = "An abstract syntax tree for Python with inference support." 
+optional = false +python-versions = ">=3.8.0" +files = [ + {file = "astroid-3.2.4-py3-none-any.whl", hash = "sha256:413658a61eeca6202a59231abb473f932038fbcbf1666587f66d482083413a25"}, + {file = "astroid-3.2.4.tar.gz", hash = "sha256:0e14202810b30da1b735827f78f5157be2bbd4a7a59b7707ca0bfc2fb4c0063a"}, +] + [[package]] name = "attrs" version = "23.2.0" @@ -740,6 +751,21 @@ wrapt = ">=1.10,<2" [package.extras] dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] +[[package]] +name = "dill" +version = "0.3.8" +description = "serialize all of Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"}, + {file = "dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"}, +] + +[package.extras] +graph = ["objgraph (>=1.7.2)"] +profile = ["gprof2dot (>=2022.7.29)"] + [[package]] name = "dnspython" version = "2.6.1" @@ -880,6 +906,27 @@ files = [ [package.extras] tests = ["black", "pytest", "pytest-cov", "tox"] +[[package]] +name = "fakesnow" +version = "0.9.20" +description = "Fake Snowflake Connector for Python. Run, mock and test Snowflake DB locally." +optional = false +python-versions = ">=3.9" +files = [ + {file = "fakesnow-0.9.20-py3-none-any.whl", hash = "sha256:4a497c6f67303a14b6b24a487cf424318a83507421ea94543ed621ff3760596c"}, + {file = "fakesnow-0.9.20.tar.gz", hash = "sha256:45847ad6392dbf164255f731ffe123b231fd87995dfb871ff7308e86804a62e9"}, +] + +[package.dependencies] +duckdb = ">=1.0.0,<1.1.0" +pyarrow = "*" +snowflake-connector-python = "*" +sqlglot = ">=25.3.0,<25.4.0" + +[package.extras] +dev = ["build (>=1.0,<2.0)", "pandas-stubs", "pre-commit (>=3.4,<4.0)", "pytest (>=8.0,<9.0)", "ruff (>=0.4.2,<0.5.0)", "snowflake-connector-python[pandas,secure-local-storage]", "snowflake-sqlalchemy (>=1.5.0,<1.6.0)", "twine (>=5.0,<6.0)"] +notebook = ["duckdb-engine", "ipykernel", "jupysql"] + [[package]] name = "fastapi" version = "0.111.0" @@ -1650,6 +1697,20 @@ doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linke perf = ["ipython"] test = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] +[[package]] +name = "isort" +version = "5.13.2" +description = "A Python utility / library to sort Python imports." 
+optional = false +python-versions = ">=3.8.0" +files = [ + {file = "isort-5.13.2-py3-none-any.whl", hash = "sha256:8ca5e72a8d85860d5a3fa69b8745237f2939afe12dbf656afbcb47fe72d947a6"}, + {file = "isort-5.13.2.tar.gz", hash = "sha256:48fdfcb9face5d58a4f6dde2e72a1fb8dcaf8ab26f95ab49fab84c2ddefb0109"}, +] + +[package.extras] +colors = ["colorama (>=0.4.6)"] + [[package]] name = "jaraco-classes" version = "3.4.0" @@ -1895,6 +1956,17 @@ files = [ {file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"}, ] +[[package]] +name = "mccabe" +version = "0.7.0" +description = "McCabe checker, plugin for flake8" +optional = false +python-versions = ">=3.6" +files = [ + {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"}, + {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, +] + [[package]] name = "mdit-py-plugins" version = "0.4.1" @@ -2796,6 +2868,33 @@ dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pyte docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"] tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] +[[package]] +name = "pylint" +version = "3.2.6" +description = "python code static checker" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "pylint-3.2.6-py3-none-any.whl", hash = "sha256:03c8e3baa1d9fb995b12c1dbe00aa6c4bcef210c2a2634374aedeb22fb4a8f8f"}, + {file = "pylint-3.2.6.tar.gz", hash = "sha256:a5d01678349454806cff6d886fb072294f56a58c4761278c97fb557d708e1eb3"}, +] + +[package.dependencies] +astroid = ">=3.2.4,<=3.3.0-dev0" +colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} +dill = [ + {version = ">=0.3.7", markers = "python_version >= \"3.12\""}, + {version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, +] +isort = ">=4.2.5,<5.13.0 || >5.13.0,<6" +mccabe = ">=0.6,<0.8" +platformdirs = ">=2.2.0" +tomlkit = ">=0.10.1" + +[package.extras] +spelling = ["pyenchant (>=3.2,<4.0)"] +testutils = ["gitpython (>3)"] + [[package]] name = "pyopenssl" version = "24.1.0" @@ -3566,92 +3665,19 @@ sqlcipher = ["sqlcipher3_binary"] [[package]] name = "sqlglot" -version = "25.5.1" +version = "25.3.3" description = "An easily customizable SQL parser and transpiler" optional = false python-versions = ">=3.7" files = [ - {file = "sqlglot-25.5.1-py3-none-any.whl", hash = "sha256:80019318158069edc11e6892f74c696e5579a5588da2a0ce11dd7e215a3cb318"}, - {file = "sqlglot-25.5.1.tar.gz", hash = "sha256:c167eac4536dd6ed202fee5107b76b8cb05db85550a47e8448bf6452c4780158"}, + {file = "sqlglot-25.3.3-py3-none-any.whl", hash = "sha256:678adad59a01ce58a7f33463b907dcdb43a0309caa5d55d6ac0af33acf698494"}, + {file = "sqlglot-25.3.3.tar.gz", hash = "sha256:6b9a4c740fb8c78509eb3f51ab1d79f7c4ab3dc9d274251007acdd85faf4aeeb"}, ] -[package.dependencies] -sqlglotrs = {version = "0.2.8", optional = true, markers = "extra == \"rs\""} - [package.extras] dev = ["duckdb (>=0.6)", "maturin (>=1.4,<2.0)", "mypy", "pandas", "pandas-stubs", "pdoc", "pre-commit", "python-dateutil", "ruff (==0.4.3)", "types-python-dateutil", "typing-extensions"] rs = ["sqlglotrs (==0.2.8)"] -[[package]] -name = "sqlglotrs" -version = "0.2.8" -description = "" -optional = false -python-versions = ">=3.7" -files = [ - {file = "sqlglotrs-0.2.8-cp310-cp310-macosx_10_12_x86_64.whl", hash = 
"sha256:24e82375d2c004b98cbd814602236e4056194c90cf2523dbdecc0cb53a04e030"}, - {file = "sqlglotrs-0.2.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cc043373971953f67558e24082d5bf23205401caff684cfb5c6828e9ccc58f1c"}, - {file = "sqlglotrs-0.2.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9e8e3349c0984e06494317d96416dd5b56f2aa66df60feb28b6005a67a3ac17"}, - {file = "sqlglotrs-0.2.8-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:141aa79fd40b2911fbb0976a8f8f4ac83b1b28d454c5249dd7ea068c48c7a810"}, - {file = "sqlglotrs-0.2.8-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:88695b00e5a04251569c9bb64265a2869f1ca27db31c5bc25237e8ad5d7b54c4"}, - {file = "sqlglotrs-0.2.8-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:389275e162075d9f516f76b97f8b28de959cc60b1263725fabe9d081b6b729b9"}, - {file = "sqlglotrs-0.2.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f2e41df30777adc5f0b3def20e6ac6ba052b20f152cb4d03e9494d0483f9103"}, - {file = "sqlglotrs-0.2.8-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7b49ebe6c17ca4b4622b88b110447ac7cbd087b7f3c0e0ee2f4a815c739f2300"}, - {file = "sqlglotrs-0.2.8-cp310-none-win32.whl", hash = "sha256:50bd7d3fccb043fb080edd65041286131ad09d5fb526ee8ce12d000a17866bfd"}, - {file = "sqlglotrs-0.2.8-cp310-none-win_amd64.whl", hash = "sha256:ad2b09aa06d004662be281c688b4665d466cafc7574c38c518230bc8edab52a2"}, - {file = "sqlglotrs-0.2.8-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:26c4842816a2dc6ea9b62afccfe6d1491c93afb0b808af981b9ba6f6490761c8"}, - {file = "sqlglotrs-0.2.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c5e163928006eb0e4719f8eda485d79ef9768644d8f320d806da6c64bed6cdd7"}, - {file = "sqlglotrs-0.2.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08967b891b0e184c9008032fde519a0b934a4bb237024f853bd7b0e76a14dcc8"}, - {file = "sqlglotrs-0.2.8-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c95d9e8ccbe4380f669df51045822a75f0de6a0f98753ec33b9103f4f609f0d2"}, - {file = "sqlglotrs-0.2.8-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3c3750b48a1273b884ea057d2e3e37f443e35390785ed099fb7789fb32f67195"}, - {file = "sqlglotrs-0.2.8-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:20f131190bad15bdf4daae1e9853c843545efb0d8bc2199e9722abb4fde47447"}, - {file = "sqlglotrs-0.2.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:588f0a16bc08371d0bdc6f3cedba2902a03b1e3f8d98bd82c049dfafe8272afb"}, - {file = "sqlglotrs-0.2.8-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:284a237a3f3ff6bfc8307da544dd47c5f714d6b586cc0f15e8baa836cddcd729"}, - {file = "sqlglotrs-0.2.8-cp311-none-win32.whl", hash = "sha256:6b05462e2570855a76d4523d18ccdf5da9f17c09018727764581d66db967796a"}, - {file = "sqlglotrs-0.2.8-cp311-none-win_amd64.whl", hash = "sha256:60e59aa0052a86f26febb3c3cff76028a36023e8d38bec76ce2b2bb8023a0b75"}, - {file = "sqlglotrs-0.2.8-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:bb97d7e90b244005bfa783bf7de65fd9a0748731a16b06412b7eeedc30756553"}, - {file = "sqlglotrs-0.2.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b81996937b32d5c734e507c22f03847a7348abdea59c07b56669a0c620157acd"}, - {file = "sqlglotrs-0.2.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:70d2b395ca2c5bc35c9f59063f8298ba4353470cf3aa4adf43576ecc742111cf"}, - {file = "sqlglotrs-0.2.8-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:27cf838fe731105d6ecc1fc80ea4603357a3e02824721176502ea3ed10461191"}, - {file = "sqlglotrs-0.2.8-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ba65c687eb6d41f357b4b03657d76fdfc1d1c582835271d0abc7d125d910bf13"}, - {file = "sqlglotrs-0.2.8-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a02447644e7092df11d2acefb07d94c83a84f01599ac0fe4e3ca9f17a0f9c125"}, - {file = "sqlglotrs-0.2.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:010e453e7f86a5b26e2cb31c7bbdadaf4131fc5b7c9cfd79a42d4b989446c317"}, - {file = "sqlglotrs-0.2.8-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:da42b7c8595dcd008b6bb61ffbad241ccf0269fa2f555c7e703b79568a959a7c"}, - {file = "sqlglotrs-0.2.8-cp312-none-win32.whl", hash = "sha256:56321198b8bb2d5268f55945465fbadc23ac00d6f143b93b13c31263e5e22151"}, - {file = "sqlglotrs-0.2.8-cp312-none-win_amd64.whl", hash = "sha256:1798db5197d4f450efc8585c9c6d5d554d25d1cdfe3c2a8de147e116ea09fa5f"}, - {file = "sqlglotrs-0.2.8-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:03bed19b8cabdd19221e9407620f0e39e0177022aff92724168b88f68578562d"}, - {file = "sqlglotrs-0.2.8-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:fb41c2e79e47fdf796e33cf32c444f1bade7d56b127356a5120676c34b0d14bc"}, - {file = "sqlglotrs-0.2.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5a9155a7b4bd46f9df867ebbd8d04fb171d284f745a5f70178da269734a834bf"}, - {file = "sqlglotrs-0.2.8-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:20876f8cc63936153edb1be9b9caa0f81bf13e91745f9d2df9fc7117affcbea8"}, - {file = "sqlglotrs-0.2.8-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f088f74ef2f6251536b70a8b1be9b7497a82fa8e2cc85b9830860ad540c3de07"}, - {file = "sqlglotrs-0.2.8-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:78bb4fc644887cb66f19362e6a8e066e1986ffb6e602b38a7afd9a0c709f6362"}, - {file = "sqlglotrs-0.2.8-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:84de27c1f248465f6e98a3ec175a3c5b2debe01b7027628f65366e395568eaf1"}, - {file = "sqlglotrs-0.2.8-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b15ddb34c6d9ef631df70ed0780628eb04cc8d0359edcb1246cabcec0cc14971"}, - {file = "sqlglotrs-0.2.8-cp37-none-win32.whl", hash = "sha256:d1091950daa7e53ce6239f4607c74a9761a3a20692658f0cd4d2231705e9d8ea"}, - {file = "sqlglotrs-0.2.8-cp37-none-win_amd64.whl", hash = "sha256:65f9afc8f9ac26fc7dcfe9f4380f9071c03ab615017b7c316db4c5962c162a62"}, - {file = "sqlglotrs-0.2.8-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:38a56de8bf7a3b43ef11a49ff88dbef9558dcb805f1a206081547bdd707af271"}, - {file = "sqlglotrs-0.2.8-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:bb6ff9837a3564f4e6b87475a99a8f5e3311b88f505778d674a5264bb50bb13d"}, - {file = "sqlglotrs-0.2.8-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7a4fff83f7321fa445c5868f632f395f5e5e35c283b05751a6fe3ec30dfc76ba"}, - {file = "sqlglotrs-0.2.8-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cf9079fb799001d428964cacf01d54222f31a080520b959f3908f37bd0da2fec"}, - {file = "sqlglotrs-0.2.8-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:57f38e44ea07640495790469394a2df2d5b8b53d70b37212faaa6ae5621fb1ed"}, - {file = "sqlglotrs-0.2.8-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8b8ed26ff6eba66b79cc52e0c1c9caf3a2972cd2dcd234413a3e313575b52561"}, - {file = "sqlglotrs-0.2.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88d9bbed1352daa44171cd679150358bd81c91a6a1b46dc4a14cc41f95de4d09"}, - {file = "sqlglotrs-0.2.8-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8cac4be0fabf0a6dba9e5a0d9ddde13a3e38717f3f538c2b955c71fd33307020"}, - {file = "sqlglotrs-0.2.8-cp38-none-win32.whl", hash = "sha256:dab5c9a0eedfe7fb9d3e7956bbf707db4241e0c83c603bd6ac38dffee9bfb731"}, - {file = "sqlglotrs-0.2.8-cp38-none-win_amd64.whl", hash = "sha256:0c3baa393e4bb064075cb0b3ff48806dfee1f0eb2fb2ffc592cba4636b6ed07f"}, - {file = "sqlglotrs-0.2.8-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:7c69b6bd458e3129dcb9b9d76bda9829a30e88089a5540c280f145434f0bd9bc"}, - {file = "sqlglotrs-0.2.8-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b4a9bb015c9bee16a00e73a4250534019b492d4e010c2158de54b4105e9d3f29"}, - {file = "sqlglotrs-0.2.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3cd9cd6af79623b09828848488771cd2d879f72da3827421c6f0125dd509de5c"}, - {file = "sqlglotrs-0.2.8-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5d0748086dcba0126aff9c28342adf90541b57facff1af70b59f81c911a9dafd"}, - {file = "sqlglotrs-0.2.8-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f3451c9337a561ae51eeebbff9c6ed75798655a85e95684ba751ad13af35ed2d"}, - {file = "sqlglotrs-0.2.8-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:21e732aed7fc889671d939519dc2791ca72a30996ba5a5509e732736344c3282"}, - {file = "sqlglotrs-0.2.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95ba5cec63d1add75a25100987d8cf0fd0935a864351a846cb56418c8bf2f0e9"}, - {file = "sqlglotrs-0.2.8-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4d098963cbc1121f63ef8191aff0b87e43dff8937e5ea309a6957d66c5040e90"}, - {file = "sqlglotrs-0.2.8-cp39-none-win32.whl", hash = "sha256:57c33aec2593728cadbd29df09e03e0f3d0c17cacf303dd9ea6745c8e4e8ff60"}, - {file = "sqlglotrs-0.2.8-cp39-none-win_amd64.whl", hash = "sha256:eae62e5c11bca383a81b0f4c8c9f1695812baa7398511a3644420daac2f27b1d"}, - {file = "sqlglotrs-0.2.8.tar.gz", hash = "sha256:7a9c451cd850621ff2da1ff660a72ae6ceba562a2d629659b70427edf2c09b58"}, -] - [[package]] name = "starlette" version = "0.37.2" @@ -4376,4 +4402,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "e001ca383709166dd4fbc7402bc182a403d6907c3891241b4c6d4137015817b4" +content-hash = "1b4d28aa8f31d91c4dbd9df823909e63f6ffe36c96acc728fcfaae3da6ea24a7" diff --git a/pyproject.toml b/pyproject.toml index 4077f39..ec3f067 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,6 @@ uvicorn = "^0.30.1" snowflake-connector-python = {extras = ["pandas", "secure-local-storage"], version = "^3.11.0"} eval-type-backport = "^0.2.0" pip-system-certs = "^4.0" -sqlglot = {extras = ["rs"], version = "^25.5.1"} chdb = "^1.4.1" fsspec = "^2024.6.1" gcsfs = "^2024.6.1" @@ -29,6 +28,7 @@ pyiceberg = "^0.7.0" sqlalchemy = "^2.0.31" fastapi-utils = "^0.7.0" +fakesnow = "^0.9.20" [tool.poetry.dev-dependencies] pylint = ">=2.11.1" diff --git a/tests/sqlglot_tests.py b/tests/sqlglot_tests.py index 
2d0987d..d318581 100644 --- a/tests/sqlglot_tests.py +++ b/tests/sqlglot_tests.py @@ -1,5 +1,11 @@ +import time +import pyarrow as pa + import duckdb import sqlglot +from fakesnow.fakes import FakeSnowflakeCursor, FakeSnowflakeConnection + +from universql.warehouse.duckdb import fix_snowflake_to_duckdb_types # queries = sqlglot.parse(""" # SET tables = (SHOW TABLES); @@ -8,17 +14,38 @@ # """, read="snowflake") -query = sqlglot.parse_one(""" -SET stmt = $$ - SELECT PI(); -$$; - -SELECT *, 1 FROM $stmt; -""", dialect="snowflake") - -zsql = query.sql(dialect="duckdb") - -query = duckdb.sql(query.sql()) - - -print(sql) +# query = sqlglot.parse_one(""" +# SET stmt = $$ +# SELECT PI(); +# $$; +# +# SELECT *, 1 FROM $stmt; +# """, dialect="snowflake") + +fields = [ + pa.field("epoch", nullable=False, type=pa.int64()), + pa.field("fraction", nullable=False, type=pa.int32()), + pa.field("timezone", nullable=False, type=pa.int32()), +] +pa_type = pa.struct(fields) +pa.StructArray.from_arrays(arrays=[pa.array([1, 2, 3], type=pa.int64()), pa.array([1, 2, 3], type=pa.int32()), pa.array([1, 2, 3], type=pa.int32())], fields=fields) + +query = """ +SELECT + CAST('2023-01-01 10:34:56 +00:00' AS TIMESTAMPLTZ) AS sample_timestamp_ltz, + CAST('2023-01-01 11:34:56' AS TIMESTAMP) AS sample_timestamp_ntz, + CAST('2023-01-01 12:34:56 +00:00' AS TIMESTAMPTZ) AS sample_timestamp_tz, + CAST(JSON('{"key":"value"}') /* Semi-structured data types */ AS VARIANT) AS sample_variant, +""" + +start = time.time() +for i in range(10): + con = duckdb.connect(f"ali/{i}") + con.execute("CREATE TABLE test (a int, b int)") +print(time.time() - start) +ast = sqlglot.parse_one(query, dialect="duckdb") +transformed_ast = ast.transform(fix_snowflake_to_duckdb_types) +query = transformed_ast.sql(dialect="duckdb", pretty=True) +print(query) +response = duckdb.sql(query) +print(response.show()) diff --git a/universql/catalog/__init__.py b/universql/catalog/__init__.py index 7f750bf..f86fd40 100644 --- a/universql/catalog/__init__.py +++ b/universql/catalog/__init__.py @@ -6,7 +6,6 @@ from snowflake.connector.options import pyarrow from universql import util -from universql.lake.cloud import CACHE_DIRECTORY_KEY from universql.util import Catalog @@ -25,7 +24,7 @@ def get_catalog(context: dict, query_id: str, credentials: dict): class Cursor(ABC): @abstractmethod - def execute(self, ast: sqlglot.exp.Expression) -> None: + def execute(self, ast: typing.Optional[sqlglot.exp.Expression], raw_query : str) -> None: pass @abstractmethod diff --git a/universql/catalog/snow/show_iceberg_tables.py b/universql/catalog/snow/show_iceberg_tables.py index c0975db..22c3b6f 100644 --- a/universql/catalog/snow/show_iceberg_tables.py +++ b/universql/catalog/snow/show_iceberg_tables.py @@ -1,13 +1,14 @@ import json import logging import typing -from typing import Optional, List +from traceback import print_exc +from typing import List import duckdb -import pandas as pd import pyarrow as pa import snowflake.connector import sqlglot +from pyarrow import ArrowInvalid from snowflake.connector import NotSupportedError, DatabaseError from snowflake.connector.constants import FIELD_TYPES, FIELD_ID_TO_NAME from snowflake.connector.cursor import ResultMetadataV2, SnowflakeCursor @@ -18,18 +19,23 @@ MAX_LIMIT = 10000 logging.basicConfig(level=logging.INFO) -cloud_logger = logging.getLogger(f"❄️(cloud)") +cloud_logger = logging.getLogger("❄️(cloud services)") + + class SnowflakeIcebergCursor(Cursor): def __init__(self, query_id, cursor: SnowflakeCursor): self.query_id 
= query_id self.cursor = cursor - def execute(self, ast: sqlglot.exp.Expression) -> None: - compiled_sql = ast.sql(dialect="snowflake") + def execute(self, asts: typing.Optional[List[sqlglot.exp.Expression]], raw_query: str) -> None: + if asts is None: + compiled_sql = raw_query + else: + compiled_sql = ";".join([ast.sql(dialect="snowflake", pretty=True) for ast in asts]) try: self.cursor.execute(compiled_sql) emoji = "" - if ast.key == 'show': + if all(ast.key == 'show' for ast in asts): logger = cloud_logger else: emoji = "💰" @@ -63,12 +69,15 @@ def get_v1_columns(self): return columns def get_as_table(self) -> pa.Table: - schema = pa.schema([self.get_field_for_snowflake(column) for column in self.cursor._description]) try: arrow_all = self.cursor.fetch_arrow_all(force_return_table=True) - return arrow_all.cast(schema) + for idx, column in enumerate(self.cursor._description): + (field, value) = self.get_field_for_snowflake(column, arrow_all[idx]) + arrow_all.set_column(idx, field, value) + return arrow_all + # return from snowflake is not using arrow except NotSupportedError: - values = [[] for _ in schema] + values = [[] for _ in range(len(self.cursor._description))] row = self.cursor.fetchone() while row is not None: @@ -76,17 +85,31 @@ def get_as_table(self) -> pa.Table: values[idx].append(column) row = self.cursor.fetchone() - return pa.Table.from_pydict(dict(zip(schema.names, values)), schema=schema) + fields = [] + for idx, column in enumerate(self.cursor._description): + (field, _) = self.get_field_for_snowflake(column) + fields.append(field) + schema = pa.schema(fields) - @staticmethod - def get_field_for_snowflake(column: ResultMetadataV2) -> pa.Field: + result_data = pa.Table.from_arrays([pa.array(value) for value in values], names=schema.names) + + for idx, column in enumerate(self.cursor._description): + (field, value) = self.get_field_for_snowflake(column, result_data[idx]) + try: + result_data = result_data.set_column(idx, field, value) + except ArrowInvalid as e: + # TODO: find a better approach (maybe casting?) 
+ if any(value is not None for value in values): + result_data = result_data.set_column(idx, field, pa.nulls(len(result_data), field.type)) + else: + raise SnowflakeError(self.query_id, f"Unable to transform response: {e}") + + return result_data + + def get_field_for_snowflake(self, column: ResultMetadataV2, value: typing.Optional[pa.Array] = None) -> \ + typing.Tuple[ + pa.Field, pa.Array]: arrow_field = FIELD_TYPES[column.type_code] - if arrow_field.name == 'FIXED': - pa_type = pa.decimal128(column.precision, column.scale) - elif arrow_field.name == 'DATE': - pa_type = pa.date32() - else: - pa_type = arrow_field.pa_type(column) metadata = { "logicalType": arrow_field.name, @@ -94,15 +117,70 @@ def get_field_for_snowflake(column: ResultMetadataV2) -> pa.Field: "byteLength": "8388608", } + if arrow_field.name == "GEOGRAPHY": + metadata["logicalType"] = "OBJECT" + + if arrow_field.name == 'FIXED': + pa_type = pa.decimal128(column.precision, column.scale) + if value is not None: + value = value.cast(pa_type) + elif arrow_field.name == 'DATE': + pa_type = pa.date32() + if value is not None: + value = value.cast(pa_type) + elif arrow_field.name == 'TIME': + pa_type = pa.int64() + if value is not None: + value = value.cast(pa_type) + elif arrow_field.name == 'TIMESTAMP_LTZ' or arrow_field.name == 'TIMESTAMP_NTZ' or arrow_field.name == 'TIMESTAMP': + metadata["final_type"] = "T" + timestamp_fields = [ + pa.field("epoch", nullable=False, type=pa.int64(), metadata=metadata), + pa.field("fraction", nullable=False, type=pa.int32(), metadata=metadata), + ] + pa_type = pa.struct(timestamp_fields) + if value is not None: + epoch = pa.compute.divide(value.cast(pa.int64()), 1_000_000_000).combine_chunks() + value = pa.StructArray.from_arrays(arrays=[epoch, pa.nulls(len(value), type=pa.int32())], + fields=timestamp_fields) + elif arrow_field.name == 'TIMESTAMP_TZ': + metadata["final_type"] = "T" + timestamp_fields = [ + pa.field("epoch", nullable=False, type=pa.int64(), metadata=metadata), + pa.field("fraction", nullable=False, type=pa.int32(), metadata=metadata), + pa.field("timezone", nullable=False, type=pa.int32(), metadata=metadata), + ] + pa_type = pa.struct(timestamp_fields) + if value is not None: + epoch = pa.compute.divide(value.cast(pa.int64()), 1_000_000_000).combine_chunks() + + value = pa.StructArray.from_arrays( + arrays=[epoch, + # TODO: modulos 1_000_000_000 to get the fraction of a second, pyarrow doesn't support the operator yet + pa.nulls(len(value), type=pa.int32()), + # TODO: reverse engineer the timezone conversion + pa.nulls(len(value), type=pa.int32()), + ], + fields=timestamp_fields) + else: + pa_type = arrow_field.pa_type(column) + if column.precision is not None: metadata["precision"] = str(column.precision) if column.scale is not None: metadata["scale"] = str(column.scale) - return pa.field(column.name, type=pa_type, nullable=column.is_nullable, metadata=metadata) + + field = pa.field(column.name, type=pa_type, nullable=column.is_nullable, metadata=metadata) + try: + return (field, value) + except Exception as e: + print_exc() + raise SnowflakeError(self.query_id, + f"Unable to convert Snowflake data to Arrow, please create a Github issue with the stacktrace above: {e}") class SnowflakeShowIcebergTables(IcebergCatalog): - def __init__(self, account : str, query_id: str, credentials: dict): + def __init__(self, account: str, query_id: str, credentials: dict): super().__init__(query_id, credentials) self.databases = {} self.connection = 
snowflake.connector.connect(**credentials, account=account) @@ -110,47 +188,6 @@ def __init__(self, account : str, query_id: str, credentials: dict): def cursor(self) -> Cursor: return SnowflakeIcebergCursor(self.query_id, self.connection.cursor()) - def load_iceberg_tables(self, database: str, schema: str, after: Optional[str] = None) -> pd.DataFrame: - query = "SHOW ICEBERG TABLES IN SCHEMA IDENTIFIER(%s) LIMIT %s", [database + '.' + schema, MAX_LIMIT] - if after is not None: - query[0] += " AFTER %s" - query[1].append(after) - tables = pd.read_sql(query[0], self.connection, params=query[1]) - if len(tables.index) >= MAX_LIMIT: - after = tables.iloc[-1, :]["name"] - return tables + self.load_iceberg_tables(database, schema, after=after) - else: - return tables - def load_external_volumes_for_tables(self, tables: pd.DataFrame) -> pd.DataFrame: - volumes = tables["external_volume_name"].unique() - - volume_mapping = {} - for volume in volumes: - volume_location = pd.read_sql("DESC EXTERNAL VOLUME identifier(%s)", self.connection, params=[volume]) - active_storage = duckdb.sql("""select property_value from volume_location - where parent_property = 'STORAGE_LOCATIONS' and property = 'ACTIVE' - """).fetchall()[0][0] - all_properties = duckdb.execute("""select property_value from volume_location - where parent_property = 'STORAGE_LOCATIONS' and property like 'STORAGE_LOCATION_%'""").fetchall() - for properties in all_properties: - loads = json.loads(properties[0]) - if loads.get('NAME') == active_storage: - volume_mapping[volume] = loads - break - return volume_mapping - - def load_database_schema(self, database: str, schema: str): - tables = self.load_iceberg_tables(database, schema) - external_volumes = self.load_external_volumes_for_tables(tables) - - tables["external_location"] = tables.apply( - lambda x: (external_volumes[x["external_volume_name"]].get('STORAGE_BASE_URL') - + x["base_location"]), axis=1) - if database not in self.databases: - self.databases[database] = {} - - self.databases[database][schema] = dict(zip(tables.name, tables.external_location)) - def get_table_references(self, cursor: duckdb.DuckDBPyConnection, tables: List[sqlglot.exp.Table]) -> typing.Dict[ sqlglot.exp.Table, sqlglot.exp.Expression]: if len(tables) == 0: @@ -163,7 +200,8 @@ def get_table_references(self, cursor: duckdb.DuckDBPyConnection, tables: List[s result = cur.fetchall() used_tables = ",".join(set(table.sql() for table in tables)) - logging.getLogger("❄️cloud").info(f"Executed metadata query to get Iceberg table locations for tables {used_tables}") + logging.getLogger("❄️cloud").info( + f"Executed metadata query to get Iceberg table locations for tables {used_tables}") return {table: SnowflakeShowIcebergTables._get_ref(json.loads(result[0][idx])) for idx, table in enumerate(tables)} @@ -173,12 +211,53 @@ def _get_ref(table_information): return sqlglot.exp.func("iceberg_scan", sqlglot.exp.Literal.string(location)) - def find_table_location(self, database: str, schema: str, table_name: str, lazy_check: bool = True) -> str: - table_location = self.databases.get(database, {}).get(schema, {}).get(table_name) - if table_location is None: - if lazy_check: - self.load_database_schema(database, schema) - return self.find_table_location(database, schema, table_name, lazy_check=False) - else: - raise Exception(f"Table {table_name} not found in {database}.{schema}") - return table_location + # def find_table_location(self, database: str, schema: str, table_name: str, lazy_check: bool = True) -> str: + # 
table_location = self.databases.get(database, {}).get(schema, {}).get(table_name) + # if table_location is None: + # if lazy_check: + # self.load_database_schema(database, schema) + # return self.find_table_location(database, schema, table_name, lazy_check=False) + # else: + # raise Exception(f"Table {table_name} not found in {database}.{schema}") + # return table_location + # def load_external_volumes_for_tables(self, tables: pd.DataFrame) -> pd.DataFrame: + # volumes = tables["external_volume_name"].unique() + # + # volume_mapping = {} + # for volume in volumes: + # volume_location = pd.read_sql("DESC EXTERNAL VOLUME identifier(%s)", self.connection, params=[volume]) + # active_storage = duckdb.sql("""select property_value from volume_location + # where parent_property = 'STORAGE_LOCATIONS' and property = 'ACTIVE' + # """).fetchall()[0][0] + # all_properties = duckdb.execute("""select property_value from volume_location + # where parent_property = 'STORAGE_LOCATIONS' and property like 'STORAGE_LOCATION_%'""").fetchall() + # for properties in all_properties: + # loads = json.loads(properties[0]) + # if loads.get('NAME') == active_storage: + # volume_mapping[volume] = loads + # break + # return volume_mapping + + # def load_database_schema(self, database: str, schema: str): + # tables = self.load_iceberg_tables(database, schema) + # external_volumes = self.load_external_volumes_for_tables(tables) + # + # tables["external_location"] = tables.apply( + # lambda x: (external_volumes[x["external_volume_name"]].get('STORAGE_BASE_URL') + # + x["base_location"]), axis=1) + # if database not in self.databases: + # self.databases[database] = {} + # + # self.databases[database][schema] = dict(zip(tables.name, tables.external_location)) + + # def load_iceberg_tables(self, database: str, schema: str, after: Optional[str] = None) -> pd.DataFrame: + # query = "SHOW ICEBERG TABLES IN SCHEMA IDENTIFIER(%s) LIMIT %s", [database + '.' 
+ schema, MAX_LIMIT] + # if after is not None: + # query[0] += " AFTER %s" + # query[1].append(after) + # tables = pd.read_sql(query[0], self.connection, params=query[1]) + # if len(tables.index) >= MAX_LIMIT: + # after = tables.iloc[-1, :]["name"] + # return tables + self.load_iceberg_tables(database, schema, after=after) + # else: + # return tables diff --git a/universql/lake/cloud.py b/universql/lake/cloud.py index b74fb16..d1e96b8 100644 --- a/universql/lake/cloud.py +++ b/universql/lake/cloud.py @@ -48,9 +48,4 @@ def get_iceberg_table_from_data_lake(metadata_file_path: str, cache_directory): PY_IO_IMPL: "universql.lake.cloud.iceberg", CACHE_DIRECTORY_KEY: cache_directory, }) - return from_metadata - - -def register_data_lake(duckdb: DuckDBPyConnection, args: dict): - duckdb.register_filesystem(s3(args.get('cache_directory'), args.get('aws_profile'))) - duckdb.register_filesystem(gcs(args.get('cache_directory'), args.get('gcp_project'))) + return from_metadata \ No newline at end of file diff --git a/universql/server.py b/universql/server.py index c6023bb..aeb3d47 100644 --- a/universql/server.py +++ b/universql/server.py @@ -60,7 +60,7 @@ async def login_request(request: Request) -> JSONResponse: token = str(uuid4()) message = None try: - session = UniverSQLSession(token, credentials, login_data.get("SESSION_PARAMETERS")) + session = UniverSQLSession(current_context, token, credentials, login_data.get("SESSION_PARAMETERS")) sessions[session.token] = session except OAuthError as e: message = e.args[0] diff --git a/universql/util.py b/universql/util.py index 7470730..57a0dbd 100644 --- a/universql/util.py +++ b/universql/util.py @@ -65,10 +65,6 @@ class Catalog(Enum): "name": "CLIENT_OUT_OF_BAND_TELEMETRY_ENABLED", "value": False }, - { - "name": "ENABLE_STAGE_S3_PRIVATELINK_FOR_US_EAST_1", - "value": False - }, { "name": "TIMESTAMP_NTZ_OUTPUT_FORMAT", "value": "YYYY-MM-DD HH24:MI:SS.FF3" diff --git a/universql/warehouse/duckdb.py b/universql/warehouse/duckdb.py index 72048fe..285ec91 100644 --- a/universql/warehouse/duckdb.py +++ b/universql/warehouse/duckdb.py @@ -1,169 +1,170 @@ import datetime -import json import logging import os import time -from typing import List +import typing +from typing import List, Optional -import click import duckdb import pyarrow import pyarrow as pa import sqlglot -from pyarrow import DataType +from fakesnow.fakes import FakeSnowflakeCursor, FakeSnowflakeConnection +from pyarrow import Table +from pyarrow.lib import ChunkedArray from snowflake.connector import DatabaseError +from sqlglot import ParseError +from sqlglot.optimizer.simplify import simplify from universql.catalog import get_catalog from universql.catalog.snow.show_iceberg_tables import cloud_logger -from universql.lake.cloud import register_data_lake +from universql.lake.cloud import s3, gcs from universql.util import get_columns_for_duckdb, SnowflakeError, Compute, Catalog logging.basicConfig(level=logging.INFO) logger = logging.getLogger("🐥") -context = click.get_current_context() -COMPUTE = context.params.get('compute') -CATALOG = context.params.get('catalog') -con = duckdb.connect(read_only=False, config={ - 'max_memory': context.params.get('max_memory'), - 'temp_directory': os.path.join(context.params.get('cache_directory'), "duckdb-staging"), - 'max_temp_directory_size': context.params.get('max_cache_size'), -}) -con.install_extension("iceberg") -con.load_extension("iceberg") - - -def apply_transformation(arrow): - for idx, field in enumerate(arrow.schema): - if 
pyarrow.types.is_int64(field.type): - new_type = pa.decimal128(38, 0) - pa_field = pa.field(field.name, type=new_type, nullable=field.nullable, - metadata=field.metadata) - arrow = arrow.set_column(idx, pa_field, arrow[idx].cast(new_type)) - if pyarrow.types.is_timestamp(field.type): - pa_field = pa.field(field.name, type=pa.int64(), nullable=field.nullable, - metadata=field.metadata) - cast = pa.compute.divide(arrow[idx].cast(pa.int64()), 1000000) - arrow = arrow.set_column(idx, pa_field, cast) - if pyarrow.types.is_time(field.type): - pa_field = pa.field(field.name, type=pa.int64(), nullable=field.nullable, - metadata=field.metadata) - cast = arrow[idx].cast(pa.int64()) - arrow = arrow.set_column(idx, pa_field, cast) - return arrow - - -locally_supported_queries = ["select", "union", "join"] +queries_that_doesnt_need_warehouse = ["show"] class UniverSQLSession: - def __init__(self, token, credentials: dict, session_parameters: dict): + def __init__(self, context, token, credentials: dict, session_parameters: dict): + self.context = context self.credentials = credentials self.session_parameters = [{"name": item[0], "value": item[1]} for item in session_parameters.items()] self.token = token - self.catalog = get_catalog(context.params, self.token, + self.catalog = get_catalog(context, self.token, self.credentials) - self.duckdb = con.cursor() + self.duckdb = duckdb.connect(read_only=False, config={ + 'max_memory': context.get('max_memory'), + 'temp_directory': os.path.join(context.get('cache_directory'), "duckdb-staging"), + 'max_temp_directory_size': context.get('max_cache_size'), + }) + self.duckdb.install_extension("iceberg") + self.duckdb.load_extension("iceberg") + fake_snowflake_conn = FakeSnowflakeConnection(self.duckdb, "main", "public", False, False) + fake_snowflake_conn.database_set = True + fake_snowflake_conn.schema_set = True + self.duckdb_emulator = FakeSnowflakeCursor(fake_snowflake_conn, self.duckdb) self.snowflake = self.catalog.cursor() - register_data_lake(self.duckdb, context.params) + self.register_data_lake(context) self.processing = False - def get_duckdb_transformer(self, tables: List[sqlglot.exp.Expression]): - locations = self.catalog.get_table_references(self.duckdb, tables) + def register_data_lake(self, args: dict): + self.duckdb.register_filesystem(s3(args.get('cache_directory'), args.get('aws_profile'))) + self.duckdb.register_filesystem(gcs(args.get('cache_directory'), args.get('gcp_project'))) + def sync_duckdb_catalog(self, tables: List[sqlglot.exp.Expression], ast: sqlglot.exp.Expression) -> Optional[ + sqlglot.exp.Expression]: + try: + locations = self.catalog.get_table_references(self.duckdb_emulator, tables) + except DatabaseError as e: + error_message = (f"[{self.token}] Unable to find location of Iceberg tables. " + f"See: https://github.com/buremba/universql#cant-query-native-snowflake-tables. 
Cause: {e.msg}") + cloud_logger.warning(error_message) + return None + + views = [f"CREATE OR REPLACE VIEW main.\"{table.sql()}\" AS SELECT * FROM {expression.sql()};" for + table, expression in locations.items()] + views_sql = "\n".join(views) + if views: + self.duckdb.execute(views_sql) + logger.info(f"[{self.token}] Creating views for Iceberg tables: \n{views_sql}") def replace_icebergs_with_duckdb_reference( expression: sqlglot.exp.Expression) -> sqlglot.exp.Expression: if isinstance(expression, sqlglot.exp.Table): if expression.name != "": - return locations[expression] + new_table = sqlglot.exp.to_table(f"main.{sqlglot.exp.parse_identifier(expression.sql())}") + return new_table + # return locations[expression] else: return expression return expression - return replace_icebergs_with_duckdb_reference + return ast.transform(replace_icebergs_with_duckdb_reference).transform(fix_snowflake_to_duckdb_types) - def do_query(self, query: str) -> (str, List, pyarrow.Table): - self.processing = True - try: - logger.info("[%s] Executing \n%s" % (self.token, query)) - start_time = time.perf_counter() - queries = sqlglot.parse(query, read="snowflake") - last_run_on_duckdb = False + def _do_query(self, raw_query: str) -> (str, List, pyarrow.Table): + start_time = time.perf_counter() + compute = self.context.get('compute') + local_error_message = "" - for ast in queries: - can_run_locally = ast.key in locally_supported_queries + try: + queries = sqlglot.parse(raw_query, read="snowflake") + except ParseError as e: + local_error_message = f"Unable to parse query with SQLGlot: {e.args}" + queries = None - passthrough_message = f"Can't run the query locally. " \ - f"Only {', '.join(locally_supported_queries)} queries are supported." - if COMPUTE == Compute.LOCAL.value and not not can_run_locally: - raise SnowflakeError(self.token, passthrough_message) + should_run_locally = compute != Compute.SNOWFLAKE.value + can_run_locally = queries is not None + run_snowflake_already = False - if can_run_locally: + if can_run_locally and should_run_locally: + for ast in queries: + if ast.key in queries_that_doesnt_need_warehouse: + self.do_snowflake_query(queries, raw_query, start_time, local_error_message) + run_snowflake_already = True + else: tables = list(ast.find_all(sqlglot.exp.Table)) + transformed_ast = self.sync_duckdb_catalog(tables, simplify(ast)) + if transformed_ast is None: + can_run_locally = False + break + sql = transformed_ast.sql(dialect="duckdb", pretty=True) + planned_duration = time.perf_counter() - start_time + timedelta = datetime.timedelta(seconds=planned_duration) + + logger.info("[%s] Re-written for DuckDB as: (%s)\n%s" % (self.token, timedelta, sql)) try: - transformer = self.get_duckdb_transformer(tables) - duckdb_query = ast.transform(transformer) - - try: - sql = duckdb_query.sql(dialect="duckdb") - planned_duration = time.perf_counter() - start_time - timedelta = datetime.timedelta(seconds=planned_duration) - - logger.info("[%s] Re-written as: (%s)\n%s" % (self.token, timedelta, sql)) - self.duckdb.execute(sql) - last_run_on_duckdb = True - continue - except duckdb.Error as e: - if COMPUTE == Compute.LOCAL.value: - raise SnowflakeError(self.token, json.dumps(e.args), - getattr(e, 'sqlstate', None)) - logger.warning("Unable to run DuckDB query locally. ") + self.duckdb_emulator.execute(sql) + # except duckdb.Error as e: except DatabaseError as e: - error_message = (f"[{self.token}] Unable to find location of Iceberg tables. 
" - f"See: https://github.com/buremba/universql#cant-query-native-snowflake-tables. Cause: {e.msg}") - if COMPUTE == Compute.LOCAL.value: - raise SnowflakeError(self.token, error_message, e.sqlstate) - else: - cloud_logger.warning(error_message) - else: - logger.info(f"[{self.token}] {passthrough_message}") + local_error_message = f"DuckDB error: {e.args}" + can_run_locally = False + break + + catalog = self.context.get('catalog') + if compute == Compute.LOCAL.value or catalog == Catalog.POLARIS.value: + raise SnowflakeError(self.token, f"Can't run the query locally, {local_error_message}") + + if can_run_locally and not run_snowflake_already: + formatting = (self.token, datetime.timedelta(seconds=time.perf_counter() - start_time)) + logger.info(f"[{self.token}] Run locally 🚀 ({formatting})") + return self.get_duckdb_result() + else: + self.do_snowflake_query(queries, raw_query, start_time, local_error_message) + return self.get_snowflake_result() - if CATALOG == Catalog.POLARIS.value: - raise SnowflakeError(self.token, - "Polaris catalog only supports read-only queries and DuckDB query is failed. " - "Unable to run the query.") - elif CATALOG == Catalog.SNOWFLAKE.value: - try: - self.snowflake.execute(ast) - last_run_on_duckdb = False - except SnowflakeError as e: - cloud_logger.error(f"[{self.token}] {e.message}") - raise SnowflakeError(self.token, e.message, e.sql_state) - - end_time = time.perf_counter() - start_time - formatting = (self.token, datetime.timedelta(seconds=end_time)) - if last_run_on_duckdb: - result = self.get_duckdb_result() - logger.info("[%s] Run locally 🚀 (%s)" % formatting) - else: - result = self.get_snowflake_result() - logger.info("[%s] Query is done. (%s)" % formatting) + + def do_snowflake_query(self, queries, raw_query, start_time, local_error_message): + try: + self.snowflake.execute(queries, raw_query) + formatting = (self.token, datetime.timedelta(seconds=time.perf_counter() - start_time)) + logger.info(f"[{self.token}] Query is done. ({formatting})") + except SnowflakeError as e: + final_error = f"{local_error_message}. 
{e.message}" + cloud_logger.error(f"[{self.token}] {final_error}") + raise SnowflakeError(self.token, final_error, e.sql_state) + + def do_query(self, raw_query: str) -> (str, List, pyarrow.Table): + logger.info("[%s] Executing \n%s" % (self.token, raw_query)) + self.processing = True + try: + return self._do_query(raw_query) finally: self.processing = False - return result - def close(self): - self.duckdb.close() + self.duckdb_emulator.close() self.snowflake.close() - @staticmethod - def get_field_for_duckdb(column: list[str], arrow_type: DataType) -> pa.Field: + def get_field_from_duckdb(self, column: list[str], arrow_table: Table, idx: int) -> typing.Tuple[ + Optional[ChunkedArray], pa.Field]: (field_name, field_type) = column[0], column[1] - pa_type = arrow_type + pa_type = arrow_table.schema[idx].type metadata = {} + transformed_data = None if field_type == 'NUMBER': metadata["logicalType"] = "FIXED" @@ -173,7 +174,9 @@ def get_field_for_duckdb(column: list[str], arrow_type: DataType) -> pa.Field: elif field_type == 'Date': pa_type = pa.date32() metadata["logicalType"] = "DATE" - elif field_type == pa.binary(): + elif field_type == 'Time': + metadata["logicalType"] = "TIME" + elif field_type == "BINARY": metadata["logicalType"] = "BINARY" elif field_type == "TIMESTAMP" or field_type == "DATETIME" or field_type == "TIMESTAMP_LTZ": metadata["logicalType"] = "TIMESTAMP_LTZ" @@ -190,27 +193,60 @@ def get_field_for_duckdb(column: list[str], arrow_type: DataType) -> pa.Field: metadata["precision"] = "0" metadata["scale"] = "9" metadata["physicalType"] = "SB16" - elif arrow_type == pa.string(): - metadata["logicalType"] = "TEXT" - metadata["charLength"] = "8388608" - metadata["byteLength"] = "8388608" - elif arrow_type == pa.bool_(): + elif field_type == "JSON": + pa_type = pa.utf8() + metadata["logicalType"] = "OBJECT" + metadata["charLength"] = "16777216" + metadata["byteLength"] = "16777216" + metadata["scale"] = "0" + metadata["precision"] = "38" + metadata["finalType"] = "T" + elif pa_type == pa.bool_(): metadata["logicalType"] = "BOOLEAN" + elif field_type == 'list': + pa_type = pa.utf8() + arrow_to_project = self.duckdb.from_arrow(arrow_table.select([field_name])) + metadata["logicalType"] = "ARRAY" + metadata["charLength"] = "16777216" + metadata["byteLength"] = "16777216" + metadata["scale"] = "0" + metadata["precision"] = "38" + metadata["finalType"] = "T" + transformed_data = (arrow_to_project.project(f"to_json({field_name})").arrow())[0] + elif pa_type == pa.string(): + metadata["logicalType"] = "TEXT" + metadata["charLength"] = "16777216" + metadata["byteLength"] = "16777216" else: raise Exception() - return pa.field(field_name, type=pa_type, nullable=True, metadata=metadata) + field = pa.field(field_name, type=pa_type, nullable=True, metadata=metadata) + if transformed_data is None: + return arrow_table[idx].cast(field.type), field + else: + return transformed_data, field def get_duckdb_result(self): - arrow_table = self.duckdb.fetch_arrow_table() - schema = pa.schema( - [self.get_field_for_duckdb(column, arrow_table.schema[idx].type) - for idx, column in enumerate(self.duckdb.description)]) - table = arrow_table.cast(schema) - return "arrow", get_columns_for_duckdb(table.schema), apply_transformation(table) + arrow_table = self.duckdb_emulator._arrow_table + if arrow_table is None: + raise SnowflakeError(self.token, "No result returned from DuckDB") + for idx, column in enumerate(self.duckdb.description): + array, schema = self.get_field_from_duckdb(column, arrow_table, idx) + 
arrow_table = arrow_table.set_column(idx, schema, array) + return "arrow", get_columns_for_duckdb(arrow_table.schema), arrow_table def get_snowflake_result(self): arrow = self.snowflake.get_as_table() columns = self.snowflake.get_v1_columns() - transformed_arrow = apply_transformation(arrow) - return "arrow", columns, transformed_arrow + return "arrow", columns, arrow + + +def fix_snowflake_to_duckdb_types( + expression: sqlglot.exp.Expression) -> sqlglot.exp.Expression: + if isinstance(expression, sqlglot.exp.DataType): + if expression.this.value in ["TIMESTAMPLTZ", "TIMESTAMPTZ"]: + return sqlglot.exp.DataType.build("TIMESTAMPTZ") + if expression.this.value in ["VARIANT"]: + return sqlglot.exp.DataType.build("JSON") + + return expression
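For reference, here is a minimal usage sketch (not part of the patch above) of the new fix_snowflake_to_duckdb_types transform, following the pattern already used in tests/sqlglot_tests.py. The query, table, and column names are made up for illustration, and the import assumes the patched universql package is importable in the current environment.

# Illustrative sketch: rewrite Snowflake-specific types (TIMESTAMPLTZ -> TIMESTAMPTZ,
# VARIANT -> JSON) before handing the SQL to DuckDB.
import sqlglot

from universql.warehouse.duckdb import fix_snowflake_to_duckdb_types

snowflake_sql = """
SELECT
    CAST(event_time AS TIMESTAMPLTZ) AS event_time,
    CAST(payload AS VARIANT) AS payload
FROM events
"""

# Parse with the Snowflake dialect, apply the type rewrite, emit DuckDB SQL.
ast = sqlglot.parse_one(snowflake_sql, read="snowflake")
duckdb_sql = ast.transform(fix_snowflake_to_duckdb_types).sql(dialect="duckdb", pretty=True)
print(duckdb_sql)  # the casts come out as TIMESTAMPTZ and JSON

This is the same rewrite that sync_duckdb_catalog applies to each statement before it is executed through the FakeSnowflakeCursor.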