Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 35 additions & 22 deletions optimizely/cmab/cmab_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import uuid
import json
import hashlib
import threading

from typing import Optional, List, TypedDict
from optimizely.cmab.cmab_client import DefaultCmabClient
Expand All @@ -22,6 +23,8 @@
from optimizely.decision.optimizely_decide_option import OptimizelyDecideOption
from optimizely import logger as _logging

NUM_LOCK_STRIPES = 1000


class CmabDecision(TypedDict):
variation_id: str
Expand Down Expand Up @@ -52,40 +55,50 @@ def __init__(self, cmab_cache: LRUCache[str, CmabCacheValue],
self.cmab_cache = cmab_cache
self.cmab_client = cmab_client
self.logger = logger
self.locks = [threading.Lock() for _ in range(NUM_LOCK_STRIPES)]

def _get_lock_index(self, user_id: str, rule_id: str) -> int:
"""Calculate the lock index for a given user and rule combination."""
# Create a hash of user_id + rule_id for consistent lock selection
hash_input = f"{user_id}{rule_id}"
hash_value = int(hashlib.md5(hash_input.encode()).hexdigest(), 16) % NUM_LOCK_STRIPES
return hash_value

def get_decision(self, project_config: ProjectConfig, user_context: OptimizelyUserContext,
rule_id: str, options: List[str]) -> CmabDecision:

filtered_attributes = self._filter_attributes(project_config, user_context, rule_id)
lock_index = self._get_lock_index(user_context.user_id, rule_id)
with self.locks[lock_index]:
filtered_attributes = self._filter_attributes(project_config, user_context, rule_id)

if OptimizelyDecideOption.IGNORE_CMAB_CACHE in options:
return self._fetch_decision(rule_id, user_context.user_id, filtered_attributes)
if OptimizelyDecideOption.IGNORE_CMAB_CACHE in options:
return self._fetch_decision(rule_id, user_context.user_id, filtered_attributes)

if OptimizelyDecideOption.RESET_CMAB_CACHE in options:
self.cmab_cache.reset()
if OptimizelyDecideOption.RESET_CMAB_CACHE in options:
self.cmab_cache.reset()

cache_key = self._get_cache_key(user_context.user_id, rule_id)
cache_key = self._get_cache_key(user_context.user_id, rule_id)

if OptimizelyDecideOption.INVALIDATE_USER_CMAB_CACHE in options:
self.cmab_cache.remove(cache_key)
if OptimizelyDecideOption.INVALIDATE_USER_CMAB_CACHE in options:
self.cmab_cache.remove(cache_key)

cached_value = self.cmab_cache.lookup(cache_key)
cached_value = self.cmab_cache.lookup(cache_key)

attributes_hash = self._hash_attributes(filtered_attributes)
attributes_hash = self._hash_attributes(filtered_attributes)

if cached_value:
if cached_value['attributes_hash'] == attributes_hash:
return CmabDecision(variation_id=cached_value['variation_id'], cmab_uuid=cached_value['cmab_uuid'])
else:
self.cmab_cache.remove(cache_key)
if cached_value:
if cached_value['attributes_hash'] == attributes_hash:
return CmabDecision(variation_id=cached_value['variation_id'], cmab_uuid=cached_value['cmab_uuid'])
else:
self.cmab_cache.remove(cache_key)

cmab_decision = self._fetch_decision(rule_id, user_context.user_id, filtered_attributes)
self.cmab_cache.save(cache_key, {
'attributes_hash': attributes_hash,
'variation_id': cmab_decision['variation_id'],
'cmab_uuid': cmab_decision['cmab_uuid'],
})
return cmab_decision
cmab_decision = self._fetch_decision(rule_id, user_context.user_id, filtered_attributes)
self.cmab_cache.save(cache_key, {
'attributes_hash': attributes_hash,
'variation_id': cmab_decision['variation_id'],
'cmab_uuid': cmab_decision['cmab_uuid'],
})
return cmab_decision

def _fetch_decision(self, rule_id: str, user_id: str, attributes: UserAttributes) -> CmabDecision:
cmab_uuid = str(uuid.uuid4())
Expand Down
40 changes: 39 additions & 1 deletion tests/test_cmab_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# limitations under the License.
import unittest
from unittest.mock import MagicMock
from optimizely.cmab.cmab_service import DefaultCmabService
from optimizely.cmab.cmab_service import DefaultCmabService, NUM_LOCK_STRIPES
from optimizely.optimizely_user_context import OptimizelyUserContext
from optimizely.decision.optimizely_decide_option import OptimizelyDecideOption
from optimizely.odp.lru_cache import LRUCache
Expand Down Expand Up @@ -185,3 +185,41 @@ def test_only_cmab_attributes_passed_to_client(self):
{"age": 25, "location": "USA"},
decision["cmab_uuid"]
)

def test_same_user_rule_combination_uses_consistent_lock(self):
"""Verifies that the same user/rule combination always uses the same lock index"""
user_id = "test_user"
rule_id = "test_rule"

# Get lock index multiple times
index1 = self.cmab_service._get_lock_index(user_id, rule_id)
index2 = self.cmab_service._get_lock_index(user_id, rule_id)
index3 = self.cmab_service._get_lock_index(user_id, rule_id)

# All should be the same
self.assertEqual(index1, index2, "Same user/rule should always use same lock")
self.assertEqual(index2, index3, "Same user/rule should always use same lock")

def test_lock_striping_distribution(self):
"""Verifies that different user/rule combinations use different locks to allow for better concurrency"""
test_cases = [
("user1", "rule1"),
("user2", "rule1"),
("user1", "rule2"),
("user3", "rule3"),
("user4", "rule4"),
]

lock_indices = set()
for user_id, rule_id in test_cases:
index = self.cmab_service._get_lock_index(user_id, rule_id)

# Verify index is within expected range
self.assertGreaterEqual(index, 0, "Lock index should be non-negative")
self.assertLess(index, NUM_LOCK_STRIPES, "Lock index should be less than NUM_LOCK_STRIPES")

lock_indices.add(index)

# We should have multiple different lock indices (though not necessarily all unique due to hash collisions)
self.assertGreater(len(lock_indices), 1,
"Different user/rule combinations should generally use different locks")
Loading