diff --git a/agent_memory_server/__init__.py b/agent_memory_server/__init__.py index c06ed54..074b7f6 100644 --- a/agent_memory_server/__init__.py +++ b/agent_memory_server/__init__.py @@ -1,3 +1,3 @@ """Redis Agent Memory Server - A memory system for conversational AI.""" -__version__ = "0.9.2" +__version__ = "0.9.3" diff --git a/agent_memory_server/auth.py b/agent_memory_server/auth.py index ec62230..4631faa 100644 --- a/agent_memory_server/auth.py +++ b/agent_memory_server/auth.py @@ -346,7 +346,7 @@ async def verify_token(token: str) -> UserInfo: ) from e -def get_current_user( +async def get_current_user( credentials: HTTPAuthorizationCredentials | None = Depends(oauth2_scheme), ) -> UserInfo: if settings.disable_auth or settings.auth_mode == "disabled": @@ -371,9 +371,7 @@ def get_current_user( # Determine authentication mode if settings.auth_mode == "token" or settings.token_auth_enabled: - import asyncio - - return asyncio.run(verify_token(credentials.credentials)) + return await verify_token(credentials.credentials) if settings.auth_mode == "oauth2": return verify_jwt(credentials.credentials) # Default to OAuth2 for backward compatibility @@ -381,7 +379,7 @@ def get_current_user( def require_scope(required_scope: str): - def scope_dependency(user: UserInfo = Depends(get_current_user)) -> UserInfo: + async def scope_dependency(user: UserInfo = Depends(get_current_user)) -> UserInfo: if settings.disable_auth: return user @@ -397,7 +395,7 @@ def scope_dependency(user: UserInfo = Depends(get_current_user)) -> UserInfo: def require_role(required_role: str): - def role_dependency(user: UserInfo = Depends(get_current_user)) -> UserInfo: + async def role_dependency(user: UserInfo = Depends(get_current_user)) -> UserInfo: if settings.disable_auth: return user diff --git a/tests/test_auth.py b/tests/test_auth.py index 0627cff..f2b100c 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -685,7 +685,7 @@ async def test_get_current_user_disabled_auth(self, mock_settings): """Test get_current_user when authentication is disabled""" mock_settings.disable_auth = True - result = get_current_user(None) + result = await get_current_user(None) assert isinstance(result, UserInfo) assert result.sub == "local-dev-user" @@ -700,7 +700,7 @@ async def test_get_current_user_missing_credentials(self, mock_settings): mock_settings.auth_mode = "oauth2" with pytest.raises(HTTPException) as exc_info: - get_current_user(None) + await get_current_user(None) assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED assert "Missing authorization header" in str(exc_info.value.detail) @@ -717,7 +717,7 @@ async def test_get_current_user_empty_credentials(self, mock_settings): empty_creds = HTTPAuthorizationCredentials(scheme="Bearer", credentials="") with pytest.raises(HTTPException) as exc_info: - get_current_user(empty_creds) + await get_current_user(empty_creds) assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED assert "Missing bearer token" in str(exc_info.value.detail) @@ -736,7 +736,7 @@ async def test_get_current_user_valid_token(self, mock_settings, valid_token): expected_user = UserInfo(sub="test-user", email="test@example.com") mock_verify.return_value = expected_user - result = get_current_user(creds) + result = await get_current_user(creds) assert result == expected_user mock_verify.assert_called_once_with(valid_token) @@ -753,7 +753,7 @@ async def test_require_scope_success(self, mock_settings): user = UserInfo(sub="test-user", scope="read write admin") scope_dependency = require_scope("read") - result = scope_dependency(user) + result = await scope_dependency(user) assert result == user @pytest.mark.asyncio @@ -765,7 +765,7 @@ async def test_require_scope_failure(self, mock_settings): scope_dependency = require_scope("admin") with pytest.raises(HTTPException) as exc_info: - scope_dependency(user) + await scope_dependency(user) assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN assert "Insufficient permissions" in str(exc_info.value.detail) @@ -780,7 +780,7 @@ async def test_require_scope_no_scope(self, mock_settings): scope_dependency = require_scope("read") with pytest.raises(HTTPException) as exc_info: - scope_dependency(user) + await scope_dependency(user) assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN @@ -792,7 +792,7 @@ async def test_require_scope_disabled_auth(self, mock_settings): user = UserInfo(sub="test-user", scope=None) scope_dependency = require_scope("admin") - result = scope_dependency(user) + result = await scope_dependency(user) assert result == user @pytest.mark.asyncio @@ -803,7 +803,7 @@ async def test_require_role_success(self, mock_settings): user = UserInfo(sub="test-user", roles=["user", "admin"]) role_dependency = require_role("admin") - result = role_dependency(user) + result = await role_dependency(user) assert result == user @pytest.mark.asyncio @@ -815,7 +815,7 @@ async def test_require_role_failure(self, mock_settings): role_dependency = require_role("admin") with pytest.raises(HTTPException) as exc_info: - role_dependency(user) + await role_dependency(user) assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN assert "Insufficient permissions" in str(exc_info.value.detail) @@ -830,7 +830,7 @@ async def test_require_role_no_roles(self, mock_settings): role_dependency = require_role("admin") with pytest.raises(HTTPException) as exc_info: - role_dependency(user) + await role_dependency(user) assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN @@ -842,7 +842,7 @@ async def test_require_role_disabled_auth(self, mock_settings): user = UserInfo(sub="test-user", roles=None) role_dependency = require_role("admin") - result = role_dependency(user) + result = await role_dependency(user) assert result == user diff --git a/tests/test_token_auth.py b/tests/test_token_auth.py index f2c4804..12875b3 100644 --- a/tests/test_token_auth.py +++ b/tests/test_token_auth.py @@ -189,28 +189,31 @@ async def test_verify_token_wrong_token(self, mock_redis, sample_token_info): class TestGetCurrentUser: """Test get_current_user with token authentication.""" - def test_get_current_user_disabled_auth(self, mock_settings): + @pytest.mark.asyncio + async def test_get_current_user_disabled_auth(self, mock_settings): """Test get_current_user with disabled authentication.""" mock_settings.disable_auth = True mock_settings.auth_mode = "disabled" - user_info = get_current_user(None) + user_info = await get_current_user(None) assert user_info.sub == "local-dev-user" assert user_info.aud == "local-dev" - def test_get_current_user_missing_credentials(self, mock_settings): + @pytest.mark.asyncio + async def test_get_current_user_missing_credentials(self, mock_settings): """Test get_current_user with missing credentials.""" mock_settings.disable_auth = False mock_settings.auth_mode = "token" with pytest.raises(HTTPException) as exc_info: - get_current_user(None) + await get_current_user(None) assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED assert "Missing authorization header" in exc_info.value.detail - def test_get_current_user_missing_token(self, mock_settings): + @pytest.mark.asyncio + async def test_get_current_user_missing_token(self, mock_settings): """Test get_current_user with missing token.""" mock_settings.disable_auth = False mock_settings.auth_mode = "token" @@ -218,13 +221,14 @@ def test_get_current_user_missing_token(self, mock_settings): credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials="") with pytest.raises(HTTPException) as exc_info: - get_current_user(credentials) + await get_current_user(credentials) assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED assert "Missing bearer token" in exc_info.value.detail - @patch("agent_memory_server.auth.verify_token") - def test_get_current_user_token_auth(self, mock_verify_token, mock_settings): + @patch("agent_memory_server.auth.verify_token", new_callable=AsyncMock) + @pytest.mark.asyncio + async def test_get_current_user_token_auth(self, mock_verify_token, mock_settings): """Test get_current_user with token authentication.""" mock_settings.disable_auth = False mock_settings.auth_mode = "token" @@ -232,16 +236,15 @@ def test_get_current_user_token_auth(self, mock_verify_token, mock_settings): # Mock verify_token to return a user mock_user = Mock() mock_user.sub = "token-user" + mock_verify_token.return_value = mock_user - # Mock asyncio.run to return the user directly - with patch("asyncio.run", return_value=mock_user): - credentials = HTTPAuthorizationCredentials( - scheme="Bearer", credentials="test_token" - ) + credentials = HTTPAuthorizationCredentials( + scheme="Bearer", credentials="test_token" + ) - user_info = get_current_user(credentials) + user_info = await get_current_user(credentials) - assert user_info.sub == "token-user" + assert user_info.sub == "token-user" class TestAuthConfig: