Skip to content

Commit c801c77

Browse files
authored
Updates to flag providers (#146)
* Updates to flag providers * Add additional updates * Add tests * Update test utils * Update test utils * default split
1 parent 139e398 commit c801c77

File tree

6 files changed

+167
-27
lines changed

6 files changed

+167
-27
lines changed

mixpanel/flags/local_feature_flags.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
normalized_hash,
1818
prepare_common_query_params,
1919
EXPOSURE_EVENT,
20+
generate_traceparent
2021
)
2122

2223
logger = logging.getLogger(__name__)
@@ -168,17 +169,18 @@ def is_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool:
168169
:param Dict[str, Any] context: Context dictionary containing user's distinct_id and any other attributes needed for rollout evaluation
169170
"""
170171
variant_value = self.get_variant_value(flag_key, False, context)
171-
return bool(variant_value)
172+
return variant_value == True
172173

173174
def get_variant(
174-
self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any]
175+
self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any], report_exposure: bool = True
175176
) -> SelectedVariant:
176177
"""
177178
Gets the selected variant for a feature flag
178179
179180
:param str flag_key: The key of the feature flag to evaluate
180181
:param SelectedVariant fallback_value: The default variant to return if evaluation fails
181182
:param Dict[str, Any] context: Context dictionary containing user's distinct_id and any other attributes needed for rollout evaluation
183+
:param bool report_exposure: Whether to track an exposure event for this flag evaluation. Defaults to True.
182184
"""
183185
start_time = time.perf_counter()
184186
flag_definition = self._flag_definitions.get(flag_key)
@@ -193,20 +195,21 @@ def get_variant(
193195
)
194196
return fallback_value
195197

198+
selected_variant: Optional[SelectedVariant] = None
199+
196200
if test_user_variant := self._get_variant_override_for_test_user(
197201
flag_definition, context
198202
):
199-
return test_user_variant
200-
201-
if rollout := self._get_assigned_rollout(
202-
flag_definition, context_value, context
203-
):
204-
variant = self._get_assigned_variant(
203+
selected_variant = test_user_variant
204+
elif rollout := self._get_assigned_rollout(flag_definition, context_value, context):
205+
selected_variant = self._get_assigned_variant(
205206
flag_definition, context_value, flag_key, rollout
206207
)
208+
209+
if report_exposure and selected_variant is not None:
207210
end_time = time.perf_counter()
208-
self._track_exposure(flag_key, variant, end_time - start_time, context)
209-
return variant
211+
self._track_exposure(flag_key, selected_variant, end_time - start_time, context)
212+
return selected_variant
210213

211214
logger.info(
212215
f"{flag_definition.context} context {context_value} not eligible for any rollout for flag: {flag_key}"
@@ -241,12 +244,17 @@ def _get_assigned_variant(
241244
):
242245
return variant
243246

244-
variants = flag_definition.ruleset.variants
245247

246248
hash_input = str(context_value) + flag_name
247249

248250
variant_hash = normalized_hash(hash_input, "variant")
249251

252+
variants = [variant.model_copy(deep=True) for variant in flag_definition.ruleset.variants]
253+
if rollout.variant_splits:
254+
for variant in variants:
255+
if variant.key in rollout.variant_splits:
256+
variant.split = rollout.variant_splits[variant.key]
257+
250258
selected = variants[0]
251259
cumulative = 0.0
252260
for variant in variants:
@@ -255,7 +263,11 @@ def _get_assigned_variant(
255263
if variant_hash < cumulative:
256264
break
257265

258-
return SelectedVariant(variant_key=selected.key, variant_value=selected.value)
266+
return SelectedVariant(
267+
variant_key=selected.key,
268+
variant_value=selected.value,
269+
experiment_id=flag_definition.experiment_id,
270+
is_experiment_active=flag_definition.is_experiment_active)
259271

260272
def _get_assigned_rollout(
261273
self,
@@ -304,15 +316,20 @@ def _get_matching_variant(
304316
for variant in flag.ruleset.variants:
305317
if variant_key.casefold() == variant.key.casefold():
306318
return SelectedVariant(
307-
variant_key=variant.key, variant_value=variant.value
319+
variant_key=variant.key,
320+
variant_value=variant.value,
321+
experiment_id=flag.experiment_id,
322+
is_experiment_active=flag.is_experiment_active,
323+
is_qa_tester=True,
308324
)
309325
return None
310326

311327
async def _afetch_flag_definitions(self) -> None:
312328
try:
313329
start_time = datetime.now()
330+
headers = {"traceparent": generate_traceparent()}
314331
response = await self._async_client.get(
315-
self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params
332+
self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params, headers=headers
316333
)
317334
end_time = datetime.now()
318335
self._handle_response(response, start_time, end_time)
@@ -322,8 +339,9 @@ async def _afetch_flag_definitions(self) -> None:
322339
def _fetch_flag_definitions(self) -> None:
323340
try:
324341
start_time = datetime.now()
342+
headers = {"traceparent": generate_traceparent()}
325343
response = self._sync_client.get(
326-
self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params
344+
self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params, headers=headers
327345
)
328346
end_time = datetime.now()
329347
self._handle_response(response, start_time, end_time)
@@ -370,6 +388,9 @@ def _track_exposure(
370388
"$experiment_type": "feature_flag",
371389
"Flag evaluation mode": "local",
372390
"Variant fetch latency (ms)": latency_in_seconds * 1000,
391+
"$experiment_id": variant.experiment_id,
392+
"$is_experiment_active": variant.is_experiment_active,
393+
"$is_qa_tester": variant.is_qa_tester,
373394
}
374395

375396
self._tracker(distinct_id, EXPOSURE_EVENT, properties)

mixpanel/flags/remote_feature_flags.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from asgiref.sync import sync_to_async
99

1010
from .types import RemoteFlagsConfig, SelectedVariant, RemoteFlagsResponse
11-
from .utils import REQUEST_HEADERS, EXPOSURE_EVENT, prepare_common_query_params
11+
from .utils import REQUEST_HEADERS, EXPOSURE_EVENT, prepare_common_query_params, generate_traceparent
1212

1313
logger = logging.getLogger(__name__)
1414
logging.getLogger("httpx").setLevel(logging.ERROR)
@@ -66,7 +66,8 @@ async def aget_variant(
6666
try:
6767
params = self._prepare_query_params(flag_key, context)
6868
start_time = datetime.now()
69-
response = await self._async_client.get(self.FLAGS_URL_PATH, params=params)
69+
headers = {"traceparent": generate_traceparent()}
70+
response = await self._async_client.get(self.FLAGS_URL_PATH, params=params, headers=headers)
7071
end_time = datetime.now()
7172
self._instrument_call(start_time, end_time)
7273
selected_variant, is_fallback = self._handle_response(
@@ -96,7 +97,7 @@ async def ais_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool:
9697
:param Dict[str, Any] context: Context dictionary containing user attributes and rollout context
9798
"""
9899
variant_value = await self.aget_variant_value(flag_key, False, context)
99-
return bool(variant_value)
100+
return variant_value == True
100101

101102
def get_variant_value(
102103
self, flag_key: str, fallback_value: Any, context: Dict[str, Any]
@@ -126,7 +127,8 @@ def get_variant(
126127
try:
127128
params = self._prepare_query_params(flag_key, context)
128129
start_time = datetime.now()
129-
response = self._sync_client.get(self.FLAGS_URL_PATH, params=params)
130+
headers = {"traceparent": generate_traceparent()}
131+
response = self._sync_client.get(self.FLAGS_URL_PATH, params=params, headers=headers)
130132
end_time = datetime.now()
131133
self._instrument_call(start_time, end_time)
132134
selected_variant, is_fallback = self._handle_response(
@@ -152,7 +154,7 @@ def is_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool:
152154
:param Dict[str, Any] context: Context dictionary containing user attributes and rollout context
153155
"""
154156
variant_value = self.get_variant_value(flag_key, False, context)
155-
return bool(variant_value)
157+
return variant_value == True
156158

157159
def _prepare_query_params(
158160
self, flag_key: str, context: Dict[str, Any]

mixpanel/flags/test_local_feature_flags.py

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,18 @@
99
from .types import LocalFlagsConfig, ExperimentationFlag, RuleSet, Variant, Rollout, FlagTestUsers, ExperimentationFlags, VariantOverride
1010
from .local_feature_flags import LocalFeatureFlagsProvider
1111

12+
1213
def create_test_flag(
1314
flag_key: str = "test_flag",
1415
context: str = "distinct_id",
1516
variants: Optional[list[Variant]] = None,
1617
variant_override: Optional[VariantOverride] = None,
1718
rollout_percentage: float = 100.0,
1819
runtime_evaluation: Optional[Dict] = None,
19-
test_users: Optional[Dict[str, str]] = None) -> ExperimentationFlag:
20+
test_users: Optional[Dict[str, str]] = None,
21+
experiment_id: Optional[str] = None,
22+
is_experiment_active: Optional[bool] = None,
23+
variant_splits: Optional[Dict[str, float]] = None) -> ExperimentationFlag:
2024

2125
if variants is None:
2226
variants = [
@@ -27,7 +31,8 @@ def create_test_flag(
2731
rollouts = [Rollout(
2832
rollout_percentage=rollout_percentage,
2933
runtime_evaluation_definition=runtime_evaluation,
30-
variant_override=variant_override
34+
variant_override=variant_override,
35+
variant_splits=variant_splits
3136
)]
3237

3338
test_config = None
@@ -47,7 +52,9 @@ def create_test_flag(
4752
status="active",
4853
project_id=123,
4954
ruleset=ruleset,
50-
context=context
55+
context=context,
56+
experiment_id=experiment_id,
57+
is_experiment_active=is_experiment_active
5158
)
5259

5360

@@ -216,6 +223,32 @@ async def test_get_variant_value_picks_correct_variant_with_hundred_percent_spli
216223
result = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"})
217224
assert result == "variant_a"
218225

226+
@respx.mock
227+
async def test_get_variant_value_picks_correct_variant_with_half_migrated_group_splits(self):
228+
variants = [
229+
Variant(key="A", value="variant_a", is_control=False, split=100.0),
230+
Variant(key="B", value="variant_b", is_control=False, split=0.0),
231+
Variant(key="C", value="variant_c", is_control=False, split=0.0)
232+
]
233+
variant_splits = {"A": 0.0, "B": 100.0, "C": 0.0}
234+
flag = create_test_flag(variants=variants, rollout_percentage=100.0, variant_splits=variant_splits)
235+
await self.setup_flags([flag])
236+
result = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"})
237+
assert result == "variant_b"
238+
239+
@respx.mock
240+
async def test_get_variant_value_picks_correct_variant_with_full_migrated_group_splits(self):
241+
variants = [
242+
Variant(key="A", value="variant_a", is_control=False),
243+
Variant(key="B", value="variant_b", is_control=False),
244+
Variant(key="C", value="variant_c", is_control=False),
245+
]
246+
variant_splits = {"A": 0.0, "B": 0.0, "C": 100.0}
247+
flag = create_test_flag(variants=variants, rollout_percentage=100.0, variant_splits=variant_splits)
248+
await self.setup_flags([flag])
249+
result = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"})
250+
assert result == "variant_c"
251+
219252
@respx.mock
220253
async def test_get_variant_value_picks_overriden_variant(self):
221254
variants = [
@@ -236,6 +269,43 @@ async def test_get_variant_value_tracks_exposure_when_variant_selected(self):
236269
_ = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"})
237270
self._mock_tracker.assert_called_once()
238271

272+
@respx.mock
273+
@pytest.mark.parametrize("experiment_id,is_experiment_active,use_qa_user", [
274+
("exp-123", True, True), # QA tester with active experiment
275+
("exp-456", False, True), # QA tester with inactive experiment
276+
("exp-789", True, False), # Regular user with active experiment
277+
("exp-000", False, False), # Regular user with inactive experiment
278+
(None, None, True), # QA tester with no experiment
279+
(None, None, False), # Regular user with no experiment
280+
])
281+
async def test_get_variant_value_tracks_exposure_with_correct_properties(self, experiment_id, is_experiment_active, use_qa_user):
282+
flag = create_test_flag(
283+
experiment_id=experiment_id,
284+
is_experiment_active=is_experiment_active,
285+
test_users={"qa_user": "treatment"}
286+
)
287+
288+
await self.setup_flags([flag])
289+
290+
distinct_id = "qa_user" if use_qa_user else "regular_user"
291+
292+
with patch('mixpanel.flags.utils.normalized_hash') as mock_hash:
293+
mock_hash.return_value = 0.5
294+
_ = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": distinct_id})
295+
296+
self._mock_tracker.assert_called_once()
297+
298+
call_args = self._mock_tracker.call_args
299+
properties = call_args[0][2]
300+
301+
assert properties["$experiment_id"] == experiment_id
302+
assert properties["$is_experiment_active"] == is_experiment_active
303+
304+
if use_qa_user:
305+
assert properties["$is_qa_tester"] == True
306+
else:
307+
assert properties.get("$is_qa_tester") is None
308+
239309
@respx.mock
240310
async def test_get_variant_value_does_not_track_exposure_on_fallback(self):
241311
await self.setup_flags([])

mixpanel/flags/test_utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import re
2+
import pytest
3+
import random
4+
import string
5+
from .utils import generate_traceparent, normalized_hash
6+
7+
class TestUtils:
8+
def test_traceparent_format_is_correct(self):
9+
traceparent = generate_traceparent()
10+
11+
# W3C traceparent format: 00-{32 hex chars}-{16 hex chars}-{2 hex chars}
12+
# https://www.w3.org/TR/trace-context/#traceparent-header
13+
pattern = r'^00-[0-9a-f]{32}-[0-9a-f]{16}-01$'
14+
15+
assert re.match(pattern, traceparent), f"Traceparent '{traceparent}' does not match W3C format"
16+
17+
@pytest.mark.parametrize("key,salt,expected_hash", [
18+
("abc", "variant", 0.72),
19+
("def", "variant", 0.21),
20+
])
21+
def test_normalized_hash_for_known_inputs(self, key, salt, expected_hash):
22+
result = normalized_hash(key, salt)
23+
assert result == expected_hash, f"Expected hash of {expected_hash} for '{key}' with salt '{salt}', got {result}"

mixpanel/flags/types.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class Variant(BaseModel):
2020
key: str
2121
value: Any
2222
is_control: bool
23-
split: float
23+
split: Optional[float] = 0.0
2424

2525
class FlagTestUsers(BaseModel):
2626
users: Dict[str, str]
@@ -32,6 +32,7 @@ class Rollout(BaseModel):
3232
rollout_percentage: float
3333
runtime_evaluation_definition: Optional[Dict[str, str]] = None
3434
variant_override: Optional[VariantOverride] = None
35+
variant_splits: Optional[Dict[str,float]] = None
3536

3637
class RuleSet(BaseModel):
3738
variants: List[Variant]
@@ -41,16 +42,23 @@ class RuleSet(BaseModel):
4142
class ExperimentationFlag(BaseModel):
4243
id: str
4344
name: str
44-
key: str
45+
key: str
4546
status: str
4647
project_id: int
47-
ruleset: RuleSet
48+
ruleset: RuleSet
4849
context: str
50+
experiment_id: Optional[str] = None
51+
is_experiment_active: Optional[bool] = None
52+
4953

5054
class SelectedVariant(BaseModel):
5155
# variant_key can be None if being used as a fallback
5256
variant_key: Optional[str] = None
5357
variant_value: Any
58+
experiment_id: Optional[str] = None
59+
is_experiment_active: Optional[bool] = None
60+
is_qa_tester: Optional[bool] = None
61+
5462

5563
class ExperimentationFlags(BaseModel):
5664
flags: List[ExperimentationFlag]

mixpanel/flags/utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import uuid
2+
import httpx
13
from typing import Dict
24

35
EXPOSURE_EVENT = "$experiment_started"
@@ -47,4 +49,18 @@ def prepare_common_query_params(token: str, sdk_version: str) -> Dict[str, str]:
4749
'token': token
4850
}
4951

50-
return params
52+
return params
53+
54+
def generate_traceparent() -> str:
55+
"""Generates a W3C traceparent header for easy interop with distributed tracing systems i.e Open Telemetry
56+
https://www.w3.org/TR/trace-context/#traceparent-header
57+
:return: A traceparent string
58+
"""
59+
trace_id = uuid.uuid4().hex
60+
span_id = uuid.uuid4().hex[:16]
61+
62+
# Trace flags: '01' for sampled
63+
trace_flags = '01'
64+
65+
traceparent = f"00-{trace_id}-{span_id}-{trace_flags}"
66+
return traceparent

0 commit comments

Comments
 (0)