diff --git a/CHANGES.txt b/CHANGES.txt index 259700b..cfdb755 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,3 +1,6 @@ +v5.0.0b2 +* Update local flags evaluation to not use threadpool for exposure event tracking and add some docs + v5.0.0b1 * Added initial feature flagging support diff --git a/mixpanel/__init__.py b/mixpanel/__init__.py index 8faac12..decc461 100644 --- a/mixpanel/__init__.py +++ b/mixpanel/__init__.py @@ -30,7 +30,7 @@ from .flags.remote_feature_flags import RemoteFeatureFlagsProvider from .flags.types import LocalFlagsConfig, RemoteFlagsConfig -__version__ = '5.0.0b1' +__version__ = '5.0.0b2' logger = logging.getLogger(__name__) diff --git a/mixpanel/flags/local_feature_flags.py b/mixpanel/flags/local_feature_flags.py index 40dc8b8..6d95334 100644 --- a/mixpanel/flags/local_feature_flags.py +++ b/mixpanel/flags/local_feature_flags.py @@ -5,24 +5,44 @@ import threading from datetime import datetime, timedelta from typing import Dict, Any, Callable, Optional -from concurrent.futures import Future, ThreadPoolExecutor -from .types import ExperimentationFlag, ExperimentationFlags, SelectedVariant, LocalFlagsConfig, Rollout -from .utils import REQUEST_HEADERS, normalized_hash, prepare_common_query_params, EXPOSURE_EVENT +from .types import ( + ExperimentationFlag, + ExperimentationFlags, + SelectedVariant, + LocalFlagsConfig, + Rollout, +) +from .utils import ( + REQUEST_HEADERS, + normalized_hash, + prepare_common_query_params, + EXPOSURE_EVENT, +) logger = logging.getLogger(__name__) logging.getLogger("httpx").setLevel(logging.ERROR) + class LocalFeatureFlagsProvider: FLAGS_DEFINITIONS_URL_PATH = "/flags/definitions" - def __init__(self, token: str, config: LocalFlagsConfig, version: str, tracker: Callable) -> None: + def __init__( + self, token: str, config: LocalFlagsConfig, version: str, tracker: Callable + ) -> None: + """ + Initializes the LocalFeatureFlagsProvider + :param str token: your project's Mixpanel token + :param LocalFlagsConfig config: configuration options for the local feature flags provider + :param str version: the version of the Mixpanel library being used, just for tracking + :param Callable tracker: A function used to track flags exposure events to mixpanel + """ self._token: str = token self._config: LocalFlagsConfig = config self._version = version self._tracker: Callable = tracker - self._executor: ThreadPoolExecutor = config.custom_executor or ThreadPoolExecutor(max_workers=5) self._flag_definitions: Dict[str, ExperimentationFlag] = dict() + self._are_flags_ready = False httpx_client_parameters = { "base_url": f"https://{config.api_host}", @@ -33,42 +53,63 @@ def __init__(self, token: str, config: LocalFlagsConfig, version: str, tracker: self._request_params = prepare_common_query_params(self._token, self._version) - self._async_client: httpx.AsyncClient = httpx.AsyncClient(**httpx_client_parameters) + self._async_client: httpx.AsyncClient = httpx.AsyncClient( + **httpx_client_parameters + ) self._sync_client: httpx.Client = httpx.Client(**httpx_client_parameters) self._async_polling_task: Optional[asyncio.Task] = None - self._sync_polling_task: Optional[Future] = None + self._sync_polling_task: Optional[threading.Thread] = None self._sync_stop_event = threading.Event() def start_polling_for_definitions(self): + """ + Fetches flag definitions for the current project. + If configured by the caller, starts a background thread to poll for updates at regular intervals, if one does not already exist. + """ self._fetch_flag_definitions() if self._config.enable_polling: if not self._sync_polling_task and not self._async_polling_task: self._sync_stop_event.clear() - self._sync_polling_task = self._executor.submit(self._start_continuous_polling) + self._sync_polling_task = threading.Thread( + target=self._start_continuous_polling, daemon=True + ) + self._sync_polling_task.start() else: - logging.error("A polling task is already running") + logging.warning("A polling task is already running") def stop_polling_for_definitions(self): + """ + If there exists a reference to a background thread polling for flag definition updates, signal it to stop and clear the reference. + Once stopped, the polling thread cannot be restarted. + """ if self._sync_polling_task: self._sync_stop_event.set() - self._sync_polling_task.cancel() self._sync_polling_task = None else: logging.info("There is no polling task to cancel.") async def astart_polling_for_definitions(self): + """ + Fetches flag definitions for the current project. + If configured by the caller, starts an async task on the event loop to poll for updates at regular intervals, if one does not already exist. + """ await self._afetch_flag_definitions() if self._config.enable_polling: if not self._sync_polling_task and not self._async_polling_task: - self._async_polling_task = asyncio.create_task(self._astart_continuous_polling()) + self._async_polling_task = asyncio.create_task( + self._astart_continuous_polling() + ) else: logging.error("A polling task is already running") async def astop_polling_for_definitions(self): + """ + If there exists an async task to poll for flag definition updates, cancel the task and clear the reference to it. + """ if self._async_polling_task: self._async_polling_task.cancel() self._async_polling_task = None @@ -76,7 +117,9 @@ async def astop_polling_for_definitions(self): logging.info("There is no polling task to cancel.") async def _astart_continuous_polling(self): - logging.info(f"Initialized async polling for flag definition updates every '{self._config.polling_interval_in_seconds}' seconds") + logging.info( + f"Initialized async polling for flag definition updates every '{self._config.polling_interval_in_seconds}' seconds" + ) try: while True: await asyncio.sleep(self._config.polling_interval_in_seconds) @@ -85,29 +128,58 @@ async def _astart_continuous_polling(self): logging.info("Async polling was cancelled") def _start_continuous_polling(self): - logging.info(f"Initialized sync polling for flag definition updates every '{self._config.polling_interval_in_seconds}' seconds") + logging.info( + f"Initialized sync polling for flag definition updates every '{self._config.polling_interval_in_seconds}' seconds" + ) while not self._sync_stop_event.is_set(): - if self._sync_stop_event.wait(timeout=self._config.polling_interval_in_seconds): + if self._sync_stop_event.wait( + timeout=self._config.polling_interval_in_seconds + ): break self._fetch_flag_definitions() def are_flags_ready(self) -> bool: """ - Check if flag definitions have been loaded and are ready for use. - :return: True if flag definitions are populated, False otherwise. + Check if the call to fetch flag definitions has been made successfully. """ - return bool(self._flag_definitions) + return self._are_flags_ready - def get_variant_value(self, flag_key: str, fallback_value: Any, context: Dict[str, Any]) -> Any: - variant = self.get_variant(flag_key, SelectedVariant(variant_value=fallback_value), context) + def get_variant_value( + self, flag_key: str, fallback_value: Any, context: Dict[str, Any] + ) -> Any: + """ + Get the value of a feature flag variant. + + :param str flag_key: The key of the feature flag to evaluate + :param Any fallback_value: The default value to return if the flag is not found or evaluation fails + :param Dict[str, Any] context: Context dictionary containing user's distinct_id and any other attributes needed for rollout evaluation + """ + variant = self.get_variant( + flag_key, SelectedVariant(variant_value=fallback_value), context + ) return variant.variant_value def is_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool: + """ + Check if a feature flag is enabled for the given context. + + :param str flag_key: The key of the feature flag to check + :param Dict[str, Any] context: Context dictionary containing user's distinct_id and any other attributes needed for rollout evaluation + """ variant_value = self.get_variant_value(flag_key, False, context) return bool(variant_value) - def get_variant(self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any]) -> SelectedVariant: + def get_variant( + self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any] + ) -> SelectedVariant: + """ + Gets the selected variant for a feature flag + + :param str flag_key: The key of the feature flag to evaluate + :param SelectedVariant fallback_value: The default variant to return if evaluation fails + :param Dict[str, Any] context: Context dictionary containing user's distinct_id and any other attributes needed for rollout evaluation + """ start_time = time.perf_counter() flag_definition = self._flag_definitions.get(flag_key) @@ -115,23 +187,35 @@ def get_variant(self, flag_key: str, fallback_value: SelectedVariant, context: D logger.warning(f"Cannot find flag definition for key: '{flag_key}'") return fallback_value - if not(context_value := context.get(flag_definition.context)): - logger.warning(f"The rollout context, '{flag_definition.context}' for flag, '{flag_key}' is not present in the supplied context dictionary") + if not (context_value := context.get(flag_definition.context)): + logger.warning( + f"The rollout context, '{flag_definition.context}' for flag, '{flag_key}' is not present in the supplied context dictionary" + ) return fallback_value - if test_user_variant := self._get_variant_override_for_test_user(flag_definition, context): + if test_user_variant := self._get_variant_override_for_test_user( + flag_definition, context + ): return test_user_variant - if rollout := self._get_assigned_rollout(flag_definition, context_value, context): - variant = self._get_assigned_variant(flag_definition, context_value, flag_key, rollout) + if rollout := self._get_assigned_rollout( + flag_definition, context_value, context + ): + variant = self._get_assigned_variant( + flag_definition, context_value, flag_key, rollout + ) end_time = time.perf_counter() - self.track_exposure(flag_key, variant, end_time - start_time, context) + self._track_exposure(flag_key, variant, end_time - start_time, context) return variant - logger.info(f"{flag_definition.context} context {context_value} not eligible for any rollout for flag: {flag_key}") + logger.info( + f"{flag_definition.context} context {context_value} not eligible for any rollout for flag: {flag_key}" + ) return fallback_value - def _get_variant_override_for_test_user(self, flag_definition: ExperimentationFlag, context: Dict[str, Any]) -> Optional[SelectedVariant]: + def _get_variant_override_for_test_user( + self, flag_definition: ExperimentationFlag, context: Dict[str, Any] + ) -> Optional[SelectedVariant]: """""" if not flag_definition.ruleset.test or not flag_definition.ruleset.test.users: return None @@ -144,9 +228,17 @@ def _get_variant_override_for_test_user(self, flag_definition: ExperimentationFl return self._get_matching_variant(variant_key, flag_definition) - def _get_assigned_variant(self, flag_definition: ExperimentationFlag, context_value: Any, flag_name: str, rollout: Rollout) -> SelectedVariant: + def _get_assigned_variant( + self, + flag_definition: ExperimentationFlag, + context_value: Any, + flag_name: str, + rollout: Rollout, + ) -> SelectedVariant: if rollout.variant_override: - if variant := self._get_matching_variant(rollout.variant_override.key, flag_definition): + if variant := self._get_matching_variant( + rollout.variant_override.key, flag_definition + ): return variant variants = flag_definition.ruleset.variants @@ -165,18 +257,28 @@ def _get_assigned_variant(self, flag_definition: ExperimentationFlag, context_va return SelectedVariant(variant_key=selected.key, variant_value=selected.value) - def _get_assigned_rollout(self, flag_definition: ExperimentationFlag, context_value: Any, context: Dict[str, Any]) -> Optional[Rollout]: + def _get_assigned_rollout( + self, + flag_definition: ExperimentationFlag, + context_value: Any, + context: Dict[str, Any], + ) -> Optional[Rollout]: hash_input = str(context_value) + flag_definition.key rollout_hash = normalized_hash(hash_input, "rollout") for rollout in flag_definition.ruleset.rollout: - if rollout_hash < rollout.rollout_percentage and self._is_runtime_evaluation_satisfied(rollout, context): + if ( + rollout_hash < rollout.rollout_percentage + and self._is_runtime_evaluation_satisfied(rollout, context) + ): return rollout return None - def _is_runtime_evaluation_satisfied(self, rollout: Rollout, context: Dict[str, Any]) -> bool: + def _is_runtime_evaluation_satisfied( + self, rollout: Rollout, context: Dict[str, Any] + ) -> bool: if not rollout.runtime_evaluation_definition: return True @@ -196,16 +298,22 @@ def _is_runtime_evaluation_satisfied(self, rollout: Rollout, context: Dict[str, return True - def _get_matching_variant(self, variant_key: str, flag: ExperimentationFlag) -> Optional[SelectedVariant]: + def _get_matching_variant( + self, variant_key: str, flag: ExperimentationFlag + ) -> Optional[SelectedVariant]: for variant in flag.ruleset.variants: if variant_key.casefold() == variant.key.casefold(): - return SelectedVariant(variant_key=variant.key, variant_value=variant.value) + return SelectedVariant( + variant_key=variant.key, variant_value=variant.value + ) return None async def _afetch_flag_definitions(self) -> None: try: start_time = datetime.now() - response = await self._async_client.get(self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params) + response = await self._async_client.get( + self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params + ) end_time = datetime.now() self._handle_response(response, start_time, end_time) except Exception: @@ -214,15 +322,21 @@ async def _afetch_flag_definitions(self) -> None: def _fetch_flag_definitions(self) -> None: try: start_time = datetime.now() - response = self._sync_client.get(self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params) + response = self._sync_client.get( + self.FLAGS_DEFINITIONS_URL_PATH, params=self._request_params + ) end_time = datetime.now() self._handle_response(response, start_time, end_time) except Exception: logger.exception("Failed to fetch feature flag definitions") - def _handle_response(self, response: httpx.Response, start_time: datetime, end_time: datetime) -> None: + def _handle_response( + self, response: httpx.Response, start_time: datetime, end_time: datetime + ) -> None: request_duration: timedelta = end_time - start_time - logging.info(f"Request started at '{start_time.isoformat()}', completed at '{end_time.isoformat()}', duration: '{request_duration.total_seconds():.3f}s'") + logging.info( + f"Request started at '{start_time.isoformat()}', completed at '{end_time.isoformat()}', duration: '{request_duration.total_seconds():.3f}s'" + ) response.raise_for_status() @@ -237,21 +351,32 @@ def _handle_response(self, response: httpx.Response, start_time: datetime, end_t logger.exception("Failed to parse flag definitions") self._flag_definitions = flags - logger.info(f"Successfully fetched {len(self._flag_definitions)} flag definitions") - - - def track_exposure(self, flag_key: str, variant: SelectedVariant, latency_in_seconds: float, context: Dict[str, Any]): + self._are_flags_ready = True + logger.debug( + f"Successfully fetched {len(self._flag_definitions)} flag definitions" + ) + + def _track_exposure( + self, + flag_key: str, + variant: SelectedVariant, + latency_in_seconds: float, + context: Dict[str, Any], + ): if distinct_id := context.get("distinct_id"): properties = { - 'Experiment name': flag_key, - 'Variant name': variant.variant_key, - '$experiment_type': 'feature_flag', + "Experiment name": flag_key, + "Variant name": variant.variant_key, + "$experiment_type": "feature_flag", "Flag evaluation mode": "local", - "Variant fetch latency (ms)": latency_in_seconds * 1000 + "Variant fetch latency (ms)": latency_in_seconds * 1000, } - self._executor.submit(self._tracker, distinct_id, EXPOSURE_EVENT, properties) + + self._tracker(distinct_id, EXPOSURE_EVENT, properties) else: - logging.error("Cannot track exposure event without a distinct_id in the context") + logging.error( + "Cannot track exposure event without a distinct_id in the context" + ) async def __aenter__(self): return self diff --git a/mixpanel/flags/remote_feature_flags.py b/mixpanel/flags/remote_feature_flags.py index 1969bb1..5d7f5d9 100644 --- a/mixpanel/flags/remote_feature_flags.py +++ b/mixpanel/flags/remote_feature_flags.py @@ -3,26 +3,27 @@ import json import urllib.parse import asyncio -from datetime import datetime +from datetime import datetime from typing import Dict, Any, Callable from asgiref.sync import sync_to_async from .types import RemoteFlagsConfig, SelectedVariant, RemoteFlagsResponse -from concurrent.futures import ThreadPoolExecutor from .utils import REQUEST_HEADERS, EXPOSURE_EVENT, prepare_common_query_params logger = logging.getLogger(__name__) logging.getLogger("httpx").setLevel(logging.ERROR) + class RemoteFeatureFlagsProvider: FLAGS_URL_PATH = "/flags" - def __init__(self, token: str, config: RemoteFlagsConfig, version: str, tracker: Callable) -> None: + def __init__( + self, token: str, config: RemoteFlagsConfig, version: str, tracker: Callable + ) -> None: self._token: str = token self._config: RemoteFlagsConfig = config self._version: str = version self._tracker: Callable = tracker - self._executor: ThreadPoolExecutor = config.custom_executor or ThreadPoolExecutor(max_workers=5) httpx_client_parameters = { "base_url": f"https://{config.api_host}", @@ -31,27 +32,56 @@ def __init__(self, token: str, config: RemoteFlagsConfig, version: str, tracker: "timeout": httpx.Timeout(config.request_timeout_in_seconds), } - self._async_client: httpx.AsyncClient = httpx.AsyncClient(**httpx_client_parameters) + self._async_client: httpx.AsyncClient = httpx.AsyncClient( + **httpx_client_parameters + ) self._sync_client: httpx.Client = httpx.Client(**httpx_client_parameters) self._request_params_base = prepare_common_query_params(self._token, version) - async def aget_variant_value(self, flag_key: str, fallback_value: Any, context: Dict[str, Any]) -> Any: - variant = await self.aget_variant(flag_key, SelectedVariant(variant_value=fallback_value), context) + async def aget_variant_value( + self, flag_key: str, fallback_value: Any, context: Dict[str, Any] + ) -> Any: + """ + Gets the selected variant value of a feature flag variant for the current user context from remote server. + + :param str flag_key: The key of the feature flag to evaluate + :param Any fallback_value: The default value to return if the flag is not found or evaluation fails + :param Dict[str, Any] context: Context dictionary containing user attributes and rollout context + """ + variant = await self.aget_variant( + flag_key, SelectedVariant(variant_value=fallback_value), context + ) return variant.variant_value - async def aget_variant(self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any]) -> SelectedVariant: + async def aget_variant( + self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any] + ) -> SelectedVariant: + """ + Asynchronously gets the selected variant of a feature flag variant for the current user context from remote server. + + :param str flag_key: The key of the feature flag to evaluate + :param SelectedVariant fallback_value: The default variant to return if evaluation fails + :param Dict[str, Any] context: Context dictionary containing user attributes and rollout context + """ try: params = self._prepare_query_params(flag_key, context) start_time = datetime.now() response = await self._async_client.get(self.FLAGS_URL_PATH, params=params) end_time = datetime.now() self._instrument_call(start_time, end_time) - selected_variant, is_fallback = self._handle_response(flag_key, fallback_value, response) + selected_variant, is_fallback = self._handle_response( + flag_key, fallback_value, response + ) if not is_fallback and (distinct_id := context.get("distinct_id")): - properties = self._build_tracking_properties(flag_key, selected_variant, start_time, end_time) + properties = self._build_tracking_properties( + flag_key, selected_variant, start_time, end_time + ) asyncio.create_task( - sync_to_async(self._tracker, executor=self._executor, thread_sensitive=False)(distinct_id, EXPOSURE_EVENT, properties)) + sync_to_async(self._tracker, thread_sensitive=False)( + distinct_id, EXPOSURE_EVENT, properties + ) + ) return selected_variant except Exception: @@ -59,25 +89,55 @@ async def aget_variant(self, flag_key: str, fallback_value: SelectedVariant, con return fallback_value async def ais_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool: + """ + Asynchronously checks if a feature flag is enabled for the given context. + + :param str flag_key: The key of the feature flag to check + :param Dict[str, Any] context: Context dictionary containing user attributes and rollout context + """ variant_value = await self.aget_variant_value(flag_key, False, context) return bool(variant_value) - def get_variant_value(self, flag_key: str, fallback_value: Any, context: Dict[str, Any]) -> Any: - variant = self.get_variant(flag_key, SelectedVariant(variant_value=fallback_value), context) + def get_variant_value( + self, flag_key: str, fallback_value: Any, context: Dict[str, Any] + ) -> Any: + """ + Synchronously gets the value of a feature flag variant from remote server. + + :param str flag_key: The key of the feature flag to evaluate + :param Any fallback_value: The default value to return if the flag is not found or evaluation fails + :param Dict[str, Any] context: Context dictionary containing user attributes and rollout context + """ + variant = self.get_variant( + flag_key, SelectedVariant(variant_value=fallback_value), context + ) return variant.variant_value - def get_variant(self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any]) -> SelectedVariant: + def get_variant( + self, flag_key: str, fallback_value: SelectedVariant, context: Dict[str, Any] + ) -> SelectedVariant: + """ + Synchronously gets the selected variant for a feature flag from remote server. + + :param str flag_key: The key of the feature flag to evaluate + :param SelectedVariant fallback_value: The default variant to return if evaluation fails + :param Dict[str, Any] context: Context dictionary containing user attributes and rollout context + """ try: params = self._prepare_query_params(flag_key, context) start_time = datetime.now() response = self._sync_client.get(self.FLAGS_URL_PATH, params=params) end_time = datetime.now() self._instrument_call(start_time, end_time) - selected_variant, is_fallback = self._handle_response(flag_key, fallback_value, response) + selected_variant, is_fallback = self._handle_response( + flag_key, fallback_value, response + ) if not is_fallback and (distinct_id := context.get("distinct_id")): - properties = self._build_tracking_properties(flag_key, selected_variant, start_time, end_time) - self._executor.submit(self._tracker, distinct_id, EXPOSURE_EVENT, properties) + properties = self._build_tracking_properties( + flag_key, selected_variant, start_time, end_time + ) + self._tracker(distinct_id, EXPOSURE_EVENT, properties) return selected_variant except Exception: @@ -85,41 +145,56 @@ def get_variant(self, flag_key: str, fallback_value: SelectedVariant, context: D return fallback_value def is_enabled(self, flag_key: str, context: Dict[str, Any]) -> bool: + """ + Synchronously checks if a feature flag is enabled for the given context. + + :param str flag_key: The key of the feature flag to check + :param Dict[str, Any] context: Context dictionary containing user attributes and rollout context + """ variant_value = self.get_variant_value(flag_key, False, context) return bool(variant_value) - def _prepare_query_params(self, flag_key: str, context: Dict[str, Any]) -> Dict[str, str]: + def _prepare_query_params( + self, flag_key: str, context: Dict[str, Any] + ) -> Dict[str, str]: params = self._request_params_base.copy() - context_json = json.dumps(context).encode('utf-8') + context_json = json.dumps(context).encode("utf-8") url_encoded_context = urllib.parse.quote(context_json) - params.update({ - 'flag_key': flag_key, - 'context': url_encoded_context - }) + params.update({"flag_key": flag_key, "context": url_encoded_context}) return params def _instrument_call(self, start_time: datetime, end_time: datetime) -> None: request_duration = end_time - start_time formatted_start_time = start_time.isoformat() formatted_end_time = end_time.isoformat() - logging.info(f"Request started at '{formatted_start_time}', completed at '{formatted_end_time}', duration: '{request_duration.total_seconds():.3f}s'") - - def _build_tracking_properties(self, flag_key: str, variant: SelectedVariant, start_time: datetime, end_time: datetime) -> Dict[str, Any]: + logging.info( + f"Request started at '{formatted_start_time}', completed at '{formatted_end_time}', duration: '{request_duration.total_seconds():.3f}s'" + ) + + def _build_tracking_properties( + self, + flag_key: str, + variant: SelectedVariant, + start_time: datetime, + end_time: datetime, + ) -> Dict[str, Any]: request_duration = end_time - start_time formatted_start_time = start_time.isoformat() formatted_end_time = end_time.isoformat() return { - 'Experiment name': flag_key, - 'Variant name': variant.variant_key, - '$experiment_type': 'feature_flag', + "Experiment name": flag_key, + "Variant name": variant.variant_key, + "$experiment_type": "feature_flag", "Flag evaluation mode": "remote", "Variant fetch start time": formatted_start_time, "Variant fetch complete time": formatted_end_time, "Variant fetch latency (ms)": request_duration.total_seconds() * 1000, } - def _handle_response(self, flag_key: str, fallback_value: SelectedVariant, response: httpx.Response) -> tuple[SelectedVariant, bool]: + def _handle_response( + self, flag_key: str, fallback_value: SelectedVariant, response: httpx.Response + ) -> tuple[SelectedVariant, bool]: response.raise_for_status() flags_response = RemoteFlagsResponse.model_validate(response.json()) @@ -127,7 +202,9 @@ def _handle_response(self, flag_key: str, fallback_value: SelectedVariant, respo if flag_key in flags_response.flags: return flags_response.flags[flag_key], False else: - logging.warning(f"Flag '{flag_key}' not found in remote response. Returning fallback, '{fallback_value}'") + logging.warning( + f"Flag '{flag_key}' not found in remote response. Returning fallback, '{fallback_value}'" + ) return fallback_value, True def __enter__(self): diff --git a/mixpanel/flags/test_local_feature_flags.py b/mixpanel/flags/test_local_feature_flags.py index 63a437f..9019a01 100644 --- a/mixpanel/flags/test_local_feature_flags.py +++ b/mixpanel/flags/test_local_feature_flags.py @@ -60,17 +60,25 @@ def create_flags_response(flags: List[ExperimentationFlag]) -> httpx.Response: @pytest.mark.asyncio class TestLocalFeatureFlagsProviderAsync: - async def get_flags_provider(self, config: LocalFlagsConfig) -> LocalFeatureFlagsProvider: - mock_tracker = Mock() - flags_provider = LocalFeatureFlagsProvider("test-token", config, "1.0.0", mock_tracker) - await flags_provider.astart_polling_for_definitions() - return flags_provider + @pytest.fixture(autouse=True) + async def setup_method(self): + self._mock_tracker = Mock() + + config_no_polling = LocalFlagsConfig(enable_polling=False) + self._flags = LocalFeatureFlagsProvider("test-token", config_no_polling, "1.0.0", self._mock_tracker) + + config_with_polling = LocalFlagsConfig(enable_polling=True, polling_interval_in_seconds=0) + self._flags_with_polling = LocalFeatureFlagsProvider("test-token", config_with_polling, "1.0.0", self._mock_tracker) + + yield + + await self._flags.__aexit__(None, None, None) + await self._flags_with_polling.__aexit__(None, None, None) async def setup_flags(self, flags: List[ExperimentationFlag]): respx.get("https://api.mixpanel.com/flags/definitions").mock( return_value=create_flags_response(flags)) - - return await self.get_flags_provider(LocalFlagsConfig(enable_polling=False)) + await self._flags.astart_polling_for_definitions() async def setup_flags_with_polling(self, flags_in_order: List[List[ExperimentationFlag]] = [[]]): responses = [create_flags_response(flag) for flag in flags_in_order] @@ -81,14 +89,13 @@ async def setup_flags_with_polling(self, flags_in_order: List[List[Experimentati repeat(responses[-1]), ) ) - - return await self.get_flags_provider(LocalFlagsConfig(enable_polling=True, polling_interval_in_seconds=0)) + await self._flags_with_polling.astart_polling_for_definitions() @respx.mock async def test_get_variant_value_returns_fallback_when_no_flag_definitions(self): - flags = await self.setup_flags([]) - result = flags.get_variant_value("nonexistent_flag", "control", {"distinct_id": "user123"}) + await self.setup_flags([]) + result = self._flags.get_variant_value("nonexistent_flag", "control", {"distinct_id": "user123"}) assert result == "control" @respx.mock @@ -97,29 +104,29 @@ async def test_get_variant_value_returns_fallback_if_flag_definition_call_fails( return_value=httpx.Response(status_code=500) ) - flags = await self.get_flags_provider(LocalFlagsConfig(enable_polling=False)) - result = flags.get_variant_value("nonexistent_flag", "control", {"distinct_id": "user123"}) + await self._flags.astart_polling_for_definitions() + result = self._flags.get_variant_value("nonexistent_flag", "control", {"distinct_id": "user123"}) assert result == "control" @respx.mock async def test_get_variant_value_returns_fallback_when_flag_does_not_exist(self): other_flag = create_test_flag("other_flag") - flags = await self.setup_flags([other_flag]) - result = flags.get_variant_value("nonexistent_flag", "control", {"distinct_id": "user123"}) + await self.setup_flags([other_flag]) + result = self._flags.get_variant_value("nonexistent_flag", "control", {"distinct_id": "user123"}) assert result == "control" @respx.mock async def test_get_variant_value_returns_fallback_when_no_context(self): flag = create_test_flag(context="distinct_id") - flags = await self.setup_flags([flag]) - result = flags.get_variant_value("test_flag", "fallback", {}) + await self.setup_flags([flag]) + result = self._flags.get_variant_value("test_flag", "fallback", {}) assert result == "fallback" @respx.mock async def test_get_variant_value_returns_fallback_when_wrong_context_key(self): flag = create_test_flag(context="user_id") - flags = await self.setup_flags([flag]) - result = flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"}) + await self.setup_flags([flag]) + result = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"}) assert result == "fallback" @respx.mock @@ -133,8 +140,8 @@ async def test_get_variant_value_returns_test_user_variant_when_configured(self) test_users={"test_user": "treatment"} ) - flags = await self.setup_flags([flag]) - result = flags.get_variant_value("test_flag", "control", {"distinct_id": "test_user"}) + await self.setup_flags([flag]) + result = self._flags.get_variant_value("test_flag", "control", {"distinct_id": "test_user"}) assert result == "true" @respx.mock @@ -147,31 +154,31 @@ async def test_get_variant_value_returns_fallback_when_test_user_variant_not_con variants=variants, test_users={"test_user": "nonexistent_variant"} ) - flags = await self.setup_flags([flag]) + await self.setup_flags([flag]) with patch('mixpanel.flags.utils.normalized_hash') as mock_hash: mock_hash.return_value = 0.5 - result = flags.get_variant_value("test_flag", "fallback", {"distinct_id": "test_user"}) + result = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "test_user"}) assert result == "false" @respx.mock async def test_get_variant_value_returns_fallback_when_rollout_percentage_zero(self): flag = create_test_flag(rollout_percentage=0.0) - flags = await self.setup_flags([flag]) - result = flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"}) + await self.setup_flags([flag]) + result = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"}) assert result == "fallback" @respx.mock async def test_get_variant_value_returns_variant_when_rollout_percentage_hundred(self): flag = create_test_flag(rollout_percentage=100.0) - flags = await self.setup_flags([flag]) - result = flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"}) + await self.setup_flags([flag]) + result = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"}) assert result != "fallback" @respx.mock async def test_get_variant_value_respects_runtime_evaluation_satisfied(self): runtime_eval = {"plan": "premium", "region": "US"} flag = create_test_flag(runtime_evaluation=runtime_eval) - flags = await self.setup_flags([flag]) + await self.setup_flags([flag]) context = { "distinct_id": "user123", "custom_properties": { @@ -179,14 +186,14 @@ async def test_get_variant_value_respects_runtime_evaluation_satisfied(self): "region": "US" } } - result = flags.get_variant_value("test_flag", "fallback", context) + result = self._flags.get_variant_value("test_flag", "fallback", context) assert result != "fallback" @respx.mock async def test_get_variant_value_returns_fallback_when_runtime_evaluation_not_satisfied(self): runtime_eval = {"plan": "premium", "region": "US"} flag = create_test_flag(runtime_evaluation=runtime_eval) - flags = await self.setup_flags([flag]) + await self.setup_flags([flag]) context = { "distinct_id": "user123", "custom_properties": { @@ -194,7 +201,7 @@ async def test_get_variant_value_returns_fallback_when_runtime_evaluation_not_sa "region": "US" } } - result = flags.get_variant_value("test_flag", "fallback", context) + result = self._flags.get_variant_value("test_flag", "fallback", context) assert result == "fallback" @respx.mock @@ -205,8 +212,8 @@ async def test_get_variant_value_picks_correct_variant_with_hundred_percent_spli Variant(key="C", value="variant_c", is_control=False, split=0.0) ] flag = create_test_flag(variants=variants, rollout_percentage=100.0) - flags = await self.setup_flags([flag]) - result = flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"}) + await self.setup_flags([flag]) + result = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"}) assert result == "variant_a" @respx.mock @@ -216,45 +223,49 @@ async def test_get_variant_value_picks_overriden_variant(self): Variant(key="B", value="variant_b", is_control=False, split=0.0), ] flag = create_test_flag(variants=variants, variant_override=VariantOverride(key="B")) - flags = await self.setup_flags([flag]) - result = flags.get_variant_value("test_flag", "control", {"distinct_id": "user123"}) + await self.setup_flags([flag]) + result = self._flags.get_variant_value("test_flag", "control", {"distinct_id": "user123"}) assert result == "variant_b" @respx.mock async def test_get_variant_value_tracks_exposure_when_variant_selected(self): flag = create_test_flag() - flags = await self.setup_flags([flag]) + await self.setup_flags([flag]) with patch('mixpanel.flags.utils.normalized_hash') as mock_hash: mock_hash.return_value = 0.5 - _ = flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"}) - flags._executor.shutdown() - flags._tracker.assert_called_once() + _ = self._flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"}) + self._mock_tracker.assert_called_once() @respx.mock async def test_get_variant_value_does_not_track_exposure_on_fallback(self): - flags = await self.setup_flags([]) - _ = flags.get_variant_value("nonexistent_flag", "fallback", {"distinct_id": "user123"}) - flags._executor.shutdown() - flags._tracker.assert_not_called() + await self.setup_flags([]) + _ = self._flags.get_variant_value("nonexistent_flag", "fallback", {"distinct_id": "user123"}) + self._mock_tracker.assert_not_called() @respx.mock async def test_get_variant_value_does_not_track_exposure_without_distinct_id(self): flag = create_test_flag(context="company") - flags = await self.setup_flags([flag]) - _ = flags.get_variant_value("nonexistent_flag", "fallback", {"company_id": "company123"}) - flags._executor.shutdown() - flags._tracker.assert_not_called() + await self.setup_flags([flag]) + _ = self._flags.get_variant_value("nonexistent_flag", "fallback", {"company_id": "company123"}) + self._mock_tracker.assert_not_called() @respx.mock async def test_are_flags_ready_returns_true_when_flags_loaded(self): flag = create_test_flag() - flags = await self.setup_flags([flag]) - assert flags.are_flags_ready() == True + await self.setup_flags([flag]) + assert self._flags.are_flags_ready() == True + + @respx.mock + async def test_are_flags_ready_returns_true_when_empty_flags_loaded(self): + flag = create_test_flag() + await self.setup_flags([]) + assert self._flags.are_flags_ready() == True + @respx.mock async def test_is_enabled_returns_false_for_nonexistent_flag(self): - flags = await self.setup_flags([]) - result = flags.is_enabled("nonexistent_flag", {"distinct_id": "user123"}) + await self.setup_flags([]) + result = self._flags.is_enabled("nonexistent_flag", {"distinct_id": "user123"}) assert result == False @respx.mock @@ -263,8 +274,8 @@ async def test_is_enabled_returns_true_for_true_variant_value(self): Variant(key="treatment", value=True, is_control=False, split=100.0) ] flag = create_test_flag(variants=variants, rollout_percentage=100.0) - flags = await self.setup_flags([flag]) - result = flags.is_enabled("test_flag", {"distinct_id": "user123"}) + await self.setup_flags([flag]) + result = self._flags.is_enabled("test_flag", {"distinct_id": "user123"}) assert result == True @respx.mock @@ -285,16 +296,22 @@ async def track_fetch_calls(self): flag_v2 = create_test_flag(rollout_percentage=100.0) flags_in_order=[[flag_v1], [flag_v2]] - flags = await self.setup_flags_with_polling(flags_in_order) + await self.setup_flags_with_polling(flags_in_order) async with polling_limit_check: await polling_limit_check.wait_for(lambda: polling_iterations >= len(flags_in_order)) - result2 = flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"}) + result2 = self._flags_with_polling.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"}) assert result2 != "fallback" - await flags.astop_polling_for_definitions() - class TestLocalFeatureFlagsProviderSync: + def setup_method(self): + self.mock_tracker = Mock() + config_with_polling = LocalFlagsConfig(enable_polling=True, polling_interval_in_seconds=0) + self._flags_with_polling = LocalFeatureFlagsProvider("test-token", config_with_polling, "1.0.0", self.mock_tracker) + + def teardown_method(self): + self._flags_with_polling.__exit__(None, None, None) + def setup_flags_with_polling(self, flags_in_order: List[List[ExperimentationFlag]] = [[]]): responses = [create_flags_response(flag) for flag in flags_in_order] @@ -305,13 +322,7 @@ def setup_flags_with_polling(self, flags_in_order: List[List[ExperimentationFlag ) ) - return self.get_flags_provider(LocalFlagsConfig(enable_polling=True, polling_interval_in_seconds=0)) - - def get_flags_provider(self, config: LocalFlagsConfig) -> LocalFeatureFlagsProvider: - mock_tracker = Mock() - flags_provider = LocalFeatureFlagsProvider("test-token", config, "1.0.0", mock_tracker) - flags_provider.start_polling_for_definitions() - return flags_provider + self._flags_with_polling.start_polling_for_definitions() @respx.mock def test_get_variant_value_uses_most_recent_polled_flag(self): @@ -332,9 +343,8 @@ def track_fetch_calls(self): return original_fetch(self) with patch.object(LocalFeatureFlagsProvider, '_fetch_flag_definitions', track_fetch_calls): - flags = self.setup_flags_with_polling(flags_in_order) + self.setup_flags_with_polling(flags_in_order) polling_event.wait(timeout=5.0) - flags.stop_polling_for_definitions() assert (polling_iterations >= 3 ) - result2 = flags.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"}) + result2 = self._flags_with_polling.get_variant_value("test_flag", "fallback", {"distinct_id": "user123"}) assert result2 != "fallback" diff --git a/mixpanel/flags/test_remote_feature_flags.py b/mixpanel/flags/test_remote_feature_flags.py index e7fff93..def080c 100644 --- a/mixpanel/flags/test_remote_feature_flags.py +++ b/mixpanel/flags/test_remote_feature_flags.py @@ -64,14 +64,12 @@ async def test_get_variant_value_tracks_exposure_event_if_variant_selected(self) if pending: await asyncio.gather(*pending, return_exceptions=True) - self._flags._executor.shutdown() self.mock_tracker.assert_called_once() @respx.mock async def test_get_variant_value_does_not_track_exposure_event_if_fallback(self): respx.get(ENDPOINT).mock(side_effect=httpx.RequestError("Network error")) await self._flags.aget_variant_value("test_flag", "control", {"distinct_id": "user123"}) - self._flags._executor.shutdown() self.mock_tracker.assert_not_called() @respx.mock @@ -135,14 +133,12 @@ def test_get_variant_value_tracks_exposure_event_if_variant_selected(self): return_value=create_success_response({"test_flag": SelectedVariant(variant_key="treatment", variant_value="treatment")})) self._flags.get_variant_value("test_flag", "control", {"distinct_id": "user123"}) - self._flags._executor.shutdown() self.mock_tracker.assert_called_once() @respx.mock def test_get_variant_value_does_not_track_exposure_event_if_fallback(self): respx.get(ENDPOINT).mock(side_effect=httpx.RequestError("Network error")) self._flags.get_variant_value("test_flag", "control", {"distinct_id": "user123"}) - self._flags._executor.shutdown() self.mock_tracker.assert_not_called() @respx.mock diff --git a/mixpanel/flags/types.py b/mixpanel/flags/types.py index b71eec6..186371d 100644 --- a/mixpanel/flags/types.py +++ b/mixpanel/flags/types.py @@ -1,5 +1,4 @@ from typing import Optional, List, Dict, Any -from concurrent.futures import ThreadPoolExecutor from pydantic import BaseModel, ConfigDict MIXPANEL_DEFAULT_API_ENDPOINT = "api.mixpanel.com" @@ -9,7 +8,6 @@ class FlagsConfig(BaseModel): api_host: str = "api.mixpanel.com" request_timeout_in_seconds: int = 10 - custom_executor: Optional[ThreadPoolExecutor] = None class LocalFlagsConfig(FlagsConfig): enable_polling: bool = True