Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -225,4 +225,4 @@ local.settings.json
**/frontend/dist/

# Database files
*.db
*.db
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

import importlib.metadata

from agent_framework_durabletask import AgentCallbackContext, AgentResponseCallbackProtocol
from agent_framework_durabletask import AgentCallbackContext, AgentResponseCallbackProtocol, DurableAIAgent

from ._app import AgentFunctionApp
from ._orchestration import DurableAIAgent

try:
__version__ = importlib.metadata.version(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import json
import re
import uuid
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from datetime import datetime, timezone
Expand All @@ -28,14 +29,16 @@
WAIT_FOR_RESPONSE_FIELD,
WAIT_FOR_RESPONSE_HEADER,
AgentResponseCallbackProtocol,
AgentSessionId,
ApiResponseFields,
DurableAgentState,
DurableAIAgent,
RunRequest,
)

from ._entities import create_agent_entity
from ._errors import IncomingRequestError
from ._models import AgentSessionId
from ._orchestration import AgentOrchestrationContextType, DurableAIAgent
from ._orchestration import AgentOrchestrationContextType, AgentTask, AzureFunctionsAgentExecutor

logger = get_logger("agent_framework.azurefunctions")

Expand Down Expand Up @@ -296,7 +299,7 @@ def get_agent(
self,
context: AgentOrchestrationContextType,
agent_name: str,
) -> DurableAIAgent:
) -> DurableAIAgent[AgentTask]:
"""Return a DurableAIAgent proxy for a registered agent.

Args:
Expand All @@ -307,14 +310,15 @@ def get_agent(
ValueError: If the requested agent has not been registered.

Returns:
DurableAIAgent wrapper bound to the orchestration context.
DurableAIAgent[AgentTask] wrapper bound to the orchestration context.
"""
normalized_name = str(agent_name)

if normalized_name not in self._agent_metadata:
raise ValueError(f"Agent '{normalized_name}' is not registered with this app.")

return DurableAIAgent(context, normalized_name)
executor = AzureFunctionsAgentExecutor(context)
return DurableAIAgent(executor, normalized_name)

def _setup_agent_functions(
self,
Expand Down Expand Up @@ -407,11 +411,13 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien
logger.debug(f"[HTTP Trigger] Generated correlation ID: {correlation_id}")
logger.debug("[HTTP Trigger] Calling entity to run agent...")

entity_instance_id = session_id.to_entity_id()
entity_instance_id = df.EntityId(
name=session_id.entity_name,
key=session_id.key,
)
run_request = self._build_request_data(
req_body,
message,
thread_id,
correlation_id,
request_response_format,
)
Expand Down Expand Up @@ -624,14 +630,16 @@ async def _handle_mcp_tool_invocation(
session_id = AgentSessionId.with_random_key(agent_name)

# Build entity instance ID
entity_instance_id = session_id.to_entity_id()
entity_instance_id = df.EntityId(
name=session_id.entity_name,
key=session_id.key,
)

# Create run request
correlation_id = self._generate_unique_id()
run_request = self._build_request_data(
req_body={"message": query, "role": "user"},
message=query,
thread_id=str(session_id),
correlation_id=correlation_id,
request_response_format=REQUEST_RESPONSE_FORMAT_TEXT,
)
Expand Down Expand Up @@ -782,8 +790,9 @@ async def _poll_entity_for_response(

agent_response = state.try_get_agent_response(correlation_id)
if agent_response:
response_message = "\n".join(message.text for message in agent_response.messages if message.text)
result = self._build_success_result(
response_data=agent_response,
response_message=response_message,
message=message,
thread_id=thread_id,
correlation_id=correlation_id,
Expand Down Expand Up @@ -829,23 +838,22 @@ async def _build_timeout_result(self, message: str, thread_id: str, correlation_
)

def _build_success_result(
self, response_data: dict[str, Any], message: str, thread_id: str, correlation_id: str, state: DurableAgentState
self, response_message: str, message: str, thread_id: str, correlation_id: str, state: DurableAgentState
) -> dict[str, Any]:
"""Build the success result returned to the HTTP caller."""
return self._build_response_payload(
response=response_data.get("content"),
response=response_message,
message=message,
thread_id=thread_id,
status="success",
correlation_id=correlation_id,
extra_fields={"message_count": response_data.get("message_count", state.message_count)},
extra_fields={ApiResponseFields.MESSAGE_COUNT: state.message_count},
)

def _build_request_data(
self,
req_body: dict[str, Any],
message: str,
thread_id: str,
correlation_id: str,
request_response_format: str,
) -> dict[str, Any]:
Expand Down Expand Up @@ -912,15 +920,13 @@ def _convert_payload_to_text(self, payload: dict[str, Any]) -> str:

def _generate_unique_id(self) -> str:
"""Generate a new unique identifier."""
import uuid

return uuid.uuid4().hex

def _create_session_id(self, func_name: str, thread_id: str | None) -> AgentSessionId:
def _create_session_id(self, agent_name: str, thread_id: str | None) -> AgentSessionId:
"""Create a session identifier using the provided thread id or a random value."""
if thread_id:
return AgentSessionId(name=func_name, key=thread_id)
return AgentSessionId.with_random_key(name=func_name)
return AgentSessionId(name=agent_name, key=thread_id)
return AgentSessionId.with_random_key(name=agent_name)

def _resolve_thread_id(self, req: func.HttpRequest, req_body: dict[str, Any]) -> str:
"""Retrieve the thread identifier from request body or query parameters."""
Expand Down

This file was deleted.

Loading
Loading