Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def on_change(**_):
def another_method():
pass

assert server.state._change_callbacks["a"][0] == on_change
assert server.state._change_callbacks["a"][0][0] == on_change
assert server.trigger_name(another_method) == "my_name"
assert server.name == "test_enable_module"

Expand Down
79 changes: 78 additions & 1 deletion tests/test_translator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from trame_server.core import Controller, State
from trame_server.core import Controller, State, Translator

logger = logging.getLogger(__name__)

Expand All @@ -17,6 +17,34 @@ def func3():
return 3


def test_translator():
a_translator = Translator()
a_translator.add_translation("foo", "a_foo")

assert a_translator.translate_key("foo") == "a_foo"
assert a_translator.translate_key("bar") == "bar"
assert a_translator.reverse_translate_key("a_foo") == "foo"
assert a_translator.reverse_translate_key("bar") == "bar"

b_translator = Translator()
b_translator.set_prefix("b_")

assert b_translator.translate_key("foo") == "b_foo"
assert b_translator.translate_key("bar") == "b_bar"
assert b_translator.reverse_translate_key("b_foo") == "foo"
assert b_translator.reverse_translate_key("b_bar") == "bar"

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you test a translator that has prefix+custom-translation?

c_translator = Translator()
c_translator.set_prefix("c_")
c_translator.add_translation("foo", "still_foo")

assert c_translator.translate_key("foo") == "still_foo"
assert c_translator.translate_key("bar") == "c_bar"
assert c_translator.reverse_translate_key("still_foo") == "foo"
assert c_translator.reverse_translate_key("c_foo") == "foo"
assert c_translator.reverse_translate_key("c_bar") == "bar"


def test_state_translation():
root_state = State()
a_state = State(internal=root_state)
Expand Down Expand Up @@ -273,3 +301,52 @@ def test_controller_prefix_and_translation():
func() == root_controller[func_name]()
for func_name, func in expected_controller.items()
)


def test_change_callback():
# Ensure change callbacks are passed translated kwargs when using translations
test_passed = False

root_state = State()

a_state = State(internal=root_state)
a_state.translator.add_translation("foo", "a_foo")

def on_a_foo_change(*_args, **kwargs):
nonlocal test_passed
assert "foo" in kwargs
assert "a_foo" not in kwargs
assert kwargs["foo"] == 123
test_passed = "foo" in kwargs and "a_foo" not in kwargs

a_state.change("foo")(on_a_foo_change)
a_state.ready()
a_state.foo = 123
root_state.foo = 456
a_state.flush()

assert test_passed

# Ensure change callbacks are passed translated kwargs when using prefix
test_passed = False

root_state = State()

b_state = State(internal=root_state)
b_state.translator.set_prefix("b_")

def on_b_foo_change(*_args, **kwargs):
nonlocal test_passed
assert "foo" in kwargs
assert "b_foo" not in kwargs
assert kwargs["foo"] == 456
test_passed = "foo" in kwargs and "b_foo" not in kwargs

b_state.change("foo")(on_b_foo_change)

b_state.ready()
b_state.foo = 456
root_state.foo = 123
b_state.flush()

assert test_passed
9 changes: 6 additions & 3 deletions trame_server/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def flush(self):

# Execute state listeners
self._state_listeners.add_all(_keys)
for fn in self._state_listeners:
for fn, translator in self._state_listeners:
if isinstance(fn, weakref.WeakMethod):
callback = fn()
if callback is None:
Expand All @@ -308,7 +308,10 @@ def flush(self):
if not inspect.iscoroutinefunction(callback):
callback = reload(callback)

coroutine = callback(**self._pushed_state)
reverse_translated_state = translator.reverse_translate_dict(
self._pushed_state
)
coroutine = callback(**reverse_translated_state)
if inspect.isawaitable(coroutine):
asynchronous.create_task(coroutine)

Expand Down Expand Up @@ -362,7 +365,7 @@ def register_change_callback(func):
if name not in self._change_callbacks:
self._change_callbacks[name] = []

self._change_callbacks[name].append(func)
self._change_callbacks[name].append((func, self._translator))
return func

return register_change_callback
32 changes: 32 additions & 0 deletions trame_server/utils/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,14 @@ def __init__(self, prefix=None):
logger.info("Translator(prefix=%s)", prefix)
self._prefix = prefix
self._transl = {}
self._reverse_transl = {}

def set_prefix(self, prefix):
self._prefix = prefix

def add_translation(self, key, translated_key):
self._transl[key] = translated_key
self._reverse_transl[translated_key] = key

def translate_key(self, key):
# Reserved keys
Expand All @@ -85,12 +87,42 @@ def translate_key(self, key):

return key

def reverse_translate_key(self, translated_key):
# Reserved keys
if is_name_reserved(translated_key):
return translated_key

if translated_key in self._reverse_transl:
return self._reverse_transl[translated_key]

if self._prefix:
return translated_key.removeprefix(self._prefix)

return translated_key

def translate_list(self, key_list):
return [self.translate_key(v) for v in key_list]

def translate_dict(self, key_dict):
return {self.translate_key(k): v for k, v in key_dict.items()}

def reverse_translate_list(self, key_list):
return [self.reverse_translate_key(v) for v in key_list]

def reverse_translate_dict(self, key_dict):
d = {}

for key, value in key_dict.items():
reverse_key = self.reverse_translate_key(key)
translated_key = self.translate_key(reverse_key)

# If key != translated_key it means that this key is shadowed by something
# else in this state, so it should not be included in the translated dict
if key == translated_key:
d[reverse_key] = value

return d

def translate_js_expression(self, state, expression):
tokens = []
for token in split_when(expression, js_tokenizer):
Expand Down