Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- Feature #710: `experimental.utils.generate_base_model` can now be called with an instance of MarkupMachine directly (thanks @patrickwolf)
- Bug #715: `HierarchicalMachine._final_check` wrongly determined a parallel state final when the last child was final (thanks @DenizKucukozturk)
- Bug #716: `HierarchicalMachine` caused an `AssertionError` when `model_override` was `True` and `NestedSeperator` differed from `_` (thanks @pritam-dey3)
- Feat #706: Instroduce completion transitions which will be executed after a transition has been conducted (thanks @oEscal)

## 0.9.3 (July 2024)

Expand Down
10 changes: 10 additions & 0 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,16 @@ async def run():
await machine.switch_model_context(self)
self.assertEqual(len(w), 1)

def test_completion_transition(self):
states = ['A', 'B', 'C']
m = self.machine_cls(states=states, initial='A', auto_transitions=False)
m.add_transition('walk', 'A', 'B')
m.add_transition('', 'B', 'C')

async def run():
assert await m.walk()
assert m.is_C()

asyncio.run(run())


Expand Down
13 changes: 13 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1418,3 +1418,16 @@ class MyMachine(self.machine_cls): # type: ignore
assert trans[0].my_int == 23
assert trans[0].my_dict == {"baz": "bar"}
assert trans[0].my_none is None

def test_completion_transition(self):
states = ['A', 'B', 'C']
transitions = [
['walk', 'A', 'B'],
['', 'B', 'C'],
['complete', 'C', 'A']
]

m = self.machine_cls(states=states, transitions=transitions, initial='A', auto_transitions=False)
self.assertTrue(m.is_A())
m.walk()
self.assertTrue(m.is_C())
11 changes: 5 additions & 6 deletions tests/test_graphviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def test_roi(self):
g1 = m.get_graph(show_roi=True)
dot, nodes, edges = self.parse_dot(g1)
self.assertEqual(0, len(edges))
self.assertIn(r'label="A\l"', dot)
self.assertIn('label="A\\n"', dot)
# make sure that generating a graph without ROI has not influence on the later generated graph
# this has to be checked since graph.custom_style is a class property and is persistent for multiple
# calls of graph.generate()
Expand All @@ -213,7 +213,7 @@ def test_roi(self):
_ = m.get_graph()
g2 = m.get_graph(show_roi=True)
dot, _, _ = self.parse_dot(g2)
self.assertNotIn(r'label="A\l"', dot)
self.assertNotIn('label="A\\n"', dot)
m.to_B()
g3 = m.get_graph(show_roi=True)
_, nodes, edges = self.parse_dot(g3)
Expand Down Expand Up @@ -247,10 +247,10 @@ class CustomMachine(self.machine_cls): # type: ignore
transitions=[{'trigger': 'event', 'source': 'A', 'dest': 'B', 'label': 'LabelEvent'}],
initial='A', graph_engine=self.graph_engine)
dot, _, _ = self.parse_dot(m.get_graph())
self.assertIn(r'label="LabelA\l"', dot)
self.assertIn(r'label="NotLabelA\l"', dot)
self.assertIn('label="LabelA\\n"', dot)
self.assertIn('label="NotLabelA\\n"', dot)
self.assertIn("label=LabelEvent", dot)
self.assertNotIn(r'label="A\l"', dot)
self.assertNotIn('label="A\\n"', dot)
self.assertNotIn("label=event", dot)

def test_binary_stream(self):
Expand Down Expand Up @@ -389,7 +389,6 @@ def is_fast(*args, **kwargs):
g1 = model.get_graph(show_roi=True)
_, nodes, edges = self.parse_dot(g1)
self.assertEqual(len(edges), 2) # reset and walk
print(nodes)
self.assertEqual(len(nodes), 4)
model.walk()
model.run()
Expand Down
17 changes: 1 addition & 16 deletions tests/test_mermaid.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,11 @@
from .test_graphviz import TestDiagrams, TestDiagramsNested
from .utils import Stuff, DummyModel
from .test_core import TestTransitions, TYPE_CHECKING

from transitions.extensions import (
LockedGraphMachine, GraphMachine, HierarchicalGraphMachine, LockedHierarchicalGraphMachine
)
from transitions.extensions.states import add_state_features, Timeout, Tags
from transitions.extensions import HierarchicalGraphMachine, LockedHierarchicalGraphMachine
from unittest import skipIf
import tempfile
import os
import re
import sys
from unittest import TestCase

try:
# Just to skip tests if graphviz not installed
import graphviz as pgv # @UnresolvedImport
except ImportError: # pragma: no cover
pgv = None

if TYPE_CHECKING:
from typing import Type, List, Collection, Union, Literal


class TestMermaidDiagrams(TestDiagrams):
Expand Down
12 changes: 9 additions & 3 deletions transitions/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,12 @@ def trigger(self, model, *args, **kwargs):
# noinspection PyProtectedMember
# Machine._process should not be called somewhere else. That's why it should not be exposed
# to Machine users.
return self.machine._process(func)
res = self.machine._process(func)
if res and self.machine._can_trigger(model, "", *args, **kwargs):
_LOGGER.debug("%sTriggering completion event", self.machine.name)
# Trigger the completion event if the machine allows it
res = self.machine.events[""].trigger(model, *args, **kwargs)
return res

def _trigger(self, event_data):
"""Internal trigger function called by the ``Machine`` instance. This should not
Expand Down Expand Up @@ -925,8 +930,9 @@ def _add_may_transition_func_for_trigger(self, trigger, model):
self._checked_assignment(model, "may_%s" % trigger, partial(self._can_trigger, model, trigger))

def _add_trigger_to_model(self, trigger, model):
self._checked_assignment(model, trigger, partial(self.events[trigger].trigger, model))
self._add_may_transition_func_for_trigger(trigger, model)
if trigger:
self._checked_assignment(model, trigger, partial(self.events[trigger].trigger, model))
self._add_may_transition_func_for_trigger(trigger, model)

def _get_trigger(self, model, trigger_name, *args, **kwargs):
"""Convenience function added to the model to trigger events by name.
Expand Down
9 changes: 8 additions & 1 deletion transitions/extensions/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,11 @@ async def trigger(self, model, *args, **kwargs):
successfully executed (True if successful, False if not).
"""
func = partial(self._trigger, EventData(None, self, self.machine, model, args=args, kwargs=kwargs))
return await self.machine.process_context(func, model)
res = await self.machine.process_context(func, model)
if res and await self.machine._can_trigger(model, "", *args, **kwargs):
_LOGGER.debug("%sTriggering completion event", self.machine.name)
res = await self.machine.events[""].trigger(model, *args, **kwargs)
return res

async def _trigger(self, event_data):
event_data.state = self.machine.get_state(getattr(event_data.model, self.machine.model_attribute))
Expand Down Expand Up @@ -256,6 +260,9 @@ async def trigger_nested(self, event_data):
while elems:
done.add(machine.state_cls.separator.join(elems))
elems.pop()
if event_data.result and await self.machine._can_trigger(model, "", *event_data.args, **event_data.kwargs):
_LOGGER.debug("%sTriggering completion event", self.machine.name)
event_data.result = await self.machine.events[""].trigger_nested(event_data)
return event_data.result

async def _process(self, event_data):
Expand Down
1 change: 1 addition & 0 deletions transitions/extensions/diagrams.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class GraphMachine(MarkupMachine):
"directed": "true",
"strict": "false",
"rankdir": "LR",
"compound": "true",
}

style_attributes = {
Expand Down
17 changes: 10 additions & 7 deletions transitions/extensions/diagrams_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,21 @@ def get_graph(self, title=None, roi_state=None):
"""

def _convert_state_attributes(self, state):
label = state.get("label", state["name"])
label = state.get("label", state["name"]) + "\\n"
if self.machine.show_state_attributes:
if "tags" in state:
label += " [" + ", ".join(state["tags"]) + "]"
label += "[" + ", ".join(state["tags"]) + "]\\n"
if "on_enter" in state:
label += r"\l- enter:\l + " + r"\l + ".join(state["on_enter"])
label += "- enter:\\l"
for action in state["on_enter"]:
label += " + " + action + "\\l"
if "on_exit" in state:
label += r"\l- exit:\l + " + r"\l + ".join(state["on_exit"])
label += "- exit:\\l"
for action in state["on_exit"]:
label += " + " + action + "\\l"
if "timeout" in state:
label += r'\l- timeout(' + state['timeout'] + 's) -> (' + ', '.join(state['on_timeout']) + ')'
# end each label with a left-aligned newline
return label + r"\l"
label += '- timeout(' + state['timeout'] + 's) -> (' + ', '.join(state['on_timeout']) + ')'
return label

def _get_state_names(self, state):
if isinstance(state, (list, tuple, set)):
Expand Down
55 changes: 35 additions & 20 deletions transitions/extensions/diagrams_mermaid.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import copy
import logging
from collections import defaultdict
import re

from .diagrams_graphviz import filter_states
from .diagrams_base import BaseGraph
Expand Down Expand Up @@ -42,9 +43,9 @@ def reset_styling(self):

def _add_nodes(self, states, container):
for state in states:
container.append("state \"{}\" as {}".format(self._convert_state_attributes(state), state["name"]))
container.append("Class {} s_{}".format(state["name"],
self.custom_styles["node"][state["name"]] or "default"))
state_id = self._name_to_id(state["name"])
self.custom_styles["node"][state_id] = self.custom_styles["node"][state_id] or ""
container.append("state \"{}\" as {}".format(self._convert_state_attributes(state), state_id))

def _add_edges(self, transitions, container):
edge_labels = defaultdict(lambda: defaultdict(list))
Expand Down Expand Up @@ -99,10 +100,11 @@ def get_graph(self, title=None, roi_state=None):
active_states = active_states.union({k for k, style in self.custom_styles["node"].items() if style})
states = filter_states(copy.deepcopy(states), active_states, self.machine.state_cls)
self._add_nodes(states, fsm_graph)
self._add_node_styles(fsm_graph)
fsm_graph.append("")
self._add_edges(transitions, fsm_graph)
if self.machine.initial and (roi_state is None or roi_state == self.machine.initial):
fsm_graph.append("[*] --> {}".format(self.machine.initial))
fsm_graph.append("[*] --> {}".format(self._name_to_id(self.machine.initial)))

indent = 0
for i in range(len(fsm_graph)):
Expand All @@ -128,9 +130,21 @@ def _convert_state_attributes(self, state):
label += r"\n- exit:\n + " + r"\n + ".join(state["on_exit"])
if "timeout" in state:
label += r'\n- timeout(' + state['timeout'] + 's) -> (' + ', '.join(state['on_timeout']) + ')'
# end each label with a left-aligned newline
return label

def _name_to_id(self, name):
"""Convert a state name to a valid identifier."""
# replace all non-alphanumeric characters with underscores
return re.sub(r'\W+', '___', name)

def _add_node_styles(self, container):
"""Add styles to the graph."""
collection = defaultdict(set)
for state_id, style_name in self.custom_styles["node"].items():
collection[style_name or "default"].add(state_id)
for style_name, state_ids in collection.items():
container.append("class {} {}".format(", ".join(state_ids), "s_" + style_name))


class NestedGraph(Graph):
"""Graph creation support for transitions.extensions.nested.HierarchicalGraphMachine."""
Expand All @@ -141,28 +155,26 @@ def __init__(self, *args, **kwargs):

def set_node_style(self, state, style):
for state_name in self._get_state_names(state):
super(NestedGraph, self).set_node_style(state_name, style)
super(NestedGraph, self).set_node_style(self._name_to_id(state_name), style)

def set_previous_transition(self, src, dst):
self.custom_styles["edge"][src][dst] = "previous"
self.set_node_style(src, "previous")

def _add_nodes(self, states, container):
self._add_nested_nodes(states, container, prefix="", default_style="default")
self._add_nested_nodes(states, container, prefix="", default_style="")

def _add_nested_nodes(self, states, container, prefix, default_style):
for state in states:
name = prefix + state["name"]
state_id = self._name_to_id(state["name"])
name = prefix + state_id
container.append("state \"{}\" as {}".format(self._convert_state_attributes(state), name))
if state.get("final", False):
container.append("{} --> [*]".format(name))
if not prefix:
container.append("Class {} s_{}".format(name.replace(" ", ""),
self.custom_styles["node"][name] or default_style))
self.custom_styles["node"][name] = self.custom_styles["node"][name] or default_style
if state.get("children", None) is not None:
container.append("state {} {{".format(name))
self._cluster_states.append(name)
# with container.subgraph(name=cluster_name, graph_attr=attr) as sub:
initial = state.get("initial", "")
is_parallel = isinstance(initial, list)
if is_parallel:
Expand All @@ -171,20 +183,20 @@ def _add_nested_nodes(self, states, container, prefix, default_style):
[child],
container,
default_style="parallel",
prefix=prefix + state["name"] + self.machine.state_cls.separator,
prefix=prefix + state_id + self.machine.state_cls.separator,
)
container.append("--")
if state["children"]:
container.pop()
else:
if initial:
container.append("[*] --> {}".format(
prefix + state["name"] + self.machine.state_cls.separator + initial))
prefix + state_id + self.machine.state_cls.separator + initial))
self._add_nested_nodes(
state["children"],
container,
default_style="default",
prefix=prefix + state["name"] + self.machine.state_cls.separator,
default_style="",
prefix=prefix + state_id + self.machine.state_cls.separator,
)
container.append("}")

Expand Down Expand Up @@ -213,10 +225,13 @@ def _add_edges(self, transitions, container):
)

for src, dests in edges_attr.items():
source_id = self._name_to_id(src)
for dst, attr in dests.items():
if not attr["label"]:
continue
container.append("{source} --> {dest}: {label}".format(**attr))
dest_id = self._name_to_id(dst)
t = "{} --> {}".format(source_id, dest_id)
if attr["label"]:
t += ": {}".format(attr["label"])
container.append(t)

def _create_edge_attr(self, src, dst, transition):
return {"source": src, "dest": dst, "label": self._transition_label(transition)}
Expand Down Expand Up @@ -252,7 +267,7 @@ def draw(self, filename, format=None, prog="dot", args=""):
return None


invalid = {"style", "shape", "peripheries", "strict", "directed"}
invalid = {"style", "shape", "peripheries", "strict", "directed", "compound"}
convertible = {"fillcolor": "fill", "rankdir": "direction"}


Expand Down
9 changes: 6 additions & 3 deletions transitions/extensions/nesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ def trigger_nested(self, event_data):
while elems:
done.add(machine.state_cls.separator.join(elems))
elems.pop()
if event_data.result and self.machine._can_trigger(model, "", *event_data.args, **event_data.kwargs):
_LOGGER.debug("%sTriggering completion event", machine.name)
event_data.result = machine.events[""].trigger_nested(event_data)
return event_data.result

def _process(self, event_data):
Expand Down Expand Up @@ -793,9 +796,9 @@ def _can_trigger_nested(self, model, trigger, path, *args, **kwargs):
else:
raise
source_path.pop(-1)
if path:
with self(path.pop(0)):
return self._can_trigger_nested(model, trigger, path, *args, **kwargs)
if path and path[0] in self.states:
with self(path[0]):
return self._can_trigger_nested(model, trigger, path[1:], *args, **kwargs)
return False

def get_triggers(self, *args):
Expand Down
Loading