diff --git a/Changelog.md b/Changelog.md index 59c6ce17..1b3c8326 100644 --- a/Changelog.md +++ b/Changelog.md @@ -9,6 +9,7 @@ - Feature #710: `experimental.utils.generate_base_model` can now be called with an instance of MarkupMachine directly (thanks @patrickwolf) - Bug #715: `HierarchicalMachine._final_check` wrongly determined a parallel state final when the last child was final (thanks @DenizKucukozturk) - Bug #716: `HierarchicalMachine` caused an `AssertionError` when `model_override` was `True` and `NestedSeperator` differed from `_` (thanks @pritam-dey3) +- Feat #706: Instroduce completion transitions which will be executed after a transition has been conducted (thanks @oEscal) ## 0.9.3 (July 2024) diff --git a/tests/test_async.py b/tests/test_async.py index db4d67b4..50cb8ba0 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -741,6 +741,16 @@ async def run(): await machine.switch_model_context(self) self.assertEqual(len(w), 1) + def test_completion_transition(self): + states = ['A', 'B', 'C'] + m = self.machine_cls(states=states, initial='A', auto_transitions=False) + m.add_transition('walk', 'A', 'B') + m.add_transition('', 'B', 'C') + + async def run(): + assert await m.walk() + assert m.is_C() + asyncio.run(run()) diff --git a/tests/test_core.py b/tests/test_core.py index c9a12136..d70a9c54 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1418,3 +1418,16 @@ class MyMachine(self.machine_cls): # type: ignore assert trans[0].my_int == 23 assert trans[0].my_dict == {"baz": "bar"} assert trans[0].my_none is None + + def test_completion_transition(self): + states = ['A', 'B', 'C'] + transitions = [ + ['walk', 'A', 'B'], + ['', 'B', 'C'], + ['complete', 'C', 'A'] + ] + + m = self.machine_cls(states=states, transitions=transitions, initial='A', auto_transitions=False) + self.assertTrue(m.is_A()) + m.walk() + self.assertTrue(m.is_C()) diff --git a/tests/test_graphviz.py b/tests/test_graphviz.py index 54dc95ad..5942afdc 100644 --- a/tests/test_graphviz.py +++ b/tests/test_graphviz.py @@ -204,7 +204,7 @@ def test_roi(self): g1 = m.get_graph(show_roi=True) dot, nodes, edges = self.parse_dot(g1) self.assertEqual(0, len(edges)) - self.assertIn(r'label="A\l"', dot) + self.assertIn('label="A\\n"', dot) # make sure that generating a graph without ROI has not influence on the later generated graph # this has to be checked since graph.custom_style is a class property and is persistent for multiple # calls of graph.generate() @@ -213,7 +213,7 @@ def test_roi(self): _ = m.get_graph() g2 = m.get_graph(show_roi=True) dot, _, _ = self.parse_dot(g2) - self.assertNotIn(r'label="A\l"', dot) + self.assertNotIn('label="A\\n"', dot) m.to_B() g3 = m.get_graph(show_roi=True) _, nodes, edges = self.parse_dot(g3) @@ -247,10 +247,10 @@ class CustomMachine(self.machine_cls): # type: ignore transitions=[{'trigger': 'event', 'source': 'A', 'dest': 'B', 'label': 'LabelEvent'}], initial='A', graph_engine=self.graph_engine) dot, _, _ = self.parse_dot(m.get_graph()) - self.assertIn(r'label="LabelA\l"', dot) - self.assertIn(r'label="NotLabelA\l"', dot) + self.assertIn('label="LabelA\\n"', dot) + self.assertIn('label="NotLabelA\\n"', dot) self.assertIn("label=LabelEvent", dot) - self.assertNotIn(r'label="A\l"', dot) + self.assertNotIn('label="A\\n"', dot) self.assertNotIn("label=event", dot) def test_binary_stream(self): @@ -389,7 +389,6 @@ def is_fast(*args, **kwargs): g1 = model.get_graph(show_roi=True) _, nodes, edges = self.parse_dot(g1) self.assertEqual(len(edges), 2) # reset and walk - print(nodes) self.assertEqual(len(nodes), 4) model.walk() model.run() diff --git a/tests/test_mermaid.py b/tests/test_mermaid.py index 589a1003..c80049de 100644 --- a/tests/test_mermaid.py +++ b/tests/test_mermaid.py @@ -1,26 +1,11 @@ from .test_graphviz import TestDiagrams, TestDiagramsNested -from .utils import Stuff, DummyModel -from .test_core import TestTransitions, TYPE_CHECKING -from transitions.extensions import ( - LockedGraphMachine, GraphMachine, HierarchicalGraphMachine, LockedHierarchicalGraphMachine -) -from transitions.extensions.states import add_state_features, Timeout, Tags +from transitions.extensions import HierarchicalGraphMachine, LockedHierarchicalGraphMachine from unittest import skipIf import tempfile import os import re import sys -from unittest import TestCase - -try: - # Just to skip tests if graphviz not installed - import graphviz as pgv # @UnresolvedImport -except ImportError: # pragma: no cover - pgv = None - -if TYPE_CHECKING: - from typing import Type, List, Collection, Union, Literal class TestMermaidDiagrams(TestDiagrams): diff --git a/transitions/core.py b/transitions/core.py index 3d469ac5..e166baa1 100644 --- a/transitions/core.py +++ b/transitions/core.py @@ -403,7 +403,12 @@ def trigger(self, model, *args, **kwargs): # noinspection PyProtectedMember # Machine._process should not be called somewhere else. That's why it should not be exposed # to Machine users. - return self.machine._process(func) + res = self.machine._process(func) + if res and self.machine._can_trigger(model, "", *args, **kwargs): + _LOGGER.debug("%sTriggering completion event", self.machine.name) + # Trigger the completion event if the machine allows it + res = self.machine.events[""].trigger(model, *args, **kwargs) + return res def _trigger(self, event_data): """Internal trigger function called by the ``Machine`` instance. This should not @@ -925,8 +930,9 @@ def _add_may_transition_func_for_trigger(self, trigger, model): self._checked_assignment(model, "may_%s" % trigger, partial(self._can_trigger, model, trigger)) def _add_trigger_to_model(self, trigger, model): - self._checked_assignment(model, trigger, partial(self.events[trigger].trigger, model)) - self._add_may_transition_func_for_trigger(trigger, model) + if trigger: + self._checked_assignment(model, trigger, partial(self.events[trigger].trigger, model)) + self._add_may_transition_func_for_trigger(trigger, model) def _get_trigger(self, model, trigger_name, *args, **kwargs): """Convenience function added to the model to trigger events by name. diff --git a/transitions/extensions/asyncio.py b/transitions/extensions/asyncio.py index d2525e53..cb8b098b 100644 --- a/transitions/extensions/asyncio.py +++ b/transitions/extensions/asyncio.py @@ -187,7 +187,11 @@ async def trigger(self, model, *args, **kwargs): successfully executed (True if successful, False if not). """ func = partial(self._trigger, EventData(None, self, self.machine, model, args=args, kwargs=kwargs)) - return await self.machine.process_context(func, model) + res = await self.machine.process_context(func, model) + if res and await self.machine._can_trigger(model, "", *args, **kwargs): + _LOGGER.debug("%sTriggering completion event", self.machine.name) + res = await self.machine.events[""].trigger(model, *args, **kwargs) + return res async def _trigger(self, event_data): event_data.state = self.machine.get_state(getattr(event_data.model, self.machine.model_attribute)) @@ -256,6 +260,9 @@ async def trigger_nested(self, event_data): while elems: done.add(machine.state_cls.separator.join(elems)) elems.pop() + if event_data.result and await self.machine._can_trigger(model, "", *event_data.args, **event_data.kwargs): + _LOGGER.debug("%sTriggering completion event", self.machine.name) + event_data.result = await self.machine.events[""].trigger_nested(event_data) return event_data.result async def _process(self, event_data): diff --git a/transitions/extensions/diagrams.py b/transitions/extensions/diagrams.py index 7538a900..325575ee 100644 --- a/transitions/extensions/diagrams.py +++ b/transitions/extensions/diagrams.py @@ -66,6 +66,7 @@ class GraphMachine(MarkupMachine): "directed": "true", "strict": "false", "rankdir": "LR", + "compound": "true", } style_attributes = { diff --git a/transitions/extensions/diagrams_base.py b/transitions/extensions/diagrams_base.py index e97bd06d..1738f6b9 100644 --- a/transitions/extensions/diagrams_base.py +++ b/transitions/extensions/diagrams_base.py @@ -63,18 +63,21 @@ def get_graph(self, title=None, roi_state=None): """ def _convert_state_attributes(self, state): - label = state.get("label", state["name"]) + label = state.get("label", state["name"]) + "\\n" if self.machine.show_state_attributes: if "tags" in state: - label += " [" + ", ".join(state["tags"]) + "]" + label += "[" + ", ".join(state["tags"]) + "]\\n" if "on_enter" in state: - label += r"\l- enter:\l + " + r"\l + ".join(state["on_enter"]) + label += "- enter:\\l" + for action in state["on_enter"]: + label += " + " + action + "\\l" if "on_exit" in state: - label += r"\l- exit:\l + " + r"\l + ".join(state["on_exit"]) + label += "- exit:\\l" + for action in state["on_exit"]: + label += " + " + action + "\\l" if "timeout" in state: - label += r'\l- timeout(' + state['timeout'] + 's) -> (' + ', '.join(state['on_timeout']) + ')' - # end each label with a left-aligned newline - return label + r"\l" + label += '- timeout(' + state['timeout'] + 's) -> (' + ', '.join(state['on_timeout']) + ')' + return label def _get_state_names(self, state): if isinstance(state, (list, tuple, set)): diff --git a/transitions/extensions/diagrams_mermaid.py b/transitions/extensions/diagrams_mermaid.py index 9df99ad3..b06bcbab 100644 --- a/transitions/extensions/diagrams_mermaid.py +++ b/transitions/extensions/diagrams_mermaid.py @@ -8,6 +8,7 @@ import copy import logging from collections import defaultdict +import re from .diagrams_graphviz import filter_states from .diagrams_base import BaseGraph @@ -42,9 +43,9 @@ def reset_styling(self): def _add_nodes(self, states, container): for state in states: - container.append("state \"{}\" as {}".format(self._convert_state_attributes(state), state["name"])) - container.append("Class {} s_{}".format(state["name"], - self.custom_styles["node"][state["name"]] or "default")) + state_id = self._name_to_id(state["name"]) + self.custom_styles["node"][state_id] = self.custom_styles["node"][state_id] or "" + container.append("state \"{}\" as {}".format(self._convert_state_attributes(state), state_id)) def _add_edges(self, transitions, container): edge_labels = defaultdict(lambda: defaultdict(list)) @@ -99,10 +100,11 @@ def get_graph(self, title=None, roi_state=None): active_states = active_states.union({k for k, style in self.custom_styles["node"].items() if style}) states = filter_states(copy.deepcopy(states), active_states, self.machine.state_cls) self._add_nodes(states, fsm_graph) + self._add_node_styles(fsm_graph) fsm_graph.append("") self._add_edges(transitions, fsm_graph) if self.machine.initial and (roi_state is None or roi_state == self.machine.initial): - fsm_graph.append("[*] --> {}".format(self.machine.initial)) + fsm_graph.append("[*] --> {}".format(self._name_to_id(self.machine.initial))) indent = 0 for i in range(len(fsm_graph)): @@ -128,9 +130,21 @@ def _convert_state_attributes(self, state): label += r"\n- exit:\n + " + r"\n + ".join(state["on_exit"]) if "timeout" in state: label += r'\n- timeout(' + state['timeout'] + 's) -> (' + ', '.join(state['on_timeout']) + ')' - # end each label with a left-aligned newline return label + def _name_to_id(self, name): + """Convert a state name to a valid identifier.""" + # replace all non-alphanumeric characters with underscores + return re.sub(r'\W+', '___', name) + + def _add_node_styles(self, container): + """Add styles to the graph.""" + collection = defaultdict(set) + for state_id, style_name in self.custom_styles["node"].items(): + collection[style_name or "default"].add(state_id) + for style_name, state_ids in collection.items(): + container.append("class {} {}".format(", ".join(state_ids), "s_" + style_name)) + class NestedGraph(Graph): """Graph creation support for transitions.extensions.nested.HierarchicalGraphMachine.""" @@ -141,28 +155,26 @@ def __init__(self, *args, **kwargs): def set_node_style(self, state, style): for state_name in self._get_state_names(state): - super(NestedGraph, self).set_node_style(state_name, style) + super(NestedGraph, self).set_node_style(self._name_to_id(state_name), style) def set_previous_transition(self, src, dst): self.custom_styles["edge"][src][dst] = "previous" self.set_node_style(src, "previous") def _add_nodes(self, states, container): - self._add_nested_nodes(states, container, prefix="", default_style="default") + self._add_nested_nodes(states, container, prefix="", default_style="") def _add_nested_nodes(self, states, container, prefix, default_style): for state in states: - name = prefix + state["name"] + state_id = self._name_to_id(state["name"]) + name = prefix + state_id container.append("state \"{}\" as {}".format(self._convert_state_attributes(state), name)) if state.get("final", False): container.append("{} --> [*]".format(name)) - if not prefix: - container.append("Class {} s_{}".format(name.replace(" ", ""), - self.custom_styles["node"][name] or default_style)) + self.custom_styles["node"][name] = self.custom_styles["node"][name] or default_style if state.get("children", None) is not None: container.append("state {} {{".format(name)) self._cluster_states.append(name) - # with container.subgraph(name=cluster_name, graph_attr=attr) as sub: initial = state.get("initial", "") is_parallel = isinstance(initial, list) if is_parallel: @@ -171,7 +183,7 @@ def _add_nested_nodes(self, states, container, prefix, default_style): [child], container, default_style="parallel", - prefix=prefix + state["name"] + self.machine.state_cls.separator, + prefix=prefix + state_id + self.machine.state_cls.separator, ) container.append("--") if state["children"]: @@ -179,12 +191,12 @@ def _add_nested_nodes(self, states, container, prefix, default_style): else: if initial: container.append("[*] --> {}".format( - prefix + state["name"] + self.machine.state_cls.separator + initial)) + prefix + state_id + self.machine.state_cls.separator + initial)) self._add_nested_nodes( state["children"], container, - default_style="default", - prefix=prefix + state["name"] + self.machine.state_cls.separator, + default_style="", + prefix=prefix + state_id + self.machine.state_cls.separator, ) container.append("}") @@ -213,10 +225,13 @@ def _add_edges(self, transitions, container): ) for src, dests in edges_attr.items(): + source_id = self._name_to_id(src) for dst, attr in dests.items(): - if not attr["label"]: - continue - container.append("{source} --> {dest}: {label}".format(**attr)) + dest_id = self._name_to_id(dst) + t = "{} --> {}".format(source_id, dest_id) + if attr["label"]: + t += ": {}".format(attr["label"]) + container.append(t) def _create_edge_attr(self, src, dst, transition): return {"source": src, "dest": dst, "label": self._transition_label(transition)} @@ -252,7 +267,7 @@ def draw(self, filename, format=None, prog="dot", args=""): return None -invalid = {"style", "shape", "peripheries", "strict", "directed"} +invalid = {"style", "shape", "peripheries", "strict", "directed", "compound"} convertible = {"fillcolor": "fill", "rankdir": "direction"} diff --git a/transitions/extensions/nesting.py b/transitions/extensions/nesting.py index a392482a..5516a914 100644 --- a/transitions/extensions/nesting.py +++ b/transitions/extensions/nesting.py @@ -140,6 +140,9 @@ def trigger_nested(self, event_data): while elems: done.add(machine.state_cls.separator.join(elems)) elems.pop() + if event_data.result and self.machine._can_trigger(model, "", *event_data.args, **event_data.kwargs): + _LOGGER.debug("%sTriggering completion event", machine.name) + event_data.result = machine.events[""].trigger_nested(event_data) return event_data.result def _process(self, event_data): @@ -793,9 +796,9 @@ def _can_trigger_nested(self, model, trigger, path, *args, **kwargs): else: raise source_path.pop(-1) - if path: - with self(path.pop(0)): - return self._can_trigger_nested(model, trigger, path, *args, **kwargs) + if path and path[0] in self.states: + with self(path[0]): + return self._can_trigger_nested(model, trigger, path[1:], *args, **kwargs) return False def get_triggers(self, *args):