Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
__pycache__/
*.pyc
.bedrock_agentcore/
.bedrock_agentcore.yaml
.venv/
129 changes: 129 additions & 0 deletions 01-tutorials/03-AgentCore-identity/13-async-wat-refresh/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Async WAT Refresh for Long-Running Agents

## Overview

When an AgentCore agent runs a long-running background task in a thread, the Workload Access Token (WAT) — created from the inbound JWT — expires after the JWT's TTL. This causes the SDK to fall back to creating orphaned workload identities via IAM, breaking user binding and auditability.

This sample demonstrates a companion library (`agentcore_thread_utils`) that solves the WAT expiration problem by:
- Propagating the WAT to background threads via `contextvars.copy_context()`
- Pausing the thread when the WAT expires and waiting for the client to send a fresh JWT
- Retrying the credential provider call with the refreshed WAT

No orphan workload identities are created. The WAT stays bound to the original user.

## Architecture

| Component | Description |
|---|---|
| `@with_wat_refresh` | Drop-in replacement for `@requires_access_token`. Catches WAT expiration, pauses the thread, waits for client refresh, retries. |
| `ThreadTaskManager` | Manages thread lifecycle with WAT propagation. Handles start/status/refresh/result actions. |

## Prerequisites

- AWS CLI configured
- `agentcore` CLI installed
- `jq` installed
- Python 3.10+

## Setup

### 1. Create Cognito User Pool (5-min access token TTL)

```bash
source setup_cognito.sh
```

### 2. Create Credential Provider

```bash
bash setup_credential_provider.sh
```

### 3. Deploy the Agent

```bash
bash deploy.sh
```

Export the Agent ARN from the output:

```bash
export AGENT_ARN="arn:aws:bedrock-agentcore:us-east-1:123456789012:runtime/thread_async_utils-XXXXXXXXXX"
```

## Test

### Start a task

```bash
bash test_curl.sh
```

### Wait 6 minutes, then check status

```bash
bash test_refresh.sh <session-id> status
```

### Send WAT refresh

```bash
bash test_refresh.sh <session-id> refresh
```

### Check result

```bash
bash test_refresh.sh <session-id> status
```

## Expected Flow

1. Client invokes agent with JWT (5-min TTL) → task starts in background thread with propagated WAT
2. Thread sleeps 6 minutes (WAT expires at minute 5)
3. Thread calls credential provider → WAT expired → decorator pauses thread
4. Client sends `{"action": "refresh"}` with fresh JWT → Runtime creates new WAT
5. Thread resumes, retries credential provider → success
6. Task completes, agent returns to Healthy

## How It Works

```mermaid
sequenceDiagram
participant Client
participant Entrypoint as Entrypoint (main thread)
participant Runtime as AgentCore Runtime
participant Thread as Background Thread
participant Provider as Credential Provider

Client->>Entrypoint: {"action": "start"} + JWT
Runtime->>Runtime: Create WAT from JWT (same exp)
Entrypoint->>Thread: Start with copy_context() (WAT propagated)
Entrypoint-->>Client: {"task_id": 123, "status": "started"}

Note over Thread: Business logic runs...

Thread->>Provider: @with_wat_refresh → get token
Provider-->>Thread: Token has expired

Note over Thread: Thread paused, waiting for refresh

Client->>Entrypoint: {"action": "refresh"} + fresh JWT
Runtime->>Runtime: Create new WAT
Entrypoint->>Thread: Signal with new WAT

Note over Thread: Thread resumes
Thread->>Provider: Retry → get token ✓
```

## Cleanup

```bash
agentcore destroy
aws cognito-idp delete-user-pool --user-pool-id $POOL_ID --region us-east-1
```

## Related

- [AgentCore Identity - Getting Started](../01-getting_started.md)
- [AgentCore Identity - How It Works](../02-how_it_works.md)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""AgentCore Thread Utils — companion library for threaded async agents with WAT refresh."""

from .decorator import with_wat_refresh, set_task_context
from .helper import ThreadTaskManager

__all__ = ["with_wat_refresh", "set_task_context", "ThreadTaskManager"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""with_wat_refresh — thread-safe drop-in replacement for @requires_access_token."""

import logging
import threading
from functools import wraps
from typing import Any, Callable, List, Literal, Optional

from bedrock_agentcore.runtime import BedrockAgentCoreContext
from bedrock_agentcore.identity.auth import requires_access_token

logger = logging.getLogger("agentcore_thread_utils.decorator")

# Thread-local storage for task context — safe for concurrent tasks
_local = threading.local()


def set_task_context(manager, task_id: int):
"""Set the current task context for WAT refresh coordination.

Must be called at the start of each task function.
Uses thread-local storage so concurrent tasks don't interfere.
"""
_local.task_manager = manager
_local.task_id = task_id


def _get_task_context():
"""Get the current task context from thread-local storage."""
manager = getattr(_local, "task_manager", None)
task_id = getattr(_local, "task_id", None)
return manager, task_id


def with_wat_refresh(
*,
provider_name: str,
scopes: List[str],
auth_flow: Literal["M2M", "USER_FEDERATION"] = "M2M",
into: str = "access_token",
on_auth_url: Optional[Callable] = None,
callback_url: Optional[str] = None,
force_authentication: bool = False,
max_retries: int = 2,
) -> Callable:
"""Decorator that wraps @requires_access_token with WAT refresh for threads.

Same interface as @requires_access_token but handles WAT expiration
by pausing the thread and waiting for a client refresh.

Args:
max_retries: Maximum number of refresh attempts before giving up (default: 2).
"""

def decorator(func: Callable) -> Callable:
@requires_access_token(
provider_name=provider_name,
scopes=scopes,
auth_flow=auth_flow,
into=into,
on_auth_url=on_auth_url,
callback_url=callback_url,
force_authentication=force_authentication,
)
def _inner_call(*, access_token: str):
return func(access_token=access_token)

@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
manager, task_id = _get_task_context()
last_error = None

for attempt in range(max_retries + 1):
try:
return _inner_call(**kwargs)
except Exception as e:
last_error = e
if "expired" in str(e).lower() and manager and task_id:
if attempt < max_retries:
logger.info(
f"WAT expired (attempt {attempt + 1}/{max_retries}). "
f"Requesting refresh for task {task_id}..."
)
new_wat = manager.wait_for_wat_refresh(task_id)
BedrockAgentCoreContext.set_workload_access_token(new_wat)
logger.info("WAT refreshed. Retrying...")
else:
logger.error(f"WAT expired after {max_retries} refresh attempts.")
raise
else:
raise

raise last_error

return wrapper

return decorator
Loading
Loading