Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -208,13 +208,14 @@ WARP.md
**/projectBrief.md

# Azurite storage emulator files
*/__azurite_db_blob__.json
*/__azurite_db_blob_extent__.json
*/__azurite_db_queue__.json
*/__azurite_db_queue_extent__.json
*/__azurite_db_table__.json
*/__azurite_db_blob__.json*
*/__azurite_db_blob_extent__.json*
*/__azurite_db_queue__.json*
*/__azurite_db_queue_extent__.json*
*/__azurite_db_table__.json*
*/__blobstorage__/
*/__queuestorage__/
*/AzuriteConfig

# Azure Functions local settings
local.settings.json
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

import importlib.metadata

from agent_framework_durabletask import AgentCallbackContext, AgentResponseCallbackProtocol
from agent_framework_durabletask import AgentCallbackContext, AgentResponseCallbackProtocol, DurableAIAgent

from ._app import AgentFunctionApp
from ._orchestration import DurableAIAgent

try:
__version__ = importlib.metadata.version(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import json
import re
import uuid
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from datetime import datetime, timezone
Expand All @@ -28,14 +29,16 @@
WAIT_FOR_RESPONSE_FIELD,
WAIT_FOR_RESPONSE_HEADER,
AgentResponseCallbackProtocol,
AgentSessionId,
ApiResponseFields,
DurableAgentState,
DurableAIAgent,
RunRequest,
)

from ._entities import create_agent_entity
from ._errors import IncomingRequestError
from ._models import AgentSessionId
from ._orchestration import AgentOrchestrationContextType, DurableAIAgent
from ._orchestration import AgentOrchestrationContextType, AgentTask, AzureFunctionsAgentExecutor

logger = get_logger("agent_framework.azurefunctions")

Expand Down Expand Up @@ -296,7 +299,7 @@ def get_agent(
self,
context: AgentOrchestrationContextType,
agent_name: str,
) -> DurableAIAgent:
) -> DurableAIAgent[AgentTask]:
"""Return a DurableAIAgent proxy for a registered agent.

Args:
Expand All @@ -307,14 +310,15 @@ def get_agent(
ValueError: If the requested agent has not been registered.

Returns:
DurableAIAgent wrapper bound to the orchestration context.
DurableAIAgent[AgentTask] wrapper bound to the orchestration context.
"""
normalized_name = str(agent_name)

if normalized_name not in self._agent_metadata:
raise ValueError(f"Agent '{normalized_name}' is not registered with this app.")

return DurableAIAgent(context, normalized_name)
executor = AzureFunctionsAgentExecutor(context)
return DurableAIAgent(executor, normalized_name)

def _setup_agent_functions(
self,
Expand Down Expand Up @@ -377,8 +381,6 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien
"enable_tool_calls": true|false (optional, default: true)
}
"""
logger.debug(f"[HTTP Trigger] Received request on route: /api/agents/{agent_name}/run")

request_response_format: str = REQUEST_RESPONSE_FORMAT_JSON
thread_id: str | None = None

Expand All @@ -387,9 +389,9 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien
thread_id = self._resolve_thread_id(req=req, req_body=req_body)
wait_for_response = self._should_wait_for_response(req=req, req_body=req_body)

logger.debug(f"[HTTP Trigger] Message: {message}")
logger.debug(f"[HTTP Trigger] Thread ID: {thread_id}")
logger.debug(f"[HTTP Trigger] wait_for_response: {wait_for_response}")
logger.debug(
f"[HTTP Trigger] Message: {message}, Thread ID: {thread_id}, wait_for_response: {wait_for_response}"
)

if not message:
logger.warning("[HTTP Trigger] Request rejected: Missing message")
Expand All @@ -403,15 +405,18 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien
session_id = self._create_session_id(agent_name, thread_id)
correlation_id = self._generate_unique_id()

logger.debug(f"[HTTP Trigger] Using session ID: {session_id}")
logger.debug(f"[HTTP Trigger] Generated correlation ID: {correlation_id}")
logger.debug("[HTTP Trigger] Calling entity to run agent...")
logger.debug(
f"[HTTP Trigger] Calling entity to run agent using session ID: {session_id} "
f"and correlation ID: {correlation_id}"
)

entity_instance_id = session_id.to_entity_id()
entity_instance_id = df.EntityId(
name=session_id.entity_name,
key=session_id.key,
)
run_request = self._build_request_data(
req_body,
message,
thread_id,
correlation_id,
request_response_format,
)
Expand Down Expand Up @@ -624,14 +629,16 @@ async def _handle_mcp_tool_invocation(
session_id = AgentSessionId.with_random_key(agent_name)

# Build entity instance ID
entity_instance_id = session_id.to_entity_id()
entity_instance_id = df.EntityId(
name=session_id.entity_name,
key=session_id.key,
)

# Create run request
correlation_id = self._generate_unique_id()
run_request = self._build_request_data(
req_body={"message": query, "role": "user"},
message=query,
thread_id=str(session_id),
correlation_id=correlation_id,
request_response_format=REQUEST_RESPONSE_FORMAT_TEXT,
)
Expand Down Expand Up @@ -783,7 +790,7 @@ async def _poll_entity_for_response(
agent_response = state.try_get_agent_response(correlation_id)
if agent_response:
result = self._build_success_result(
response_data=agent_response,
response_message=agent_response.text,
message=message,
thread_id=thread_id,
correlation_id=correlation_id,
Expand Down Expand Up @@ -829,23 +836,22 @@ async def _build_timeout_result(self, message: str, thread_id: str, correlation_
)

def _build_success_result(
self, response_data: dict[str, Any], message: str, thread_id: str, correlation_id: str, state: DurableAgentState
self, response_message: str, message: str, thread_id: str, correlation_id: str, state: DurableAgentState
) -> dict[str, Any]:
"""Build the success result returned to the HTTP caller."""
return self._build_response_payload(
response=response_data.get("content"),
response=response_message,
message=message,
thread_id=thread_id,
status="success",
correlation_id=correlation_id,
extra_fields={"message_count": response_data.get("message_count", state.message_count)},
extra_fields={ApiResponseFields.MESSAGE_COUNT: state.message_count},
)

def _build_request_data(
self,
req_body: dict[str, Any],
message: str,
thread_id: str,
correlation_id: str,
request_response_format: str,
) -> dict[str, Any]:
Expand Down Expand Up @@ -912,15 +918,13 @@ def _convert_payload_to_text(self, payload: dict[str, Any]) -> str:

def _generate_unique_id(self) -> str:
"""Generate a new unique identifier."""
import uuid

return uuid.uuid4().hex

def _create_session_id(self, func_name: str, thread_id: str | None) -> AgentSessionId:
def _create_session_id(self, agent_name: str, thread_id: str | None) -> AgentSessionId:
"""Create a session identifier using the provided thread id or a random value."""
if thread_id:
return AgentSessionId(name=func_name, key=thread_id)
return AgentSessionId.with_random_key(name=func_name)
return AgentSessionId(name=agent_name, key=thread_id)
return AgentSessionId.with_random_key(name=agent_name)

def _resolve_thread_id(self, req: func.HttpRequest, req_body: dict[str, Any]) -> str:
"""Retrieve the thread identifier from request body or query parameters."""
Expand Down
Loading
Loading