Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 36 additions & 15 deletions mixpanel/flags/local_feature_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
normalized_hash,
prepare_common_query_params,
EXPOSURE_EVENT,
generate_traceparent
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -168,17 +169,18 @@ def is_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool:
:param Dict[str, Any] context: Context dictionary containing user's distinct_id and any other attributes needed for rollout evaluation
"""
variant_value = self.get_variant_value(flag_key, False, context)
return bool(variant_value)
return variant_value == True

def get_variant(
self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any]
self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any], report_exposure: bool = True
) -> SelectedVariant:
"""
Gets the selected variant for a feature flag

:param str flag_key: The key of the feature flag to evaluate
:param SelectedVariant fallback_value: The default variant to return if evaluation fails
:param Dict[str, Any] context: Context dictionary containing user's distinct_id and any other attributes needed for rollout evaluation
:param bool report_exposure: Whether to track an exposure event for this flag evaluation. Defaults to True.
"""
start_time = time.perf_counter()
flag_definition = self._flag_definitions.get(flag_key)
Expand All @@ -193,20 +195,21 @@ def get_variant(
)
return fallback_value

selected_variant: Optional[SelectedVariant] = None

if test_user_variant := self._get_variant_override_for_test_user(
flag_definition, context
):
return test_user_variant

if rollout := self._get_assigned_rollout(
flag_definition, context_value, context
):
variant = self._get_assigned_variant(
selected_variant = test_user_variant
elif rollout := self._get_assigned_rollout(flag_definition, context_value, context):
selected_variant = self._get_assigned_variant(
flag_definition, context_value, flag_key, rollout
)

if report_exposure and selected_variant is not None:
end_time = time.perf_counter()
self._track_exposure(flag_key, variant, end_time - start_time, context)
return variant
self._track_exposure(flag_key, selected_variant, end_time - start_time, context)
return selected_variant

logger.info(
f"{flag_definition.context} context {context_value} not eligible for any rollout for flag: {flag_key}"
Expand Down Expand Up @@ -241,12 +244,17 @@ def _get_assigned_variant(
):
return variant

variants = flag_definition.ruleset.variants

hash_input = str(context_value) + flag_name

variant_hash = normalized_hash(hash_input, "variant")

variants = [variant.model_copy(deep=True) for variant in flag_definition.ruleset.variants]
if rollout.variant_splits:
for variant in variants:
if variant.key in rollout.variant_splits:
variant.split = rollout.variant_splits[variant.key]

selected = variants[0]
cumulative = 0.0
for variant in variants:
Expand All @@ -255,7 +263,11 @@ def _get_assigned_variant(
if variant_hash < cumulative:
break

return SelectedVariant(variant_key=selected.key, variant_value=selected.value)
return SelectedVariant(
variant_key=selected.key,
variant_value=selected.value,
experiment_id=flag_definition.experiment_id,
is_experiment_active=flag_definition.is_experiment_active)

def _get_assigned_rollout(
self,
Expand Down Expand Up @@ -304,15 +316,20 @@ def _get_matching_variant(
for variant in flag.ruleset.variants:
if variant_key.casefold() == variant.key.casefold():
return SelectedVariant(
variant_key=variant.key, variant_value=variant.value
variant_key=variant.key,
variant_value=variant.value,
experiment_id=flag.experiment_id,
is_experiment_active=flag.is_experiment_active,
is_qa_tester=True,
)
return None

async def _afetch_flag_definitions(self) -> None:
try:
start_time = datetime.now()
headers = {"traceparent": generate_traceparent()}
response = await self._async_client.get(
self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params
self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params, headers=headers
)
end_time = datetime.now()
self._handle_response(response, start_time, end_time)
Expand All @@ -322,8 +339,9 @@ async def _afetch_flag_definitions(self) -> None:
def _fetch_flag_definitions(self) -> None:
try:
start_time = datetime.now()
headers = {"traceparent": generate_traceparent()}
response = self._sync_client.get(
self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params
self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params, headers=headers
)
end_time = datetime.now()
self._handle_response(response, start_time, end_time)
Expand Down Expand Up @@ -370,6 +388,9 @@ def _track_exposure(
"$experiment_type": "feature_flag",
"Flag evaluation mode": "local",
"Variant fetch latency (ms)": latency_in_seconds * 1000,
"$experiment_id": variant.experiment_id,
"$is_experiment_active": variant.is_experiment_active,
"$is_qa_tester": variant.is_qa_tester,
}

self._tracker(distinct_id, EXPOSURE_EVENT, properties)
Expand Down
12 changes: 7 additions & 5 deletions mixpanel/flags/remote_feature_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from asgiref.sync import sync_to_async

from .types import RemoteFlagsConfig, SelectedVariant, RemoteFlagsResponse
from .utils import REQUEST_HEADERS, EXPOSURE_EVENT, prepare_common_query_params
from .utils import REQUEST_HEADERS, EXPOSURE_EVENT, prepare_common_query_params, generate_traceparent

logger = logging.getLogger(__name__)
logging.getLogger("httpx").setLevel(logging.ERROR)
Expand Down Expand Up @@ -66,7 +66,8 @@ async def aget_variant(
try:
params = self._prepare_query_params(flag_key, context)
start_time = datetime.now()
response = await self._async_client.get(self.FLAGS_URL_PATH, params=params)
headers = {"traceparent": generate_traceparent()}
response = await self._async_client.get(self.FLAGS_URL_PATH, params=params, headers=headers)
end_time = datetime.now()
self._instrument_call(start_time, end_time)
selected_variant, is_fallback = self._handle_response(
Expand Down Expand Up @@ -96,7 +97,7 @@ async def ais_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool:
:param Dict[str, Any] context: Context dictionary containing user attributes and rollout context
"""
variant_value = await self.aget_variant_value(flag_key, False, context)
return bool(variant_value)
return variant_value == True

def get_variant_value(
self, flag_key: str, fallback_value: Any, context: Dict[str, Any]
Expand Down Expand Up @@ -126,7 +127,8 @@ def get_variant(
try:
params = self._prepare_query_params(flag_key, context)
start_time = datetime.now()
response = self._sync_client.get(self.FLAGS_URL_PATH, params=params)
headers = {"traceparent": generate_traceparent()}
response = self._sync_client.get(self.FLAGS_URL_PATH, params=params, headers=headers)
end_time = datetime.now()
self._instrument_call(start_time, end_time)
selected_variant, is_fallback = self._handle_response(
Expand All @@ -152,7 +154,7 @@ def is_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool:
:param Dict[str, Any] context: Context dictionary containing user attributes and rollout context
"""
variant_value = self.get_variant_value(flag_key, False, context)
return bool(variant_value)
return variant_value == True

def _prepare_query_params(
self, flag_key: str, context: Dict[str, Any]
Expand Down
76 changes: 73 additions & 3 deletions mixpanel/flags/test_local_feature_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,18 @@
from .types import LocalFlagsConfig, ExperimentationFlag, RuleSet, Variant, Rollout, FlagTestUsers, ExperimentationFlags, VariantOverride
from .local_feature_flags import LocalFeatureFlagsProvider


def create_test_flag(
flag_key: str = "test_flag",
context: str = "distinct_id",
variants: Optional[list[Variant]] = None,
variant_override: Optional[VariantOverride] = None,
rollout_percentage: float = 100.0,
runtime_evaluation: Optional[Dict] = None,
test_users: Optional[Dict[str, str]] = None) -> ExperimentationFlag:
test_users: Optional[Dict[str, str]] = None,
experiment_id: Optional[str] = None,
is_experiment_active: Optional[bool] = None,
variant_splits: Optional[Dict[str, float]] = None) -> ExperimentationFlag:

if variants is None:
variants = [
Expand All @@ -27,7 +31,8 @@ def create_test_flag(
rollouts = [Rollout(
rollout_percentage=rollout_percentage,
runtime_evaluation_definition=runtime_evaluation,
variant_override=variant_override
variant_override=variant_override,
variant_splits=variant_splits
)]

test_config = None
Expand All @@ -47,7 +52,9 @@ def create_test_flag(
status="active",
project_id=123,
ruleset=ruleset,
context=context
context=context,
experiment_id=experiment_id,
is_experiment_active=is_experiment_active
)


Expand Down Expand Up @@ -216,6 +223,32 @@ async def test_get_variant_value_picks_correct_variant_with_hundred_percent_spli
result = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"})
assert result == "variant_a"

@respx.mock
async def test_get_variant_value_picks_correct_variant_with_half_migrated_group_splits(self):
variants = [
Variant(key="A", value="variant_a", is_control=False, split=100.0),
Variant(key="B", value="variant_b", is_control=False, split=0.0),
Variant(key="C", value="variant_c", is_control=False, split=0.0)
]
variant_splits = {"A": 0.0, "B": 100.0, "C": 0.0}
flag = create_test_flag(variants=variants, rollout_percentage=100.0, variant_splits=variant_splits)
await self.setup_flags([flag])
result = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"})
assert result == "variant_b"

@respx.mock
async def test_get_variant_value_picks_correct_variant_with_full_migrated_group_splits(self):
variants = [
Variant(key="A", value="variant_a", is_control=False),
Variant(key="B", value="variant_b", is_control=False),
Variant(key="C", value="variant_c", is_control=False),
]
variant_splits = {"A": 0.0, "B": 0.0, "C": 100.0}
flag = create_test_flag(variants=variants, rollout_percentage=100.0, variant_splits=variant_splits)
await self.setup_flags([flag])
result = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"})
assert result == "variant_c"

@respx.mock
async def test_get_variant_value_picks_overriden_variant(self):
variants = [
Expand All @@ -236,6 +269,43 @@ async def test_get_variant_value_tracks_exposure_when_variant_selected(self):
_ = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"})
self._mock_tracker.assert_called_once()

@respx.mock
@pytest.mark.parametrize("experiment_id,is_experiment_active,use_qa_user", [
("exp-123", True, True), # QA tester with active experiment
("exp-456", False, True), # QA tester with inactive experiment
("exp-789", True, False), # Regular user with active experiment
("exp-000", False, False), # Regular user with inactive experiment
(None, None, True), # QA tester with no experiment
(None, None, False), # Regular user with no experiment
])
async def test_get_variant_value_tracks_exposure_with_correct_properties(self, experiment_id, is_experiment_active, use_qa_user):
flag = create_test_flag(
experiment_id=experiment_id,
is_experiment_active=is_experiment_active,
test_users={"qa_user": "treatment"}
)

await self.setup_flags([flag])

distinct_id = "qa_user" if use_qa_user else "regular_user"

with patch('mixpanel.flags.utils.normalized_hash') as mock_hash:
mock_hash.return_value = 0.5
_ = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": distinct_id})

self._mock_tracker.assert_called_once()

call_args = self._mock_tracker.call_args
properties = call_args[0][2]

assert properties["$experiment_id"] == experiment_id
assert properties["$is_experiment_active"] == is_experiment_active

if use_qa_user:
assert properties["$is_qa_tester"] == True
else:
assert properties.get("$is_qa_tester") is None

@respx.mock
async def test_get_variant_value_does_not_track_exposure_on_fallback(self):
await self.setup_flags([])
Expand Down
39 changes: 39 additions & 0 deletions mixpanel/flags/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import re
import pytest
import random
import string
from .utils import generate_traceparent, normalized_hash

class TestUtils:
def test_traceparent_format_is_correct(self):
traceparent = generate_traceparent()

# W3C traceparent format: 00-{32 hex chars}-{16 hex chars}-{2 hex chars}
# https://www.w3.org/TR/trace-context/#traceparent-header
pattern = r'^00-[0-9a-f]{32}-[0-9a-f]{16}-01$'

assert re.match(pattern, traceparent), f"Traceparent '{traceparent}' does not match W3C format"

def test_traceparent_pseudo_randomness(self):
traceparents = set()

for _ in range(100):
traceparents.add(generate_traceparent())

assert len(traceparents) == 100, f"Expected 100 unique traceparents, got {len(traceparents)}"

@pytest.mark.parametrize("key,salt,expected_hash", [
("abc", "variant", 0.72),
("def", "variant", 0.21),
])
def test_normalized_hash_for_known_inputs(self, key, salt, expected_hash):
result = normalized_hash(key, salt)
assert result == expected_hash, f"Expected hash of {expected_hash} for '{key}' with salt '{salt}', got {result}"

def test_normalized_hash_is_between_0_and_1(self):
for _ in range(100):
length = random.randint(5, 20)
random_string = ''.join(random.choices(string.ascii_letters + string.digits, k=length))
random_salt = ''.join(random.choices(string.ascii_letters, k=10))
result = normalized_hash(random_string, random_salt)
assert 0.0 <= result < 1.0, f"Hash value {result} is not in range [0, 1] for input '{random_string}'"
14 changes: 11 additions & 3 deletions mixpanel/flags/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class Variant(BaseModel):
key: str
value: Any
is_control: bool
split: float
split: Optional[float] = None

class FlagTestUsers(BaseModel):
users: Dict[str, str]
Expand All @@ -32,6 +32,7 @@ class Rollout(BaseModel):
rollout_percentage: float
runtime_evaluation_definition: Optional[Dict[str, str]] = None
variant_override: Optional[VariantOverride] = None
variant_splits: Optional[Dict[str,float]] = None

class RuleSet(BaseModel):
variants: List[Variant]
Expand All @@ -41,16 +42,23 @@ class RuleSet(BaseModel):
class ExperimentationFlag(BaseModel):
id: str
name: str
key: str
key: str
status: str
project_id: int
ruleset: RuleSet
ruleset: RuleSet
context: str
experiment_id: Optional[str] = None
is_experiment_active: Optional[bool] = None


class SelectedVariant(BaseModel):
# variant_key can be None if being used as a fallback
variant_key: Optional[str] = None
variant_value: Any
experiment_id: Optional[str] = None
is_experiment_active: Optional[bool] = None
is_qa_tester: Optional[bool] = None


class ExperimentationFlags(BaseModel):
flags: List[ExperimentationFlag]
Expand Down
Loading