From f7e99bf447135ddd2b1ede1ecee94b5ba8647546 Mon Sep 17 00:00:00 2001 From: Zhongkai Fu Date: Tue, 6 Jan 2026 20:50:27 -0800 Subject: [PATCH] Filter LLM context around loop subgraphs --- velvetflow/planner/structure.py | 179 +++++++++++++++++++++++++------- 1 file changed, 143 insertions(+), 36 deletions(-) diff --git a/velvetflow/planner/structure.py b/velvetflow/planner/structure.py index 608fe4e..8ee83b4 100644 --- a/velvetflow/planner/structure.py +++ b/velvetflow/planner/structure.py @@ -20,7 +20,7 @@ WorkflowBuilder, attach_condition_branches, ) -from velvetflow.loop_dsl import iter_workflow_and_loop_body_nodes +from velvetflow.loop_dsl import index_loop_body_nodes, iter_workflow_and_loop_body_nodes from velvetflow.planner.requirement_analysis import _normalize_requirements_payload from velvetflow.search import HybridActionSearchService from velvetflow.models import ( @@ -772,47 +772,78 @@ def _summarize_node_outputs( return _summarize_output_fields_from_schema(schema) -def _build_global_context( +def _summarize_node_outputs_from_map( + node: Mapping[str, Any], + action_schemas: Mapping[str, Dict[str, Any]], +) -> List[str]: + node_type = node.get("type") + if node_type == "loop": + params = node.get("params") if isinstance(node.get("params"), Mapping) else {} + exports = params.get("exports") if isinstance(params, Mapping) else {} + if isinstance(exports, Mapping): + return [k for k in exports.keys() if isinstance(k, str)] + return [] + + action_id = node.get("action_id") + schema = None + if isinstance(action_id, str): + schema = action_schemas.get(action_id, {}).get("output_schema") + schema = node.get("out_params_schema") or schema + return _summarize_output_fields_from_schema(schema if isinstance(schema, Mapping) else None) + + +def _build_global_context_from_nodes( *, - workflow: Workflow, + workflow_name: str, + description: str, + nodes: Sequence[Mapping[str, Any]], + edges: Sequence[Mapping[str, Any]], action_schemas: Mapping[str, Dict[str, Any]], filled_params: Mapping[str, Mapping[str, Any]], processed_node_ids: Sequence[str], binding_memory: Mapping[str, str], ) -> Dict[str, Any]: upstream_map: Dict[str, List[str]] = {} - for e in workflow.edges: - upstream_map.setdefault(e.to_node, []).append(e.from_node) + for e in edges: + if not isinstance(e, Mapping): + continue + to_node = e.get("to") or e.get("to_node") + from_node = e.get("from") or e.get("from_node") + if isinstance(to_node, str) and isinstance(from_node, str): + upstream_map.setdefault(to_node, []).append(from_node) node_summaries: List[Dict[str, Any]] = [] - for n in workflow.nodes: - schema = action_schemas.get(n.action_id, {}) if n.action_id else {} + for node in nodes: + if not isinstance(node, Mapping): + continue + node_id = node.get("id") if isinstance(node.get("id"), str) else None + action_id = node.get("action_id") if isinstance(node.get("action_id"), str) else None + schema = action_schemas.get(action_id, {}) if action_id else {} node_summaries.append( { - "id": n.id, - "type": n.type, - "action_id": n.action_id, - "display_name": n.display_name, + "id": node_id, + "type": node.get("type"), + "action_id": action_id, + "display_name": node.get("display_name"), "domain": schema.get("domain"), - "out_params_schema": getattr(n, "out_params_schema", None) - or schema.get("output_schema"), - "output_fields": _summarize_node_outputs(n, action_schemas), + "out_params_schema": node.get("out_params_schema") or schema.get("output_schema"), + "output_fields": _summarize_node_outputs_from_map(node, action_schemas), "arg_required_fields": ( schema.get("arg_schema", {}).get("required") if isinstance(schema.get("arg_schema"), Mapping) else None ), - "upstream": upstream_map.get(n.id, []), - "params_snapshot": filled_params.get(n.id) - if n.id in processed_node_ids + "upstream": upstream_map.get(node_id, []), + "params_snapshot": filled_params.get(node_id) + if node_id in processed_node_ids else None, } ) return { "workflow": { - "name": workflow.workflow_name, - "description": workflow.description, + "name": workflow_name, + "description": description, }, "node_summaries": node_summaries, "entity_binding_hints": [ @@ -1884,17 +1915,83 @@ def _build_workflow_for_params() -> Workflow: return Workflow.model_validate(workflow_dict) def _build_param_context(node_id: str) -> Dict[str, Any]: - workflow = _build_workflow_for_params() - nodes_by_id = {n.id: n for n in workflow.nodes} + workflow_dict = _attach_inferred_edges(builder.to_workflow()) + nodes = [ + n + for n in iter_workflow_and_loop_body_nodes(workflow_dict) + if isinstance(n, Mapping) + ] + nodes_by_id = { + n.get("id"): n for n in nodes if isinstance(n.get("id"), str) + } node = nodes_by_id.get(node_id) if not node: return {} - upstream_nodes = get_referenced_nodes(workflow, node_id) - allowed_node_ids = [n.id for n in upstream_nodes] + edges = workflow_dict.get("edges") + edges = edges if isinstance(edges, list) else [] + loop_body_parents = index_loop_body_nodes(workflow_dict) + loop_parent_id = loop_body_parents.get(node_id) + loop_body_ids = set(loop_body_parents.keys()) + + def _find_referenced_nodes(target_id: str) -> List[Mapping[str, Any]]: + referenced: List[Mapping[str, Any]] = [] + seen: set[str] = set() + for edge in edges: + if not isinstance(edge, Mapping): + continue + to_node = edge.get("to") or edge.get("to_node") + from_node = edge.get("from") or edge.get("from_node") + if to_node != target_id or not isinstance(from_node, str): + continue + if from_node in seen: + continue + upstream_node = nodes_by_id.get(from_node) + if upstream_node: + referenced.append(upstream_node) + seen.add(from_node) + return referenced + + loop_related_ids: set[str] | None = None + if loop_parent_id: + loop_related_ids = {loop_parent_id} + loop_related_ids.update( + { + nid + for nid, parent_id in loop_body_parents.items() + if parent_id == loop_parent_id + } + ) + loop_related_ids.update( + n.get("id") + for n in _find_referenced_nodes(loop_parent_id) + if isinstance(n.get("id"), str) + ) + + def _allow_node_id(nid: str) -> bool: + if loop_related_ids is not None: + return nid in loop_related_ids + return nid not in loop_body_ids + + upstream_nodes = [ + n + for n in _find_referenced_nodes(node_id) + if isinstance(n.get("id"), str) and _allow_node_id(n["id"]) + ] + allowed_node_ids = [ + n.get("id") for n in upstream_nodes if isinstance(n.get("id"), str) + ] binding_memory = _build_binding_memory(filled_params, validated_node_ids) - global_context = _build_global_context( - workflow=workflow, + context_nodes = [ + n + for n in nodes + if isinstance(n.get("id"), str) and _allow_node_id(n["id"]) + ] + global_context = _build_global_context_from_nodes( + workflow_name=workflow_dict.get("workflow_name", "unnamed_workflow"), + description=workflow_dict.get("description", ""), + nodes=context_nodes, + edges=edges, action_schemas=action_schemas, filled_params=filled_params, processed_node_ids=validated_node_ids, @@ -1902,24 +1999,34 @@ def _build_param_context(node_id: str) -> Dict[str, Any]: ) upstream_context = [] for n in upstream_nodes: - action_schema = action_schemas.get(n.action_id, {}) if n.action_id else {} + action_id = n.get("action_id") if isinstance(n.get("action_id"), str) else None + action_schema = action_schemas.get(action_id, {}) if action_id else {} upstream_context.append( { - "id": n.id, - "type": n.type, - "action_id": n.action_id, + "id": n.get("id"), + "type": n.get("type"), + "action_id": action_id, "output_schema": action_schema.get("output_schema"), - "params": filled_params.get(n.id, n.params), + "params": filled_params.get( + n.get("id"), + n.get("params") if isinstance(n.get("params"), Mapping) else {}, + ), } ) - target_action_schema = action_schemas.get(node.action_id, {}) if node.action_id else {} + target_action_id = node.get("action_id") if isinstance(node.get("action_id"), str) else None + target_action_schema = ( + action_schemas.get(target_action_id, {}) if target_action_id else {} + ) return { "target_node": { - "id": node.id, - "type": node.type, - "action_id": node.action_id, - "display_name": node.display_name, - "existing_params": filled_params.get(node.id, node.params), + "id": node_id, + "type": node.get("type"), + "action_id": target_action_id, + "display_name": node.get("display_name"), + "existing_params": filled_params.get( + node_id, + node.get("params") if isinstance(node.get("params"), Mapping) else {}, + ), }, "arg_schema": target_action_schema.get("arg_schema"), "allowed_node_ids": allowed_node_ids,