diff --git a/tests/test_jinja_normalize_unwrapped_expression.py b/tests/test_jinja_normalize_unwrapped_expression.py new file mode 100644 index 0000000..1cce11e --- /dev/null +++ b/tests/test_jinja_normalize_unwrapped_expression.py @@ -0,0 +1,72 @@ +# Author: Zhongkai Fu (fuzhongkai@gmail.com) +# License: BSD 3-Clause License + +from pathlib import Path +import sys + +ROOT_DIR = Path(__file__).resolve().parents[1] +if str(ROOT_DIR) not in sys.path: + sys.path.insert(0, str(ROOT_DIR)) + +from velvetflow.verification.jinja_validation import ( + normalize_condition_params_to_jinja, + normalize_params_to_jinja, +) + + +BASE_WORKFLOW = { + "workflow_name": "normalize_expr", + "description": "", + "nodes": [ + { + "id": "fetch", + "type": "action", + "action_id": "hr.get_today_temperatures.v1", + "params": {"date": "2024-01-01"}, + }, + { + "id": "record", + "type": "action", + "action_id": "hr.record_health_event.v1", + "params": { + "event_type": "temperature", + "date": "result_of.fetch.date", + "abnormal_count": "result_of.fetch.data | length > 0", + }, + }, + { + "id": "check", + "type": "condition", + "params": { + "expression": "result_of.fetch.data | length > 0", + }, + "true_to_node": None, + "false_to_node": None, + }, + ], + "edges": [ + {"from": "fetch", "to": "record"}, + {"from": "record", "to": "check"}, + ], +} + + +def test_normalize_params_wraps_unwrapped_jinja_expression(): + normalized, summary, errors = normalize_params_to_jinja(BASE_WORKFLOW) + + assert errors == [] + assert summary.get("applied") is True + + record_params = normalized["nodes"][1]["params"] + assert record_params["date"] == "{{ result_of.fetch.date }}" + assert record_params["abnormal_count"] == "{{ result_of.fetch.data | length > 0 }}" + + +def test_normalize_condition_params_wraps_unwrapped_expression(): + normalized, summary, errors = normalize_condition_params_to_jinja(BASE_WORKFLOW) + + assert errors == [] + assert summary.get("applied") is True + + condition_params = normalized["nodes"][2]["params"] + assert condition_params["expression"] == "{{ result_of.fetch.data | length > 0 }}" diff --git a/tests/test_jinja_template_reference_validation.py b/tests/test_jinja_template_reference_validation.py new file mode 100644 index 0000000..6b04a76 --- /dev/null +++ b/tests/test_jinja_template_reference_validation.py @@ -0,0 +1,63 @@ +# Author: Zhongkai Fu (fuzhongkai@gmail.com) +# License: BSD 3-Clause License + +from pathlib import Path +import sys + +ROOT_DIR = Path(__file__).resolve().parents[1] +if str(ROOT_DIR) not in sys.path: + sys.path.insert(0, str(ROOT_DIR)) + +from velvetflow.action_registry import BUSINESS_ACTIONS +from velvetflow.verification.validation import validate_completed_workflow + +ACTION_REGISTRY = BUSINESS_ACTIONS + + +def _workflow_with_expression_param(expression: str): + return { + "workflow_name": "jinja_expression_param", + "description": "", + "nodes": [ + { + "id": "fetch_temperatures", + "type": "action", + "action_id": "hr.get_today_temperatures.v1", + "params": {"date": "2024-01-01"}, + }, + { + "id": "record_event", + "type": "action", + "action_id": "hr.record_health_event.v1", + "params": { + "event_type": "temperature", + "date": "{{ result_of.fetch_temperatures.date }}", + "abnormal_count": expression, + }, + }, + ], + "edges": [ + {"from": "fetch_temperatures", "to": "record_event"}, + ], + } + + +def test_template_expression_with_missing_field_is_reported(): + workflow = _workflow_with_expression_param( + "{{ (result_of.fetch_temperatures.data | length) > 0 and result_of.fetch_temperatures.data[0].missing }}" + ) + errors = validate_completed_workflow(workflow, action_registry=ACTION_REGISTRY) + + assert any( + err.code == "SCHEMA_MISMATCH" and err.field == "abnormal_count" + for err in errors + ) + + +def test_template_expression_with_valid_field_passes(): + workflow = _workflow_with_expression_param( + "{{ (result_of.fetch_temperatures.data | length) > 0 and result_of.fetch_temperatures.data[0].temperature }}" + ) + errors = validate_completed_workflow(workflow, action_registry=ACTION_REGISTRY) + + assert errors == [] diff --git a/tests/test_rule_based_repair_pipeline.py b/tests/test_rule_based_repair_pipeline.py index 26f2671..58adbc5 100644 --- a/tests/test_rule_based_repair_pipeline.py +++ b/tests/test_rule_based_repair_pipeline.py @@ -84,3 +84,32 @@ def test_rule_repairs_surface_missing_loop_body(): err.code == "INVALID_LOOP_BODY" and err.field == "body_subgraph" for err in remaining_errors ) + + +def test_rule_repairs_drop_invalid_switch_defaults(): + workflow = { + "nodes": [ + { + "id": "route_by_label", + "type": "switch", + "params": {"source": "{{ result_of.start.label }}", "field": "label"}, + "cases": [], + "default_to_node": "null", + }, + {"id": "start", "type": "action", "action_id": "demo.start", "params": {}}, + ], + "edges": [], + } + errors = [ + ValidationError( + code="UNDEFINED_REFERENCE", + node_id="route_by_label", + field="default_to_node", + message="switch default branch default_to_node points to nonexistent node 'null'", + ) + ] + + patched, _remaining_errors = apply_rule_based_repairs(workflow, [], errors) + + route = next(node for node in patched["nodes"] if node.get("id") == "route_by_label") + assert route.get("default_to_node") is None diff --git a/tests/test_workflow_builder_update.py b/tests/test_workflow_builder_update.py index f0dc325..ba48040 100644 --- a/tests/test_workflow_builder_update.py +++ b/tests/test_workflow_builder_update.py @@ -190,3 +190,41 @@ def test_builder_adds_switch_nodes_and_infers_edges(): assert any(cond == "default" for _, _, cond in edge_conditions) +def test_builder_replaces_loop_body_exports_references(): + builder = WorkflowBuilder() + builder.add_node( + node_id="loop_node", + node_type="loop", + params={ + "loop_kind": "for_each", + "source": "{{ result_of.start.items }}", + "item_alias": "item", + "exports": {"total": "{{ result_of.inner.total }}"}, + }, + ) + builder.add_node( + node_id="inner", + node_type="action", + action_id="demo.inner", + params={"value": "{{ loop.item.value }}"}, + parent_node_id="loop_node", + ) + builder.add_node( + node_id="use_export", + node_type="action", + action_id="demo.use_export", + params={ + "value": "{{ exports.total }}", + "raw_value": "exports.total", + }, + parent_node_id="loop_node", + ) + + workflow = builder.to_workflow() + loop_node = _find(workflow["nodes"], "loop_node") + body_nodes = (loop_node.get("params") or {}).get("body_subgraph", {}).get("nodes", []) + use_export = _find(body_nodes, "use_export") + + assert use_export["params"]["value"] == "{{ result_of.inner.total }}" + assert use_export["params"]["raw_value"] == "{{ result_of.inner.total }}" + diff --git a/velvetflow/jinja_utils.py b/velvetflow/jinja_utils.py index f2ae9fc..ff3c7b7 100644 --- a/velvetflow/jinja_utils.py +++ b/velvetflow/jinja_utils.py @@ -145,6 +145,71 @@ def _iter_filter_nodes(node: nodes.Node) -> list[nodes.Filter]: raise ValueError(f"{path} 使用了未注册的测试: {test_arg.value}") +def extract_jinja_reference_paths(expr: str) -> list[str]: + """Extract dotted reference paths from a Jinja expression. + + Returns a list of reference strings such as ``result_of.node.field`` or + ``loop.item.name`` derived from attribute/item lookup chains. + """ + + if not isinstance(expr, str) or not expr.strip(): + return [] + + env = get_jinja_env() + try: + parsed = env.parse(f"{{{{ {expr} }}}}") + except TemplateError: + return [] + + def _build_path(node: nodes.Node) -> str | None: + if isinstance(node, nodes.Name): + return node.name + if isinstance(node, nodes.Getattr): + base = _build_path(node.node) + if base: + return f"{base}.{node.attr}" + if isinstance(node, nodes.Getitem): + base = _build_path(node.node) + if not base: + return None + arg = node.arg + if isinstance(arg, nodes.Const): + if isinstance(arg.value, int): + return f"{base}[{arg.value}]" + if isinstance(arg.value, str): + return f"{base}.{arg.value}" + return None + return None + + def _collect(node: nodes.Node) -> list[str]: + collected: list[str] = [] + path = _build_path(node) + if path: + collected.append(path) + for child in node.iter_child_nodes(): + collected.extend(_collect(child)) + return collected + + raw_paths = _collect(parsed) + unique_paths = list(dict.fromkeys(raw_paths)) + + def _is_prefix(candidate: str, full: str) -> bool: + if not full.startswith(candidate) or len(full) <= len(candidate): + return False + next_char = full[len(candidate)] + return next_char in {".", "["} + + filtered: list[str] = [] + for path in unique_paths: + if any(_is_prefix(path, other) for other in unique_paths if other != path): + continue + if "." not in path and "[" not in path: + continue + filtered.append(path) + + return filtered + + def eval_jinja_expr(expr: str, context: Mapping[str, Any]) -> Any: """Evaluate a Jinja expression string with the provided context.""" diff --git a/velvetflow/planner/repair.py b/velvetflow/planner/repair.py index 6c001cb..e7424db 100644 --- a/velvetflow/planner/repair.py +++ b/velvetflow/planner/repair.py @@ -445,7 +445,11 @@ def apply_rule_based_repairs( ) if any(err.code == "UNDEFINED_REFERENCE" for err in validation_errors): patched_workflow, drop_summary = apply_repair_tool( - "drop_invalid_references", patched_workflow, remove_edges=True + tool_name="drop_invalid_references", + args={"remove_edges": True}, + workflow=patched_workflow, + validation_errors=validation_errors, + action_registry=action_registry, ) log_info( "[AutoRepair] UNDEFINED_REFERENCE was auto-cleaned; handing off to the LLM for analysis.", diff --git a/velvetflow/planner/workflow_builder.py b/velvetflow/planner/workflow_builder.py index b0bec3a..442cf1b 100644 --- a/velvetflow/planner/workflow_builder.py +++ b/velvetflow/planner/workflow_builder.py @@ -3,6 +3,7 @@ """Utilities for building workflow skeletons during planning.""" +import re import copy from typing import Any, Dict, Mapping, Optional @@ -10,6 +11,7 @@ from velvetflow.logging_utils import log_warn from velvetflow.loop_dsl import iter_workflow_and_loop_body_nodes from velvetflow.models import infer_depends_on_from_edges, infer_edges_from_bindings +from velvetflow.reference_utils import canonicalize_template_placeholders, normalize_reference_path def _normalize_condition_label(raw: Any) -> Optional[str]: @@ -83,6 +85,104 @@ def attach_condition_branches(workflow: Dict[str, Any]) -> Dict[str, Any]: _UNSET = object() +_TEMPLATE_REF_PATTERN = re.compile( + r"\{\{\s*([^{}]+?)\s*\}\}|\$\{\{\s*([^{}]+?)\s*\}\}|\$\{\s*([^{}]+?)\s*\}" +) + + +def _extract_export_expr(value: Any) -> str | None: + if not isinstance(value, str): + return None + stripped = value.strip() + if not stripped: + return None + normalized = normalize_reference_path(stripped) + return normalized.strip() + + +def _replace_exports_in_string( + value: str, *, export_exprs: Mapping[str, str], loop_id: str | None +) -> str: + if not export_exprs: + return value + + canonical = canonicalize_template_placeholders(value) + stripped = canonical.strip() + for key, expr in export_exprs.items(): + if stripped == f"exports.{key}": + return f"{{{{ {expr} }}}}" + if loop_id and stripped == f"result_of.{loop_id}.exports.{key}": + return f"{{{{ {expr} }}}}" + + def _match_key(expr: str) -> str | None: + expr = expr.strip() + if expr.startswith("exports."): + key = expr[len("exports.") :] + elif loop_id and expr.startswith(f"result_of.{loop_id}.exports."): + key = expr[len(f"result_of.{loop_id}.exports.") :] + else: + return None + return key if key in export_exprs else None + + def _replace(match: re.Match[str]) -> str: + expr = match.group(1) or match.group(2) or match.group(3) or "" + key = _match_key(expr) + if key: + return f"{{{{ {export_exprs[key]} }}}}" + return match.group(0) + + return _TEMPLATE_REF_PATTERN.sub(_replace, canonical) + + +def _replace_exports_in_params( + value: Any, *, export_exprs: Mapping[str, str], loop_id: str | None +) -> Any: + if isinstance(value, Mapping): + return { + key: _replace_exports_in_params(val, export_exprs=export_exprs, loop_id=loop_id) + for key, val in value.items() + } + if isinstance(value, list): + return [ + _replace_exports_in_params(item, export_exprs=export_exprs, loop_id=loop_id) + for item in value + ] + if isinstance(value, str): + return _replace_exports_in_string(value, export_exprs=export_exprs, loop_id=loop_id) + return value + + +def _replace_loop_body_exports(node: Mapping[str, Any]) -> None: + params = node.get("params") if isinstance(node.get("params"), Mapping) else None + if not isinstance(params, Mapping): + return + exports = params.get("exports") + if not isinstance(exports, Mapping): + return + loop_id = node.get("id") if isinstance(node.get("id"), str) else None + export_exprs = { + key: expr + for key, value in exports.items() + if isinstance(key, str) and (expr := _extract_export_expr(value)) + } + if not export_exprs: + return + + body = params.get("body_subgraph") if isinstance(params.get("body_subgraph"), Mapping) else None + body_nodes = body.get("nodes") if isinstance(body, Mapping) else None + if not isinstance(body_nodes, list): + return + + for body_node in body_nodes: + if not isinstance(body_node, Mapping): + continue + body_params = body_node.get("params") + if isinstance(body_params, Mapping): + body_node["params"] = _replace_exports_in_params( + body_params, export_exprs=export_exprs, loop_id=loop_id + ) + if body_node.get("type") == "loop": + _replace_loop_body_exports(body_node) class WorkflowBuilder: @@ -256,6 +356,9 @@ def to_workflow(self) -> Dict[str, Any]: "description": self.description, "nodes": root_nodes, } + for node in workflow.get("nodes", []) or []: + if isinstance(node, Mapping) and node.get("type") == "loop": + _replace_loop_body_exports(node) # Provide implicitly derived edges as read-only context for downstream # tools/LLM refinement while keeping the source of truth in param # bindings. 遍历主图与 loop 子图的所有节点,确保子图中的引用同样被纳入。 diff --git a/velvetflow/verification/jinja_validation.py b/velvetflow/verification/jinja_validation.py index fa592f6..dc5be00 100644 --- a/velvetflow/verification/jinja_validation.py +++ b/velvetflow/verification/jinja_validation.py @@ -7,7 +7,7 @@ from jinja2 import TemplateError -from velvetflow.jinja_utils import get_jinja_env +from velvetflow.jinja_utils import get_jinja_env, validate_jinja_expr from velvetflow.models import ValidationError from velvetflow.reference_utils import normalize_reference_path @@ -17,6 +17,16 @@ r"^(?P.+?)\s+if\s+(?P.+?)\s+else\s+(?P.+)$", re.IGNORECASE, ) +_JINJA_EXPR_HINTS_RE = re.compile( + r"(result_of\.|loop\.|\||==|!=|>=|<=|>|<|\band\b|\bor\b|\bnot\b|\bin\b)" +) + + +def _looks_like_jinja_expression(value: str) -> bool: + if not value: + return False + + return bool(_JINJA_EXPR_HINTS_RE.search(value)) def _normalize_jinja_expr(value: Any) -> Tuple[Any, bool]: @@ -45,6 +55,12 @@ def _normalize_jinja_expr(value: Any) -> Tuple[Any, bool]: if stripped and "{{" not in stripped and "{%" not in stripped: if _SIMPLE_PATH_RE.match(stripped): return f"{{{{ {stripped} }}}}", True + if _looks_like_jinja_expression(stripped): + try: + validate_jinja_expr(stripped) + except ValueError: + return f"{{{{ {repr(value)} }}}}", True + return f"{{{{ {stripped} }}}}", True # Fallback: wrap raw literals as Jinja templates so every param # remains Jinja-compatible (e.g., condition.field="employee_id"). return f"{{{{ {repr(value)} }}}}", True diff --git a/velvetflow/verification/node_rules.py b/velvetflow/verification/node_rules.py index 21e204b..a349977 100644 --- a/velvetflow/verification/node_rules.py +++ b/velvetflow/verification/node_rules.py @@ -6,7 +6,7 @@ import re from typing import Any, Dict, List, Mapping, Optional -from velvetflow.jinja_utils import validate_jinja_expr +from velvetflow.jinja_utils import extract_jinja_reference_paths, validate_jinja_expr from velvetflow.models import ValidationError from velvetflow.reference_utils import normalize_reference_path, parse_field_path from velvetflow.loop_dsl import loop_body_has_action @@ -498,108 +498,113 @@ def _walk_params_for_templates(obj: Any, path_prefix: str = "") -> None: for ref in _iter_template_references(obj): _validate_jinja_expression(ref, node_id=nid, field=path_prefix or "params", errors=errors) - ref_head = _strip_jinja_filters(ref) - if not ref_head: - continue - ref_path = normalize_reference_path(ref_head) - try: - ref_parts = parse_field_path(ref_path) - except Exception: - continue + for ref_path in extract_jinja_reference_paths(ref): + try: + ref_parts = parse_field_path(ref_path) + except Exception: + continue + + missing_target_node = False + if ref_path.startswith("result_of."): + target_node = None + if ref_parts and len(ref_parts) >= 2 and isinstance(ref_parts[1], str): + target_node = ref_parts[1] + + if target_node and target_node not in nodes_by_id: + missing_target_node = True + errors.append( + ValidationError( + code="SCHEMA_MISMATCH", + node_id=nid, + field=path_prefix, + message=( + f"Template reference '{ref}' on action node '{nid}' is invalid: " + f"Referenced node '{target_node}' does not exist." + ), + ) + ) + + if missing_target_node: + continue - missing_target_node = False - if ref_path.startswith("result_of."): - target_node = None - if ref_parts and len(ref_parts) >= 2 and isinstance(ref_parts[1], str): - target_node = ref_parts[1] + if _is_self_reference_path(ref_path, nid): + field_label = path_prefix or "params" + if path_prefix and not path_prefix.startswith("params"): + field_label = f"params.{path_prefix}" + _flag_self_reference(field_label, ref_path) + + if ref_parts and isinstance(ref_parts[0], str): + loop_ctx_root = ref_parts[0] + loop_node = nodes_by_id.get(loop_ctx_root) + if isinstance(loop_node, Mapping) and loop_node.get("type") == "loop": + loop_params = ( + loop_node.get("params") if isinstance(loop_node.get("params"), Mapping) else {} + ) + loop_item_alias = ( + loop_params.get("item_alias") + if isinstance(loop_params.get("item_alias"), str) + else None + ) + if len(ref_parts) >= 2 and isinstance(ref_parts[1], str): + root_field = ref_parts[1] + allowed_loop_fields = {"index", "size", "accumulator"} + if loop_item_alias: + allowed_loop_fields.add(loop_item_alias) + # Allow using item directly only when item_alias is missing or equals 'item' + if root_field == "item" and loop_item_alias and loop_item_alias != "item": + errors.append( + ValidationError( + code="SCHEMA_MISMATCH", + node_id=nid, + field=path_prefix, + message=( + f"Template reference '{ref}' on node '{nid}' is invalid: loop node '{loop_ctx_root}' currently exposes item_alias '{loop_item_alias}'," + "Use the alias instead of '.item' to access loop elements." + ), + ) + ) + continue + if isinstance(root_field, str) and root_field not in allowed_loop_fields and not ( + root_field == "item" and loop_item_alias in {None, "item"} + ): + errors.append( + ValidationError( + code="SCHEMA_MISMATCH", + node_id=nid, + field=path_prefix, + message=( + f"Template reference '{ref}' on node '{nid}' is invalid: loop node '{loop_ctx_root}' context only exposes {', '.join(sorted(allowed_loop_fields | {'item'}))}," + f"Field '{root_field}' was not found." + ), + ) + ) + continue + + schema_err = None + if ref_parts: + alias = ref_parts[0] + if alias_schemas and alias in alias_schemas: + schema_err = _schema_path_error(alias_schemas[alias], ref_parts[1:]) + else: + schema_err = _check_output_path_against_schema( + ref_path, + nodes_by_id, + actions_by_id, + loop_body_parents, + context_node_id=nid, + ) - if target_node and target_node not in nodes_by_id: - missing_target_node = True + if schema_err: errors.append( ValidationError( code="SCHEMA_MISMATCH", node_id=nid, field=path_prefix, message=( - f"Template reference '{ref}' on action node '{nid}' is invalid: " - f"Referenced node '{target_node}' does not exist." + f"Template reference '{ref}' on action node '{nid}' is invalid: {schema_err}" ), ) ) - - if missing_target_node: - continue - - if _is_self_reference_path(ref_path, nid): - field_label = path_prefix or "params" - if path_prefix and not path_prefix.startswith("params"): - field_label = f"params.{path_prefix}" - _flag_self_reference(field_label, ref_path) - - if ref_parts and isinstance(ref_parts[0], str): - loop_ctx_root = ref_parts[0] - loop_node = nodes_by_id.get(loop_ctx_root) - if isinstance(loop_node, Mapping) and loop_node.get("type") == "loop": - loop_params = loop_node.get("params") if isinstance(loop_node.get("params"), Mapping) else {} - loop_item_alias = loop_params.get("item_alias") if isinstance(loop_params.get("item_alias"), str) else None - if len(ref_parts) >= 2 and isinstance(ref_parts[1], str): - root_field = ref_parts[1] - allowed_loop_fields = {"index", "size", "accumulator"} - if loop_item_alias: - allowed_loop_fields.add(loop_item_alias) - # Allow using item directly only when item_alias is missing or equals 'item' - if root_field == "item" and loop_item_alias and loop_item_alias != "item": - errors.append( - ValidationError( - code="SCHEMA_MISMATCH", - node_id=nid, - field=path_prefix, - message=( - f"Template reference '{ref}' on node '{nid}' is invalid: loop node '{loop_ctx_root}' currently exposes item_alias '{loop_item_alias}'," - "Use the alias instead of '.item' to access loop elements." - ), - ) - ) - continue - if isinstance(root_field, str) and root_field not in allowed_loop_fields and not (root_field == "item" and loop_item_alias in {None, "item"}): - errors.append( - ValidationError( - code="SCHEMA_MISMATCH", - node_id=nid, - field=path_prefix, - message=( - f"Template reference '{ref}' on node '{nid}' is invalid: loop node '{loop_ctx_root}' context only exposes {', '.join(sorted(allowed_loop_fields | {'item'}))}," - f"Field '{root_field}' was not found." - ), - ) - ) - continue - - schema_err = None - if ref_parts: - alias = ref_parts[0] - if alias_schemas and alias in alias_schemas: - schema_err = _schema_path_error(alias_schemas[alias], ref_parts[1:]) - else: - schema_err = _check_output_path_against_schema( - ref_path, - nodes_by_id, - actions_by_id, - loop_body_parents, - context_node_id=nid, - ) - - if schema_err: - errors.append( - ValidationError( - code="SCHEMA_MISMATCH", - node_id=nid, - field=path_prefix, - message=( - f"Template reference '{ref}' on action node '{nid}' is invalid: {schema_err}" - ), - ) - ) elif isinstance(obj, Mapping): for key, value in list(obj.items()): new_prefix = f"{path_prefix}.{key}" if path_prefix else str(key)