diff --git a/mixpanel/flags/local_feature_flags.py b/mixpanel/flags/local_feature_flags.py index 6d95334..4b70132 100644 --- a/mixpanel/flags/local_feature_flags.py +++ b/mixpanel/flags/local_feature_flags.py @@ -17,6 +17,7 @@ normalized_hash, prepare_common_query_params, EXPOSURE_EVENT, + generate_traceparent ) logger = logging.getLogger(__name__) @@ -168,10 +169,10 @@ def is_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool: :param Dict[str, Any] context: Context dictionary containing user's distinct_id and any other attributes needed for rollout evaluation """ variant_value = self.get_variant_value(flag_key, False, context) - return bool(variant_value) + return variant_value == True def get_variant( - self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any] + self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any], report_exposure: bool = True ) -> SelectedVariant: """ Gets the selected variant for a feature flag @@ -179,6 +180,7 @@ def get_variant( :param str flag_key: The key of the feature flag to evaluate :param SelectedVariant fallback_value: The default variant to return if evaluation fails :param Dict[str, Any] context: Context dictionary containing user's distinct_id and any other attributes needed for rollout evaluation + :param bool report_exposure: Whether to track an exposure event for this flag evaluation. Defaults to True. """ start_time = time.perf_counter() flag_definition = self._flag_definitions.get(flag_key) @@ -193,20 +195,21 @@ def get_variant( ) return fallback_value + selected_variant: Optional[SelectedVariant] = None + if test_user_variant := self._get_variant_override_for_test_user( flag_definition, context ): - return test_user_variant - - if rollout := self._get_assigned_rollout( - flag_definition, context_value, context - ): - variant = self._get_assigned_variant( + selected_variant = test_user_variant + elif rollout := self._get_assigned_rollout(flag_definition, context_value, context): + selected_variant = self._get_assigned_variant( flag_definition, context_value, flag_key, rollout ) + + if report_exposure and selected_variant is not None: end_time = time.perf_counter() - self._track_exposure(flag_key, variant, end_time - start_time, context) - return variant + self._track_exposure(flag_key, selected_variant, end_time - start_time, context) + return selected_variant logger.info( f"{flag_definition.context} context {context_value} not eligible for any rollout for flag: {flag_key}" @@ -241,12 +244,17 @@ def _get_assigned_variant( ): return variant - variants = flag_definition.ruleset.variants hash_input = str(context_value) + flag_name variant_hash = normalized_hash(hash_input, "variant") + variants = [variant.model_copy(deep=True) for variant in flag_definition.ruleset.variants] + if rollout.variant_splits: + for variant in variants: + if variant.key in rollout.variant_splits: + variant.split = rollout.variant_splits[variant.key] + selected = variants[0] cumulative = 0.0 for variant in variants: @@ -255,7 +263,11 @@ def _get_assigned_variant( if variant_hash < cumulative: break - return SelectedVariant(variant_key=selected.key, variant_value=selected.value) + return SelectedVariant( + variant_key=selected.key, + variant_value=selected.value, + experiment_id=flag_definition.experiment_id, + is_experiment_active=flag_definition.is_experiment_active) def _get_assigned_rollout( self, @@ -304,15 +316,20 @@ def _get_matching_variant( for variant in flag.ruleset.variants: if variant_key.casefold() == variant.key.casefold(): return SelectedVariant( - variant_key=variant.key, variant_value=variant.value + variant_key=variant.key, + variant_value=variant.value, + experiment_id=flag.experiment_id, + is_experiment_active=flag.is_experiment_active, + is_qa_tester=True, ) return None async def _afetch_flag_definitions(self) -> None: try: start_time = datetime.now() + headers = {"traceparent": generate_traceparent()} response = await self._async_client.get( - self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params + self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params, headers=headers ) end_time = datetime.now() self._handle_response(response, start_time, end_time) @@ -322,8 +339,9 @@ async def _afetch_flag_definitions(self) -> None: def _fetch_flag_definitions(self) -> None: try: start_time = datetime.now() + headers = {"traceparent": generate_traceparent()} response = self._sync_client.get( - self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params + self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params, headers=headers ) end_time = datetime.now() self._handle_response(response, start_time, end_time) @@ -370,6 +388,9 @@ def _track_exposure( "$experiment_type": "feature_flag", "Flag evaluation mode": "local", "Variant fetch latency (ms)": latency_in_seconds * 1000, + "$experiment_id": variant.experiment_id, + "$is_experiment_active": variant.is_experiment_active, + "$is_qa_tester": variant.is_qa_tester, } self._tracker(distinct_id, EXPOSURE_EVENT, properties) diff --git a/mixpanel/flags/remote_feature_flags.py b/mixpanel/flags/remote_feature_flags.py index 5d7f5d9..af62c74 100644 --- a/mixpanel/flags/remote_feature_flags.py +++ b/mixpanel/flags/remote_feature_flags.py @@ -8,7 +8,7 @@ from asgiref.sync import sync_to_async from .types import RemoteFlagsConfig, SelectedVariant, RemoteFlagsResponse -from .utils import REQUEST_HEADERS, EXPOSURE_EVENT, prepare_common_query_params +from .utils import REQUEST_HEADERS, EXPOSURE_EVENT, prepare_common_query_params, generate_traceparent logger = logging.getLogger(__name__) logging.getLogger("httpx").setLevel(logging.ERROR) @@ -66,7 +66,8 @@ async def aget_variant( try: params = self._prepare_query_params(flag_key, context) start_time = datetime.now() - response = await self._async_client.get(self.FLAGS_URL_PATH, params=params) + headers = {"traceparent": generate_traceparent()} + response = await self._async_client.get(self.FLAGS_URL_PATH, params=params, headers=headers) end_time = datetime.now() self._instrument_call(start_time, end_time) selected_variant, is_fallback = self._handle_response( @@ -96,7 +97,7 @@ async def ais_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool: :param Dict[str, Any] context: Context dictionary containing user attributes and rollout context """ variant_value = await self.aget_variant_value(flag_key, False, context) - return bool(variant_value) + return variant_value == True def get_variant_value( self, flag_key: str, fallback_value: Any, context: Dict[str, Any] @@ -126,7 +127,8 @@ def get_variant( try: params = self._prepare_query_params(flag_key, context) start_time = datetime.now() - response = self._sync_client.get(self.FLAGS_URL_PATH, params=params) + headers = {"traceparent": generate_traceparent()} + response = self._sync_client.get(self.FLAGS_URL_PATH, params=params, headers=headers) end_time = datetime.now() self._instrument_call(start_time, end_time) selected_variant, is_fallback = self._handle_response( @@ -152,7 +154,7 @@ def is_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool: :param Dict[str, Any] context: Context dictionary containing user attributes and rollout context """ variant_value = self.get_variant_value(flag_key, False, context) - return bool(variant_value) + return variant_value == True def _prepare_query_params( self, flag_key: str, context: Dict[str, Any] diff --git a/mixpanel/flags/test_local_feature_flags.py b/mixpanel/flags/test_local_feature_flags.py index 9019a01..dba1d20 100644 --- a/mixpanel/flags/test_local_feature_flags.py +++ b/mixpanel/flags/test_local_feature_flags.py @@ -9,6 +9,7 @@ from .types import LocalFlagsConfig, ExperimentationFlag, RuleSet, Variant, Rollout, FlagTestUsers, ExperimentationFlags, VariantOverride from .local_feature_flags import LocalFeatureFlagsProvider + def create_test_flag( flag_key: str = "test_flag", context: str = "distinct_id", @@ -16,7 +17,10 @@ def create_test_flag( variant_override: Optional[VariantOverride] = None, rollout_percentage: float = 100.0, runtime_evaluation: Optional[Dict] = None, - test_users: Optional[Dict[str, str]] = None) -> ExperimentationFlag: + test_users: Optional[Dict[str, str]] = None, + experiment_id: Optional[str] = None, + is_experiment_active: Optional[bool] = None, + variant_splits: Optional[Dict[str, float]] = None) -> ExperimentationFlag: if variants is None: variants = [ @@ -27,7 +31,8 @@ def create_test_flag( rollouts = [Rollout( rollout_percentage=rollout_percentage, runtime_evaluation_definition=runtime_evaluation, - variant_override=variant_override + variant_override=variant_override, + variant_splits=variant_splits )] test_config = None @@ -47,7 +52,9 @@ def create_test_flag( status="active", project_id=123, ruleset=ruleset, - context=context + context=context, + experiment_id=experiment_id, + is_experiment_active=is_experiment_active ) @@ -216,6 +223,32 @@ async def test_get_variant_value_picks_correct_variant_with_hundred_percent_spli result = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"}) assert result == "variant_a" + @respx.mock + async def test_get_variant_value_picks_correct_variant_with_half_migrated_group_splits(self): + variants = [ + Variant(key="A", value="variant_a", is_control=False, split=100.0), + Variant(key="B", value="variant_b", is_control=False, split=0.0), + Variant(key="C", value="variant_c", is_control=False, split=0.0) + ] + variant_splits = {"A": 0.0, "B": 100.0, "C": 0.0} + flag = create_test_flag(variants=variants, rollout_percentage=100.0, variant_splits=variant_splits) + await self.setup_flags([flag]) + result = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"}) + assert result == "variant_b" + + @respx.mock + async def test_get_variant_value_picks_correct_variant_with_full_migrated_group_splits(self): + variants = [ + Variant(key="A", value="variant_a", is_control=False), + Variant(key="B", value="variant_b", is_control=False), + Variant(key="C", value="variant_c", is_control=False), + ] + variant_splits = {"A": 0.0, "B": 0.0, "C": 100.0} + flag = create_test_flag(variants=variants, rollout_percentage=100.0, variant_splits=variant_splits) + await self.setup_flags([flag]) + result = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"}) + assert result == "variant_c" + @respx.mock async def test_get_variant_value_picks_overriden_variant(self): variants = [ @@ -236,6 +269,43 @@ async def test_get_variant_value_tracks_exposure_when_variant_selected(self): _ = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"}) self._mock_tracker.assert_called_once() + @respx.mock + @pytest.mark.parametrize("experiment_id,is_experiment_active,use_qa_user", [ + ("exp-123", True, True), # QA tester with active experiment + ("exp-456", False, True), # QA tester with inactive experiment + ("exp-789", True, False), # Regular user with active experiment + ("exp-000", False, False), # Regular user with inactive experiment + (None, None, True), # QA tester with no experiment + (None, None, False), # Regular user with no experiment + ]) + async def test_get_variant_value_tracks_exposure_with_correct_properties(self, experiment_id, is_experiment_active, use_qa_user): + flag = create_test_flag( + experiment_id=experiment_id, + is_experiment_active=is_experiment_active, + test_users={"qa_user": "treatment"} + ) + + await self.setup_flags([flag]) + + distinct_id = "qa_user" if use_qa_user else "regular_user" + + with patch('mixpanel.flags.utils.normalized_hash') as mock_hash: + mock_hash.return_value = 0.5 + _ = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": distinct_id}) + + self._mock_tracker.assert_called_once() + + call_args = self._mock_tracker.call_args + properties = call_args[0][2] + + assert properties["$experiment_id"] == experiment_id + assert properties["$is_experiment_active"] == is_experiment_active + + if use_qa_user: + assert properties["$is_qa_tester"] == True + else: + assert properties.get("$is_qa_tester") is None + @respx.mock async def test_get_variant_value_does_not_track_exposure_on_fallback(self): await self.setup_flags([]) diff --git a/mixpanel/flags/test_utils.py b/mixpanel/flags/test_utils.py new file mode 100644 index 0000000..b60b514 --- /dev/null +++ b/mixpanel/flags/test_utils.py @@ -0,0 +1,23 @@ +import re +import pytest +import random +import string +from .utils import generate_traceparent, normalized_hash + +class TestUtils: + def test_traceparent_format_is_correct(self): + traceparent = generate_traceparent() + + # W3C traceparent format: 00-{32 hex chars}-{16 hex chars}-{2 hex chars} + # https://www.w3.org/TR/trace-context/#traceparent-header + pattern = r'^00-[0-9a-f]{32}-[0-9a-f]{16}-01$' + + assert re.match(pattern, traceparent), f"Traceparent '{traceparent}' does not match W3C format" + + @pytest.mark.parametrize("key,salt,expected_hash", [ + ("abc", "variant", 0.72), + ("def", "variant", 0.21), + ]) + def test_normalized_hash_for_known_inputs(self, key, salt, expected_hash): + result = normalized_hash(key, salt) + assert result == expected_hash, f"Expected hash of {expected_hash} for '{key}' with salt '{salt}', got {result}" \ No newline at end of file diff --git a/mixpanel/flags/types.py b/mixpanel/flags/types.py index 186371d..20fe6ad 100644 --- a/mixpanel/flags/types.py +++ b/mixpanel/flags/types.py @@ -20,7 +20,7 @@ class Variant(BaseModel): key: str value: Any is_control: bool - split: float + split: Optional[float] = 0.0 class FlagTestUsers(BaseModel): users: Dict[str, str] @@ -32,6 +32,7 @@ class Rollout(BaseModel): rollout_percentage: float runtime_evaluation_definition: Optional[Dict[str, str]] = None variant_override: Optional[VariantOverride] = None + variant_splits: Optional[Dict[str,float]] = None class RuleSet(BaseModel): variants: List[Variant] @@ -41,16 +42,23 @@ class RuleSet(BaseModel): class ExperimentationFlag(BaseModel): id: str name: str - key: str + key: str status: str project_id: int - ruleset: RuleSet + ruleset: RuleSet context: str + experiment_id: Optional[str] = None + is_experiment_active: Optional[bool] = None + class SelectedVariant(BaseModel): # variant_key can be None if being used as a fallback variant_key: Optional[str] = None variant_value: Any + experiment_id: Optional[str] = None + is_experiment_active: Optional[bool] = None + is_qa_tester: Optional[bool] = None + class ExperimentationFlags(BaseModel): flags: List[ExperimentationFlag] diff --git a/mixpanel/flags/utils.py b/mixpanel/flags/utils.py index 987392b..863a705 100644 --- a/mixpanel/flags/utils.py +++ b/mixpanel/flags/utils.py @@ -1,3 +1,5 @@ +import uuid +import httpx from typing import Dict EXPOSURE_EVENT = "$experiment_started" @@ -47,4 +49,18 @@ def prepare_common_query_params(token: str, sdk_version: str) -> Dict[str, str]: 'token': token } - return params \ No newline at end of file + return params + +def generate_traceparent() -> str: + """Generates a W3C traceparent header for easy interop with distributed tracing systems i.e Open Telemetry + https://www.w3.org/TR/trace-context/#traceparent-header + :return: A traceparent string + """ + trace_id = uuid.uuid4().hex + span_id = uuid.uuid4().hex[:16] + + # Trace flags: '01' for sampled + trace_flags = '01' + + traceparent = f"00-{trace_id}-{span_id}-{trace_flags}" + return traceparent \ No newline at end of file