Skip to content

Commit 906d902

Browse files
committed
Updates to flag providers
1 parent 139e398 commit 906d902

File tree

4 files changed

+77
-14
lines changed

4 files changed

+77
-14
lines changed

mixpanel/flags/local_feature_flags.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import time
55
import threading
66
from datetime import datetime, timedelta
7-
from typing import Dict, Any, Callable, Optional
7+
from typing import List, Dict, Any, Callable, Optional
88
from .types import (
99
ExperimentationFlag,
1010
ExperimentationFlags,
@@ -17,6 +17,7 @@
1717
normalized_hash,
1818
prepare_common_query_params,
1919
EXPOSURE_EVENT,
20+
generate_traceparent,
2021
)
2122

2223
logger = logging.getLogger(__name__)
@@ -170,15 +171,32 @@ def is_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool:
170171
variant_value = self.get_variant_value(flag_key, False, context)
171172
return bool(variant_value)
172173

174+
def get_all_variants(self, context: Dict[str, Any], reportExposureEvents: bool = False) -> List[SelectedVariant]:
175+
"""
176+
Gets the selected variant for all feature flags that the current user context is in the rollout for.
177+
:param Dict[str, Any] context: The user context to evaluate against the feature flags
178+
:param bool reportExposureEvents: Whether to immediately report exposure events to your Mixpanel project for each flag evaluated. Defaults to False.
179+
"""
180+
variants: Dict[str, SelectedVariant] = {}
181+
fallback = SelectedVariant(variant_key=None, variant_value=None)
182+
183+
for flag_key in self._flag_definitions.keys():
184+
variant = self.get_variant(flag_key, fallback, context, report_exposure=reportExposureEvents)
185+
if variant.variant_key is not None:
186+
variants[flag_key] = variant
187+
188+
return variants
189+
173190
def get_variant(
174-
self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any]
191+
self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any], report_exposure: bool = True
175192
) -> SelectedVariant:
176193
"""
177194
Gets the selected variant for a feature flag
178195
179196
:param str flag_key: The key of the feature flag to evaluate
180197
:param SelectedVariant fallback_value: The default variant to return if evaluation fails
181198
:param Dict[str, Any] context: Context dictionary containing user's distinct_id and any other attributes needed for rollout evaluation
199+
:param bool report_exposure: Whether to track an exposure event for this flag evaluation. Defaults to True.
182200
"""
183201
start_time = time.perf_counter()
184202
flag_definition = self._flag_definitions.get(flag_key)
@@ -205,7 +223,8 @@ def get_variant(
205223
flag_definition, context_value, flag_key, rollout
206224
)
207225
end_time = time.perf_counter()
208-
self._track_exposure(flag_key, variant, end_time - start_time, context)
226+
if report_exposure:
227+
self._track_exposure(flag_key, variant, end_time - start_time, context)
209228
return variant
210229

211230
logger.info(
@@ -241,12 +260,17 @@ def _get_assigned_variant(
241260
):
242261
return variant
243262

244-
variants = flag_definition.ruleset.variants
245263

246264
hash_input = str(context_value) + flag_name
247265

248266
variant_hash = normalized_hash(hash_input, "variant")
249267

268+
variants = [variant.model_copy(deep=True) for variant in flag_definition.ruleset.variants]
269+
if rollout.variant_splits:
270+
for variant in variants:
271+
if variant.key in rollout.variant_splits:
272+
variant.split = rollout.variant_splits[variant.key]
273+
250274
selected = variants[0]
251275
cumulative = 0.0
252276
for variant in variants:
@@ -255,7 +279,11 @@ def _get_assigned_variant(
255279
if variant_hash < cumulative:
256280
break
257281

258-
return SelectedVariant(variant_key=selected.key, variant_value=selected.value)
282+
return SelectedVariant(
283+
variant_key=selected.key,
284+
variant_value=selected.value,
285+
experiment_id=flag_definition.experiment_id,
286+
is_experiment_active=flag_definition.is_experiment_active)
259287

260288
def _get_assigned_rollout(
261289
self,
@@ -304,15 +332,20 @@ def _get_matching_variant(
304332
for variant in flag.ruleset.variants:
305333
if variant_key.casefold() == variant.key.casefold():
306334
return SelectedVariant(
307-
variant_key=variant.key, variant_value=variant.value
335+
variant_key=variant.key,
336+
variant_value=variant.value,
337+
experiment_id=flag.experiment_id,
338+
is_experiment_active=flag.is_experiment_active,
339+
is_qa_tester=True,
308340
)
309341
return None
310342

311343
async def _afetch_flag_definitions(self) -> None:
312344
try:
313345
start_time = datetime.now()
346+
headers = {"traceparent": generate_traceparent()}
314347
response = await self._async_client.get(
315-
self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params
348+
self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params, headers=headers
316349
)
317350
end_time = datetime.now()
318351
self._handle_response(response, start_time, end_time)
@@ -322,8 +355,9 @@ async def _afetch_flag_definitions(self) -> None:
322355
def _fetch_flag_definitions(self) -> None:
323356
try:
324357
start_time = datetime.now()
358+
headers = {"traceparent": generate_traceparent()}
325359
response = self._sync_client.get(
326-
self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params
360+
self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params, headers=headers
327361
)
328362
end_time = datetime.now()
329363
self._handle_response(response, start_time, end_time)
@@ -370,6 +404,9 @@ def _track_exposure(
370404
"$experiment_type": "feature_flag",
371405
"Flag evaluation mode": "local",
372406
"Variant fetch latency (ms)": latency_in_seconds * 1000,
407+
"$experiment_id": variant.experiment_id,
408+
"$is_experiment_active": variant.is_experiment_active,
409+
"$is_qa_tester": variant.is_qa_tester,
373410
}
374411

375412
self._tracker(distinct_id, EXPOSURE_EVENT, properties)

mixpanel/flags/remote_feature_flags.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from asgiref.sync import sync_to_async
99

1010
from .types import RemoteFlagsConfig, SelectedVariant, RemoteFlagsResponse
11-
from .utils import REQUEST_HEADERS, EXPOSURE_EVENT, prepare_common_query_params
11+
from .utils import REQUEST_HEADERS, EXPOSURE_EVENT, prepare_common_query_params, generate_traceparent
1212

1313
logger = logging.getLogger(__name__)
1414
logging.getLogger("httpx").setLevel(logging.ERROR)
@@ -66,7 +66,8 @@ async def aget_variant(
6666
try:
6767
params = self._prepare_query_params(flag_key, context)
6868
start_time = datetime.now()
69-
response = await self._async_client.get(self.FLAGS_URL_PATH, params=params)
69+
headers = {"traceparent": generate_traceparent()}
70+
response = await self._async_client.get(self.FLAGS_URL_PATH, params=params, headers=headers)
7071
end_time = datetime.now()
7172
self._instrument_call(start_time, end_time)
7273
selected_variant, is_fallback = self._handle_response(
@@ -126,7 +127,8 @@ def get_variant(
126127
try:
127128
params = self._prepare_query_params(flag_key, context)
128129
start_time = datetime.now()
129-
response = self._sync_client.get(self.FLAGS_URL_PATH, params=params)
130+
headers = {"traceparent": generate_traceparent()}
131+
response = self._sync_client.get(self.FLAGS_URL_PATH, params=params, headers=headers)
130132
end_time = datetime.now()
131133
self._instrument_call(start_time, end_time)
132134
selected_variant, is_fallback = self._handle_response(

mixpanel/flags/types.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class Rollout(BaseModel):
3232
rollout_percentage: float
3333
runtime_evaluation_definition: Optional[Dict[str, str]] = None
3434
variant_override: Optional[VariantOverride] = None
35+
variant_splits: Optional[Dict[str,float]] = None
3536

3637
class RuleSet(BaseModel):
3738
variants: List[Variant]
@@ -41,16 +42,23 @@ class RuleSet(BaseModel):
4142
class ExperimentationFlag(BaseModel):
4243
id: str
4344
name: str
44-
key: str
45+
key: str
4546
status: str
4647
project_id: int
47-
ruleset: RuleSet
48+
ruleset: RuleSet
4849
context: str
50+
experiment_id: Optional[str] = None
51+
is_experiment_active: Optional[bool] = None
52+
4953

5054
class SelectedVariant(BaseModel):
5155
# variant_key can be None if being used as a fallback
5256
variant_key: Optional[str] = None
5357
variant_value: Any
58+
experiment_id: Optional[str] = None
59+
is_experiment_active: Optional[bool] = None
60+
is_qa_tester: Optional[bool] = None
61+
5462

5563
class ExperimentationFlags(BaseModel):
5664
flags: List[ExperimentationFlag]

mixpanel/flags/utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Dict
2+
from uuid import uuid
23

34
EXPOSURE_EVENT = "$experiment_started"
45

@@ -47,4 +48,19 @@ def prepare_common_query_params(token: str, sdk_version: str) -> Dict[str, str]:
4748
'token': token
4849
}
4950

50-
return params
51+
return params
52+
53+
def generate_traceparent() -> str:
54+
""" Generates a W3C traceparent header for easy interop with distributed tracing systems i.e Open Telemetry
55+
https://www.w3.org/TR/trace-context/#traceparent-header
56+
:return: A traceparent string
57+
"""
58+
59+
trace_id = uuid.uuid4().hex
60+
span_id = uuid.uuid4().hex[:16]
61+
62+
# Trace flags: '01' for sampled
63+
trace_flags = '01'
64+
65+
traceparent = f"00-{trace_id}-{span_id}-{trace_flags}"
66+
return traceparent

0 commit comments

Comments
 (0)