Skip to content

feat: Add A2A Multi-Agent Demo Notebook with Setup Modifications #246

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions demos/a2a_llama_stack/agents/a2a_composer/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import os
import logging

from llama_stack_client import LlamaStackClient, Agent

from common.server import A2AServer
from common.types import AgentCard, AgentCapabilities, AgentSkill

from .task_manager import AgentTaskManager, SUPPORTED_CONTENT_TYPES

logging.basicConfig(level=logging.INFO)


def build_server(host: str = "0.0.0.0", port: int = 10012):
    """Build and return the A2AServer hosting the writing ("composer") agent.

    Args:
        host: Interface the server binds to; also advertised in the agent card URL.
        port: TCP port the server listens on.

    Returns:
        A configured ``A2AServer`` instance, ready to ``start()``.
    """
    # 1) Instantiate the agent. The llama-stack endpoint and model are
    #    overridable via environment variables so deployments can retarget
    #    the server without code changes.
    agent = Agent(
        client=LlamaStackClient(base_url=os.getenv("REMOTE_BASE_URL", "http://localhost:8321")),
        model=os.getenv("INFERENCE_MODEL_ID", "llama3.2:3b-instruct-fp16"),
        instructions="You are skilled at writing human-friendly text based on the query and associated skills.",
        max_infer_iters=3,
        sampling_params={
            "strategy": {"type": "greedy"},
            "max_tokens": 4096,
        },
    )

    # 2) Wrap the agent in the A2A TaskManager. internal_session_id=True makes
    #    the manager create (and reuse) its own llama-stack session.
    task_manager = AgentTaskManager(agent=agent, internal_session_id=True)

    # 3) Advertise the agent's capability as an AgentSkill on its card.
    card = AgentCard(
        name="Writing Agent",
        description="Generate human-friendly text based on the query and associated skills",
        url=f"http://{host}:{port}/",
        version="0.1.0",
        defaultInputModes=["text/plain"],
        defaultOutputModes=SUPPORTED_CONTENT_TYPES,
        capabilities=AgentCapabilities(
            streaming=False,
            pushNotifications=False,
            stateTransitionHistory=False,
        ),
        skills=[
            AgentSkill(
                id="writing_agent",
                name="Writing Agent",
                description="Write human-friendly text based on the query and associated skills",
                tags=["writing"],
                examples=["Write human-friendly text based on the query and associated skills"],
                inputModes=["text/plain"],
                outputModes=["application/json"],
            )
        ],
    )

    return A2AServer(
        agent_card=card,
        task_manager=task_manager,
        host=host,
        port=port,
    )

if __name__ == "__main__":
    import click

    @click.command()
    @click.option("--host", default="0.0.0.0", help="Interface to bind the server to.")
    # Default aligned with build_server's default (10012). The previous CLI
    # default of 10010 disagreed with build_server and collided with the
    # custom-tools agent's port.
    @click.option("--port", default=10012, type=int, help="TCP port to listen on.")
    def main(host, port):
        build_server(host, port).start()

    main()
Empty file.
134 changes: 134 additions & 0 deletions demos/a2a_llama_stack/agents/a2a_composer/task_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import logging
from typing import AsyncIterable, Union, AsyncIterator

from llama_stack_client import Agent, AgentEventLogger

import common.server.utils as utils
from common.server.task_manager import InMemoryTaskManager
from common.types import (
SendTaskRequest, SendTaskResponse,
SendTaskStreamingRequest, SendTaskStreamingResponse,
TaskStatus, Artifact,
Message, TaskState,
TaskStatusUpdateEvent, TaskArtifactUpdateEvent,
JSONRPCResponse,
)

logger = logging.getLogger(__name__)

SUPPORTED_CONTENT_TYPES = ["text", "text/plain", "application/json"]


class AgentTaskManager(InMemoryTaskManager):
    """A2A task manager that routes task requests through a llama-stack Agent.

    Bridges the A2A protocol (send / send-subscribe tasks) to ``Agent.create_turn``,
    recording task status and artifacts in the in-memory store provided by
    ``InMemoryTaskManager`` (which supplies ``self.lock``, ``self.tasks`` and
    ``upsert_task`` — not visible in this file).
    """

    def __init__(self, agent: Agent, internal_session_id: bool = False):
        """Wrap *agent*; optionally create one shared llama-stack session.

        Args:
            agent: The llama-stack Agent that answers queries.
            internal_session_id: When True, create a single session up front and
                reuse it for every task; when False, a new session is created
                per incoming A2A sessionId (see ``_invoke``).
        """
        super().__init__()
        self.agent = agent
        if internal_session_id:
            # One shared session for the lifetime of this manager.
            self.session_id = self.agent.create_session("custom-agent-session")
        else:
            # Defer session creation until a request arrives.
            self.session_id = None

    def _validate_request(
        self, request: Union[SendTaskRequest, SendTaskStreamingRequest]
    ) -> JSONRPCResponse | None:
        """Return an error response if the client's accepted output modes are
        incompatible with SUPPORTED_CONTENT_TYPES; None when the request is OK."""
        params = request.params
        if not utils.are_modalities_compatible(
            params.acceptedOutputModes,
            SUPPORTED_CONTENT_TYPES
        ):
            logger.warning("Unsupported output modes: %s", params.acceptedOutputModes)
            return utils.new_incompatible_types_error(request.id)
        return None

    async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
        """Handle a non-streaming task: invoke the agent once and return the
        completed task with its text artifact."""
        err = self._validate_request(request)
        if err:
            return err

        await self.upsert_task(request.params)
        # Only the first message part is used as the query — NOTE(review):
        # multi-part messages are silently truncated; confirm this is intended.
        result = self._invoke(
            request.params.message.parts[0].text,
            request.params.sessionId
        )
        parts = [{"type": "text", "text": result}]
        status = TaskStatus(state=TaskState.COMPLETED, message=Message(role="agent", parts=parts))
        task = await self._update_store(request.params.id, status, [Artifact(parts=parts)])
        return SendTaskResponse(id=request.id, result=task)

    async def on_send_task_subscribe(
        self, request: SendTaskStreamingRequest
    ) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse:
        """Handle a streaming task: register it, then return an async generator
        of status/artifact update events (or an immediate error response)."""
        err = self._validate_request(request)
        if err:
            return err

        await self.upsert_task(request.params)
        return self._stream_generator(request)

    async def _stream_generator(
        self, request: SendTaskStreamingRequest
    ) -> AsyncIterable[SendTaskStreamingResponse]:
        """Yield A2A streaming responses for each update from ``_stream``.

        For every update: persist the new status (and any final artifact) to the
        store first, then emit a TaskStatusUpdateEvent, followed by a
        TaskArtifactUpdateEvent when the task completed.
        """
        params = request.params
        query = params.message.parts[0].text

        async for update in self._stream(query, params.sessionId):
            done = update["is_task_complete"]
            content = update["content"]
            delta = update["updates"]

            state = TaskState.COMPLETED if done else TaskState.WORKING
            # Final event carries the full content; interim events carry deltas.
            text = content if done else delta
            parts = [{"type": "text", "text": text}]
            artifacts = [Artifact(parts=parts)] if done else None

            status = TaskStatus(state=state, message=Message(role="agent", parts=parts))
            await self._update_store(request.params.id, status, artifacts or [])

            yield SendTaskStreamingResponse(
                id=request.id,
                result=TaskStatusUpdateEvent(id=params.id, status=status, final=done)
            )
            if artifacts:
                yield SendTaskStreamingResponse(
                    id=request.id,
                    result=TaskArtifactUpdateEvent(id=params.id, artifact=artifacts[0])
                )

    async def _update_store(self, task_id: str, status: TaskStatus, artifacts):
        """Atomically update a stored task's status and append any artifacts.

        Raises KeyError if *task_id* was never upserted.
        """
        async with self.lock:
            task = self.tasks[task_id]
            task.status = status
            if artifacts:
                task.artifacts = (task.artifacts or []) + artifacts
            return task

    def _invoke(self, query: str, session_id: str) -> str:
        """
        Route the user query through the Agent, executing tools as needed.

        Returns the concatenated text content of all events emitted for the
        turn. Note: this is a synchronous (blocking) call.
        """
        # Determine which session to use: the shared internal session if one
        # was created in __init__, otherwise a fresh session per A2A sessionId.
        if self.session_id is not None:
            sid = self.session_id
        else:
            sid = self.agent.create_session(session_id)

        # Send the user query to the Agent
        turn_resp = self.agent.create_turn(
            messages=[{"role": "user", "content": query}],
            session_id=sid,
        )

        # Extract tool and LLM outputs from events
        logs = AgentEventLogger().log(turn_resp)
        output = ""
        for event in logs:
            if hasattr(event, "content") and event.content:
                output += event.content
        return output

    async def _stream(self, query: str, session_id: str) -> AsyncIterator[dict]:
        """
        Simplest streaming stub: synchronously invoke and emit once.

        Yields a single dict with the full result as both the delta ("updates")
        and the final content. NOTE(review): ``_invoke`` blocks the event loop
        here; true incremental streaming is not implemented.
        """
        result = self._invoke(query, session_id)
        yield {"updates": result, "is_task_complete": True, "content": result}
Empty file.
85 changes: 85 additions & 0 deletions demos/a2a_llama_stack/agents/a2a_custom_tools/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import os
import logging

from llama_stack_client import LlamaStackClient, Agent

from common.server import A2AServer
from common.types import AgentCard, AgentCapabilities, AgentSkill

from .agent import random_number_tool, date_tool
from .task_manager import AgentTaskManager, SUPPORTED_CONTENT_TYPES

logging.basicConfig(level=logging.INFO)


def build_server(host: str = "0.0.0.0", port: int = 10010):
    """Build and return the A2AServer hosting the custom-tools agent.

    Args:
        host: Interface the server binds to; also advertised in the agent card URL.
        port: TCP port the server listens on.

    Returns:
        A configured ``A2AServer`` instance, ready to ``start()``.
    """
    # 1) Instantiate the agent. The llama-stack endpoint and model are
    #    overridable via environment variables so deployments can retarget
    #    the server without code changes.
    agent = Agent(
        client=LlamaStackClient(base_url=os.getenv("REMOTE_BASE_URL", "http://localhost:8321")),
        model=os.getenv("INFERENCE_MODEL_ID", "llama3.1:8b-instruct-fp16"),
        instructions=(
            "You have access to two tools:\n"
            "- random_number_tool: generates one random integer between 1 and 100\n"
            "- date_tool: returns today's date in YYYY-MM-DD format\n"
            "Always use the appropriate tool to answer user queries."
        ),
        tools=[random_number_tool, date_tool],
        max_infer_iters=3,
    )

    # 2) Wrap the agent in the A2A TaskManager. internal_session_id=True makes
    #    the manager create (and reuse) its own llama-stack session.
    task_manager = AgentTaskManager(agent=agent, internal_session_id=True)

    # 3) Advertise the two tools as AgentSkills on the agent card.
    card = AgentCard(
        name="Custom Agent",
        # Grammar fix: was "Generates random numbers or retrieve today's dates".
        description="Generates random numbers or retrieves today's date",
        url=f"http://{host}:{port}/",
        version="0.1.0",
        defaultInputModes=["text/plain"],
        defaultOutputModes=SUPPORTED_CONTENT_TYPES,
        capabilities=AgentCapabilities(
            streaming=False,
            pushNotifications=False,
            stateTransitionHistory=False,
        ),
        skills=[
            AgentSkill(
                id="random_number_tool",
                name="Random Number Generator",
                description="Generates a random number between 1 and 100",
                tags=["random"],
                examples=["Give me a random number between 1 and 100"],
                inputModes=["text/plain"],
                outputModes=["text/plain"],
            ),
            AgentSkill(
                id="date_tool",
                name="Date Provider",
                description="Returns today's date in YYYY-MM-DD format",
                tags=["date"],
                examples=["What's the date today?"],
                inputModes=["text/plain"],
                outputModes=["text/plain"],
            ),
        ],
    )

    return A2AServer(
        agent_card=card,
        task_manager=task_manager,
        host=host,
        port=port,
    )

if __name__ == "__main__":
    # Imported here so the CLI dependency is only required when run as a script.
    import click

    @click.command()
    @click.option("--host", default="0.0.0.0")
    @click.option("--port", default=10010, type=int)
    def main(host, port):
        # Build the configured server, then block serving requests.
        server = build_server(host=host, port=port)
        server.start()

    main()
Loading