Skip to content

Commit 119373e

Browse files
committed
Add tests
1 parent 4b3acb1 commit 119373e

File tree

6 files changed

+137
-33
lines changed

6 files changed

+137
-33
lines changed

mixpanel/flags/local_feature_flags.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import time
55
import threading
66
from datetime import datetime, timedelta
7-
from typing import List, Dict, Any, Callable, Optional
7+
from typing import Dict, Any, Callable, Optional
88
from .types import (
99
ExperimentationFlag,
1010
ExperimentationFlags,
@@ -17,7 +17,7 @@
1717
normalized_hash,
1818
prepare_common_query_params,
1919
EXPOSURE_EVENT,
20-
add_traceparent_header_to_request
20+
generate_traceparent
2121
)
2222

2323
logger = logging.getLogger(__name__)
@@ -50,7 +50,6 @@ def __init__(
5050
"headers": REQUEST_HEADERS,
5151
"auth": httpx.BasicAuth(token, ""),
5252
"timeout": httpx.Timeout(config.request_timeout_in_seconds),
53-
"event_hooks": {"request": [add_traceparent_header_to_request]},
5453
}
5554

5655
self._request_params = prepare_common_query_params(self._token, self._version)
@@ -170,7 +169,7 @@ def is_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool:
170169
:param Dict[str, Any] context: Context dictionary containing user's distinct_id and any other attributes needed for rollout evaluation
171170
"""
172171
variant_value = self.get_variant_value(flag_key, False, context)
173-
return bool(variant_value)
172+
return variant_value == True
174173

175174
def get_variant(
176175
self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any], report_exposure: bool = True
@@ -196,21 +195,21 @@ def get_variant(
196195
)
197196
return fallback_value
198197

198+
selected_variant: Optional[SelectedVariant] = None
199+
199200
if test_user_variant := self._get_variant_override_for_test_user(
200201
flag_definition, context
201202
):
202-
return test_user_variant
203-
204-
if rollout := self._get_assigned_rollout(
205-
flag_definition, context_value, context
206-
):
207-
variant = self._get_assigned_variant(
203+
selected_variant = test_user_variant
204+
elif rollout := self._get_assigned_rollout(flag_definition, context_value, context):
205+
selected_variant = self._get_assigned_variant(
208206
flag_definition, context_value, flag_key, rollout
209207
)
208+
209+
if report_exposure and selected_variant is not None:
210210
end_time = time.perf_counter()
211-
if report_exposure:
212-
self._track_exposure(flag_key, variant, end_time - start_time, context)
213-
return variant
211+
self._track_exposure(flag_key, selected_variant, end_time - start_time, context)
212+
return selected_variant
214213

215214
logger.info(
216215
f"{flag_definition.context} context {context_value} not eligible for any rollout for flag: {flag_key}"
@@ -328,8 +327,9 @@ def _get_matching_variant(
328327
async def _afetch_flag_definitions(self) -> None:
329328
try:
330329
start_time = datetime.now()
330+
headers = {"traceparent": generate_traceparent()}
331331
response = await self._async_client.get(
332-
self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params,
332+
self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params, headers=headers
333333
)
334334
end_time = datetime.now()
335335
self._handle_response(response, start_time, end_time)
@@ -339,8 +339,9 @@ async def _afetch_flag_definitions(self) -> None:
339339
def _fetch_flag_definitions(self) -> None:
340340
try:
341341
start_time = datetime.now()
342+
headers = {"traceparent": generate_traceparent()}
342343
response = self._sync_client.get(
343-
self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params,
344+
self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params, headers=headers
344345
)
345346
end_time = datetime.now()
346347
self._handle_response(response, start_time, end_time)

mixpanel/flags/remote_feature_flags.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from asgiref.sync import sync_to_async
99

1010
from .types import RemoteFlagsConfig, SelectedVariant, RemoteFlagsResponse
11-
from .utils import REQUEST_HEADERS, EXPOSURE_EVENT, prepare_common_query_params, add_traceparent_header_to_request
11+
from .utils import REQUEST_HEADERS, EXPOSURE_EVENT, prepare_common_query_params, generate_traceparent
1212

1313
logger = logging.getLogger(__name__)
1414
logging.getLogger("httpx").setLevel(logging.ERROR)
@@ -30,7 +30,6 @@ def __init__(
3030
"headers": REQUEST_HEADERS,
3131
"auth": httpx.BasicAuth(token, ""),
3232
"timeout": httpx.Timeout(config.request_timeout_in_seconds),
33-
"event_hooks": {"request": [add_traceparent_header_to_request]},
3433
}
3534

3635
self._async_client: httpx.AsyncClient = httpx.AsyncClient(
@@ -67,7 +66,8 @@ async def aget_variant(
6766
try:
6867
params = self._prepare_query_params(flag_key, context)
6968
start_time = datetime.now()
70-
response = await self._async_client.get(self.FLAGS_URL_PATH, params=params)
69+
headers = {"traceparent": generate_traceparent()}
70+
response = await self._async_client.get(self.FLAGS_URL_PATH, params=params, headers=headers)
7171
end_time = datetime.now()
7272
self._instrument_call(start_time, end_time)
7373
selected_variant, is_fallback = self._handle_response(
@@ -97,7 +97,7 @@ async def ais_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool:
9797
:param Dict[str, Any] context: Context dictionary containing user attributes and rollout context
9898
"""
9999
variant_value = await self.aget_variant_value(flag_key, False, context)
100-
return bool(variant_value)
100+
return variant_value == True
101101

102102
def get_variant_value(
103103
self, flag_key: str, fallback_value: Any, context: Dict[str, Any]
@@ -127,7 +127,8 @@ def get_variant(
127127
try:
128128
params = self._prepare_query_params(flag_key, context)
129129
start_time = datetime.now()
130-
response = self._sync_client.get(self.FLAGS_URL_PATH, params=params)
130+
headers = {"traceparent": generate_traceparent()}
131+
response = self._sync_client.get(self.FLAGS_URL_PATH, params=params, headers=headers)
131132
end_time = datetime.now()
132133
self._instrument_call(start_time, end_time)
133134
selected_variant, is_fallback = self._handle_response(
@@ -153,7 +154,7 @@ def is_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool:
153154
:param Dict[str, Any] context: Context dictionary containing user attributes and rollout context
154155
"""
155156
variant_value = self.get_variant_value(flag_key, False, context)
156-
return bool(variant_value)
157+
return variant_value == True
157158

158159
def _prepare_query_params(
159160
self, flag_key: str, context: Dict[str, Any]

mixpanel/flags/test_local_feature_flags.py

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,18 @@
99
from .types import LocalFlagsConfig, ExperimentationFlag, RuleSet, Variant, Rollout, FlagTestUsers, ExperimentationFlags, VariantOverride
1010
from .local_feature_flags import LocalFeatureFlagsProvider
1111

12+
1213
def create_test_flag(
1314
flag_key: str = "test_flag",
1415
context: str = "distinct_id",
1516
variants: Optional[list[Variant]] = None,
1617
variant_override: Optional[VariantOverride] = None,
1718
rollout_percentage: float = 100.0,
1819
runtime_evaluation: Optional[Dict] = None,
19-
test_users: Optional[Dict[str, str]] = None) -> ExperimentationFlag:
20+
test_users: Optional[Dict[str, str]] = None,
21+
experiment_id: Optional[str] = None,
22+
is_experiment_active: Optional[bool] = None,
23+
variant_splits: Optional[Dict[str, float]] = None) -> ExperimentationFlag:
2024

2125
if variants is None:
2226
variants = [
@@ -27,7 +31,8 @@ def create_test_flag(
2731
rollouts = [Rollout(
2832
rollout_percentage=rollout_percentage,
2933
runtime_evaluation_definition=runtime_evaluation,
30-
variant_override=variant_override
34+
variant_override=variant_override,
35+
variant_splits=variant_splits
3136
)]
3237

3338
test_config = None
@@ -47,7 +52,9 @@ def create_test_flag(
4752
status="active",
4853
project_id=123,
4954
ruleset=ruleset,
50-
context=context
55+
context=context,
56+
experiment_id=experiment_id,
57+
is_experiment_active=is_experiment_active
5158
)
5259

5360

@@ -216,6 +223,32 @@ async def test_get_variant_value_picks_correct_variant_with_hundred_percent_spli
216223
result = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"})
217224
assert result == "variant_a"
218225

226+
@respx.mock
227+
async def test_get_variant_value_picks_correct_variant_with_half_migrated_group_splits(self):
228+
variants = [
229+
Variant(key="A", value="variant_a", is_control=False, split=100.0),
230+
Variant(key="B", value="variant_b", is_control=False, split=0.0),
231+
Variant(key="C", value="variant_c", is_control=False, split=0.0)
232+
]
233+
variant_splits = {"A": 0.0, "B": 100.0, "C": 0.0}
234+
flag = create_test_flag(variants=variants, rollout_percentage=100.0, variant_splits=variant_splits)
235+
await self.setup_flags([flag])
236+
result = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"})
237+
assert result == "variant_b"
238+
239+
@respx.mock
240+
async def test_get_variant_value_picks_correct_variant_with_full_migrated_group_splits(self):
241+
variants = [
242+
Variant(key="A", value="variant_a", is_control=False),
243+
Variant(key="B", value="variant_b", is_control=False),
244+
Variant(key="C", value="variant_c", is_control=False),
245+
]
246+
variant_splits = {"A": 0.0, "B": 0.0, "C": 100.0}
247+
flag = create_test_flag(variants=variants, rollout_percentage=100.0, variant_splits=variant_splits)
248+
await self.setup_flags([flag])
249+
result = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"})
250+
assert result == "variant_c"
251+
219252
@respx.mock
220253
async def test_get_variant_value_picks_overriden_variant(self):
221254
variants = [
@@ -236,6 +269,43 @@ async def test_get_variant_value_tracks_exposure_when_variant_selected(self):
236269
_ = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"})
237270
self._mock_tracker.assert_called_once()
238271

272+
@respx.mock
273+
@pytest.mark.parametrize("experiment_id,is_experiment_active,use_qa_user", [
274+
("exp-123", True, True), # QA tester with active experiment
275+
("exp-456", False, True), # QA tester with inactive experiment
276+
("exp-789", True, False), # Regular user with active experiment
277+
("exp-000", False, False), # Regular user with inactive experiment
278+
(None, None, True), # QA tester with no experiment
279+
(None, None, False), # Regular user with no experiment
280+
])
281+
async def test_get_variant_value_tracks_exposure_with_correct_properties(self, experiment_id, is_experiment_active, use_qa_user):
282+
flag = create_test_flag(
283+
experiment_id=experiment_id,
284+
is_experiment_active=is_experiment_active,
285+
test_users={"qa_user": "treatment"}
286+
)
287+
288+
await self.setup_flags([flag])
289+
290+
distinct_id = "qa_user" if use_qa_user else "regular_user"
291+
292+
with patch('mixpanel.flags.utils.normalized_hash') as mock_hash:
293+
mock_hash.return_value = 0.5
294+
_ = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": distinct_id})
295+
296+
self._mock_tracker.assert_called_once()
297+
298+
call_args = self._mock_tracker.call_args
299+
properties = call_args[0][2]
300+
301+
assert properties["$experiment_id"] == experiment_id
302+
assert properties["$is_experiment_active"] == is_experiment_active
303+
304+
if use_qa_user:
305+
assert properties["$is_qa_tester"] == True
306+
else:
307+
assert properties.get("$is_qa_tester") is None
308+
239309
@respx.mock
240310
async def test_get_variant_value_does_not_track_exposure_on_fallback(self):
241311
await self.setup_flags([])

mixpanel/flags/test_utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import re
2+
import pytest
3+
import random
4+
import string
5+
from .utils import generate_traceparent, normalized_hash
6+
7+
class TestUtils:
8+
def test_traceparent_format_is_correct(self):
9+
traceparent = generate_traceparent()
10+
11+
# W3C traceparent format: 00-{32 hex chars}-{16 hex chars}-{2 hex chars}
12+
# https://www.w3.org/TR/trace-context/#traceparent-header
13+
pattern = r'^00-[0-9a-f]{32}-[0-9a-f]{16}-01$'
14+
15+
assert re.match(pattern, traceparent), f"Traceparent '{traceparent}' does not match W3C format"
16+
17+
def test_traceparent_pseudo_randomness(self):
18+
traceparents = set()
19+
20+
for _ in range(100):
21+
traceparents.add(generate_traceparent())
22+
23+
assert len(traceparents) == 100, f"Expected 100 unique traceparents, got {len(traceparents)}"
24+
25+
@pytest.mark.parametrize("key,salt,expected_hash", [
26+
("abc", "variant", 0.72),
27+
("def", "variant", 0.21),
28+
])
29+
def test_normalized_hash_for_known_inputs(self, key, salt, expected_hash):
30+
result = normalized_hash(key, salt)
31+
assert result == expected_hash, f"Expected hash of {expected_hash} for '{key}' with salt '{salt}', got {result}"
32+
33+
def test_normalized_hash_is_between_0_and_1(self):
34+
for _ in range(100):
35+
length = random.randint(5, 20)
36+
random_string = ''.join(random.choices(string.ascii_letters + string.digits, k=length))
37+
random_salt = ''.join(random.choices(string.ascii_letters, k=10))
38+
result = normalized_hash(random_string, random_salt)
39+
assert 0.0 <= result < 1.0, f"Hash value {result} is not in range [0, 1] for input '{random_string}'"

mixpanel/flags/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class Variant(BaseModel):
2020
key: str
2121
value: Any
2222
is_control: bool
23-
split: float
23+
split: Optional[float] = None
2424

2525
class FlagTestUsers(BaseModel):
2626
users: Dict[str, str]

mixpanel/flags/utils.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,4 @@ def generate_traceparent() -> str:
6363
trace_flags = '01'
6464

6565
traceparent = f"00-{trace_id}-{span_id}-{trace_flags}"
66-
return traceparent
67-
68-
def add_traceparent_header_to_request(request: httpx.Request) -> None:
69-
"""Adds a W3C traceparent header to an outgoing HTTPX request for distributed tracing
70-
:param request: The HTTPX request object
71-
"""
72-
traceparent = generate_traceparent()
73-
request.headers['traceparent'] = traceparent
66+
return traceparent

0 commit comments

Comments
 (0)