Skip to content

Fix authentication event loop corruption by converting get_current_user to async #40

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion agent_memory_server/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Redis Agent Memory Server - A memory system for conversational AI."""

__version__ = "0.9.2"
__version__ = "0.9.3"
10 changes: 4 additions & 6 deletions agent_memory_server/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ async def verify_token(token: str) -> UserInfo:
) from e


def get_current_user(
async def get_current_user(
credentials: HTTPAuthorizationCredentials | None = Depends(oauth2_scheme),
) -> UserInfo:
if settings.disable_auth or settings.auth_mode == "disabled":
Expand All @@ -371,17 +371,15 @@ def get_current_user(

# Determine authentication mode
if settings.auth_mode == "token" or settings.token_auth_enabled:
import asyncio

return asyncio.run(verify_token(credentials.credentials))
return await verify_token(credentials.credentials)
if settings.auth_mode == "oauth2":
return verify_jwt(credentials.credentials)
# Default to OAuth2 for backward compatibility
return verify_jwt(credentials.credentials)


def require_scope(required_scope: str):
def scope_dependency(user: UserInfo = Depends(get_current_user)) -> UserInfo:
async def scope_dependency(user: UserInfo = Depends(get_current_user)) -> UserInfo:
if settings.disable_auth:
return user

Expand All @@ -397,7 +395,7 @@ def scope_dependency(user: UserInfo = Depends(get_current_user)) -> UserInfo:


def require_role(required_role: str):
def role_dependency(user: UserInfo = Depends(get_current_user)) -> UserInfo:
async def role_dependency(user: UserInfo = Depends(get_current_user)) -> UserInfo:
if settings.disable_auth:
return user

Expand Down
24 changes: 12 additions & 12 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ async def test_get_current_user_disabled_auth(self, mock_settings):
"""Test get_current_user when authentication is disabled"""
mock_settings.disable_auth = True

result = get_current_user(None)
result = await get_current_user(None)

assert isinstance(result, UserInfo)
assert result.sub == "local-dev-user"
Expand All @@ -700,7 +700,7 @@ async def test_get_current_user_missing_credentials(self, mock_settings):
mock_settings.auth_mode = "oauth2"

with pytest.raises(HTTPException) as exc_info:
get_current_user(None)
await get_current_user(None)

assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert "Missing authorization header" in str(exc_info.value.detail)
Expand All @@ -717,7 +717,7 @@ async def test_get_current_user_empty_credentials(self, mock_settings):
empty_creds = HTTPAuthorizationCredentials(scheme="Bearer", credentials="")

with pytest.raises(HTTPException) as exc_info:
get_current_user(empty_creds)
await get_current_user(empty_creds)

assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert "Missing bearer token" in str(exc_info.value.detail)
Expand All @@ -736,7 +736,7 @@ async def test_get_current_user_valid_token(self, mock_settings, valid_token):
expected_user = UserInfo(sub="test-user", email="[email protected]")
mock_verify.return_value = expected_user

result = get_current_user(creds)
result = await get_current_user(creds)

assert result == expected_user
mock_verify.assert_called_once_with(valid_token)
Expand All @@ -753,7 +753,7 @@ async def test_require_scope_success(self, mock_settings):
user = UserInfo(sub="test-user", scope="read write admin")
scope_dependency = require_scope("read")

result = scope_dependency(user)
result = await scope_dependency(user)
assert result == user

@pytest.mark.asyncio
Expand All @@ -765,7 +765,7 @@ async def test_require_scope_failure(self, mock_settings):
scope_dependency = require_scope("admin")

with pytest.raises(HTTPException) as exc_info:
scope_dependency(user)
await scope_dependency(user)

assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
assert "Insufficient permissions" in str(exc_info.value.detail)
Expand All @@ -780,7 +780,7 @@ async def test_require_scope_no_scope(self, mock_settings):
scope_dependency = require_scope("read")

with pytest.raises(HTTPException) as exc_info:
scope_dependency(user)
await scope_dependency(user)

assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN

Expand All @@ -792,7 +792,7 @@ async def test_require_scope_disabled_auth(self, mock_settings):
user = UserInfo(sub="test-user", scope=None)
scope_dependency = require_scope("admin")

result = scope_dependency(user)
result = await scope_dependency(user)
assert result == user

@pytest.mark.asyncio
Expand All @@ -803,7 +803,7 @@ async def test_require_role_success(self, mock_settings):
user = UserInfo(sub="test-user", roles=["user", "admin"])
role_dependency = require_role("admin")

result = role_dependency(user)
result = await role_dependency(user)
assert result == user

@pytest.mark.asyncio
Expand All @@ -815,7 +815,7 @@ async def test_require_role_failure(self, mock_settings):
role_dependency = require_role("admin")

with pytest.raises(HTTPException) as exc_info:
role_dependency(user)
await role_dependency(user)

assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
assert "Insufficient permissions" in str(exc_info.value.detail)
Expand All @@ -830,7 +830,7 @@ async def test_require_role_no_roles(self, mock_settings):
role_dependency = require_role("admin")

with pytest.raises(HTTPException) as exc_info:
role_dependency(user)
await role_dependency(user)

assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN

Expand All @@ -842,7 +842,7 @@ async def test_require_role_disabled_auth(self, mock_settings):
user = UserInfo(sub="test-user", roles=None)
role_dependency = require_role("admin")

result = role_dependency(user)
result = await role_dependency(user)
assert result == user


Expand Down
33 changes: 18 additions & 15 deletions tests/test_token_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,59 +189,62 @@ async def test_verify_token_wrong_token(self, mock_redis, sample_token_info):
class TestGetCurrentUser:
"""Test get_current_user with token authentication."""

def test_get_current_user_disabled_auth(self, mock_settings):
@pytest.mark.asyncio
async def test_get_current_user_disabled_auth(self, mock_settings):
"""Test get_current_user with disabled authentication."""
mock_settings.disable_auth = True
mock_settings.auth_mode = "disabled"

user_info = get_current_user(None)
user_info = await get_current_user(None)

assert user_info.sub == "local-dev-user"
assert user_info.aud == "local-dev"

def test_get_current_user_missing_credentials(self, mock_settings):
@pytest.mark.asyncio
async def test_get_current_user_missing_credentials(self, mock_settings):
"""Test get_current_user with missing credentials."""
mock_settings.disable_auth = False
mock_settings.auth_mode = "token"

with pytest.raises(HTTPException) as exc_info:
get_current_user(None)
await get_current_user(None)

assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert "Missing authorization header" in exc_info.value.detail

def test_get_current_user_missing_token(self, mock_settings):
@pytest.mark.asyncio
async def test_get_current_user_missing_token(self, mock_settings):
"""Test get_current_user with missing token."""
mock_settings.disable_auth = False
mock_settings.auth_mode = "token"

credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials="")

with pytest.raises(HTTPException) as exc_info:
get_current_user(credentials)
await get_current_user(credentials)

assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert "Missing bearer token" in exc_info.value.detail

@patch("agent_memory_server.auth.verify_token")
def test_get_current_user_token_auth(self, mock_verify_token, mock_settings):
@patch("agent_memory_server.auth.verify_token", new_callable=AsyncMock)
@pytest.mark.asyncio
async def test_get_current_user_token_auth(self, mock_verify_token, mock_settings):
"""Test get_current_user with token authentication."""
mock_settings.disable_auth = False
mock_settings.auth_mode = "token"

# Mock verify_token to return a user
mock_user = Mock()
mock_user.sub = "token-user"
mock_verify_token.return_value = mock_user

# Mock asyncio.run to return the user directly
with patch("asyncio.run", return_value=mock_user):
credentials = HTTPAuthorizationCredentials(
scheme="Bearer", credentials="test_token"
)
credentials = HTTPAuthorizationCredentials(
scheme="Bearer", credentials="test_token"
)

user_info = get_current_user(credentials)
user_info = await get_current_user(credentials)

assert user_info.sub == "token-user"
assert user_info.sub == "token-user"


class TestAuthConfig:
Expand Down