diff --git a/examples/query_tags_example.py b/examples/query_tags_example.py new file mode 100644 index 000000000..f615d082c --- /dev/null +++ b/examples/query_tags_example.py @@ -0,0 +1,30 @@ +import os +import databricks.sql as sql + +""" +This example demonstrates how to use Query Tags. + +Query Tags are key-value pairs that can be attached to SQL executions and will appear +in the system.query.history table for analytical purposes. + +Format: "key1:value1,key2:value2,key3:value3" +""" + +print("=== Query Tags Example ===\n") + +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), + session_configuration={ + 'QUERY_TAGS': 'team:engineering,test:query-tags', + 'ansi_mode': False + } +) as connection: + + with connection.cursor() as cursor: + cursor.execute("SELECT 1") + result = cursor.fetchone() + print(f" Result: {result[0]}") + +print("\n=== Query Tags Example Complete ===") \ No newline at end of file diff --git a/src/databricks/sql/backend/sea/utils/constants.py b/src/databricks/sql/backend/sea/utils/constants.py index 46ce8c98a..61ecf969e 100644 --- a/src/databricks/sql/backend/sea/utils/constants.py +++ b/src/databricks/sql/backend/sea/utils/constants.py @@ -15,6 +15,7 @@ "STATEMENT_TIMEOUT": "0", "TIMEZONE": "UTC", "USE_CACHED_RESULT": "true", + "QUERY_TAGS": "", } diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 3fa87b1af..251a901d5 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -848,10 +848,21 @@ def test_socket_timeout_user_defined(self): query = "select * from range(1000000000)" cursor.execute(query) - def test_ssp_passthrough(self): + @pytest.mark.parametrize( + "extra_params", + [ + { + "use_sea": False, + }, + { + "use_sea": True, + }, + ], + ) + def test_ssp_passthrough(self, extra_params): for enable_ansi in (True, False): with self.cursor( - {"session_configuration": {"ansi_mode": enable_ansi}} + {"session_configuration": {"ansi_mode": enable_ansi, "QUERY_TAGS": "team:marketing,dashboard:abc123,driver:python"}, **extra_params} ) as cursor: cursor.execute("SET ansi_mode") assert list(cursor.fetchone()) == ["ansi_mode", str(enable_ansi)] diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index f604f2874..26a898cb8 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -185,6 +185,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i session_config = { "ANSI_MODE": "FALSE", # Supported parameter "STATEMENT_TIMEOUT": "3600", # Supported parameter + "QUERY_TAGS": "team:marketing,dashboard:abc123", # Supported parameter "unsupported_param": "value", # Unsupported parameter } catalog = "test_catalog" @@ -196,6 +197,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i "session_confs": { "ansi_mode": "FALSE", "statement_timeout": "3600", + "query_tags": "team:marketing,dashboard:abc123", }, "catalog": catalog, "schema": schema, @@ -641,6 +643,7 @@ def test_filter_session_configuration(self): "TIMEZONE": "UTC", "enable_photon": False, "MAX_FILE_PARTITION_BYTES": 128.5, + "QUERY_TAGS": "team:engineering,project:data-pipeline", "unsupported_param": "value", "ANOTHER_UNSUPPORTED": 42, } @@ -663,6 +666,7 @@ def test_filter_session_configuration(self): "timezone": "UTC", # string -> "UTC", key lowercased "enable_photon": "False", # boolean False -> "False", key lowercased "max_file_partition_bytes": "128.5", # float -> "128.5", key lowercased + "query_tags": "team:engineering,project:data-pipeline", } assert result == expected_result @@ -683,12 +687,14 @@ def test_filter_session_configuration(self): "ansi_mode": "false", # lowercase key "STATEMENT_TIMEOUT": 7200, # uppercase key "TiMeZoNe": "America/New_York", # mixed case key + "QueRy_TaGs": "team:marketing,test:case-insensitive", } result = _filter_session_configuration(case_insensitive_config) expected_case_result = { "ansi_mode": "false", "statement_timeout": "7200", "timezone": "America/New_York", + "query_tags": "team:marketing,test:case-insensitive", } assert result == expected_case_result diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 6823b1b33..62a6db0dc 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -154,7 +154,7 @@ def test_socket_timeout_passthrough(self, mock_client_class): @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_configuration_passthrough(self, mock_client_class): - mock_session_config = Mock() + mock_session_config = {"ANSI_MODE": "FALSE", "QUERY_TAGS": "team:engineering,project:data-pipeline"} databricks.sql.connect( session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS )