Skip to content

Commit aee6863

Browse files
authored
Telemetry server-side flag integration (databricks#646)
* feature_flag Signed-off-by: Sai Shree Pradhan <[email protected]> * fix static type check Signed-off-by: Sai Shree Pradhan <[email protected]> * fix static type check Signed-off-by: Sai Shree Pradhan <[email protected]> * force enable telemetry Signed-off-by: Sai Shree Pradhan <[email protected]> * added flag Signed-off-by: Sai Shree Pradhan <[email protected]> * linting Signed-off-by: Sai Shree Pradhan <[email protected]> * tests Signed-off-by: Sai Shree Pradhan <[email protected]> * changed flag value to be of any type Signed-off-by: Sai Shree Pradhan <[email protected]> * test fix Signed-off-by: Sai Shree Pradhan <[email protected]> --------- Signed-off-by: Sai Shree Pradhan <[email protected]>
1 parent 2f8b1ab commit aee6863

File tree

5 files changed

+289
-11
lines changed

5 files changed

+289
-11
lines changed

src/databricks/sql/client.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -248,12 +248,6 @@ def read(self) -> Optional[OAuthToken]:
248248
self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)
249249
self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
250250
self._cursors = [] # type: List[Cursor]
251-
252-
self.server_telemetry_enabled = True
253-
self.client_telemetry_enabled = kwargs.get("enable_telemetry", False)
254-
self.telemetry_enabled = (
255-
self.client_telemetry_enabled and self.server_telemetry_enabled
256-
)
257251
self.telemetry_batch_size = kwargs.get(
258252
"telemetry_batch_size", TelemetryClientFactory.DEFAULT_BATCH_SIZE
259253
)
@@ -288,6 +282,10 @@ def read(self) -> Optional[OAuthToken]:
288282
)
289283
self.staging_allowed_local_path = kwargs.get("staging_allowed_local_path", None)
290284

285+
self.force_enable_telemetry = kwargs.get("force_enable_telemetry", False)
286+
self.enable_telemetry = kwargs.get("enable_telemetry", False)
287+
self.telemetry_enabled = TelemetryHelper.is_telemetry_enabled(self)
288+
291289
TelemetryClientFactory.initialize_telemetry_client(
292290
telemetry_enabled=self.telemetry_enabled,
293291
session_id_hex=self.get_session_id_hex(),
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
import threading
2+
import time
3+
import requests
4+
from dataclasses import dataclass, field
5+
from concurrent.futures import ThreadPoolExecutor
6+
from typing import Dict, Optional, List, Any, TYPE_CHECKING
7+
8+
if TYPE_CHECKING:
9+
from databricks.sql.client import Connection
10+
11+
12+
@dataclass
13+
class FeatureFlagEntry:
14+
"""Represents a single feature flag from the server response."""
15+
16+
name: str
17+
value: str
18+
19+
20+
@dataclass
21+
class FeatureFlagsResponse:
22+
"""Represents the full JSON response from the feature flag endpoint."""
23+
24+
flags: List[FeatureFlagEntry] = field(default_factory=list)
25+
ttl_seconds: Optional[int] = None
26+
27+
@classmethod
28+
def from_dict(cls, data: Dict[str, Any]) -> "FeatureFlagsResponse":
29+
"""Factory method to create an instance from a dictionary (parsed JSON)."""
30+
flags_data = data.get("flags", [])
31+
flags_list = [FeatureFlagEntry(**flag) for flag in flags_data]
32+
return cls(flags=flags_list, ttl_seconds=data.get("ttl_seconds"))
33+
34+
35+
# --- Constants ---
36+
FEATURE_FLAGS_ENDPOINT_SUFFIX_FORMAT = (
37+
"/api/2.0/connector-service/feature-flags/PYTHON/{}"
38+
)
39+
DEFAULT_TTL_SECONDS = 900 # 15 minutes
40+
REFRESH_BEFORE_EXPIRY_SECONDS = 10 # Start proactive refresh 10s before expiry
41+
42+
43+
class FeatureFlagsContext:
44+
"""
45+
Manages fetching and caching of server-side feature flags for a connection.
46+
47+
1. The very first check for any flag is a synchronous, BLOCKING operation.
48+
2. Subsequent refreshes (triggered near TTL expiry) are done asynchronously
49+
in the background, returning stale data until the refresh completes.
50+
"""
51+
52+
def __init__(self, connection: "Connection", executor: ThreadPoolExecutor):
53+
from databricks.sql import __version__
54+
55+
self._connection = connection
56+
self._executor = executor # Used for ASYNCHRONOUS refreshes
57+
self._lock = threading.RLock()
58+
59+
# Cache state: `None` indicates the cache has never been loaded.
60+
self._flags: Optional[Dict[str, str]] = None
61+
self._ttl_seconds: int = DEFAULT_TTL_SECONDS
62+
self._last_refresh_time: float = 0
63+
64+
endpoint_suffix = FEATURE_FLAGS_ENDPOINT_SUFFIX_FORMAT.format(__version__)
65+
self._feature_flag_endpoint = (
66+
f"https://{self._connection.session.host}{endpoint_suffix}"
67+
)
68+
69+
def _is_refresh_needed(self) -> bool:
70+
"""Checks if the cache is due for a proactive background refresh."""
71+
if self._flags is None:
72+
return False # Not eligible for refresh until loaded once.
73+
74+
refresh_threshold = self._last_refresh_time + (
75+
self._ttl_seconds - REFRESH_BEFORE_EXPIRY_SECONDS
76+
)
77+
return time.monotonic() > refresh_threshold
78+
79+
def get_flag_value(self, name: str, default_value: Any) -> Any:
80+
"""
81+
Checks if a feature is enabled.
82+
- BLOCKS on the first call until flags are fetched.
83+
- Returns cached values on subsequent calls, triggering non-blocking refreshes if needed.
84+
"""
85+
with self._lock:
86+
# If cache has never been loaded, perform a synchronous, blocking fetch.
87+
if self._flags is None:
88+
self._refresh_flags()
89+
90+
# If a proactive background refresh is needed, start one. This is non-blocking.
91+
elif self._is_refresh_needed():
92+
# We don't check for an in-flight refresh; the executor queues the task, which is safe.
93+
self._executor.submit(self._refresh_flags)
94+
95+
assert self._flags is not None
96+
97+
# Now, return the value from the populated cache.
98+
return self._flags.get(name, default_value)
99+
100+
def _refresh_flags(self):
101+
"""Performs a synchronous network request to fetch and update flags."""
102+
headers = {}
103+
try:
104+
# Authenticate the request
105+
self._connection.session.auth_provider.add_headers(headers)
106+
headers["User-Agent"] = self._connection.session.useragent_header
107+
108+
response = requests.get(
109+
self._feature_flag_endpoint, headers=headers, timeout=30
110+
)
111+
112+
if response.status_code == 200:
113+
ff_response = FeatureFlagsResponse.from_dict(response.json())
114+
self._update_cache_from_response(ff_response)
115+
else:
116+
# On failure, initialize with an empty dictionary to prevent re-blocking.
117+
if self._flags is None:
118+
self._flags = {}
119+
120+
except Exception as e:
121+
# On exception, initialize with an empty dictionary to prevent re-blocking.
122+
if self._flags is None:
123+
self._flags = {}
124+
125+
def _update_cache_from_response(self, ff_response: FeatureFlagsResponse):
126+
"""Atomically updates the internal cache state from a successful server response."""
127+
with self._lock:
128+
self._flags = {flag.name: flag.value for flag in ff_response.flags}
129+
if ff_response.ttl_seconds is not None and ff_response.ttl_seconds > 0:
130+
self._ttl_seconds = ff_response.ttl_seconds
131+
self._last_refresh_time = time.monotonic()
132+
133+
134+
class FeatureFlagsContextFactory:
135+
"""
136+
Manages a singleton instance of FeatureFlagsContext per connection session.
137+
Also manages a shared ThreadPoolExecutor for all background refresh operations.
138+
"""
139+
140+
_context_map: Dict[str, FeatureFlagsContext] = {}
141+
_executor: Optional[ThreadPoolExecutor] = None
142+
_lock = threading.Lock()
143+
144+
@classmethod
145+
def _initialize(cls):
146+
"""Initializes the shared executor for async refreshes if it doesn't exist."""
147+
if cls._executor is None:
148+
cls._executor = ThreadPoolExecutor(
149+
max_workers=3, thread_name_prefix="feature-flag-refresher"
150+
)
151+
152+
@classmethod
153+
def get_instance(cls, connection: "Connection") -> FeatureFlagsContext:
154+
"""Gets or creates a FeatureFlagsContext for the given connection."""
155+
with cls._lock:
156+
cls._initialize()
157+
assert cls._executor is not None
158+
159+
# Use the unique session ID as the key
160+
key = connection.get_session_id_hex()
161+
if key not in cls._context_map:
162+
cls._context_map[key] = FeatureFlagsContext(connection, cls._executor)
163+
return cls._context_map[key]
164+
165+
@classmethod
166+
def remove_instance(cls, connection: "Connection"):
167+
"""Removes the context for a given connection and shuts down the executor if no clients remain."""
168+
with cls._lock:
169+
key = connection.get_session_id_hex()
170+
if key in cls._context_map:
171+
cls._context_map.pop(key, None)
172+
173+
# If this was the last active context, clean up the thread pool.
174+
if not cls._context_map and cls._executor is not None:
175+
cls._executor.shutdown(wait=False)
176+
cls._executor = None

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import time
33
import logging
44
from concurrent.futures import ThreadPoolExecutor
5-
from typing import Dict, Optional
5+
from typing import Dict, Optional, TYPE_CHECKING
66
from databricks.sql.common.http import TelemetryHttpClient
77
from databricks.sql.telemetry.models.event import (
88
TelemetryEvent,
@@ -36,6 +36,10 @@
3636
import uuid
3737
import locale
3838
from databricks.sql.telemetry.utils import BaseTelemetryClient
39+
from databricks.sql.common.feature_flag import FeatureFlagsContextFactory
40+
41+
if TYPE_CHECKING:
42+
from databricks.sql.client import Connection
3943

4044
logger = logging.getLogger(__name__)
4145

@@ -44,6 +48,7 @@ class TelemetryHelper:
4448
"""Helper class for getting telemetry related information."""
4549

4650
_DRIVER_SYSTEM_CONFIGURATION = None
51+
TELEMETRY_FEATURE_FLAG_NAME = "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForPythonDriver"
4752

4853
@classmethod
4954
def get_driver_system_configuration(cls) -> DriverSystemConfiguration:
@@ -98,6 +103,20 @@ def get_auth_flow(auth_provider):
98103
else:
99104
return None
100105

106+
@staticmethod
107+
def is_telemetry_enabled(connection: "Connection") -> bool:
108+
if connection.force_enable_telemetry:
109+
return True
110+
111+
if connection.enable_telemetry:
112+
context = FeatureFlagsContextFactory.get_instance(connection)
113+
flag_value = context.get_flag_value(
114+
TelemetryHelper.TELEMETRY_FEATURE_FLAG_NAME, default_value=False
115+
)
116+
return str(flag_value).lower() == "true"
117+
else:
118+
return False
119+
101120

102121
class NoopTelemetryClient(BaseTelemetryClient):
103122
"""

tests/e2e/test_concurrent_telemetry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def execute_query_worker(thread_id):
7676

7777
time.sleep(random.uniform(0, 0.05))
7878

79-
with self.connection(extra_params={"enable_telemetry": True}) as conn:
79+
with self.connection(extra_params={"force_enable_telemetry": True}) as conn:
8080
# Capture the session ID from the connection before executing the query
8181
session_id_hex = conn.get_session_id_hex()
8282
with capture_lock:

tests/unit/test_telemetry.py

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
NoopTelemetryClient,
88
TelemetryClientFactory,
99
TelemetryHelper,
10-
BaseTelemetryClient,
1110
)
1211
from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow
1312
from databricks.sql.auth.authenticators import (
1413
AccessTokenAuthProvider,
1514
DatabricksOAuthProvider,
1615
ExternalAuthProvider,
1716
)
17+
from databricks import sql
1818

1919

2020
@pytest.fixture
@@ -311,8 +311,6 @@ def test_connection_failure_sends_correct_telemetry_payload(
311311
mock_session.side_effect = Exception(error_message)
312312

313313
try:
314-
from databricks import sql
315-
316314
sql.connect(server_hostname="test-host", http_path="/test-path")
317315
except Exception as e:
318316
assert str(e) == error_message
@@ -321,3 +319,90 @@ def test_connection_failure_sends_correct_telemetry_payload(
321319
call_arguments = mock_export_failure_log.call_args
322320
assert call_arguments[0][0] == "Exception"
323321
assert call_arguments[0][1] == error_message
322+
323+
324+
@patch("databricks.sql.client.Session")
325+
class TestTelemetryFeatureFlag:
326+
"""Tests the interaction between the telemetry feature flag and connection parameters."""
327+
328+
def _mock_ff_response(self, mock_requests_get, enabled: bool):
329+
"""Helper to configure the mock response for the feature flag endpoint."""
330+
mock_response = MagicMock()
331+
mock_response.status_code = 200
332+
payload = {
333+
"flags": [
334+
{
335+
"name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForPythonDriver",
336+
"value": str(enabled).lower(),
337+
}
338+
],
339+
"ttl_seconds": 3600,
340+
}
341+
mock_response.json.return_value = payload
342+
mock_requests_get.return_value = mock_response
343+
344+
@patch("databricks.sql.common.feature_flag.requests.get")
345+
def test_telemetry_enabled_when_flag_is_true(
346+
self, mock_requests_get, MockSession
347+
):
348+
"""Telemetry should be ON when enable_telemetry=True and server flag is 'true'."""
349+
self._mock_ff_response(mock_requests_get, enabled=True)
350+
mock_session_instance = MockSession.return_value
351+
mock_session_instance.guid_hex = "test-session-ff-true"
352+
mock_session_instance.auth_provider = AccessTokenAuthProvider("token")
353+
354+
conn = sql.client.Connection(
355+
server_hostname="test",
356+
http_path="test",
357+
access_token="test",
358+
enable_telemetry=True,
359+
)
360+
361+
assert conn.telemetry_enabled is True
362+
mock_requests_get.assert_called_once()
363+
client = TelemetryClientFactory.get_telemetry_client("test-session-ff-true")
364+
assert isinstance(client, TelemetryClient)
365+
366+
@patch("databricks.sql.common.feature_flag.requests.get")
367+
def test_telemetry_disabled_when_flag_is_false(
368+
self, mock_requests_get, MockSession
369+
):
370+
"""Telemetry should be OFF when enable_telemetry=True but server flag is 'false'."""
371+
self._mock_ff_response(mock_requests_get, enabled=False)
372+
mock_session_instance = MockSession.return_value
373+
mock_session_instance.guid_hex = "test-session-ff-false"
374+
mock_session_instance.auth_provider = AccessTokenAuthProvider("token")
375+
376+
conn = sql.client.Connection(
377+
server_hostname="test",
378+
http_path="test",
379+
access_token="test",
380+
enable_telemetry=True,
381+
)
382+
383+
assert conn.telemetry_enabled is False
384+
mock_requests_get.assert_called_once()
385+
client = TelemetryClientFactory.get_telemetry_client("test-session-ff-false")
386+
assert isinstance(client, NoopTelemetryClient)
387+
388+
@patch("databricks.sql.common.feature_flag.requests.get")
389+
def test_telemetry_disabled_when_flag_request_fails(
390+
self, mock_requests_get, MockSession
391+
):
392+
"""Telemetry should default to OFF if the feature flag network request fails."""
393+
mock_requests_get.side_effect = Exception("Network is down")
394+
mock_session_instance = MockSession.return_value
395+
mock_session_instance.guid_hex = "test-session-ff-fail"
396+
mock_session_instance.auth_provider = AccessTokenAuthProvider("token")
397+
398+
conn = sql.client.Connection(
399+
server_hostname="test",
400+
http_path="test",
401+
access_token="test",
402+
enable_telemetry=True,
403+
)
404+
405+
assert conn.telemetry_enabled is False
406+
mock_requests_get.assert_called_once()
407+
client = TelemetryClientFactory.get_telemetry_client("test-session-ff-fail")
408+
assert isinstance(client, NoopTelemetryClient)

0 commit comments

Comments
 (0)