From 1d90e658f93bd29ca44a6be7f80132a178285c57 Mon Sep 17 00:00:00 2001 From: Tom Bursch Date: Fri, 3 Jan 2025 11:45:46 +0100 Subject: [PATCH] fix: Minor --- .../app/controller/auth/auth_controller.py | 1 + backend/app/models/token.py | 23 ++++---- backend/tests/api/test_auth.py | 53 ++++++++++--------- 3 files changed, 40 insertions(+), 37 deletions(-) diff --git a/backend/app/controller/auth/auth_controller.py b/backend/app/controller/auth/auth_controller.py index 85074857..e78e58ac 100644 --- a/backend/app/controller/auth/auth_controller.py +++ b/backend/app/controller/auth/auth_controller.py @@ -36,6 +36,7 @@ def check_if_token_revoked(jwt_header, jwt_payload: dict) -> bool: token.save() return token is None + # Register a callback function that takes whatever object is passed in as the # identity when creating JWTs and converts it to a JSON serializable format. @jwt.user_identity_loader diff --git a/backend/app/models/token.py b/backend/app/models/token.py index 3b11fd41..c43cac3a 100644 --- a/backend/app/models/token.py +++ b/backend/app/models/token.py @@ -101,16 +101,14 @@ def has_created_refresh_token(self) -> bool: > 0 ) - def delete_created_access_tokens(self, exclude_token_id=None): + def delete_created_access_tokens(self, commit=True): if self.type != "refresh": return - query = db.session.query(Token).filter( + Token.query.filter( Token.refresh_token_id == self.id, Token.type == "access" - ) - if exclude_token_id is not None: - query = query.filter(Token.id != exclude_token_id) - query.delete() - db.session.commit() + ).delete() + if commit: + db.session.commit() @classmethod def create_access_token( @@ -141,10 +139,10 @@ def create_refresh_token( # Check if this refresh token has already been used to create another refresh token if oldRefreshToken and oldRefreshToken.has_created_refresh_token(): - for newer_token in db.session.query(Token).filter( + for newer_token in Token.query.filter( Token.refresh_token_id == oldRefreshToken.id, Token.type == "refresh" - ): + ).all(): newer_access_used = db.session.query(Token).filter( Token.refresh_token_id == newer_token.id, Token.type == "access", @@ -161,11 +159,11 @@ def create_refresh_token( ) else: # Only invalidate the unused parallel refresh token chain - for token in db.session.query(Token).filter( + Token.query.filter( Token.refresh_token_id == newer_token.id - ).all(): - db.session.delete(token) + ).delete() newer_token.type = "invalidated_refresh" + db.session.add(newer_token) refreshToken = create_refresh_token(identity=user) model = cls() @@ -174,6 +172,7 @@ def create_refresh_token( model.name = device or oldRefreshToken.name model.user = user if oldRefreshToken: + oldRefreshToken.delete_created_access_tokens(commit=False) model.refresh_token = oldRefreshToken model.save() return refreshToken, model diff --git a/backend/tests/api/test_auth.py b/backend/tests/api/test_auth.py index ffe8f1bb..a139488b 100644 --- a/backend/tests/api/test_auth.py +++ b/backend/tests/api/test_auth.py @@ -51,9 +51,9 @@ def test_shaky_network_token_refresh(user_client, username, password): assert response.status_code == 200 # Intentionally ignore new tokens - # Use old access token, should still work since we didn't use the new one + # Use old access token, should not work since refresh invalidates them response = user_client.get("/api/user", headers={"Authorization": f"Bearer {access_token}"}) - assert response.status_code == 200 + assert response.status_code == 401 # Original refresh token should still work since we didn't use the new one @@ -83,9 +83,9 @@ def test_token_hijack_attempt(user_client, username, password): leaked_refresh_token = data["refresh_token"] - # User continues normal use with original tokens + # User cannot continue normal use with original access token response = user_client.get("/api/user", headers={"Authorization": f"Bearer {access_token}"}) - assert response.status_code == 200 + assert response.status_code == 401 # Create another refresh token (normal use) response = user_client.get("/api/auth/refresh", headers={"Authorization": f"Bearer {refresh_token}"}) @@ -225,44 +225,47 @@ def test_complex_token_chain(user_client, username, password): at5 = data["access_token"] rt5 = data["refresh_token"] - # Use AT2 to make it the active chain + # AT2 should be rejected (refresh invalidates AT but not RT) response = user_client.get("/api/user", headers={"Authorization": f"Bearer {at2}"}) + assert response.status_code == 401 + + # RT5/AT5 chain should work (last created refresh token) + response = user_client.get("/api/user", headers={"Authorization": f"Bearer {at5}"}) + assert response.status_code == 200 + response = user_client.get("/api/auth/refresh", headers={"Authorization": f"Bearer {rt5}"}) assert response.status_code == 200 + data = response.get_json() + at6 = data["access_token"] + rt6 = data["refresh_token"] - # Verify unused tokens from parallel chains are rejected + # Verify unused tokens from parallel chains are rejected triggering breach detection + response = user_client.get("/api/auth/refresh", headers={"Authorization": f"Bearer {rt4}"}) + assert response.status_code == 401 - # RT3/AT3 chain should be rejected (unused parallel chain) + # Check that no token works anymore + response = user_client.get("/api/user", headers={"Authorization": f"Bearer {at1}"}) + assert response.status_code == 401 + response = user_client.get("/api/auth/refresh", headers={"Authorization": f"Bearer {rt1}"}) + assert response.status_code == 401 + response = user_client.get("/api/user", headers={"Authorization": f"Bearer {at2}"}) + assert response.status_code == 401 + response = user_client.get("/api/auth/refresh", headers={"Authorization": f"Bearer {rt2}"}) + assert response.status_code == 401 response = user_client.get("/api/user", headers={"Authorization": f"Bearer {at3}"}) assert response.status_code == 401 response = user_client.get("/api/auth/refresh", headers={"Authorization": f"Bearer {rt3}"}) assert response.status_code == 401 - - # RT4/AT4 chain should be rejected (unused parallel chain) response = user_client.get("/api/user", headers={"Authorization": f"Bearer {at4}"}) assert response.status_code == 401 response = user_client.get("/api/auth/refresh", headers={"Authorization": f"Bearer {rt4}"}) assert response.status_code == 401 - - # RT5/AT5 chain should be rejected (unused parallel chain) response = user_client.get("/api/user", headers={"Authorization": f"Bearer {at5}"}) assert response.status_code == 401 response = user_client.get("/api/auth/refresh", headers={"Authorization": f"Bearer {rt5}"}) assert response.status_code == 401 - - # Original RT1 should be rejected - response = user_client.get("/api/auth/refresh", headers={"Authorization": f"Bearer {rt1}"}) + response = user_client.get("/api/user", headers={"Authorization": f"Bearer {at6}"}) assert response.status_code == 401 - - # Try to use one of the parallel chain tokens (RT3), which should trigger breach detection - response = user_client.get("/api/auth/refresh", headers={"Authorization": f"Bearer {rt3}"}) - assert response.status_code == 401 - - # AT2 should now be rejected as the use of RT3 indicates a potential breach - response = user_client.get("/api/user", headers={"Authorization": f"Bearer {at2}"}) - assert response.status_code == 401 - - # RT2 should be rejected (part of the compromised chain) - response = user_client.get("/api/auth/refresh", headers={"Authorization": f"Bearer {rt2}"}) + response = user_client.get("/api/auth/refresh", headers={"Authorization": f"Bearer {rt6}"}) assert response.status_code == 401 def test_complex_token_chain2(user_client, username, password):