Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions run_unified_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Author: Zhongkai Fu ([email protected])
# License: BSD 3-Clause License

"""Command-line entrypoint for the unified workflow planning Agent.

This tool drives the unified Agent/Runner pipeline that can build the
workflow structure, fill parameters, validate, repair, and update the
workflow as needed. The resulting workflow DSL will be persisted to disk.
"""

from __future__ import annotations

import argparse
import json
import os
import sys
from typing import Optional

from velvetflow.action_registry import BUSINESS_ACTIONS
from velvetflow.config import OPENAI_MODEL
from velvetflow.planner.unified_agent import run_workflow_planning_agent
from velvetflow.search import build_default_search_service
from velvetflow.visualization import render_workflow_dag

DEFAULT_OUTPUT_JSON = "workflow_unified_output.json"
DEFAULT_OUTPUT_DAG = "workflow_unified_dag.jpg"


def _prompt_requirement(default_text: str) -> str:
user_nl = input("请输入你的流程需求(直接回车使用默认示例):\n> ").strip()
if not user_nl:
user_nl = default_text
print("\n使用默认示例:", user_nl)
return user_nl


def main(argv: Optional[list[str]] = None) -> int:
parser = argparse.ArgumentParser(description="Run the unified workflow planning Agent.")
parser.add_argument(
"--requirement",
type=str,
help="自然语言需求描述,不提供时会提示输入(回车使用示例)",
)
parser.add_argument(
"--output",
type=str,
default=DEFAULT_OUTPUT_JSON,
help=f"输出 workflow JSON 文件路径(默认: {DEFAULT_OUTPUT_JSON})",
)
parser.add_argument(
"--dag",
type=str,
default=DEFAULT_OUTPUT_DAG,
help=f"输出 workflow DAG 图片路径(默认: {DEFAULT_OUTPUT_DAG})",
)
args = parser.parse_args(argv)

if not os.environ.get("OPENAI_API_KEY"):
print("请先设置环境变量 OPENAI_API_KEY 再运行。")
return 1

default_requirement = (
"每天早上 5 点,从某信息源中获取当日的若干条记录,"
"如果存在满足特定关键字条件的记录,请对这些记录进行总结,并发送通知给我。"
)
requirement = args.requirement or _prompt_requirement(default_requirement)

search_service = build_default_search_service()

try:
workflow = run_workflow_planning_agent(
nl_requirement=requirement,
action_registry=BUSINESS_ACTIONS,
search_service=search_service,
model=OPENAI_MODEL,
)
except Exception as exc: # pragma: no cover - CLI surface
print("\n[unified_agent] 工作流规划失败:", repr(exc))
return 1

try:
with open(args.output, "w", encoding="utf-8") as f:
json.dump(workflow.model_dump(by_alias=True), f, indent=2, ensure_ascii=False)
print(f"\n已将工作流以 JSON 格式保存至:{args.output}")

dag_path = render_workflow_dag(workflow, output_path=args.dag)
print(f"已将最终工作流 DAG 保存为 JPEG:{dag_path}")
except Exception as exc: # pragma: no cover - CLI surface
print("\n[warning] 工作流持久化失败:", repr(exc))
return 1

print("\n现在可以使用 execute_workflow.py 从保存的 JSON 执行该流程。")
return 0


if __name__ == "__main__":
sys.exit(main())
2 changes: 2 additions & 0 deletions velvetflow/planner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from velvetflow.planner.structure import plan_workflow_structure_with_llm
from velvetflow.planner.params import fill_params_with_llm
from velvetflow.planner.repair import repair_workflow_with_llm
from velvetflow.planner.unified_agent import run_workflow_planning_agent
from velvetflow.planner.update import update_workflow_with_llm
from velvetflow.planner.relations import (
build_node_relations,
Expand All @@ -25,6 +26,7 @@
"plan_workflow_structure_with_llm",
"fill_params_with_llm",
"repair_workflow_with_llm",
"run_workflow_planning_agent",
"update_workflow_with_llm",
"build_node_relations",
"get_referenced_nodes",
Expand Down
229 changes: 229 additions & 0 deletions velvetflow/planner/unified_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
# Author: Zhongkai Fu ([email protected])
# License: BSD 3-Clause License

"""Single Agent entrypoint that exposes all planner tools at once.

This module wraps structure planning, parameter completion, validation,
repair and workflow updates into one Agent SDK workflow. The Agent is
executed by :class:`Runner` so the LLM can freely choose the right tool
sequence to satisfy a user's requirement and return a validated workflow
JSON.
"""

from __future__ import annotations

import asyncio
import json
from typing import Any, Dict, List, Mapping, Optional, Sequence

from velvetflow.config import OPENAI_MODEL
from velvetflow.logging_utils import log_section, log_warn
from velvetflow.models import ValidationError, Workflow
from velvetflow.planner.agent_runtime import Agent, Runner, function_tool
from velvetflow.planner.params import fill_params_with_llm
from velvetflow.planner.repair import repair_workflow_with_llm
from velvetflow.planner.structure import plan_workflow_structure_with_llm
from velvetflow.planner.update import update_workflow_with_llm
from velvetflow.search import HybridActionSearchService, build_default_search_service
from velvetflow.verification import validate_completed_workflow


def _workflow_to_dict(workflow: Workflow | Mapping[str, Any]) -> Dict[str, Any]:
if isinstance(workflow, Workflow):
return workflow.model_dump(by_alias=True)
if isinstance(workflow, Mapping):
return Workflow.model_validate(workflow).model_dump(by_alias=True)
raise ValueError("workflow 必须是 Mapping 或 Workflow 实例。")


def _serialize_validation_errors(errors: Sequence[ValidationError]) -> List[Dict[str, Any]]:
return [
{
"code": err.code,
"node_id": err.node_id,
"field": err.field,
"message": err.message,
}
for err in errors
]


def run_workflow_planning_agent(
nl_requirement: str,
action_registry: List[Dict[str, Any]],
search_service: Optional[HybridActionSearchService] = None,
base_workflow: Optional[Mapping[str, Any]] = None,
*,
model: str = OPENAI_MODEL,
max_turns: int = 32,
) -> Workflow:
"""Expose planner lifecycle tools through one Agent SDK runner.

The Agent exposes dedicated tools for structure building, parameter fill,
validation, repair, update and final submission. The LLM can choose any
order and combination via the Agent SDK Runner.
"""

service = search_service or build_default_search_service()
working_workflow: Optional[Dict[str, Any]] = _workflow_to_dict(base_workflow) if base_workflow else None
finalized_workflow: Optional[Dict[str, Any]] = None
latest_validation_errors: List[ValidationError] = []

system_prompt = (
"你是一个统一的 Workflow Agent,负责使用提供的工具完成构建、补参、校验、修复与更新。\n"
"可以按需调用 build_workflow、fill_workflow_params、validate_workflow、repair_workflow、update_workflow、submit_final_workflow 组合出满足用户需求的流程。\n"
"所有 params 必须使用 Jinja 表达式或字面量,禁止 __from__/__agg__,引用 loop 结果时只能使用 exports.items/aggregates。"
)

@function_tool(strict_mode=False)
def build_workflow(requirement: Optional[str] = None) -> Mapping[str, Any]:
"""Generate a workflow skeleton from a natural language requirement."""

nonlocal working_workflow, latest_validation_errors
req = requirement or nl_requirement
skeleton = plan_workflow_structure_with_llm(
req,
search_service=service,
action_registry=action_registry,
)
working_workflow = Workflow.model_validate(skeleton).model_dump(by_alias=True)
latest_validation_errors = []
return {"status": "ok", "workflow": working_workflow}

@function_tool(strict_mode=False)
def fill_workflow_params(workflow: Optional[Mapping[str, Any]] = None) -> Mapping[str, Any]:
"""Complete node parameters for the current or provided workflow."""

nonlocal working_workflow, latest_validation_errors
target = _workflow_to_dict(workflow or working_workflow or {})
filled = fill_params_with_llm(target, action_registry=action_registry, model=model)
working_workflow = Workflow.model_validate(filled).model_dump(by_alias=True)
latest_validation_errors = []
return {"status": "ok", "workflow": working_workflow}

@function_tool(strict_mode=False)
def validate_workflow(workflow: Optional[Mapping[str, Any]] = None) -> Mapping[str, Any]:
"""Run static validation and return structured errors if any."""

nonlocal working_workflow, latest_validation_errors
target = _workflow_to_dict(workflow or working_workflow or {})
errors = validate_completed_workflow(target, action_registry=action_registry)
working_workflow = target
latest_validation_errors = errors
return {
"status": "ok" if not errors else "failed",
"workflow": target,
"errors": _serialize_validation_errors(errors),
}

@function_tool(strict_mode=False)
def repair_workflow(
workflow: Optional[Mapping[str, Any]] = None,
error_summary: Optional[str] = None,
) -> Mapping[str, Any]:
"""Repair the workflow using the latest validation errors."""

nonlocal working_workflow, latest_validation_errors
target = _workflow_to_dict(workflow or working_workflow or {})
fixed = repair_workflow_with_llm(
broken_workflow=target,
validation_errors=latest_validation_errors,
action_registry=action_registry,
error_summary=error_summary,
previous_failed_attempts=None,
model=model,
)
working_workflow = Workflow.model_validate(fixed).model_dump(by_alias=True)
latest_validation_errors = []
return {"status": "ok", "workflow": working_workflow}

@function_tool(strict_mode=False)
def update_workflow(
requirement: Optional[str] = None,
workflow: Optional[Mapping[str, Any]] = None,
validation_errors: Optional[Sequence[Mapping[str, Any]]] = None,
) -> Mapping[str, Any]:
"""Update the workflow according to a requirement or validation errors."""

nonlocal working_workflow, latest_validation_errors
req = requirement or nl_requirement
target = _workflow_to_dict(workflow or working_workflow or {})
parsed_errors: Optional[Sequence[ValidationError]] = latest_validation_errors
if validation_errors is not None:
parsed_errors = [
ValidationError(
code=e.get("code", "INVALID_SCHEMA"),
node_id=e.get("node_id"),
field=e.get("field"),
message=str(e.get("message", "")),
)
for e in validation_errors
if isinstance(e, Mapping)
]
updated = update_workflow_with_llm(
workflow_raw=target,
requirement=req,
action_registry=action_registry,
model=model,
validation_errors=parsed_errors or None,
)
working_workflow = Workflow.model_validate(updated).model_dump(by_alias=True)
latest_validation_errors = list(parsed_errors or [])
return {"status": "ok", "workflow": working_workflow}

@function_tool(strict_mode=False)
def submit_final_workflow(workflow: Optional[Mapping[str, Any]] = None) -> Mapping[str, Any]:
"""Submit the final workflow once all checks pass."""

nonlocal working_workflow, finalized_workflow
target = _workflow_to_dict(workflow or working_workflow or {})
working_workflow = target
finalized_workflow = target
return {"status": "ok", "workflow": target}

agent = Agent(
name="WorkflowOrchestrator",
instructions=system_prompt,
tools=[
build_workflow,
fill_workflow_params,
validate_workflow,
repair_workflow,
update_workflow,
submit_final_workflow,
],
model=model,
)

log_section("统一 Agent 工作流规划")
run_input: Any = json.dumps(
{
"requirement": nl_requirement,
"base_workflow": working_workflow,
"action_registry_size": len(action_registry),
},
ensure_ascii=False,
)

try:
Runner.run_sync(agent, run_input, max_turns=max_turns)
except TypeError:
coro = Runner.run(agent, run_input) # type: ignore[call-arg]
result = coro if not asyncio.iscoroutine(coro) else asyncio.run(coro)
_ = result

if finalized_workflow is None:
if working_workflow is None:
raise RuntimeError("Agent 未提交最终 workflow,且没有可用的工作副本。")
if latest_validation_errors:
raise RuntimeError(
"Agent 结束但 workflow 仍未通过校验:"
+ "; ".join(err.message for err in latest_validation_errors)
)
log_warn("[run_workflow_planning_agent] 未收到 submit_final_workflow,返回当前工作副本。")
finalized_workflow = working_workflow

return Workflow.model_validate(finalized_workflow)


__all__ = ["run_workflow_planning_agent"]
Loading