diff --git a/tests/test_server.py b/tests/test_server.py index 15c4b3f..18f13a1 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -109,7 +109,7 @@ def on_change(**_): def another_method(): pass - assert server.state._change_callbacks["a"][0] == on_change + assert server.state._change_callbacks["a"][0][0] == on_change assert server.trigger_name(another_method) == "my_name" assert server.name == "test_enable_module" diff --git a/tests/test_translator.py b/tests/test_translator.py index 15c9105..a12e56c 100644 --- a/tests/test_translator.py +++ b/tests/test_translator.py @@ -1,6 +1,6 @@ import logging -from trame_server.core import Controller, State +from trame_server.core import Controller, State, Translator logger = logging.getLogger(__name__) @@ -17,6 +17,34 @@ def func3(): return 3 +def test_translator(): + a_translator = Translator() + a_translator.add_translation("foo", "a_foo") + + assert a_translator.translate_key("foo") == "a_foo" + assert a_translator.translate_key("bar") == "bar" + assert a_translator.reverse_translate_key("a_foo") == "foo" + assert a_translator.reverse_translate_key("bar") == "bar" + + b_translator = Translator() + b_translator.set_prefix("b_") + + assert b_translator.translate_key("foo") == "b_foo" + assert b_translator.translate_key("bar") == "b_bar" + assert b_translator.reverse_translate_key("b_foo") == "foo" + assert b_translator.reverse_translate_key("b_bar") == "bar" + + c_translator = Translator() + c_translator.set_prefix("c_") + c_translator.add_translation("foo", "still_foo") + + assert c_translator.translate_key("foo") == "still_foo" + assert c_translator.translate_key("bar") == "c_bar" + assert c_translator.reverse_translate_key("still_foo") == "foo" + assert c_translator.reverse_translate_key("c_foo") == "foo" + assert c_translator.reverse_translate_key("c_bar") == "bar" + + def test_state_translation(): root_state = State() a_state = State(internal=root_state) @@ -273,3 +301,52 @@ def test_controller_prefix_and_translation(): func() == root_controller[func_name]() for func_name, func in expected_controller.items() ) + + +def test_change_callback(): + # Ensure change callbacks are passed translated kwargs when using translations + test_passed = False + + root_state = State() + + a_state = State(internal=root_state) + a_state.translator.add_translation("foo", "a_foo") + + def on_a_foo_change(*_args, **kwargs): + nonlocal test_passed + assert "foo" in kwargs + assert "a_foo" not in kwargs + assert kwargs["foo"] == 123 + test_passed = "foo" in kwargs and "a_foo" not in kwargs + + a_state.change("foo")(on_a_foo_change) + a_state.ready() + a_state.foo = 123 + root_state.foo = 456 + a_state.flush() + + assert test_passed + + # Ensure change callbacks are passed translated kwargs when using prefix + test_passed = False + + root_state = State() + + b_state = State(internal=root_state) + b_state.translator.set_prefix("b_") + + def on_b_foo_change(*_args, **kwargs): + nonlocal test_passed + assert "foo" in kwargs + assert "b_foo" not in kwargs + assert kwargs["foo"] == 456 + test_passed = "foo" in kwargs and "b_foo" not in kwargs + + b_state.change("foo")(on_b_foo_change) + + b_state.ready() + b_state.foo = 456 + root_state.foo = 123 + b_state.flush() + + assert test_passed diff --git a/trame_server/state.py b/trame_server/state.py index eb82ac0..e16456c 100644 --- a/trame_server/state.py +++ b/trame_server/state.py @@ -296,7 +296,7 @@ def flush(self): # Execute state listeners self._state_listeners.add_all(_keys) - for fn in self._state_listeners: + for fn, translator in self._state_listeners: if isinstance(fn, weakref.WeakMethod): callback = fn() if callback is None: @@ -308,7 +308,10 @@ def flush(self): if not inspect.iscoroutinefunction(callback): callback = reload(callback) - coroutine = callback(**self._pushed_state) + reverse_translated_state = translator.reverse_translate_dict( + self._pushed_state + ) + coroutine = callback(**reverse_translated_state) if inspect.isawaitable(coroutine): asynchronous.create_task(coroutine) @@ -362,7 +365,7 @@ def register_change_callback(func): if name not in self._change_callbacks: self._change_callbacks[name] = [] - self._change_callbacks[name].append(func) + self._change_callbacks[name].append((func, self._translator)) return func return register_change_callback diff --git a/trame_server/utils/namespace.py b/trame_server/utils/namespace.py index 42320aa..f73b67b 100644 --- a/trame_server/utils/namespace.py +++ b/trame_server/utils/namespace.py @@ -65,12 +65,14 @@ def __init__(self, prefix=None): logger.info("Translator(prefix=%s)", prefix) self._prefix = prefix self._transl = {} + self._reverse_transl = {} def set_prefix(self, prefix): self._prefix = prefix def add_translation(self, key, translated_key): self._transl[key] = translated_key + self._reverse_transl[translated_key] = key def translate_key(self, key): # Reserved keys @@ -85,12 +87,42 @@ def translate_key(self, key): return key + def reverse_translate_key(self, translated_key): + # Reserved keys + if is_name_reserved(translated_key): + return translated_key + + if translated_key in self._reverse_transl: + return self._reverse_transl[translated_key] + + if self._prefix: + return translated_key.removeprefix(self._prefix) + + return translated_key + def translate_list(self, key_list): return [self.translate_key(v) for v in key_list] def translate_dict(self, key_dict): return {self.translate_key(k): v for k, v in key_dict.items()} + def reverse_translate_list(self, key_list): + return [self.reverse_translate_key(v) for v in key_list] + + def reverse_translate_dict(self, key_dict): + d = {} + + for key, value in key_dict.items(): + reverse_key = self.reverse_translate_key(key) + translated_key = self.translate_key(reverse_key) + + # If key != translated_key it means that this key is shadowed by something + # else in this state, so it should not be included in the translated dict + if key == translated_key: + d[reverse_key] = value + + return d + def translate_js_expression(self, state, expression): tokens = [] for token in split_when(expression, js_tokenizer):