Skip to content

Commit 6384e29

Browse files
Update: Implement locking mechanism for CMAB service to enhance concurrency
1 parent 1ea261b commit 6384e29

File tree

2 files changed

+74
-23
lines changed

2 files changed

+74
-23
lines changed

optimizely/cmab/cmab_service.py

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import uuid
1414
import json
1515
import hashlib
16+
import threading
1617

1718
from typing import Optional, List, TypedDict
1819
from optimizely.cmab.cmab_client import DefaultCmabClient
@@ -22,6 +23,8 @@
2223
from optimizely.decision.optimizely_decide_option import OptimizelyDecideOption
2324
from optimizely import logger as _logging
2425

26+
NUM_LOCK_STRIPES = 1000
27+
2528

2629
class CmabDecision(TypedDict):
2730
variation_id: str
@@ -52,40 +55,50 @@ def __init__(self, cmab_cache: LRUCache[str, CmabCacheValue],
5255
self.cmab_cache = cmab_cache
5356
self.cmab_client = cmab_client
5457
self.logger = logger
58+
self.locks = [threading.Lock() for _ in range(NUM_LOCK_STRIPES)]
59+
60+
def _get_lock_index(self, user_id: str, rule_id: str) -> int:
61+
"""Calculate the lock index for a given user and rule combination."""
62+
# Create a hash of user_id + rule_id for consistent lock selection
63+
hash_input = f"{user_id}{rule_id}"
64+
hash_value = int(hashlib.md5(hash_input.encode()).hexdigest(), 16) % NUM_LOCK_STRIPES
65+
return hash_value
5566

5667
def get_decision(self, project_config: ProjectConfig, user_context: OptimizelyUserContext,
5768
rule_id: str, options: List[str]) -> CmabDecision:
5869

59-
filtered_attributes = self._filter_attributes(project_config, user_context, rule_id)
70+
lock_index = self._get_lock_index(user_context.user_id, rule_id)
71+
with self.locks[lock_index]:
72+
filtered_attributes = self._filter_attributes(project_config, user_context, rule_id)
6073

61-
if OptimizelyDecideOption.IGNORE_CMAB_CACHE in options:
62-
return self._fetch_decision(rule_id, user_context.user_id, filtered_attributes)
74+
if OptimizelyDecideOption.IGNORE_CMAB_CACHE in options:
75+
return self._fetch_decision(rule_id, user_context.user_id, filtered_attributes)
6376

64-
if OptimizelyDecideOption.RESET_CMAB_CACHE in options:
65-
self.cmab_cache.reset()
77+
if OptimizelyDecideOption.RESET_CMAB_CACHE in options:
78+
self.cmab_cache.reset()
6679

67-
cache_key = self._get_cache_key(user_context.user_id, rule_id)
80+
cache_key = self._get_cache_key(user_context.user_id, rule_id)
6881

69-
if OptimizelyDecideOption.INVALIDATE_USER_CMAB_CACHE in options:
70-
self.cmab_cache.remove(cache_key)
82+
if OptimizelyDecideOption.INVALIDATE_USER_CMAB_CACHE in options:
83+
self.cmab_cache.remove(cache_key)
7184

72-
cached_value = self.cmab_cache.lookup(cache_key)
85+
cached_value = self.cmab_cache.lookup(cache_key)
7386

74-
attributes_hash = self._hash_attributes(filtered_attributes)
87+
attributes_hash = self._hash_attributes(filtered_attributes)
7588

76-
if cached_value:
77-
if cached_value['attributes_hash'] == attributes_hash:
78-
return CmabDecision(variation_id=cached_value['variation_id'], cmab_uuid=cached_value['cmab_uuid'])
79-
else:
80-
self.cmab_cache.remove(cache_key)
89+
if cached_value:
90+
if cached_value['attributes_hash'] == attributes_hash:
91+
return CmabDecision(variation_id=cached_value['variation_id'], cmab_uuid=cached_value['cmab_uuid'])
92+
else:
93+
self.cmab_cache.remove(cache_key)
8194

82-
cmab_decision = self._fetch_decision(rule_id, user_context.user_id, filtered_attributes)
83-
self.cmab_cache.save(cache_key, {
84-
'attributes_hash': attributes_hash,
85-
'variation_id': cmab_decision['variation_id'],
86-
'cmab_uuid': cmab_decision['cmab_uuid'],
87-
})
88-
return cmab_decision
95+
cmab_decision = self._fetch_decision(rule_id, user_context.user_id, filtered_attributes)
96+
self.cmab_cache.save(cache_key, {
97+
'attributes_hash': attributes_hash,
98+
'variation_id': cmab_decision['variation_id'],
99+
'cmab_uuid': cmab_decision['cmab_uuid'],
100+
})
101+
return cmab_decision
89102

90103
def _fetch_decision(self, rule_id: str, user_id: str, attributes: UserAttributes) -> CmabDecision:
91104
cmab_uuid = str(uuid.uuid4())

tests/test_cmab_service.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# limitations under the License.
1313
import unittest
1414
from unittest.mock import MagicMock
15-
from optimizely.cmab.cmab_service import DefaultCmabService
15+
from optimizely.cmab.cmab_service import DefaultCmabService, NUM_LOCK_STRIPES
1616
from optimizely.optimizely_user_context import OptimizelyUserContext
1717
from optimizely.decision.optimizely_decide_option import OptimizelyDecideOption
1818
from optimizely.odp.lru_cache import LRUCache
@@ -185,3 +185,41 @@ def test_only_cmab_attributes_passed_to_client(self):
185185
{"age": 25, "location": "USA"},
186186
decision["cmab_uuid"]
187187
)
188+
189+
def test_same_user_rule_combination_uses_consistent_lock(self):
190+
"""Verifies that the same user/rule combination always uses the same lock index"""
191+
user_id = "test_user"
192+
rule_id = "test_rule"
193+
194+
# Get lock index multiple times
195+
index1 = self.cmab_service._get_lock_index(user_id, rule_id)
196+
index2 = self.cmab_service._get_lock_index(user_id, rule_id)
197+
index3 = self.cmab_service._get_lock_index(user_id, rule_id)
198+
199+
# All should be the same
200+
self.assertEqual(index1, index2, "Same user/rule should always use same lock")
201+
self.assertEqual(index2, index3, "Same user/rule should always use same lock")
202+
203+
def test_lock_striping_distribution(self):
204+
"""Verifies that different user/rule combinations use different locks to allow for better concurrency"""
205+
test_cases = [
206+
("user1", "rule1"),
207+
("user2", "rule1"),
208+
("user1", "rule2"),
209+
("user3", "rule3"),
210+
("user4", "rule4"),
211+
]
212+
213+
lock_indices = set()
214+
for user_id, rule_id in test_cases:
215+
index = self.cmab_service._get_lock_index(user_id, rule_id)
216+
217+
# Verify index is within expected range
218+
self.assertGreaterEqual(index, 0, "Lock index should be non-negative")
219+
self.assertLess(index, NUM_LOCK_STRIPES, "Lock index should be less than NUM_LOCK_STRIPES")
220+
221+
lock_indices.add(index)
222+
223+
# We should have multiple different lock indices (though not necessarily all unique due to hash collisions)
224+
self.assertGreater(len(lock_indices), 1,
225+
"Different user/rule combinations should generally use different locks")

0 commit comments

Comments
 (0)