Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions tests/test_jinja_template_reference_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Author: Zhongkai Fu ([email protected])
# License: BSD 3-Clause License

from pathlib import Path
import sys

ROOT_DIR = Path(__file__).resolve().parents[1]
if str(ROOT_DIR) not in sys.path:
sys.path.insert(0, str(ROOT_DIR))

from velvetflow.action_registry import BUSINESS_ACTIONS
from velvetflow.verification.validation import validate_completed_workflow

ACTION_REGISTRY = BUSINESS_ACTIONS


def _workflow_with_expression_param(expression: str):
return {
"workflow_name": "jinja_expression_param",
"description": "",
"nodes": [
{
"id": "fetch_temperatures",
"type": "action",
"action_id": "hr.get_today_temperatures.v1",
"params": {"date": "2024-01-01"},
},
{
"id": "record_event",
"type": "action",
"action_id": "hr.record_health_event.v1",
"params": {
"event_type": "temperature",
"date": "{{ result_of.fetch_temperatures.date }}",
"abnormal_count": expression,
},
},
],
"edges": [
{"from": "fetch_temperatures", "to": "record_event"},
],
}


def test_template_expression_with_missing_field_is_reported():
    """Validation must yield a SCHEMA_MISMATCH error on ``abnormal_count``.

    The expression references ``data[0].missing`` inside a compound Jinja
    expression (filter, comparison and ``and``), so the reference is not the
    whole template body.
    """
    expression = (
        "{{ (result_of.fetch_temperatures.data | length) > 0 and result_of.fetch_temperatures.data[0].missing }}"
    )
    reported = validate_completed_workflow(
        _workflow_with_expression_param(expression),
        action_registry=ACTION_REGISTRY,
    )

    # Stop at the first matching error, like a short-circuiting any().
    found_mismatch = False
    for issue in reported:
        if issue.code == "SCHEMA_MISMATCH" and issue.field == "abnormal_count":
            found_mismatch = True
            break

    assert found_mismatch


def test_template_expression_with_valid_field_passes():
    """Validation must return no errors when the expression uses ``data[0].temperature``."""
    expression = (
        "{{ (result_of.fetch_temperatures.data | length) > 0 and result_of.fetch_temperatures.data[0].temperature }}"
    )
    reported = validate_completed_workflow(
        _workflow_with_expression_param(expression),
        action_registry=ACTION_REGISTRY,
    )

    assert reported == []
65 changes: 65 additions & 0 deletions velvetflow/jinja_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,71 @@ def _iter_filter_nodes(node: nodes.Node) -> list[nodes.Filter]:
raise ValueError(f"{path} 使用了未注册的测试: {test_arg.value}")


def extract_jinja_reference_paths(expr: str) -> list[str]:
    """Extract dotted reference paths from a Jinja expression.

    The expression is wrapped in ``{{ ... }}``, parsed with the shared Jinja
    environment, and every attribute/item lookup chain rooted at a plain name
    is rendered as a path string such as ``result_of.node.field`` or
    ``loop.items[0].name``.

    Rendering rules for a lookup chain:

    * ``a.b`` (attribute access) becomes ``a.b``.
    * ``a[0]`` (constant integer subscript) becomes ``a[0]``.
    * ``a['b']`` becomes ``a.b`` only when ``'b'`` is a valid identifier, so
      the dotted form means exactly the same lookup.
    * Any other subscript (dynamic, or a string key such as ``'a.b'`` or
      ``'0'``) ends the chain; only the prefix before it is reported.

    Args:
        expr: The Jinja expression text, without surrounding ``{{ }}``.

    Returns:
        Unique paths in order of first appearance. Paths that are a strict
        prefix of another returned path are dropped, as are bare names with
        no attribute or index access. An empty list is returned when ``expr``
        is not a non-blank string or cannot be parsed.
    """

    if not isinstance(expr, str) or not expr.strip():
        return []

    env = get_jinja_env()
    try:
        parsed = env.parse(f"{{{{ {expr} }}}}")
    except TemplateError:
        return []

    def _build_path(node: nodes.Node) -> str | None:
        """Render ``node`` as a path string, or ``None`` if it is not a static chain."""
        if isinstance(node, nodes.Name):
            return node.name
        if isinstance(node, nodes.Getattr):
            base = _build_path(node.node)
            return f"{base}.{node.attr}" if base else None
        if isinstance(node, nodes.Getitem):
            base = _build_path(node.node)
            if not base:
                return None
            arg = node.arg
            if not isinstance(arg, nodes.Const):
                return None
            key = arg.value
            # bool is a subclass of int; ``a[true]`` is not a list index.
            if isinstance(key, int) and not isinstance(key, bool):
                return f"{base}[{key}]"
            # Only identifier-like keys may be flattened to dotted access.
            # A key such as 'a.b' or '0' would otherwise be emitted as
            # ``base.a.b`` / ``base.0`` and be misread downstream as nested
            # fields or an array index, producing false schema errors.
            if isinstance(key, str) and key.isidentifier():
                return f"{base}.{key}"
            return None
        return None

    def _collect(node: nodes.Node) -> list[str]:
        """Gather the path of ``node`` and of every descendant, depth-first."""
        collected: list[str] = []
        path = _build_path(node)
        if path:
            collected.append(path)
        for child in node.iter_child_nodes():
            collected.extend(_collect(child))
        return collected

    # dict.fromkeys de-duplicates while keeping first-seen order.
    unique_paths = list(dict.fromkeys(_collect(parsed)))

    def _is_prefix(candidate: str, full: str) -> bool:
        """Return True when ``candidate`` is ``full`` cut at a ``.`` or ``[`` boundary."""
        if not full.startswith(candidate) or len(full) <= len(candidate):
            return False
        return full[len(candidate)] in {".", "["}

    filtered: list[str] = []
    for path in unique_paths:
        # Walking the tree also yields every prefix of a chain; keep only
        # the longest form of each reference.
        if any(_is_prefix(path, other) for other in unique_paths if other != path):
            continue
        # A bare name carries no field access to validate.
        if "." not in path and "[" not in path:
            continue
        filtered.append(path)

    return filtered


def eval_jinja_expr(expr: str, context: Mapping[str, Any]) -> Any:
"""Evaluate a Jinja expression string with the provided context."""

Expand Down
189 changes: 97 additions & 92 deletions velvetflow/verification/node_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import re
from typing import Any, Dict, List, Mapping, Optional

from velvetflow.jinja_utils import validate_jinja_expr
from velvetflow.jinja_utils import extract_jinja_reference_paths, validate_jinja_expr
from velvetflow.models import ValidationError
from velvetflow.reference_utils import normalize_reference_path, parse_field_path
from velvetflow.loop_dsl import loop_body_has_action
Expand Down Expand Up @@ -498,108 +498,113 @@ def _walk_params_for_templates(obj: Any, path_prefix: str = "") -> None:
for ref in _iter_template_references(obj):
_validate_jinja_expression(ref, node_id=nid, field=path_prefix or "params", errors=errors)

ref_head = _strip_jinja_filters(ref)
if not ref_head:
continue
ref_path = normalize_reference_path(ref_head)
try:
ref_parts = parse_field_path(ref_path)
except Exception:
continue
for ref_path in extract_jinja_reference_paths(ref):
try:
ref_parts = parse_field_path(ref_path)
except Exception:
continue

missing_target_node = False
if ref_path.startswith("result_of."):
target_node = None
if ref_parts and len(ref_parts) >= 2 and isinstance(ref_parts[1], str):
target_node = ref_parts[1]

if target_node and target_node not in nodes_by_id:
missing_target_node = True
errors.append(
ValidationError(
code="SCHEMA_MISMATCH",
node_id=nid,
field=path_prefix,
message=(
f"Template reference '{ref}' on action node '{nid}' is invalid: "
f"Referenced node '{target_node}' does not exist."
),
)
)

if missing_target_node:
continue

missing_target_node = False
if ref_path.startswith("result_of."):
target_node = None
if ref_parts and len(ref_parts) >= 2 and isinstance(ref_parts[1], str):
target_node = ref_parts[1]
if _is_self_reference_path(ref_path, nid):
field_label = path_prefix or "params"
if path_prefix and not path_prefix.startswith("params"):
field_label = f"params.{path_prefix}"
_flag_self_reference(field_label, ref_path)

if ref_parts and isinstance(ref_parts[0], str):
loop_ctx_root = ref_parts[0]
loop_node = nodes_by_id.get(loop_ctx_root)
if isinstance(loop_node, Mapping) and loop_node.get("type") == "loop":
loop_params = (
loop_node.get("params") if isinstance(loop_node.get("params"), Mapping) else {}
)
loop_item_alias = (
loop_params.get("item_alias")
if isinstance(loop_params.get("item_alias"), str)
else None
)
if len(ref_parts) >= 2 and isinstance(ref_parts[1], str):
root_field = ref_parts[1]
allowed_loop_fields = {"index", "size", "accumulator"}
if loop_item_alias:
allowed_loop_fields.add(loop_item_alias)
# Allow using item directly only when item_alias is missing or equals 'item'
if root_field == "item" and loop_item_alias and loop_item_alias != "item":
errors.append(
ValidationError(
code="SCHEMA_MISMATCH",
node_id=nid,
field=path_prefix,
message=(
f"Template reference '{ref}' on node '{nid}' is invalid: loop node '{loop_ctx_root}' currently exposes item_alias '{loop_item_alias}',"
"Use the alias instead of '.item' to access loop elements."
),
)
)
continue
if isinstance(root_field, str) and root_field not in allowed_loop_fields and not (
root_field == "item" and loop_item_alias in {None, "item"}
):
errors.append(
ValidationError(
code="SCHEMA_MISMATCH",
node_id=nid,
field=path_prefix,
message=(
f"Template reference '{ref}' on node '{nid}' is invalid: loop node '{loop_ctx_root}' context only exposes {', '.join(sorted(allowed_loop_fields | {'item'}))},"
f"Field '{root_field}' was not found."
),
)
)
continue

schema_err = None
if ref_parts:
alias = ref_parts[0]
if alias_schemas and alias in alias_schemas:
schema_err = _schema_path_error(alias_schemas[alias], ref_parts[1:])
else:
schema_err = _check_output_path_against_schema(
ref_path,
nodes_by_id,
actions_by_id,
loop_body_parents,
context_node_id=nid,
)

if target_node and target_node not in nodes_by_id:
missing_target_node = True
if schema_err:
errors.append(
ValidationError(
code="SCHEMA_MISMATCH",
node_id=nid,
field=path_prefix,
message=(
f"Template reference '{ref}' on action node '{nid}' is invalid: "
f"Referenced node '{target_node}' does not exist."
f"Template reference '{ref}' on action node '{nid}' is invalid: {schema_err}"
),
)
)

if missing_target_node:
continue

if _is_self_reference_path(ref_path, nid):
field_label = path_prefix or "params"
if path_prefix and not path_prefix.startswith("params"):
field_label = f"params.{path_prefix}"
_flag_self_reference(field_label, ref_path)

if ref_parts and isinstance(ref_parts[0], str):
loop_ctx_root = ref_parts[0]
loop_node = nodes_by_id.get(loop_ctx_root)
if isinstance(loop_node, Mapping) and loop_node.get("type") == "loop":
loop_params = loop_node.get("params") if isinstance(loop_node.get("params"), Mapping) else {}
loop_item_alias = loop_params.get("item_alias") if isinstance(loop_params.get("item_alias"), str) else None
if len(ref_parts) >= 2 and isinstance(ref_parts[1], str):
root_field = ref_parts[1]
allowed_loop_fields = {"index", "size", "accumulator"}
if loop_item_alias:
allowed_loop_fields.add(loop_item_alias)
# Allow using item directly only when item_alias is missing or equals 'item'
if root_field == "item" and loop_item_alias and loop_item_alias != "item":
errors.append(
ValidationError(
code="SCHEMA_MISMATCH",
node_id=nid,
field=path_prefix,
message=(
f"Template reference '{ref}' on node '{nid}' is invalid: loop node '{loop_ctx_root}' currently exposes item_alias '{loop_item_alias}',"
"Use the alias instead of '.item' to access loop elements."
),
)
)
continue
if isinstance(root_field, str) and root_field not in allowed_loop_fields and not (root_field == "item" and loop_item_alias in {None, "item"}):
errors.append(
ValidationError(
code="SCHEMA_MISMATCH",
node_id=nid,
field=path_prefix,
message=(
f"Template reference '{ref}' on node '{nid}' is invalid: loop node '{loop_ctx_root}' context only exposes {', '.join(sorted(allowed_loop_fields | {'item'}))},"
f"Field '{root_field}' was not found."
),
)
)
continue

schema_err = None
if ref_parts:
alias = ref_parts[0]
if alias_schemas and alias in alias_schemas:
schema_err = _schema_path_error(alias_schemas[alias], ref_parts[1:])
else:
schema_err = _check_output_path_against_schema(
ref_path,
nodes_by_id,
actions_by_id,
loop_body_parents,
context_node_id=nid,
)

if schema_err:
errors.append(
ValidationError(
code="SCHEMA_MISMATCH",
node_id=nid,
field=path_prefix,
message=(
f"Template reference '{ref}' on action node '{nid}' is invalid: {schema_err}"
),
)
)
elif isinstance(obj, Mapping):
for key, value in list(obj.items()):
new_prefix = f"{path_prefix}.{key}" if path_prefix else str(key)
Expand Down
Loading