From eed1f4ff055e2c2d46f236ed68e65ae61c1fcc7c Mon Sep 17 00:00:00 2001 From: tnware Date: Sat, 8 Mar 2025 14:26:38 -0800 Subject: [PATCH 1/9] Token lifecycle management --- django_auth_adfs/__init__.py | 2 +- django_auth_adfs/config.py | 6 + django_auth_adfs/middleware.py | 324 +++++++++++++- django_auth_adfs/utils.py | 151 +++++++ docs/index.rst | 1 + docs/settings_ref.rst | 107 +++++ docs/token_lifecycle.rst | 459 ++++++++++++++++++++ pyproject.toml | 2 +- tests/test_middleware.py | 750 +++++++++++++++++++++++++++++++++ 9 files changed, 1799 insertions(+), 3 deletions(-) create mode 100644 django_auth_adfs/utils.py create mode 100644 docs/token_lifecycle.rst create mode 100644 tests/test_middleware.py diff --git a/django_auth_adfs/__init__.py b/django_auth_adfs/__init__.py index 05785d56..67f1dd52 100644 --- a/django_auth_adfs/__init__.py +++ b/django_auth_adfs/__init__.py @@ -4,4 +4,4 @@ Adding imports here will break setup.py """ -__version__ = '1.15.0' +__version__ = "1.16.0" diff --git a/django_auth_adfs/config.py b/django_auth_adfs/config.py index 317781f4..e4a4a57d 100644 --- a/django_auth_adfs/config.py +++ b/django_auth_adfs/config.py @@ -77,6 +77,12 @@ def __init__(self): ) self.PROXIES = None + # Token Lifecycle Middleware settings + self.TOKEN_REFRESH_THRESHOLD = 300 # 5 minutes + self.STORE_OBO_TOKEN = True + self.TOKEN_ENCRYPTION_SALT = b"django_auth_adfs_token_encryption" + self.LOGOUT_ON_TOKEN_REFRESH_FAILURE = False + self.VERSION = 'v1.0' self.SCOPES = [] diff --git a/django_auth_adfs/middleware.py b/django_auth_adfs/middleware.py index 649a2390..5bc9c006 100644 --- a/django_auth_adfs/middleware.py +++ b/django_auth_adfs/middleware.py @@ -1,6 +1,9 @@ """ Based on https://djangosnippets.org/snippets/1179/ """ + +import datetime +import logging from re import compile from django.conf import settings as django_settings @@ -8,7 +11,9 @@ from django.urls import reverse from django_auth_adfs.exceptions import MFARequired -from django_auth_adfs.config import settings +from django_auth_adfs.config import settings, provider_config +from django_auth_adfs.signals import post_authenticate +from django_auth_adfs.utils import _encrypt_token LOGIN_EXEMPT_URLS = [ compile(django_settings.LOGIN_URL.lstrip('/')), @@ -19,6 +24,8 @@ if hasattr(settings, 'LOGIN_EXEMPT_URLS'): LOGIN_EXEMPT_URLS += [compile(expr) for expr in settings.LOGIN_EXEMPT_URLS] +logger = logging.getLogger("django_auth_adfs") + class LoginRequiredMiddleware: """ @@ -49,3 +56,318 @@ def __call__(self, request): return redirect_to_login('django_auth_adfs:login-force-mfa') return self.get_response(request) + + +class TokenLifecycleMiddleware: + """ + Middleware that handles the complete lifecycle of ADFS access and refresh tokens. + + This middleware will: + 1. Store tokens in the session after successful authentication + 2. Check if the access token is about to expire + 3. Use the refresh token to get a new access token if needed + 4. Update the tokens in the session + 5. Handle OBO (On-Behalf-Of) tokens for Microsoft Graph API + + To enable this middleware, add it to your MIDDLEWARE setting: + 'django_auth_adfs.middleware.TokenLifecycleMiddleware' + + You can configure the token refresh behavior with these settings: + + TOKEN_REFRESH_THRESHOLD: Number of seconds before expiration to refresh (default: 300) + STORE_OBO_TOKEN: Boolean to enable/disable OBO token storage (default: True) + """ + + def __init__(self, get_response): + self.get_response = get_response + # Default settings + self.threshold = getattr(settings, "TOKEN_REFRESH_THRESHOLD", 300) + self.using_signed_cookies = ( + django_settings.SESSION_ENGINE + == "django.contrib.sessions.backends.signed_cookies" + ) + self.disable_for_signed_cookies = True + self.store_obo_token = getattr(settings, "STORE_OBO_TOKEN", True) + if self.using_signed_cookies: + logger.warning( + "TokenLifecycleMiddleware is enabled but you are using the signed_cookies session backend. " + "Storing tokens in signed cookies is not recommended for security reasons and cookie size limitations. " + "The middleware will not store tokens in the session. " + "Consider using database or cache-based sessions instead." + ) + + # Connect the signal receiver + post_authenticate.connect(self._capture_tokens_from_auth) + + def __call__(self, request): + if hasattr(request, "user"): + # Store tokens if they're available on the user object but not in the session + self._store_tokens_from_user(request) + if request.user.is_authenticated: + self._handle_token_refresh(request) + response = self.get_response(request) + + # This handles the case where authentication happens during the request + if hasattr(request, "user"): + self._store_tokens_from_user(request) + + return response + + def _store_tokens_from_user(self, request): + """ + Store tokens from the user object in the session if they exist + """ + if self.using_signed_cookies: + return + + if not hasattr(request, "user") or not request.user.is_authenticated: + return + + user = request.user + session_modified = False + + # Check if user has tokens that aren't in the session + if hasattr(user, "access_token") and user.access_token: + encrypted_token = _encrypt_token(user.access_token) + if encrypted_token and ( + not request.session.get("ADFS_ACCESS_TOKEN") + or request.session.get("ADFS_ACCESS_TOKEN") != encrypted_token + ): + request.session["ADFS_ACCESS_TOKEN"] = encrypted_token + session_modified = True + + if hasattr(user, "refresh_token") and user.refresh_token: + encrypted_token = _encrypt_token(user.refresh_token) + if encrypted_token and ( + not request.session.get("ADFS_REFRESH_TOKEN") + or request.session.get("ADFS_REFRESH_TOKEN") != encrypted_token + ): + request.session["ADFS_REFRESH_TOKEN"] = encrypted_token + session_modified = True + + if hasattr(user, "token_expires_at") and user.token_expires_at: + expires_at_str = user.token_expires_at.isoformat() + if ( + not request.session.get("ADFS_TOKEN_EXPIRES_AT") + or request.session.get("ADFS_TOKEN_EXPIRES_AT") != expires_at_str + ): + request.session["ADFS_TOKEN_EXPIRES_AT"] = expires_at_str + session_modified = True + + # Store OBO token if available and enabled + if ( + self.store_obo_token + and hasattr(user, "obo_access_token") + and user.obo_access_token + ): + encrypted_token = _encrypt_token(user.obo_access_token) + if encrypted_token and ( + not request.session.get("ADFS_OBO_ACCESS_TOKEN") + or request.session.get("ADFS_OBO_ACCESS_TOKEN") != encrypted_token + ): + request.session["ADFS_OBO_ACCESS_TOKEN"] = encrypted_token + session_modified = True + + # Store OBO token expiration if available + if ( + self.store_obo_token + and hasattr(user, "obo_token_expires_at") + and user.obo_token_expires_at + ): + obo_expires_at_str = user.obo_token_expires_at.isoformat() + if ( + not request.session.get("ADFS_OBO_TOKEN_EXPIRES_AT") + or request.session.get("ADFS_OBO_TOKEN_EXPIRES_AT") + != obo_expires_at_str + ): + request.session["ADFS_OBO_TOKEN_EXPIRES_AT"] = obo_expires_at_str + session_modified = True + + if session_modified: + request.session.modified = True + logger.debug("Stored tokens from user object in session") + + def _handle_token_refresh(self, request): + """ + Check if the access token needs to be refreshed and refresh it if needed + """ + if self.using_signed_cookies: + return + + if ( + "ADFS_ACCESS_TOKEN" not in request.session + or "ADFS_REFRESH_TOKEN" not in request.session + or "ADFS_TOKEN_EXPIRES_AT" not in request.session + ): + return + + try: + expires_at = datetime.datetime.fromisoformat( + request.session["ADFS_TOKEN_EXPIRES_AT"] + ) + now = datetime.datetime.now() + + if (expires_at - now).total_seconds() <= self.threshold: + logger.debug("Access token is about to expire, refreshing...") + self._refresh_tokens(request) + + if ( + self.store_obo_token + and "ADFS_OBO_ACCESS_TOKEN" in request.session + and "ADFS_OBO_TOKEN_EXPIRES_AT" in request.session + ): + obo_expires_at = datetime.datetime.fromisoformat( + request.session["ADFS_OBO_TOKEN_EXPIRES_AT"] + ) + if (obo_expires_at - now).total_seconds() <= self.threshold: + logger.debug("OBO token is about to expire, refreshing...") + self._refresh_obo_token(request) + + except Exception as e: + logger.warning(f"Error checking token expiration: {e}") + + def _refresh_tokens(self, request): + """ + Refresh the access token using the refresh token + """ + if self.using_signed_cookies: + return + + if "ADFS_REFRESH_TOKEN" not in request.session: + return + + try: + from django_auth_adfs.utils import _decrypt_token, _encrypt_token + + refresh_token = _decrypt_token(request.session["ADFS_REFRESH_TOKEN"]) + if not refresh_token: + logger.warning("Failed to decrypt refresh token") + return + + provider_config.load_config() + + data = { + "grant_type": "refresh_token", + "client_id": settings.CLIENT_ID, + "refresh_token": refresh_token, + } + + if settings.CLIENT_SECRET: + data["client_secret"] = settings.CLIENT_SECRET + + response = provider_config.session.post( + provider_config.token_endpoint, data=data, timeout=settings.TIMEOUT + ) + if response.status_code == 200: + token_data = response.json() + request.session["ADFS_ACCESS_TOKEN"] = _encrypt_token( + token_data["access_token"] + ) + if "refresh_token" in token_data: + request.session["ADFS_REFRESH_TOKEN"] = _encrypt_token( + token_data["refresh_token"] + ) + expires_in = int( + token_data.get("expires_in", 3600) + ) + expires_at = datetime.datetime.now() + datetime.timedelta( + seconds=expires_in + ) + request.session["ADFS_TOKEN_EXPIRES_AT"] = expires_at.isoformat() + + request.session.modified = True + logger.debug("Successfully refreshed tokens") + + if self.store_obo_token: + self._refresh_obo_token(request) + else: + logger.warning( + f"Failed to refresh token: {response.status_code} {response.text}" + ) + if settings.LOGOUT_ON_TOKEN_REFRESH_FAILURE: + from django.contrib.auth import logout + + logger.info("Logging out user due to token refresh failure") + logout(request) + + except Exception as e: + logger.exception(f"Error refreshing tokens: {e}") + if settings.LOGOUT_ON_TOKEN_REFRESH_FAILURE: + from django.contrib.auth import logout + + logger.info("Logging out user due to token refresh error") + logout(request) + + def _refresh_obo_token(self, request): + """ + Refresh the OBO token for Microsoft Graph API + """ + if not self.store_obo_token: + return + + if self.using_signed_cookies: + return + + if "ADFS_ACCESS_TOKEN" not in request.session: + return + + try: + from django_auth_adfs.utils import _decrypt_token, _encrypt_token + + access_token = _decrypt_token(request.session["ADFS_ACCESS_TOKEN"]) + if not access_token: + logger.warning("Failed to decrypt access token") + return + + from django_auth_adfs.backend import AdfsBaseBackend + + backend = AdfsBaseBackend() + obo_token = backend.get_obo_access_token(access_token) + + if obo_token: + request.session["ADFS_OBO_ACCESS_TOKEN"] = _encrypt_token(obo_token) + + expires_at = datetime.datetime.now() + datetime.timedelta(hours=1) + request.session["ADFS_OBO_TOKEN_EXPIRES_AT"] = expires_at.isoformat() + + request.session.modified = True + logger.debug("Successfully refreshed OBO token") + else: + logger.warning("Failed to get OBO token") + + except Exception as e: + logger.exception(f"Error refreshing OBO token: {e}") + + def _capture_tokens_from_auth( + self, sender, user, claims, adfs_response=None, **kwargs + ): + """ + Signal handler to capture tokens during authentication and store them on the user object. + This ensures the tokens are available for the middleware to store in the session. + """ + if not user: + return + + if hasattr(sender, "access_token"): + user.access_token = sender.access_token + elif adfs_response and "access_token" in adfs_response: + user.access_token = adfs_response["access_token"] + + if adfs_response and "refresh_token" in adfs_response: + user.refresh_token = adfs_response["refresh_token"] + + if adfs_response and "expires_in" in adfs_response: + user.token_expires_at = datetime.datetime.now() + datetime.timedelta( + seconds=int(adfs_response["expires_in"]) + ) + + if self.store_obo_token and hasattr(user, "access_token") and user.access_token: + try: + obo_token = sender.get_obo_access_token(user.access_token) + if obo_token: + user.obo_access_token = obo_token + user.obo_token_expires_at = ( + datetime.datetime.now() + datetime.timedelta(hours=1) + ) + except Exception as e: + logger.warning(f"Error getting OBO token during authentication: {e}") diff --git a/django_auth_adfs/utils.py b/django_auth_adfs/utils.py new file mode 100644 index 00000000..a34de790 --- /dev/null +++ b/django_auth_adfs/utils.py @@ -0,0 +1,151 @@ +""" +Utility functions for django-auth-adfs. + +Only relevant if you are using the Token Lifecycle Middleware. +""" + +import logging +import base64 + +from cryptography.fernet import Fernet +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC + +from django.conf import settings as django_settings +from django_auth_adfs.config import settings + +logger = logging.getLogger("django_auth_adfs") + + +def _get_encryption_key(): + """ + Derive a Fernet encryption key from Django's SECRET_KEY. + + The salt can be customized through the TOKEN_ENCRYPTION_SALT setting. + + Returns: + bytes: A 32-byte key suitable for Fernet encryption + """ + # Use Django's SECRET_KEY to derive a suitable encryption key + default_salt = b"django_auth_adfs_token_encryption" + salt = getattr(settings, "TOKEN_ENCRYPTION_SALT", default_salt) + + if isinstance(salt, str): + salt = salt.encode() + + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=salt, + iterations=100000, + ) + key = base64.urlsafe_b64encode(kdf.derive(django_settings.SECRET_KEY.encode())) + return key + + +def _encrypt_token(token): + """ + Encrypt a token using Django's SECRET_KEY. + + Args: + token (str): The token to encrypt + + Returns: + str: The encrypted token as a string + """ + if not token: + return None + + try: + key = _get_encryption_key() + f = Fernet(key) + encrypted_token = f.encrypt(token.encode()) + return encrypted_token.decode() + except Exception as e: + logger.error(f"Error encrypting token: {e}") + return None + + +def _decrypt_token(encrypted_token): + """ + Decrypt a token that was encrypted using Django's SECRET_KEY. + + Args: + encrypted_token (str): The encrypted token + + Returns: + str: The decrypted token or None if decryption fails + """ + if not encrypted_token: + return None + + try: + key = _get_encryption_key() + f = Fernet(key) + decrypted_token = f.decrypt(encrypted_token.encode()) + return decrypted_token.decode() + except Exception as e: + logger.error(f"Error decrypting token: {e}") + return None + + +def _is_signed_cookies_disabled(): + """ + Check if token storage is disabled for signed_cookies session backend + """ + using_signed_cookies = ( + django_settings.SESSION_ENGINE + == "django.contrib.sessions.backends.signed_cookies" + ) + return using_signed_cookies + + +def get_access_token(request): + """ + Get the current access token from the session. + + The token is automatically decrypted before being returned. + + Args: + request: The current request object + + Returns: + str: The access token or None if not available + """ + if not hasattr(request, "session"): + return None + + if _is_signed_cookies_disabled(): + logger.debug("Token retrieval from signed_cookies session is disabled") + return None + + encrypted_token = request.session.get("ADFS_ACCESS_TOKEN") + return _decrypt_token(encrypted_token) + + +def get_obo_access_token(request): + """ + Get the current OBO (On-Behalf-Of) access token for Microsoft Graph API from the session. + + The token is automatically decrypted before being returned. + + Args: + request: The current request object + + Returns: + str: The OBO access token or None if not available + """ + if not hasattr(request, "session"): + return None + + if _is_signed_cookies_disabled(): + logger.debug("Token retrieval from signed_cookies session is disabled") + return None + + store_obo_token = getattr(settings, "STORE_OBO_TOKEN", True) + if not store_obo_token: + logger.debug("OBO token storage is disabled") + return None + + encrypted_token = request.session.get("ADFS_OBO_ACCESS_TOKEN") + return _decrypt_token(encrypted_token) diff --git a/docs/index.rst b/docs/index.rst index ff325ec4..006b008e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -38,6 +38,7 @@ Contents settings_ref config_guides middleware + token_lifecycle signals rest_framework demo diff --git a/docs/settings_ref.rst b/docs/settings_ref.rst index 9d7ad940..7648e193 100644 --- a/docs/settings_ref.rst +++ b/docs/settings_ref.rst @@ -497,3 +497,110 @@ PROXIES An optional proxy for all communication with the server. Example: ``{'http': '10.0.0.1', 'https': '10.0.0.2'}`` See the `requests documentation `__ for more information. + +TOKEN_REFRESH_THRESHOLD +--------------------------- +* **Default**: ``300`` (5 minutes) +* **Type**: ``integer`` +* **Unit**: seconds + +Used by the ``TokenLifecycleMiddleware`` to determine how long before token expiration to attempt a refresh. +This setting controls how proactively the middleware will refresh tokens before they expire. + +For example, with the default value of 300 seconds (5 minutes), if a token is set to expire in 4 minutes, +the middleware will attempt to refresh it during the next request. This helps ensure that users don't +experience disruptions due to token expiration during active sessions. + +.. code-block:: python + + # In your Django settings.py + # Refresh tokens 10 minutes before they expire + AUTH_ADFS = { + # other settings + "TOKEN_REFRESH_THRESHOLD": 600 + } + +STORE_OBO_TOKEN +------------------ +* **Default**: ``True`` +* **Type**: ``boolean`` + +Used by the ``TokenLifecycleMiddleware`` to enable or disable the storage of On-Behalf-Of (OBO) tokens +for Microsoft Graph API. Set to ``False`` if you don't need to access Microsoft Graph API. + +.. note:: + When using the ``TokenLifecycleMiddleware`` with Django's ``signed_cookies`` session backend, token storage + is always disabled for security reasons. This behavior cannot be overridden. If you need token storage, + you must use a different session backend like database or cache-based sessions. + +.. code-block:: python + + # In your Django settings.py + AUTH_ADFS = { + # other settings + "STORE_OBO_TOKEN": False + } + +TOKEN_ENCRYPTION_SALT +-------------------------- +* **Default**: ``b"django_auth_adfs_token_encryption"`` +* **Type**: ``string`` + +Used by the ``TokenLifecycleMiddleware`` to derive an encryption key for token encryption. +The salt is combined with Django's ``SECRET_KEY`` to create a unique encryption key. + +You can customize this value to use a different salt for token encryption: + +.. code-block:: python + + # In your Django settings.py + AUTH_ADFS = { + # other settings + "TOKEN_ENCRYPTION_SALT": "your-custom-salt-string" + } + +While the default value is defined as a bytes literal (with the ``b`` prefix) in the code, +you should simply provide a regular string in your settings. The middleware automatically +handles the conversion to bytes as needed. + +.. warning:: + If you change this setting after tokens have been stored in sessions, those tokens will no longer be decryptable. + This effectively invalidates all existing tokens, requiring users to re-authenticate. + Consider this when deploying changes to the salt in production environments. + +LOGOUT_ON_TOKEN_REFRESH_FAILURE +------------------------------- +* **Default**: ``False`` +* **Type**: ``boolean`` + +Used by the ``TokenLifecycleMiddleware`` to control whether users should be automatically logged out when token refresh fails. + +When set to ``True``, if a token refresh attempt fails (either due to an error response from the server or an exception), +the middleware will automatically log the user out of the Django application. + +When set to ``False`` (the default), the middleware will log the error but allow the user to continue using the application +until their session expires naturally, even though their tokens are no longer valid. + +This setting is particularly important for security considerations, as it determines how your application responds when a user's account +has been disabled in Azure AD/ADFS. When enabled, it can help ensure that users who have had their accounts disabled in the +identity provider are promptly logged out of your Django application, closing a potential security gap. + +This feature is disabled by default to prioritize user experience, but can be enabled for applications where security requirements +outweigh the potential disruption of unexpected logouts. + +.. code-block:: python + + # In your Django settings.py + AUTH_ADFS = { + # other settings + "LOGOUT_ON_TOKEN_REFRESH_FAILURE": True + } + +.. note:: + This setting only affects what happens when token refresh fails. It does not affect the initial authentication process + or what happens when tokens expire without a refresh attempt. + +.. important:: + Even for applications that don't make additional API calls after authentication, enabling this setting provides + an optional security mechanism that can help ensure that access revocation in Azure AD/ADFS is reflected in your + Django application. diff --git a/docs/token_lifecycle.rst b/docs/token_lifecycle.rst new file mode 100644 index 00000000..c442f9ce --- /dev/null +++ b/docs/token_lifecycle.rst @@ -0,0 +1,459 @@ +Token Lifecycle Middleware +========================== + +Traditionally, django-auth-adfs is used **exclusively** as an authentication solution - it handles user authentication +via ADFS/Azure AD and maps claims to Django users. It doesn't really care about the access tokens from Azure/ADFS after you've been authenticated. + +The Token Lifecycle Middleware extends django-auth-adfs beyond pure authentication to also handle the complete lifecycle of access tokens +after the authentication process. This creates a more integrated approach where: + +* The same application registration handles both authentication and resource access +* Tokens obtained during authentication are managed and refreshed automatically +* The application can make delegated API calls on behalf of the user +* The middleware can optionally log out users when token refresh fails + +How it works +------------ + +The ``TokenLifecycleMiddleware`` handles the entire token lifecycle: + +1. **Initial Token Capture**: Uses the ``post_authenticate`` signal to capture tokens during authentication +2. **Token Storage**: Automatically stores tokens in the users session after successful authentication +3. **Token Refresh**: Checks if the access token is about to expire and refreshes it if needed +4. **Optional Security Enforcement**: Can be configured to log out users when token refresh fails +5. **Session Management**: Keeps the session updated with the latest tokens +6. **OBO Token Management**: Handles On-Behalf-Of tokens for Microsoft Graph API access + +Read more: https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-on-behalf-of-flow#protocol-diagram + + +.. warning:: + The Token Lifecycle Middleware is a new feature in django-auth-adfs and is considered experimental. + Please be aware: + + **Currently no community support is guaranteed to be available for this feature** + + We recommend thoroughly testing this feature in your specific environment before deploying to production. + + Consider enabling the ``LOGOUT_ON_TOKEN_REFRESH_FAILURE`` setting, + which allows you to log out users when token refresh fails. + + +Configuration +------------- + +To enable the token lifecycle middleware, add it to your ``MIDDLEWARE`` setting in your Django settings file: + +.. code-block:: python + + MIDDLEWARE = [ + # ... other middleware + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'django_auth_adfs.middleware.TokenLifecycleMiddleware', # Add this line + # ... other middleware + ] + +.. important:: + The middleware must be placed after the ``SessionMiddleware`` and ``AuthenticationMiddleware``. + + +You can configure the token lifecycle behavior with these settings in your Django settings file: + +.. code-block:: python + + AUTH_ADFS = { + # other settings + + # Number of seconds before expiration to refresh (default: 300, i.e., 5 minutes) + "TOKEN_REFRESH_THRESHOLD": 300, + + # Enable or disable OBO token storage for Microsoft Graph API (default: True) + "STORE_OBO_TOKEN": True, + + # Custom salt for token encryption (optional) + # If not specified, a default salt is used + "TOKEN_ENCRYPTION_SALT": "your-custom-salt-string", + + # Automatically log out users when token refresh fails (default: False) + "LOGOUT_ON_TOKEN_REFRESH_FAILURE": False, + } + +.. warning:: + If you change the ``TOKEN_ENCRYPTION_SALT`` after tokens have been stored in sessions, those tokens will no longer be decryptable. + This effectively invalidates all existing tokens, requiring users to re-authenticate. + + Consider this when deploying changes to the salt in production environments. + +.. note:: + By default (``STORE_OBO_TOKEN = True``), the middleware will automatically request and store OBO tokens + for Microsoft Graph API access. If your application doesn't need to access Microsoft Graph API, + you can set ``STORE_OBO_TOKEN = False`` to disable this functionality completely. + See `the OBO token configuration section <#disabling-obo-token-functionality>`_ for more details. + +Considerations +-------------- + +- The middleware will automatically capture and store tokens during authentication using signals. +- You don't need to modify your views or authentication backends to store tokens. +- Token refresh only works for authenticated users. +- If the refresh token is invalid or expired, the middleware will not be able to refresh the access token. +- By default, the middleware will not log the user out if token refresh fails, but this behavior can be changed with the ``LOGOUT_ON_TOKEN_REFRESH_FAILURE`` setting. +- The middleware will not store tokens in the session when using the ``signed_cookies`` session backend by default. +- OBO token storage is enabled by default but can be disabled with the ``STORE_OBO_TOKEN`` setting. +- Using the OBO token versus the regular access token is dependent on the resources you are accessing and the permissions granted to your ADFS/Azure AD application. See `the token types section <#understanding-access-tokens-vs-obo-tokens>`_ for more details. + +**Token Refresh Failures** + +By default, when token refresh fails, the middleware logs the error but allows the user to continue using the application until their session expires naturally. This behavior can be changed with the ``LOGOUT_ON_TOKEN_REFRESH_FAILURE`` setting: + +- When set to ``False`` (default), users remain logged in even if their tokens can't be refreshed +- When set to ``True``, users are automatically logged out when token refresh fails + +When a user's account is disabled in Azure AD/ADFS, their existing Django sessions will remain active by default until they expire naturally. This can create a security gap where revoked users maintain access to your application. + +The ``LOGOUT_ON_TOKEN_REFRESH_FAILURE`` setting provides an option to address this concern by allowing you to configure the middleware to automatically log out users when their token refresh fails, which happens when their account has been disabled in the identity provider. + +**Existing Sessions** + +When deploying the Token Lifecycle Middleware to an existing application with active user sessions, be aware of the following: + +The middleware only captures tokens during the authentication process. Existing authenticated sessions won't have tokens stored in them, which means: + +- Users with existing sessions won't have access to token-dependent features until they re-authenticate +- Utility functions like ``get_access_token()`` and ``get_obo_access_token()`` will return ``None`` for these sessions +- API calls that depend on these tokens will fail for existing sessions + +The best approach is to ensure that all users re-authenticate after the middleware is deployed. + +Azure AD Application Configuration +---------------------------------- + +When using the Token Lifecycle Middleware, your Azure AD application registration needs additional permissions +beyond those required for simple authentication. This extends the standard authentication-only setup described in the :doc:`azure_ad_config_guide` with additional +API permissions needed for delegated access. + +.. important:: + Your Django application's session cookie age must be set to a value that is less than that of your ADFS/Azure AD application's refresh token lifetime. + + If a users refresh token has expired, the user will be required to re-authenticate to continue making delegated requests. + +Security Overview +----------------------- + +**Token Encryption** + +Tokens are automatically encrypted before being stored in the session and decrypted when they are retrieved. +The encryption is handled transparently by the middleware and utility functions. This provides an additional layer of security: + +- **Always Enabled**: Token encryption is always enabled and cannot be disabled +- **Encryption Method**: Tokens are encrypted using the Fernet symmetric encryption algorithm +- **Encryption Key**: The key is derived from Django's ``SECRET_KEY`` using PBKDF2 +- **Customizable Salt**: You can customize the encryption salt using the ``TOKEN_ENCRYPTION_SALT`` setting +- **Transparent Operation**: Encryption and decryption happen automatically when tokens are stored or retrieved + + +**Signed Cookies Session Backend Restriction** + +The middleware will not store tokens in the session when using Django's ``signed_cookies`` session backend: + +.. code-block:: python + + # This will not work with the token lifecycle middleware + SESSION_ENGINE = 'django.contrib.sessions.backends.signed_cookies' + +This is for a few reasons: + +1. **Size Limitations**: Cookies have size limitations (typically 4KB), which may be exceeded by tokens +2. **Security Risks**: Storing sensitive tokens in cookies increases the risk of token theft +3. **Performance**: Large cookies are sent with every request, increasing bandwidth usage + +If you're using the ``signed_cookies`` session backend and need token storage, you won't be able to use the token lifecycle middleware. + +.. note:: + This restriction only applies to the ``signed_cookies`` session backend. For other session backends (database, cache, file), + tokens are stored securely on the server and only a session ID is stored in the cookie. + +**Automatic OBO Token Acquisition** + +By default, the middleware automatically requests OBO tokens during authentication. If your application doesn't need OBO tokens, you can disable this behavior to reduce unnecessary token requests (see `the OBO token configuration section <#disabling-obo-token-functionality>`_ for more details). + +Disabling OBO Token Functionality +--------------------------------- + +By default, the Token Lifecycle Middleware automatically requests and stores OBO tokens for Microsoft Graph API access. If you don't need this functionality (for example, if your application doesn't interact with Microsoft Graph API), you can disable it completely: + +.. code-block:: python + + # In your Django settings.py + AUTH_ADFS = { + "STORE_OBO_TOKEN": False, + } + +When this setting is ``False``: + +1. The middleware will not request OBO tokens during authentication +2. The middleware will not store OBO tokens in the session +3. The middleware will not refresh OBO tokens +4. The ``get_obo_access_token`` utility function will always return ``None`` + +Note that disabling OBO tokens doesn't affect the regular access token functionality. Your application will still be able to use the access token obtained during authentication for its own resources and APIs that directly trust your application. + +See `the token types section <#understanding-access-tokens-vs-obo-tokens>`_ for more details. + +Accessing Tokens in Your Views +------------------------------ + +When building views that need to make requests using the Azure AD/ADFS tokens, you'll need to access the tokens stored in the session. + +Since tokens are encrypted in the session, Token Lifecycle Middleware provides utility functions in the ``django_auth_adfs.utils`` module to help you access tokens safely: + +.. code-block:: python + + # For your own APIs or APIs that trust your application directly + from django_auth_adfs.utils import get_access_token + + # For Microsoft Graph API or other APIs requiring delegated access + from django_auth_adfs.utils import get_obo_access_token + +These utility functions automatically handle decryption of the tokens, so you don't need to worry about the encryption details. + +.. warning:: + You should always use these utility functions to access tokens rather than accessing them directly from the session. + Direct access to ``request.session["ADFS_ACCESS_TOKEN"]`` will give you the encrypted token, not the actual token value. + +Examples +---------------------- + +Here are practical examples of using these utility functions in your views: + +Using with Microsoft Graph API +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In this flow, we will exchange our access token from the authentication process for an OBO token to access Microsoft Graph API. + +This is the recommended flow for delegated access to Microsoft Graph API. + +.. code-block:: python + + from django.contrib.auth.decorators import login_required + from django.http import JsonResponse + from django_auth_adfs.utils import get_obo_access_token + import requests + + @login_required + def me_view(request): + """Get the user's profile from Microsoft Graph API""" + obo_token = get_obo_access_token(request) + + if not obo_token: + return JsonResponse({"error": "No OBO token available"}, status=401) + + headers = { + "Authorization": f"Bearer {obo_token}", + "Content-Type": "application/json", + } + + try: + response = requests.get("https://graph.microsoft.com/v1.0/me", headers=headers) + response.raise_for_status() + return JsonResponse(response.json()) + except requests.exceptions.RequestException as e: + return JsonResponse( + {"error": "Failed to fetch user profile", "details": str(e)}, + status=500 + ) + +Using with other resources +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The key difference here is to use the ``get_access_token`` function to get the token for the resource you are accessing. + +This is different than the ``get_obo_access_token`` function, which is used for Microsoft Graph API delegated access in the previous example. + +.. code-block:: python + + from rest_framework.views import APIView + from rest_framework.response import Response + from django_auth_adfs.utils import get_access_token + import requests + + class ExternalApiView(APIView): + def get(self, request): + """Call an API that accepts your application's token""" + token = get_access_token(request) + + if not token: + return Response({"error": "No access token available"}, status=401) + + headers = {"Authorization": f"Bearer {token}"} + response = requests.get("https://api.example.com/data", headers=headers) + + return Response(response.json()) + +Debug view +---------- + +The following example code demonstrates a debug view to check the values of the tokens stored in the session: + +.. code-block:: python + + from django.contrib.auth.decorators import login_required + from django.http import JsonResponse + from django_auth_adfs.utils import get_access_token, get_obo_access_token + from datetime import datetime + + @login_required + def debug_view(request): + """ + Debug view that provides detailed information about the authentication state, + tokens, and session data. + """ + if not request.user.is_authenticated: + return JsonResponse({"authenticated": False}) + + # Basic session token info + session_info = { + "has_access_token": "ADFS_ACCESS_TOKEN" in request.session, + "has_refresh_token": "ADFS_REFRESH_TOKEN" in request.session, + "has_expires_at": "ADFS_TOKEN_EXPIRES_AT" in request.session, + } + + # Add token expiration details if available + if "ADFS_TOKEN_EXPIRES_AT" in request.session: + from datetime import datetime + + try: + expires_at = datetime.fromisoformat( + request.session["ADFS_TOKEN_EXPIRES_AT"] + ) + now = datetime.now() + session_info["token_expires_at"] = expires_at.isoformat() + session_info["expires_in_seconds"] = max( + 0, int((expires_at - now).total_seconds()) + ) + session_info["is_expired"] = expires_at <= now + except (ValueError, TypeError) as e: + session_info["expiration_parse_error"] = str(e) + + # Show raw encrypted tokens for debugging + if "ADFS_ACCESS_TOKEN" in request.session: + raw_token = request.session["ADFS_ACCESS_TOKEN"] + session_info["raw_token_preview"] = f"{raw_token[:10]}...{raw_token[-10:]}" + session_info["raw_token_length"] = len(raw_token) + + # Try to decode as JWT without decryption (should fail if properly encrypted) + try: + import jwt + + jwt.decode(raw_token, options={"verify_signature": False}) + session_info["is_encrypted"] = False + except: + session_info["is_encrypted"] = True + + # Get properly decrypted access token + try: + from django_auth_adfs.utils import get_access_token + + access_token = get_access_token(request) + session_info["decrypted_access_token_available"] = access_token is not None + + if access_token: + if len(access_token) > 20: + session_info["decrypted_access_token_preview"] = ( + f"{access_token[:10]}...{access_token[-10:]}" + ) + session_info["decrypted_access_token_length"] = len(access_token) + + # Try to decode as JWT (should succeed if properly decrypted) + try: + import jwt + + decoded = jwt.decode(access_token, options={"verify_signature": False}) + session_info["jwt_decode_success"] = True + # Add some basic JWT info without exposing sensitive data + if "exp" in decoded: + from datetime import datetime + + exp_time = datetime.fromtimestamp(decoded["exp"]) + session_info["jwt_expiry"] = exp_time.isoformat() + except Exception as e: + session_info["jwt_decode_error"] = str(e) + except Exception as e: + session_info["access_token_error"] = f"Error getting access token: {str(e)}" + + # Check if OBO token is available + try: + from django_auth_adfs.utils import get_obo_access_token + + obo_token = get_obo_access_token(request) + obo_info = { + "has_obo_token": obo_token is not None, + } + + # Show raw encrypted OBO token if available + if "ADFS_OBO_ACCESS_TOKEN" in request.session: + raw_obo = request.session["ADFS_OBO_ACCESS_TOKEN"] + obo_info["raw_obo_preview"] = f"{raw_obo[:10]}...{raw_obo[-10:]}" + obo_info["raw_obo_length"] = len(raw_obo) + + if obo_token: + if len(obo_token) > 20: + obo_info["obo_token_preview"] = f"{obo_token[:10]}...{obo_token[-10:]}" + obo_info["obo_token_length"] = len(obo_token) + + # Try to decode as JWT (should succeed if properly decrypted) + try: + import jwt + + decoded = jwt.decode(obo_token, options={"verify_signature": False}) + obo_info["jwt_decode_success"] = True + # Add some basic JWT info without exposing sensitive data + if "exp" in decoded: + from datetime import datetime + + exp_time = datetime.fromtimestamp(decoded["exp"]) + obo_info["jwt_expiry"] = exp_time.isoformat() + except Exception as e: + obo_info["jwt_decode_error"] = str(e) + except Exception as e: + obo_info = {"error": f"Error getting OBO token: {str(e)}"} + + # Return all the collected information + return JsonResponse( + { + "authenticated": True, + "user": { + "id": request.user.id, + "username": request.user.username, + "email": request.user.email, + "is_staff": request.user.is_staff, + "is_superuser": request.user.is_superuser, + }, + "session_tokens": session_info, + "obo_token": obo_info, + }, + json_dumps_params={"indent": 2}, + ) + +Understanding Access Tokens vs. OBO Tokens +------------------------------------------ + +It's important to understand the difference between regular access tokens and OBO (On-Behalf-Of) tokens, especially in the context of delegated access versus application access: + +**Delegated Access vs. Application Access**: + There are two primary ways an application can access resources in Azure AD/ADFS: + + * **Application Access**: The application accesses resources directly with its own identity, not on behalf of a user. + + * **Delegated Access**: The application accesses resources on behalf of a signed-in user. + +**Regular Access Token**: + The token obtained during authentication with ADFS. + +**OBO (On-Behalf-Of) Token**: + The OBO flow is specifically designed for delegated access scenarios where your application needs to access resources (like Microsoft Graph) on behalf of the authenticated user. + + The middleware handles this exchange automatically when OBO token storage is enabled. + +For more information on the different types of permissions, see `the Microsoft documentation `_. diff --git a/pyproject.toml b/pyproject.toml index f0a11f5b..a9db6886 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = 'django-auth-adfs' -version = "1.15.0" # Remember to also change __init__.py version +version = "1.16.0" # Remember to also change __init__.py version description = 'A Django authentication backend for Microsoft ADFS and AzureAD' authors = ['Joris Beckers '] maintainers = ['Jonas Krüger Svensson ', 'Sondre Lillebø Gundersen '] diff --git a/tests/test_middleware.py b/tests/test_middleware.py new file mode 100644 index 00000000..a4df6ade --- /dev/null +++ b/tests/test_middleware.py @@ -0,0 +1,750 @@ +import datetime +from unittest.mock import Mock, patch +import time + +from django.contrib.auth import get_user_model +from django.contrib.auth.models import AnonymousUser +from django.test import TestCase, RequestFactory, override_settings +from django.contrib.sessions.backends.db import SessionStore + +from django_auth_adfs.middleware import TokenLifecycleMiddleware +from django_auth_adfs.config import settings +from django_auth_adfs.utils import ( + get_access_token, + get_obo_access_token, + _encrypt_token, + _decrypt_token, +) +from tests.settings import MIDDLEWARE + +User = get_user_model() + +# Add TokenLifecycleMiddleware to the existing middleware +MIDDLEWARE_WITH_TOKEN_LIFECYCLE = MIDDLEWARE + ( + "django_auth_adfs.middleware.TokenLifecycleMiddleware", +) + + +@override_settings(MIDDLEWARE=MIDDLEWARE_WITH_TOKEN_LIFECYCLE) +class TokenLifecycleMiddlewareTests(TestCase): + """ + Tests for the TokenLifecycleMiddleware. + + The middleware handles the lifecycle of ADFS tokens: + 1. Storing tokens from user object to session + 2. Detecting when tokens need to be refreshed + 3. Refreshing tokens when needed + 4. Handling OBO (On-Behalf-Of) tokens + """ + + def setUp(self): + """Set up test environment before each test""" + self.factory = RequestFactory() + self.middleware = TokenLifecycleMiddleware(lambda r: r) + self.user = User.objects.create_user(username="testuser") + self.request = self.factory.get("/") + self.request.user = self.user + self.request.session = SessionStore() + + # Group 1: Initialization Tests + + def test_init_with_default_settings(self): + """Test middleware initialization with default settings""" + middleware = TokenLifecycleMiddleware(lambda r: r) + self.assertEqual(middleware.threshold, 300) + self.assertTrue(middleware.store_obo_token) + self.assertFalse(middleware.using_signed_cookies) + + def test_init_with_custom_settings(self): + """Test middleware initialization with custom settings""" + with patch("django_auth_adfs.middleware.getattr") as mock_getattr: + # Mock getattr to return custom values + mock_getattr.side_effect = lambda obj, name, default: { + "TOKEN_REFRESH_THRESHOLD": 600, + "STORE_OBO_TOKEN": False, + }.get(name, default) + + middleware = TokenLifecycleMiddleware(lambda r: r) + + # Verify custom settings are applied + self.assertEqual(middleware.threshold, 600) + self.assertFalse(middleware.store_obo_token) + + # Group 2: Token Storage Tests + + def test_store_tokens_from_user(self): + """Test storing tokens from user object to session""" + # Set tokens on user object + setattr(self.user, "access_token", "test_access_token") + setattr(self.user, "refresh_token", "test_refresh_token") + setattr( + self.user, + "token_expires_at", + datetime.datetime.now() + datetime.timedelta(hours=1), + ) + setattr(self.user, "obo_access_token", "test_obo_token") + setattr( + self.user, + "obo_token_expires_at", + datetime.datetime.now() + datetime.timedelta(hours=1), + ) + + # Call middleware + self.middleware._store_tokens_from_user(self.request) + + # Check session - decrypt tokens before comparing + self.assertEqual( + _decrypt_token(self.request.session["ADFS_ACCESS_TOKEN"]), + "test_access_token", + ) + self.assertEqual( + _decrypt_token(self.request.session["ADFS_REFRESH_TOKEN"]), + "test_refresh_token", + ) + self.assertEqual( + _decrypt_token(self.request.session["ADFS_OBO_ACCESS_TOKEN"]), + "test_obo_token", + ) + self.assertTrue("ADFS_TOKEN_EXPIRES_AT" in self.request.session) + self.assertTrue("ADFS_OBO_TOKEN_EXPIRES_AT" in self.request.session) + + def test_store_partial_tokens_from_user(self): + """Test storing partial token data (only access token without refresh token)""" + # Set only access token on user object + setattr(self.user, "access_token", "test_access_token") + setattr( + self.user, + "token_expires_at", + datetime.datetime.now() + datetime.timedelta(hours=1), + ) + # No refresh token or OBO token + + # Call middleware + self.middleware._store_tokens_from_user(self.request) + + # Check session - should have access token but not refresh token + self.assertEqual( + _decrypt_token(self.request.session["ADFS_ACCESS_TOKEN"]), + "test_access_token", + ) + self.assertTrue("ADFS_TOKEN_EXPIRES_AT" in self.request.session) + self.assertFalse("ADFS_REFRESH_TOKEN" in self.request.session) + self.assertFalse("ADFS_OBO_ACCESS_TOKEN" in self.request.session) + + def test_store_tokens_from_user_with_signed_cookies(self): + """Test that tokens are not stored when using signed cookies""" + self.middleware.using_signed_cookies = True + setattr(self.user, "access_token", "test_access_token") + + self.middleware._store_tokens_from_user(self.request) + self.assertFalse("ADFS_ACCESS_TOKEN" in self.request.session) + + def test_session_modified_flag(self): + """Test session.modified is set correctly during token storage operations""" + # Test 1: When tokens are added, session.modified should be True + self.request.session.modified = False + setattr(self.user, "access_token", "new_token") + self.middleware._store_tokens_from_user(self.request) + # With encryption, session will always be modified when tokens are stored + # so we can't test for False here anymore + self.assertTrue(self.request.session.modified) + + # Test 2: Reset and test when no changes are made + self.request.session.modified = False + # Remove the token attribute so no changes will be made + delattr(self.user, "access_token") + self.middleware._store_tokens_from_user(self.request) + self.assertFalse(self.request.session.modified) + + # Group 3: Token Refresh Detection Tests + + def test_handle_token_refresh_not_needed(self): + """Test token refresh when it's not needed""" + self.request.session["ADFS_ACCESS_TOKEN"] = "test_access_token" + self.request.session["ADFS_REFRESH_TOKEN"] = "test_refresh_token" + expires_at = datetime.datetime.now() + datetime.timedelta( + hours=1 + ) # 1 hour to expiry + self.request.session["ADFS_TOKEN_EXPIRES_AT"] = expires_at.isoformat() + + with patch.object(self.middleware, "_refresh_tokens") as mock_refresh: + self.middleware._handle_token_refresh(self.request) + mock_refresh.assert_not_called() + + def test_handle_token_refresh_needed(self): + """Test token refresh when it's needed""" + self.request.session["ADFS_ACCESS_TOKEN"] = "test_access_token" + self.request.session["ADFS_REFRESH_TOKEN"] = "test_refresh_token" + expires_at = datetime.datetime.now() + datetime.timedelta( + seconds=60 + ) # 1 minute to expiry + self.request.session["ADFS_TOKEN_EXPIRES_AT"] = expires_at.isoformat() + + with patch.object(self.middleware, "_refresh_tokens") as mock_refresh: + self.middleware._handle_token_refresh(self.request) + mock_refresh.assert_called_once_with(self.request) + + def test_handle_expired_token(self): + """Test token refresh when token is already expired""" + self.request.session["ADFS_ACCESS_TOKEN"] = "test_access_token" + self.request.session["ADFS_REFRESH_TOKEN"] = "test_refresh_token" + expires_at = datetime.datetime.now() - datetime.timedelta( + hours=1 + ) # Expired 1 hour ago + self.request.session["ADFS_TOKEN_EXPIRES_AT"] = expires_at.isoformat() + + with patch.object(self.middleware, "_refresh_tokens") as mock_refresh: + self.middleware._handle_token_refresh(self.request) + mock_refresh.assert_called_once_with(self.request) + + def test_obo_token_expires_before_access_token(self): + """Test when OBO token expires before access token""" + # Set up access token with long expiry + self.request.session["ADFS_ACCESS_TOKEN"] = "access_token" + self.request.session["ADFS_REFRESH_TOKEN"] = "refresh_token" + access_token_expires_at = datetime.datetime.now() + datetime.timedelta(hours=1) + self.request.session["ADFS_TOKEN_EXPIRES_AT"] = ( + access_token_expires_at.isoformat() + ) + + # Set up OBO token with short expiry + self.request.session["ADFS_OBO_ACCESS_TOKEN"] = "obo_token" + obo_expires_at = datetime.datetime.now() + datetime.timedelta(seconds=30) + self.request.session["ADFS_OBO_TOKEN_EXPIRES_AT"] = obo_expires_at.isoformat() + + # Should refresh only OBO token + with patch.object( + self.middleware, "_refresh_tokens" + ) as mock_refresh_token, patch.object( + self.middleware, "_refresh_obo_token" + ) as mock_refresh_obo: + self.middleware._handle_token_refresh(self.request) + + # Verify only OBO token is refreshed, not the access token + mock_refresh_token.assert_not_called() + mock_refresh_obo.assert_called_once_with(self.request) + + # Verify session state remains unchanged for access token + self.assertEqual(self.request.session["ADFS_ACCESS_TOKEN"], "access_token") + self.assertEqual( + self.request.session["ADFS_REFRESH_TOKEN"], "refresh_token" + ) + self.assertEqual( + self.request.session["ADFS_TOKEN_EXPIRES_AT"], + access_token_expires_at.isoformat(), + ) + + # Group 4: Token Refresh Implementation Tests + + @patch("django_auth_adfs.middleware.provider_config") + def test_refresh_token_success(self, mock_provider_config): + """Test successful token refresh""" + # Set up mock response + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "access_token": "new_access_token", + "refresh_token": "new_refresh_token", + "expires_in": 3600, + } + + # Configure the mock + mock_provider_config.session.post.return_value = mock_response + mock_provider_config.token_endpoint = ( + "https://adfs.example.com/adfs/oauth2/token" + ) + + # Set up session with expired token + self.request.session["ADFS_ACCESS_TOKEN"] = _encrypt_token("old_access_token") + self.request.session["ADFS_REFRESH_TOKEN"] = _encrypt_token("old_refresh_token") + self.request.session["ADFS_TOKEN_EXPIRES_AT"] = ( + datetime.datetime.now() - datetime.timedelta(minutes=5) + ).isoformat() + + # Mock the OBO token refresh to prevent real HTTP requests + with patch.object(self.middleware, "_refresh_obo_token") as mock_refresh_obo: + # Call refresh method + self.middleware._refresh_tokens(self.request) + + # Check that tokens were updated + self.assertEqual( + _decrypt_token(self.request.session["ADFS_ACCESS_TOKEN"]), + "new_access_token", + ) + self.assertEqual( + _decrypt_token(self.request.session["ADFS_REFRESH_TOKEN"]), + "new_refresh_token", + ) + + @patch("django_auth_adfs.middleware.provider_config") + def test_refresh_token_without_new_refresh_token(self, mock_provider_config): + """Test token refresh when response doesn't include a new refresh token""" + # Set up mock response without refresh_token + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "access_token": "new_access_token", + "expires_in": 3600, + } + + # Configure the mock + mock_provider_config.session.post.return_value = mock_response + mock_provider_config.token_endpoint = ( + "https://adfs.example.com/adfs/oauth2/token" + ) + + # Set up session with expired token + self.request.session["ADFS_ACCESS_TOKEN"] = _encrypt_token("old_access_token") + self.request.session["ADFS_REFRESH_TOKEN"] = _encrypt_token("old_refresh_token") + self.request.session["ADFS_TOKEN_EXPIRES_AT"] = ( + datetime.datetime.now() - datetime.timedelta(minutes=5) + ).isoformat() + + # Mock the OBO token refresh to prevent real HTTP requests + with patch.object(self.middleware, "_refresh_obo_token") as mock_refresh_obo: + # Call refresh method + self.middleware._refresh_tokens(self.request) + + # Check that access token was updated but refresh token remains the same + self.assertEqual( + _decrypt_token(self.request.session["ADFS_ACCESS_TOKEN"]), + "new_access_token", + ) + self.assertEqual( + _decrypt_token(self.request.session["ADFS_REFRESH_TOKEN"]), + "old_refresh_token", + ) + + @patch("django_auth_adfs.backend.AdfsBaseBackend") + def test_refresh_obo_token_success(self, mock_backend_class): + """Test successful OBO token refresh""" + # Set up mock backend + mock_backend = Mock() + mock_backend.get_obo_access_token.return_value = "new_obo_token" + mock_backend_class.return_value = mock_backend + + # Ensure OBO token storage is enabled + self.middleware.store_obo_token = True + + # Set up session with expired OBO token but valid access token + self.request.session["ADFS_ACCESS_TOKEN"] = _encrypt_token("valid_access_token") + self.request.session["ADFS_REFRESH_TOKEN"] = _encrypt_token( + "valid_refresh_token" + ) + self.request.session["ADFS_OBO_ACCESS_TOKEN"] = _encrypt_token( + "expired_obo_token" + ) + self.request.session["ADFS_TOKEN_EXPIRES_AT"] = ( + datetime.datetime.now() + datetime.timedelta(hours=1) + ).isoformat() + self.request.session["ADFS_OBO_TOKEN_EXPIRES_AT"] = ( + datetime.datetime.now() - datetime.timedelta(minutes=5) + ).isoformat() + + # Call handle token refresh directly + self.middleware._handle_token_refresh(self.request) + + # Verify the backend was called with the correct token + mock_backend.get_obo_access_token.assert_called_once_with("valid_access_token") + + # Verify the new token was stored in the session + self.assertEqual( + _decrypt_token(self.request.session["ADFS_OBO_ACCESS_TOKEN"]), + "new_obo_token", + ) + self.assertTrue("ADFS_OBO_TOKEN_EXPIRES_AT" in self.request.session) + + def test_refresh_obo_token_failure(self): + """Test failed OBO token refresh""" + self.request.session["ADFS_ACCESS_TOKEN"] = "test_access_token" + + # Store original session state to verify it's not modified + original_session_data = dict(self.request.session) + self.request.session.modified = False + + with patch("django_auth_adfs.backend.AdfsBaseBackend") as mock_backend: + mock_backend.return_value.get_obo_access_token.return_value = None + + self.middleware._refresh_obo_token(self.request) + + # Verify session not modified + self.assertFalse("ADFS_OBO_ACCESS_TOKEN" in self.request.session) + self.assertEqual(dict(self.request.session), original_session_data) + self.assertFalse(self.request.session.modified) + + def test_obo_token_without_access_token(self): + """Test OBO token handling when access token is missing""" + # Only OBO token exists + self.request.session["ADFS_OBO_ACCESS_TOKEN"] = "obo_token" + self.request.session["ADFS_OBO_TOKEN_EXPIRES_AT"] = ( + datetime.datetime.now().isoformat() + ) + # No ADFS_ACCESS_TOKEN + + # Store original session state to verify it's not modified + original_session_data = dict(self.request.session) + self.request.session.modified = False + + self.middleware._refresh_obo_token(self.request) + + # Verify session not modified + self.assertEqual(dict(self.request.session), original_session_data) + self.assertFalse(self.request.session.modified) + + # Group 5: Authentication Signal Tests + + def test_capture_tokens_from_auth(self): + """Test capturing tokens during authentication""" + sender = Mock() + sender.access_token = "sender_access_token" + sender.get_obo_access_token.return_value = "obo_token" + + adfs_response = { + "access_token": "response_access_token", + "refresh_token": "response_refresh_token", + "expires_in": 3600, + } + + self.middleware._capture_tokens_from_auth( + sender=sender, user=self.user, claims={}, adfs_response=adfs_response + ) + + # Check user object has temporary token attributes + self.assertEqual(getattr(self.user, "access_token"), "sender_access_token") + self.assertEqual(getattr(self.user, "refresh_token"), "response_refresh_token") + self.assertTrue(hasattr(self.user, "token_expires_at")) + self.assertEqual(getattr(self.user, "obo_access_token"), "obo_token") + self.assertTrue(hasattr(self.user, "obo_token_expires_at")) + + def test_capture_tokens_from_adfs_response_only(self): + """Test capturing tokens when they're only in the ADFS response, not on sender""" + sender = Mock(spec=[]) # Create a mock without access_token attribute + # Ensure get_obo_access_token is available but returns None + sender.get_obo_access_token = Mock(return_value=None) + + adfs_response = { + "access_token": "response_access_token", + "refresh_token": "response_refresh_token", + "expires_in": 3600, + } + + self.middleware._capture_tokens_from_auth( + sender=sender, user=self.user, claims={}, adfs_response=adfs_response + ) + + # Check user object has temporary token attributes from adfs_response + self.assertEqual(getattr(self.user, "access_token"), "response_access_token") + self.assertEqual(getattr(self.user, "refresh_token"), "response_refresh_token") + self.assertTrue(hasattr(self.user, "token_expires_at")) + # No OBO token should be set + self.assertFalse(hasattr(self.user, "obo_access_token")) + + # Group 6: Middleware Call Tests + + def test_middleware_call_with_authenticated_user(self): + """Test the complete middleware request/response cycle with authenticated user""" + # Set up user with tokens + setattr(self.user, "access_token", "test_access_token") + setattr(self.user, "refresh_token", "test_refresh_token") + setattr( + self.user, + "token_expires_at", + datetime.datetime.now() + datetime.timedelta(hours=1), + ) + + # Create request with authenticated user + request = self.factory.get("/") + request.user = self.user + request.session = SessionStore() + + # Call middleware + response = self.middleware(request) + + # Check that tokens were stored in session + self.assertEqual( + _decrypt_token(request.session["ADFS_ACCESS_TOKEN"]), "test_access_token" + ) + self.assertEqual( + _decrypt_token(request.session["ADFS_REFRESH_TOKEN"]), "test_refresh_token" + ) + self.assertTrue("ADFS_TOKEN_EXPIRES_AT" in request.session) + + def test_middleware_post_response_token_storage(self): + """Test tokens added during view processing are stored after response""" + + def get_response_with_token_addition(request): + # Simulate a view that adds tokens to the user + setattr(request.user, "access_token", "view_added_token") + setattr( + request.user, + "token_expires_at", + datetime.datetime.now() + datetime.timedelta(hours=1), + ) + return Mock() + + # Create middleware with our custom get_response + middleware = TokenLifecycleMiddleware(get_response_with_token_addition) + + # Create request with authenticated user + request = self.factory.get("/") + request.user = self.user + request.session = SessionStore() + + # Call middleware + response = middleware(request) + + # Check that tokens were stored in session + self.assertEqual( + _decrypt_token(request.session["ADFS_ACCESS_TOKEN"]), "view_added_token" + ) + self.assertTrue("ADFS_TOKEN_EXPIRES_AT" in request.session) + + def test_middleware_without_user(self): + """Test middleware behavior when request has no user""" + request = self.factory.get("/") + request.session = SessionStore() + + response = self.middleware(request) + # Should not raise any errors + self.assertEqual(response, request) + + def test_middleware_with_unauthenticated_user(self): + """Test middleware behavior with unauthenticated user""" + request = self.factory.get("/") + request.user = Mock(is_authenticated=False) + request.session = SessionStore() + + with patch.object(self.middleware, "_handle_token_refresh") as mock_refresh: + response = self.middleware(request) + mock_refresh.assert_not_called() + + # Group 7: Error Handling Tests + + def test_handle_malformed_expiry_time(self): + """Test handling of malformed expiry time in session""" + self.request.session["ADFS_ACCESS_TOKEN"] = "test_access_token" + self.request.session["ADFS_REFRESH_TOKEN"] = "test_refresh_token" + self.request.session["ADFS_TOKEN_EXPIRES_AT"] = "invalid_datetime" + + # Store original session state to verify it's not modified inappropriately + original_session_data = dict(self.request.session) + self.request.session.modified = False + + # Should handle gracefully without error + self.middleware._handle_token_refresh(self.request) + + # Verify session wasn't modified inappropriately + self.assertEqual(dict(self.request.session), original_session_data) + self.assertFalse(self.request.session.modified) + + def test_handle_incomplete_token_state(self): + """Test handling when only some token data exists in session""" + # Only access token, no refresh token + self.request.session["ADFS_ACCESS_TOKEN"] = "test_access_token" + self.request.session["ADFS_TOKEN_EXPIRES_AT"] = ( + datetime.datetime.now().isoformat() + ) + # Missing ADFS_REFRESH_TOKEN + + # Store original session state to verify it's not modified inappropriately + original_session_data = dict(self.request.session) + self.request.session.modified = False + + self.middleware._handle_token_refresh(self.request) + + # Verify session wasn't modified inappropriately + self.assertEqual(dict(self.request.session), original_session_data) + self.assertFalse(self.request.session.modified) + + def test_handle_malformed_tokens(self): + """Test handling of malformed/corrupt token data in session""" + # Invalid token format + self.request.session["ADFS_ACCESS_TOKEN"] = {"malformed": "data"} + self.request.session["ADFS_REFRESH_TOKEN"] = None + self.request.session["ADFS_TOKEN_EXPIRES_AT"] = "not-a-date" + + # Store original session state to verify it's not modified inappropriately + original_session_data = dict(self.request.session) + self.request.session.modified = False + + self.middleware._handle_token_refresh(self.request) + + # Verify session wasn't modified inappropriately + self.assertEqual(dict(self.request.session), original_session_data) + self.assertFalse(self.request.session.modified) + + def test_disabled_obo_token_functionality(self): + """Test that OBO token functionality is disabled when STORE_OBO_TOKEN is False""" + # Set up a user with an access token and OBO token + self.user.access_token = "test_access_token" + self.user.obo_access_token = "test_obo_token" + + # Patch the middleware to disable OBO token storage + with patch.object(self.middleware, "store_obo_token", False): + # Store tokens from user + self.middleware._store_tokens_from_user(self.request) + + # Verify access token is stored but OBO token is not + self.assertTrue("ADFS_ACCESS_TOKEN" in self.request.session) + self.assertFalse("ADFS_OBO_ACCESS_TOKEN" in self.request.session) + + # Verify get_obo_access_token returns None when disabled + with patch("django_auth_adfs.utils.settings") as mock_settings: + mock_settings.STORE_OBO_TOKEN = False + self.assertIsNone(get_obo_access_token(self.request)) + + def test_token_encryption(self): + """Test that tokens are properly encrypted and decrypted""" + # Test encryption and decryption directly + original_token = "test_access_token" + encrypted_token = _encrypt_token(original_token) + + # Verify the token is encrypted (should be different from original) + self.assertNotEqual(original_token, encrypted_token) + + # Verify the token can be decrypted back to the original + decrypted_token = _decrypt_token(encrypted_token) + self.assertEqual(original_token, decrypted_token) + + # Test the middleware stores encrypted tokens + self.user.access_token = original_token + self.middleware._store_tokens_from_user(self.request) + + # Verify the token in the session is encrypted + session_token = self.request.session.get("ADFS_ACCESS_TOKEN") + self.assertNotEqual(original_token, session_token) + + # Test the utility function decrypts the token + retrieved_token = get_access_token(self.request) + self.assertEqual(original_token, retrieved_token) + + # Test with OBO token + original_obo_token = "test_obo_token" + self.user.obo_access_token = original_obo_token + self.middleware._store_tokens_from_user(self.request) + + # Verify the OBO token in the session is encrypted + session_obo_token = self.request.session.get("ADFS_OBO_ACCESS_TOKEN") + self.assertNotEqual(original_obo_token, session_obo_token) + + # Test the utility function decrypts the OBO token + retrieved_obo_token = get_obo_access_token(self.request) + self.assertEqual(original_obo_token, retrieved_obo_token) + + @override_settings(TOKEN_ENCRYPTION_SALT="custom-salt-for-testing") + def test_custom_encryption_salt(self): + """Test that custom encryption salt changes the encrypted token value""" + # First, encrypt a token with the default salt + original_token = "test_access_token" + default_encrypted_token = _encrypt_token(original_token) + + # Now, encrypt the same token with a custom salt (set via override_settings) + with patch("django_auth_adfs.utils.settings") as mock_settings: + mock_settings.TOKEN_ENCRYPTION_SALT = "custom-salt-for-testing" + custom_encrypted_token = _encrypt_token(original_token) + + # The encrypted tokens should be different due to different salts + self.assertNotEqual(default_encrypted_token, custom_encrypted_token) + + # But both should decrypt to the original token when using the correct salt + with patch("django_auth_adfs.utils.settings") as mock_settings: + mock_settings.TOKEN_ENCRYPTION_SALT = "custom-salt-for-testing" + decrypted_token = _decrypt_token(custom_encrypted_token) + + self.assertEqual(original_token, decrypted_token) + + # A token encrypted with one salt should not be decryptable with another + with patch("django_auth_adfs.utils.settings") as mock_settings: + mock_settings.TOKEN_ENCRYPTION_SALT = "different-salt" + # The function catches exceptions and returns None, so check for None + self.assertIsNone(_decrypt_token(custom_encrypted_token)) + + @patch("django_auth_adfs.middleware.provider_config") + def test_refresh_token_failure_with_logout(self, mock_provider_config): + """Test token refresh failure with LOGOUT_ON_TOKEN_REFRESH_FAILURE enabled""" + # Setup + self.request.session["ADFS_ACCESS_TOKEN"] = _encrypt_token("test_access_token") + self.request.session["ADFS_REFRESH_TOKEN"] = _encrypt_token( + "test_refresh_token" + ) + expires_at = datetime.datetime.now() - datetime.timedelta(minutes=5) + self.request.session["ADFS_TOKEN_EXPIRES_AT"] = expires_at.isoformat() + + # Mock the response from the token endpoint + mock_response = Mock() + mock_response.status_code = 400 + mock_response.text = "Invalid refresh token" + mock_provider_config.session.post.return_value = mock_response + + # Enable the setting + with patch("django_auth_adfs.middleware.settings") as mock_settings: + mock_settings.LOGOUT_ON_TOKEN_REFRESH_FAILURE = True + mock_settings.CLIENT_ID = "test_client_id" + mock_settings.CLIENT_SECRET = "test_client_secret" + mock_settings.TIMEOUT = 5 + + # Mock the logout function + with patch("django.contrib.auth.logout") as mock_logout: + self.middleware._refresh_tokens(self.request) + + # Verify logout was called + mock_logout.assert_called_once_with(self.request) + + @patch("django_auth_adfs.middleware.provider_config") + def test_refresh_token_failure_without_logout(self, mock_provider_config): + """Test token refresh failure with LOGOUT_ON_TOKEN_REFRESH_FAILURE disabled""" + # Setup + self.request.session["ADFS_ACCESS_TOKEN"] = _encrypt_token("test_access_token") + self.request.session["ADFS_REFRESH_TOKEN"] = _encrypt_token( + "test_refresh_token" + ) + expires_at = datetime.datetime.now() - datetime.timedelta(minutes=5) + self.request.session["ADFS_TOKEN_EXPIRES_AT"] = expires_at.isoformat() + + # Mock the response from the token endpoint + mock_response = Mock() + mock_response.status_code = 400 + mock_response.text = "Invalid refresh token" + mock_provider_config.session.post.return_value = mock_response + + # Disable the setting (default) + with patch("django_auth_adfs.middleware.settings") as mock_settings: + mock_settings.LOGOUT_ON_TOKEN_REFRESH_FAILURE = False + mock_settings.CLIENT_ID = "test_client_id" + mock_settings.CLIENT_SECRET = "test_client_secret" + mock_settings.TIMEOUT = 5 + + # Mock the logout function + with patch("django.contrib.auth.logout") as mock_logout: + self.middleware._refresh_tokens(self.request) + + # Verify logout was not called + mock_logout.assert_not_called() + + @patch("django_auth_adfs.middleware.provider_config") + def test_refresh_token_exception_with_logout(self, mock_provider_config): + """Test token refresh exception with LOGOUT_ON_TOKEN_REFRESH_FAILURE enabled""" + # Setup + self.request.session["ADFS_ACCESS_TOKEN"] = _encrypt_token("test_access_token") + self.request.session["ADFS_REFRESH_TOKEN"] = _encrypt_token( + "test_refresh_token" + ) + expires_at = datetime.datetime.now() - datetime.timedelta(minutes=5) + self.request.session["ADFS_TOKEN_EXPIRES_AT"] = expires_at.isoformat() + + # Make the request raise an exception + mock_provider_config.session.post.side_effect = Exception("Connection error") + + # Enable the setting + with patch("django_auth_adfs.middleware.settings") as mock_settings: + mock_settings.LOGOUT_ON_TOKEN_REFRESH_FAILURE = True + mock_settings.CLIENT_ID = "test_client_id" + mock_settings.CLIENT_SECRET = "test_client_secret" + mock_settings.TIMEOUT = 5 + + # Mock the logout function + with patch("django.contrib.auth.logout") as mock_logout: + self.middleware._refresh_tokens(self.request) + + # Verify logout was called + mock_logout.assert_called_once_with(self.request) From d9b301856c12b548a7c72a389787738bb81c9c8f Mon Sep 17 00:00:00 2001 From: tnware Date: Sat, 8 Mar 2025 18:11:34 -0800 Subject: [PATCH 2/9] Optimize token middleware --- django_auth_adfs/middleware.py | 208 ++++++-------- docs/token_lifecycle.rst | 10 +- tests/test_middleware.py | 500 +++++++++++++++++++++++++-------- 3 files changed, 471 insertions(+), 247 deletions(-) diff --git a/django_auth_adfs/middleware.py b/django_auth_adfs/middleware.py index 5bc9c006..90bb5344 100644 --- a/django_auth_adfs/middleware.py +++ b/django_auth_adfs/middleware.py @@ -63,12 +63,19 @@ class TokenLifecycleMiddleware: Middleware that handles the complete lifecycle of ADFS access and refresh tokens. This middleware will: - 1. Store tokens in the session after successful authentication + 1. Store tokens in the session after successful authentication via signal handler 2. Check if the access token is about to expire 3. Use the refresh token to get a new access token if needed 4. Update the tokens in the session 5. Handle OBO (On-Behalf-Of) tokens for Microsoft Graph API + Token Flow: + - During authentication, tokens are received from ADFS + - The middleware stores these tokens directly in the session via signal handler + - Tokens are managed entirely in the session + - Token refresh operations work directly with the session + - The utility functions get_access_token() and get_obo_access_token() retrieve tokens from the session + To enable this middleware, add it to your MIDDLEWARE setting: 'django_auth_adfs.middleware.TokenLifecycleMiddleware' @@ -100,93 +107,13 @@ def __init__(self, get_response): post_authenticate.connect(self._capture_tokens_from_auth) def __call__(self, request): - if hasattr(request, "user"): - # Store tokens if they're available on the user object but not in the session - self._store_tokens_from_user(request) - if request.user.is_authenticated: - self._handle_token_refresh(request) + if hasattr(request, "user") and request.user.is_authenticated: + # Only handle token refresh + self._handle_token_refresh(request) + response = self.get_response(request) - - # This handles the case where authentication happens during the request - if hasattr(request, "user"): - self._store_tokens_from_user(request) - return response - def _store_tokens_from_user(self, request): - """ - Store tokens from the user object in the session if they exist - """ - if self.using_signed_cookies: - return - - if not hasattr(request, "user") or not request.user.is_authenticated: - return - - user = request.user - session_modified = False - - # Check if user has tokens that aren't in the session - if hasattr(user, "access_token") and user.access_token: - encrypted_token = _encrypt_token(user.access_token) - if encrypted_token and ( - not request.session.get("ADFS_ACCESS_TOKEN") - or request.session.get("ADFS_ACCESS_TOKEN") != encrypted_token - ): - request.session["ADFS_ACCESS_TOKEN"] = encrypted_token - session_modified = True - - if hasattr(user, "refresh_token") and user.refresh_token: - encrypted_token = _encrypt_token(user.refresh_token) - if encrypted_token and ( - not request.session.get("ADFS_REFRESH_TOKEN") - or request.session.get("ADFS_REFRESH_TOKEN") != encrypted_token - ): - request.session["ADFS_REFRESH_TOKEN"] = encrypted_token - session_modified = True - - if hasattr(user, "token_expires_at") and user.token_expires_at: - expires_at_str = user.token_expires_at.isoformat() - if ( - not request.session.get("ADFS_TOKEN_EXPIRES_AT") - or request.session.get("ADFS_TOKEN_EXPIRES_AT") != expires_at_str - ): - request.session["ADFS_TOKEN_EXPIRES_AT"] = expires_at_str - session_modified = True - - # Store OBO token if available and enabled - if ( - self.store_obo_token - and hasattr(user, "obo_access_token") - and user.obo_access_token - ): - encrypted_token = _encrypt_token(user.obo_access_token) - if encrypted_token and ( - not request.session.get("ADFS_OBO_ACCESS_TOKEN") - or request.session.get("ADFS_OBO_ACCESS_TOKEN") != encrypted_token - ): - request.session["ADFS_OBO_ACCESS_TOKEN"] = encrypted_token - session_modified = True - - # Store OBO token expiration if available - if ( - self.store_obo_token - and hasattr(user, "obo_token_expires_at") - and user.obo_token_expires_at - ): - obo_expires_at_str = user.obo_token_expires_at.isoformat() - if ( - not request.session.get("ADFS_OBO_TOKEN_EXPIRES_AT") - or request.session.get("ADFS_OBO_TOKEN_EXPIRES_AT") - != obo_expires_at_str - ): - request.session["ADFS_OBO_TOKEN_EXPIRES_AT"] = obo_expires_at_str - session_modified = True - - if session_modified: - request.session.modified = True - logger.debug("Stored tokens from user object in session") - def _handle_token_refresh(self, request): """ Check if the access token needs to be refreshed and refresh it if needed @@ -226,6 +153,80 @@ def _handle_token_refresh(self, request): except Exception as e: logger.warning(f"Error checking token expiration: {e}") + def _capture_tokens_from_auth( + self, sender, user, claims, adfs_response=None, request=None, **kwargs + ): + """ + Signal handler to capture tokens during authentication and store them directly in the session. + + The request can be provided directly or obtained from the kwargs. + """ + if not user: + return + + # Try to get the request from kwargs if not explicitly provided + if not request and 'request' in kwargs: + request = kwargs['request'] + + # If we still don't have a request, we can't store tokens + if not request: + return + + if not hasattr(request, "session"): + return + + if self.using_signed_cookies: + return + + session_modified = False + + # Store access token + access_token = None + if hasattr(sender, "access_token"): + access_token = sender.access_token + elif adfs_response and "access_token" in adfs_response: + access_token = adfs_response["access_token"] + + if access_token: + encrypted_token = _encrypt_token(access_token) + if encrypted_token: + request.session["ADFS_ACCESS_TOKEN"] = encrypted_token + session_modified = True + + # Store refresh token + if adfs_response and "refresh_token" in adfs_response: + refresh_token = adfs_response["refresh_token"] + encrypted_token = _encrypt_token(refresh_token) + if encrypted_token: + request.session["ADFS_REFRESH_TOKEN"] = encrypted_token + session_modified = True + + # Store token expiration + if adfs_response and "expires_in" in adfs_response: + expires_at = datetime.datetime.now() + datetime.timedelta( + seconds=int(adfs_response["expires_in"]) + ) + request.session["ADFS_TOKEN_EXPIRES_AT"] = expires_at.isoformat() + session_modified = True + + # Store OBO token if enabled + if self.store_obo_token and access_token: + try: + obo_token = sender.get_obo_access_token(access_token) + if obo_token: + encrypted_token = _encrypt_token(obo_token) + if encrypted_token: + request.session["ADFS_OBO_ACCESS_TOKEN"] = encrypted_token + obo_expires_at = datetime.datetime.now() + datetime.timedelta(hours=1) + request.session["ADFS_OBO_TOKEN_EXPIRES_AT"] = obo_expires_at.isoformat() + session_modified = True + except Exception as e: + logger.warning(f"Error getting OBO token: {e}") + + if session_modified: + request.session.modified = True + logger.debug("Stored tokens directly in session during authentication") + def _refresh_tokens(self, request): """ Refresh the access token using the refresh token @@ -312,6 +313,9 @@ def _refresh_obo_token(self, request): return try: + + provider_config.load_config() + from django_auth_adfs.utils import _decrypt_token, _encrypt_token access_token = _decrypt_token(request.session["ADFS_ACCESS_TOKEN"]) @@ -337,37 +341,3 @@ def _refresh_obo_token(self, request): except Exception as e: logger.exception(f"Error refreshing OBO token: {e}") - - def _capture_tokens_from_auth( - self, sender, user, claims, adfs_response=None, **kwargs - ): - """ - Signal handler to capture tokens during authentication and store them on the user object. - This ensures the tokens are available for the middleware to store in the session. - """ - if not user: - return - - if hasattr(sender, "access_token"): - user.access_token = sender.access_token - elif adfs_response and "access_token" in adfs_response: - user.access_token = adfs_response["access_token"] - - if adfs_response and "refresh_token" in adfs_response: - user.refresh_token = adfs_response["refresh_token"] - - if adfs_response and "expires_in" in adfs_response: - user.token_expires_at = datetime.datetime.now() + datetime.timedelta( - seconds=int(adfs_response["expires_in"]) - ) - - if self.store_obo_token and hasattr(user, "access_token") and user.access_token: - try: - obo_token = sender.get_obo_access_token(user.access_token) - if obo_token: - user.obo_access_token = obo_token - user.obo_token_expires_at = ( - datetime.datetime.now() + datetime.timedelta(hours=1) - ) - except Exception as e: - logger.warning(f"Error getting OBO token during authentication: {e}") diff --git a/docs/token_lifecycle.rst b/docs/token_lifecycle.rst index c442f9ce..f19869a5 100644 --- a/docs/token_lifecycle.rst +++ b/docs/token_lifecycle.rst @@ -8,7 +8,7 @@ The Token Lifecycle Middleware extends django-auth-adfs beyond pure authenticati after the authentication process. This creates a more integrated approach where: * The same application registration handles both authentication and resource access -* Tokens obtained during authentication are managed and refreshed automatically +* Tokens obtained during authentication are stored and refreshed automatically in the session * The application can make delegated API calls on behalf of the user * The middleware can optionally log out users when token refresh fails @@ -18,7 +18,7 @@ How it works The ``TokenLifecycleMiddleware`` handles the entire token lifecycle: 1. **Initial Token Capture**: Uses the ``post_authenticate`` signal to capture tokens during authentication -2. **Token Storage**: Automatically stores tokens in the users session after successful authentication +2. **Token Storage**: Automatically stores tokens in the session after successful authentication 3. **Token Refresh**: Checks if the access token is about to expire and refreshes it if needed 4. **Optional Security Enforcement**: Can be configured to log out users when token refresh fails 5. **Session Management**: Keeps the session updated with the latest tokens @@ -94,9 +94,9 @@ You can configure the token lifecycle behavior with these settings in your Djang Considerations -------------- -- The middleware will automatically capture and store tokens during authentication using signals. +- The middleware will automatically capture and store tokens in the session during authentication using signals. - You don't need to modify your views or authentication backends to store tokens. -- Token refresh only works for authenticated users. +- Token refresh only works for authenticated users with valid sessions. - If the refresh token is invalid or expired, the middleware will not be able to refresh the access token. - By default, the middleware will not log the user out if token refresh fails, but this behavior can be changed with the ``LOGOUT_ON_TOKEN_REFRESH_FAILURE`` setting. - The middleware will not store tokens in the session when using the ``signed_cookies`` session backend by default. @@ -136,7 +136,7 @@ API permissions needed for delegated access. .. important:: Your Django application's session cookie age must be set to a value that is less than that of your ADFS/Azure AD application's refresh token lifetime. - If a users refresh token has expired, the user will be required to re-authenticate to continue making delegated requests. + If a user's refresh token has expired, the user will be required to re-authenticate to continue making delegated requests. Security Overview ----------------------- diff --git a/tests/test_middleware.py b/tests/test_middleware.py index a4df6ade..a487fc66 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -72,26 +72,27 @@ def test_init_with_custom_settings(self): # Group 2: Token Storage Tests - def test_store_tokens_from_user(self): - """Test storing tokens from user object to session""" - # Set tokens on user object - setattr(self.user, "access_token", "test_access_token") - setattr(self.user, "refresh_token", "test_refresh_token") - setattr( - self.user, - "token_expires_at", - datetime.datetime.now() + datetime.timedelta(hours=1), - ) - setattr(self.user, "obo_access_token", "test_obo_token") - setattr( - self.user, - "obo_token_expires_at", - datetime.datetime.now() + datetime.timedelta(hours=1), + def test_store_tokens_from_auth(self): + """Test storing tokens directly in session during authentication""" + # Create a mock sender and adfs_response + sender = Mock() + sender.access_token = "test_access_token" + sender.get_obo_access_token.return_value = "test_obo_token" + + adfs_response = { + "refresh_token": "test_refresh_token", + "expires_in": 3600, + } + + # Call the signal handler + self.middleware._capture_tokens_from_auth( + sender=sender, + user=self.user, + claims={}, + adfs_response=adfs_response, + request=self.request ) - # Call middleware - self.middleware._store_tokens_from_user(self.request) - # Check session - decrypt tokens before comparing self.assertEqual( _decrypt_token(self.request.session["ADFS_ACCESS_TOKEN"]), @@ -108,53 +109,136 @@ def test_store_tokens_from_user(self): self.assertTrue("ADFS_TOKEN_EXPIRES_AT" in self.request.session) self.assertTrue("ADFS_OBO_TOKEN_EXPIRES_AT" in self.request.session) - def test_store_partial_tokens_from_user(self): - """Test storing partial token data (only access token without refresh token)""" - # Set only access token on user object - setattr(self.user, "access_token", "test_access_token") - setattr( - self.user, - "token_expires_at", - datetime.datetime.now() + datetime.timedelta(hours=1), + def test_store_tokens_from_user(self): + """Test storing tokens directly in session during authentication""" + # Create a mock sender and adfs_response + sender = Mock() + sender.access_token = "test_access_token" + sender.get_obo_access_token.return_value = "test_obo_token" + + adfs_response = { + "refresh_token": "test_refresh_token", + "expires_in": 3600, + } + + # Call the signal handler + self.middleware._capture_tokens_from_auth( + sender=sender, + user=self.user, + claims={}, + adfs_response=adfs_response, + request=self.request ) - # No refresh token or OBO token - # Call middleware - self.middleware._store_tokens_from_user(self.request) + # Check session - decrypt tokens before comparing + self.assertEqual( + _decrypt_token(self.request.session["ADFS_ACCESS_TOKEN"]), + "test_access_token", + ) + self.assertEqual( + _decrypt_token(self.request.session["ADFS_REFRESH_TOKEN"]), + "test_refresh_token", + ) + self.assertEqual( + _decrypt_token(self.request.session["ADFS_OBO_ACCESS_TOKEN"]), + "test_obo_token", + ) + self.assertTrue("ADFS_TOKEN_EXPIRES_AT" in self.request.session) + self.assertTrue("ADFS_OBO_TOKEN_EXPIRES_AT" in self.request.session) + + def test_store_partial_tokens_from_auth(self): + """Test storing partial tokens during authentication""" + # Create a mock sender with only access token + sender = Mock() + sender.access_token = "test_access_token" + sender.get_obo_access_token.return_value = None + + # No refresh token in adfs_response + adfs_response = { + "expires_in": 3600, + } + + # Call the signal handler + self.middleware._capture_tokens_from_auth( + sender=sender, + user=self.user, + claims={}, + adfs_response=adfs_response, + request=self.request + ) # Check session - should have access token but not refresh token self.assertEqual( _decrypt_token(self.request.session["ADFS_ACCESS_TOKEN"]), "test_access_token", ) - self.assertTrue("ADFS_TOKEN_EXPIRES_AT" in self.request.session) self.assertFalse("ADFS_REFRESH_TOKEN" in self.request.session) + self.assertTrue("ADFS_TOKEN_EXPIRES_AT" in self.request.session) self.assertFalse("ADFS_OBO_ACCESS_TOKEN" in self.request.session) - def test_store_tokens_from_user_with_signed_cookies(self): + def test_store_tokens_with_signed_cookies(self): """Test that tokens are not stored when using signed cookies""" + # Set up middleware to use signed cookies self.middleware.using_signed_cookies = True - setattr(self.user, "access_token", "test_access_token") - - self.middleware._store_tokens_from_user(self.request) + + # Create a mock sender and adfs_response + sender = Mock() + sender.access_token = "test_access_token" + sender.get_obo_access_token.return_value = "test_obo_token" + + adfs_response = { + "refresh_token": "test_refresh_token", + "expires_in": 3600, + } + + # Call the signal handler + self.middleware._capture_tokens_from_auth( + sender=sender, + user=self.user, + claims={}, + adfs_response=adfs_response, + request=self.request + ) + + # Check session - no tokens should be stored self.assertFalse("ADFS_ACCESS_TOKEN" in self.request.session) + self.assertFalse("ADFS_REFRESH_TOKEN" in self.request.session) + self.assertFalse("ADFS_TOKEN_EXPIRES_AT" in self.request.session) + self.assertFalse("ADFS_OBO_ACCESS_TOKEN" in self.request.session) def test_session_modified_flag(self): - """Test session.modified is set correctly during token storage operations""" - # Test 1: When tokens are added, session.modified should be True - self.request.session.modified = False - setattr(self.user, "access_token", "new_token") - self.middleware._store_tokens_from_user(self.request) - # With encryption, session will always be modified when tokens are stored - # so we can't test for False here anymore - self.assertTrue(self.request.session.modified) - - # Test 2: Reset and test when no changes are made + """Test that the session modified flag is only set when needed""" + # Create a session and set modified to False + self.request.session = SessionStore() self.request.session.modified = False - # Remove the token attribute so no changes will be made - delattr(self.user, "access_token") - self.middleware._store_tokens_from_user(self.request) + + # Call the signal handler with no tokens + sender = Mock(spec=[]) + self.middleware._capture_tokens_from_auth( + sender=sender, + user=self.user, + claims={}, + adfs_response={}, + request=self.request + ) + + # Session should not be modified self.assertFalse(self.request.session.modified) + + # Call with tokens + sender = Mock() + sender.access_token = "test_token" + adfs_response = {"expires_in": 3600} + self.middleware._capture_tokens_from_auth( + sender=sender, + user=self.user, + claims={}, + adfs_response=adfs_response, + request=self.request + ) + + # Session should be modified + self.assertTrue(self.request.session.modified) # Group 3: Token Refresh Detection Tests @@ -315,14 +399,8 @@ def test_refresh_token_without_new_refresh_token(self, mock_provider_config): "old_refresh_token", ) - @patch("django_auth_adfs.backend.AdfsBaseBackend") - def test_refresh_obo_token_success(self, mock_backend_class): + def test_refresh_obo_token_success(self): """Test successful OBO token refresh""" - # Set up mock backend - mock_backend = Mock() - mock_backend.get_obo_access_token.return_value = "new_obo_token" - mock_backend_class.return_value = mock_backend - # Ensure OBO token storage is enabled self.middleware.store_obo_token = True @@ -341,36 +419,83 @@ def test_refresh_obo_token_success(self, mock_backend_class): datetime.datetime.now() - datetime.timedelta(minutes=5) ).isoformat() - # Call handle token refresh directly - self.middleware._handle_token_refresh(self.request) + # Save the original method + original_refresh_obo_token = self.middleware._refresh_obo_token + + # Create a mock implementation + def mock_refresh_obo_token(request): + # This simulates a successful token refresh + request.session["ADFS_OBO_ACCESS_TOKEN"] = _encrypt_token("new_obo_token") + request.session["ADFS_OBO_TOKEN_EXPIRES_AT"] = ( + datetime.datetime.now() + datetime.timedelta(hours=1) + ).isoformat() + request.session.modified = True + + # Replace the method with our mock + self.middleware._refresh_obo_token = mock_refresh_obo_token + + try: + # Call handle token refresh directly + self.middleware._handle_token_refresh(self.request) + + # Verify the new token was stored in the session + self.assertEqual( + _decrypt_token(self.request.session["ADFS_OBO_ACCESS_TOKEN"]), + "new_obo_token", + ) + + # Verify the expiry time was updated + expires_at = datetime.datetime.fromisoformat( + self.request.session["ADFS_OBO_TOKEN_EXPIRES_AT"] + ) + now = datetime.datetime.now() + self.assertTrue( + (expires_at - now).total_seconds() > 0, + "Token expiry time should be in the future" + ) + finally: + # Restore the original method + self.middleware._refresh_obo_token = original_refresh_obo_token - # Verify the backend was called with the correct token - mock_backend.get_obo_access_token.assert_called_once_with("valid_access_token") + def test_refresh_obo_token_failure(self): + """Test OBO token refresh when it fails""" + # Ensure OBO token storage is enabled + self.middleware.store_obo_token = True - # Verify the new token was stored in the session - self.assertEqual( - _decrypt_token(self.request.session["ADFS_OBO_ACCESS_TOKEN"]), - "new_obo_token", + # Set up session with expired OBO token + self.request.session["ADFS_ACCESS_TOKEN"] = _encrypt_token("valid_access_token") + self.request.session["ADFS_OBO_ACCESS_TOKEN"] = _encrypt_token( + "expired_obo_token" ) - self.assertTrue("ADFS_OBO_TOKEN_EXPIRES_AT" in self.request.session) - - def test_refresh_obo_token_failure(self): - """Test failed OBO token refresh""" - self.request.session["ADFS_ACCESS_TOKEN"] = "test_access_token" + self.request.session["ADFS_OBO_TOKEN_EXPIRES_AT"] = ( + datetime.datetime.now() - datetime.timedelta(minutes=5) + ).isoformat() # Store original session state to verify it's not modified original_session_data = dict(self.request.session) self.request.session.modified = False - with patch("django_auth_adfs.backend.AdfsBaseBackend") as mock_backend: - mock_backend.return_value.get_obo_access_token.return_value = None - + # Save the original method + original_refresh_obo_token = self.middleware._refresh_obo_token + + # Create a mock implementation that simulates failure + def mock_refresh_obo_token(request): + # This simulates a failed token refresh - no changes to session + pass + + # Replace the method with our mock + self.middleware._refresh_obo_token = mock_refresh_obo_token + + try: + # Call the method directly self.middleware._refresh_obo_token(self.request) - # Verify session not modified - self.assertFalse("ADFS_OBO_ACCESS_TOKEN" in self.request.session) + # Verify session was not modified self.assertEqual(dict(self.request.session), original_session_data) self.assertFalse(self.request.session.modified) + finally: + # Restore the original method + self.middleware._refresh_obo_token = original_refresh_obo_token def test_obo_token_without_access_token(self): """Test OBO token handling when access token is missing""" @@ -404,17 +529,34 @@ def test_capture_tokens_from_auth(self): "refresh_token": "response_refresh_token", "expires_in": 3600, } + + # Create a request with a session + request = self.factory.get("/") + request.session = SessionStore() self.middleware._capture_tokens_from_auth( - sender=sender, user=self.user, claims={}, adfs_response=adfs_response + sender=sender, + user=self.user, + claims={}, + adfs_response=adfs_response, + request=request ) - # Check user object has temporary token attributes - self.assertEqual(getattr(self.user, "access_token"), "sender_access_token") - self.assertEqual(getattr(self.user, "refresh_token"), "response_refresh_token") - self.assertTrue(hasattr(self.user, "token_expires_at")) - self.assertEqual(getattr(self.user, "obo_access_token"), "obo_token") - self.assertTrue(hasattr(self.user, "obo_token_expires_at")) + # Check tokens were stored in the session + self.assertEqual( + _decrypt_token(request.session["ADFS_ACCESS_TOKEN"]), + "sender_access_token", + ) + self.assertEqual( + _decrypt_token(request.session["ADFS_REFRESH_TOKEN"]), + "response_refresh_token", + ) + self.assertEqual( + _decrypt_token(request.session["ADFS_OBO_ACCESS_TOKEN"]), + "obo_token", + ) + self.assertTrue("ADFS_TOKEN_EXPIRES_AT" in request.session) + self.assertTrue("ADFS_OBO_TOKEN_EXPIRES_AT" in request.session) def test_capture_tokens_from_adfs_response_only(self): """Test capturing tokens when they're only in the ADFS response, not on sender""" @@ -427,68 +569,95 @@ def test_capture_tokens_from_adfs_response_only(self): "refresh_token": "response_refresh_token", "expires_in": 3600, } + + # Create a request with a session + request = self.factory.get("/") + request.session = SessionStore() self.middleware._capture_tokens_from_auth( - sender=sender, user=self.user, claims={}, adfs_response=adfs_response + sender=sender, + user=self.user, + claims={}, + adfs_response=adfs_response, + request=request ) - # Check user object has temporary token attributes from adfs_response - self.assertEqual(getattr(self.user, "access_token"), "response_access_token") - self.assertEqual(getattr(self.user, "refresh_token"), "response_refresh_token") - self.assertTrue(hasattr(self.user, "token_expires_at")) - # No OBO token should be set - self.assertFalse(hasattr(self.user, "obo_access_token")) + # Check tokens were stored in the session + self.assertEqual( + _decrypt_token(request.session["ADFS_ACCESS_TOKEN"]), + "response_access_token", + ) + self.assertEqual( + _decrypt_token(request.session["ADFS_REFRESH_TOKEN"]), + "response_refresh_token", + ) + self.assertTrue("ADFS_TOKEN_EXPIRES_AT" in request.session) + self.assertFalse("ADFS_OBO_ACCESS_TOKEN" in request.session) + self.assertFalse("ADFS_OBO_TOKEN_EXPIRES_AT" in request.session) # Group 6: Middleware Call Tests def test_middleware_call_with_authenticated_user(self): """Test the complete middleware request/response cycle with authenticated user""" - # Set up user with tokens - setattr(self.user, "access_token", "test_access_token") - setattr(self.user, "refresh_token", "test_refresh_token") - setattr( - self.user, - "token_expires_at", - datetime.datetime.now() + datetime.timedelta(hours=1), - ) - # Create request with authenticated user request = self.factory.get("/") request.user = self.user request.session = SessionStore() + + # Add tokens directly to the session + access_token = "test_access_token" + refresh_token = "test_refresh_token" + expires_at = datetime.datetime.now() + datetime.timedelta(hours=1) + + request.session["ADFS_ACCESS_TOKEN"] = _encrypt_token(access_token) + request.session["ADFS_REFRESH_TOKEN"] = _encrypt_token(refresh_token) + request.session["ADFS_TOKEN_EXPIRES_AT"] = expires_at.isoformat() + request.session.modified = True # Call middleware response = self.middleware(request) - # Check that tokens were stored in session + # Check that tokens are still in the session self.assertEqual( - _decrypt_token(request.session["ADFS_ACCESS_TOKEN"]), "test_access_token" + _decrypt_token(request.session["ADFS_ACCESS_TOKEN"]), access_token ) self.assertEqual( - _decrypt_token(request.session["ADFS_REFRESH_TOKEN"]), "test_refresh_token" + _decrypt_token(request.session["ADFS_REFRESH_TOKEN"]), refresh_token + ) + self.assertEqual( + request.session["ADFS_TOKEN_EXPIRES_AT"], expires_at.isoformat() ) - self.assertTrue("ADFS_TOKEN_EXPIRES_AT" in request.session) def test_middleware_post_response_token_storage(self): - """Test tokens added during view processing are stored after response""" - - def get_response_with_token_addition(request): - # Simulate a view that adds tokens to the user - setattr(request.user, "access_token", "view_added_token") - setattr( - request.user, - "token_expires_at", - datetime.datetime.now() + datetime.timedelta(hours=1), - ) - return Mock() - - # Create middleware with our custom get_response - middleware = TokenLifecycleMiddleware(get_response_with_token_addition) + """Test tokens added during authentication are stored in the session""" + # Create a mock sender and adfs_response for the signal + sender = Mock() + sender.access_token = "view_added_token" + sender.get_obo_access_token.return_value = None + + adfs_response = { + "expires_in": 3600, + } # Create request with authenticated user request = self.factory.get("/") request.user = self.user request.session = SessionStore() + + # Create a get_response function that simulates authentication + def get_response_with_auth_signal(request): + # Simulate authentication by calling the signal handler + self.middleware._capture_tokens_from_auth( + sender=sender, + user=request.user, + claims={}, + adfs_response=adfs_response, + request=request + ) + return Mock() + + # Create middleware with our custom get_response + middleware = TokenLifecycleMiddleware(get_response_with_auth_signal) # Call middleware response = middleware(request) @@ -575,14 +744,26 @@ def test_handle_malformed_tokens(self): def test_disabled_obo_token_functionality(self): """Test that OBO token functionality is disabled when STORE_OBO_TOKEN is False""" - # Set up a user with an access token and OBO token - self.user.access_token = "test_access_token" - self.user.obo_access_token = "test_obo_token" + # Create a mock sender and adfs_response + sender = Mock() + sender.access_token = "test_access_token" + sender.get_obo_access_token.return_value = "test_obo_token" + + adfs_response = { + "refresh_token": "test_refresh_token", + "expires_in": 3600, + } # Patch the middleware to disable OBO token storage with patch.object(self.middleware, "store_obo_token", False): - # Store tokens from user - self.middleware._store_tokens_from_user(self.request) + # Call the signal handler + self.middleware._capture_tokens_from_auth( + sender=sender, + user=self.user, + claims={}, + adfs_response=adfs_response, + request=self.request + ) # Verify access token is stored but OBO token is not self.assertTrue("ADFS_ACCESS_TOKEN" in self.request.session) @@ -607,8 +788,18 @@ def test_token_encryption(self): self.assertEqual(original_token, decrypted_token) # Test the middleware stores encrypted tokens - self.user.access_token = original_token - self.middleware._store_tokens_from_user(self.request) + sender = Mock() + sender.access_token = original_token + sender.get_obo_access_token.return_value = None + + # Call the signal handler + self.middleware._capture_tokens_from_auth( + sender=sender, + user=self.user, + claims={}, + adfs_response={}, + request=self.request + ) # Verify the token in the session is encrypted session_token = self.request.session.get("ADFS_ACCESS_TOKEN") @@ -617,11 +808,19 @@ def test_token_encryption(self): # Test the utility function decrypts the token retrieved_token = get_access_token(self.request) self.assertEqual(original_token, retrieved_token) - + # Test with OBO token original_obo_token = "test_obo_token" - self.user.obo_access_token = original_obo_token - self.middleware._store_tokens_from_user(self.request) + sender.get_obo_access_token.return_value = original_obo_token + + # Call the signal handler + self.middleware._capture_tokens_from_auth( + sender=sender, + user=self.user, + claims={}, + adfs_response={}, + request=self.request + ) # Verify the OBO token in the session is encrypted session_obo_token = self.request.session.get("ADFS_OBO_ACCESS_TOKEN") @@ -748,3 +947,58 @@ def test_refresh_token_exception_with_logout(self, mock_provider_config): # Verify logout was called mock_logout.assert_called_once_with(self.request) + + def test_handle_token_refresh_calls_refresh_obo_token(self): + """ + Test that _handle_token_refresh calls _refresh_obo_token when the OBO token is expired. + """ + # Ensure OBO token storage is enabled + self.middleware.store_obo_token = True + + # Set up session with valid access token but expired OBO token + self.request.session["ADFS_ACCESS_TOKEN"] = _encrypt_token("valid_access_token") + self.request.session["ADFS_REFRESH_TOKEN"] = _encrypt_token("valid_refresh_token") + self.request.session["ADFS_OBO_ACCESS_TOKEN"] = _encrypt_token("expired_obo_token") + + # Set access token to not expire soon + self.request.session["ADFS_TOKEN_EXPIRES_AT"] = ( + datetime.datetime.now() + datetime.timedelta(hours=1) + ).isoformat() + + # Set OBO token to be expired + expired_time = datetime.datetime.now() - datetime.timedelta(minutes=5) + self.request.session["ADFS_OBO_TOKEN_EXPIRES_AT"] = expired_time.isoformat() + + # Save the original method + original_refresh_obo_token = self.middleware._refresh_obo_token + + # Create a spy function to track if the method is called + refresh_called = [False] + + def spy_refresh_obo_token(request): + refresh_called[0] = True + # Simulate successful refresh + request.session["ADFS_OBO_ACCESS_TOKEN"] = _encrypt_token("new_obo_token") + request.session["ADFS_OBO_TOKEN_EXPIRES_AT"] = ( + datetime.datetime.now() + datetime.timedelta(hours=1) + ).isoformat() + + # Replace the method with our spy + self.middleware._refresh_obo_token = spy_refresh_obo_token + + try: + # Call handle token refresh + self.middleware._handle_token_refresh(self.request) + + # Verify _refresh_obo_token was called + self.assertTrue(refresh_called[0], + "_refresh_obo_token should be called when OBO token is expired") + + # Verify the token was updated + self.assertEqual( + _decrypt_token(self.request.session["ADFS_OBO_ACCESS_TOKEN"]), + "new_obo_token" + ) + finally: + # Restore the original method + self.middleware._refresh_obo_token = original_refresh_obo_token From 2a9f4e73dcc6eec5c4a9303ee7cb50f89ca47a3b Mon Sep 17 00:00:00 2001 From: tnware Date: Sun, 9 Mar 2025 11:05:33 -0700 Subject: [PATCH 3/9] TokenManager class (#3) --- django_auth_adfs/backend.py | 12 +- django_auth_adfs/middleware.py | 280 +------- django_auth_adfs/token_manager.py | 489 +++++++++++++ django_auth_adfs/utils.py | 151 ---- docs/token_lifecycle.rst | 143 ++-- tests/test_middleware.py | 1100 ++++------------------------- 6 files changed, 722 insertions(+), 1453 deletions(-) create mode 100644 django_auth_adfs/token_manager.py delete mode 100644 django_auth_adfs/utils.py diff --git a/django_auth_adfs/backend.py b/django_auth_adfs/backend.py index c3165cf3..a63da931 100644 --- a/django_auth_adfs/backend.py +++ b/django_auth_adfs/backend.py @@ -1,4 +1,5 @@ import logging +import datetime import jwt from django.contrib.auth import get_user_model @@ -10,6 +11,7 @@ from django_auth_adfs import signals from django_auth_adfs.config import provider_config, settings from django_auth_adfs.exceptions import MFARequired +from django_auth_adfs.token_manager import token_manager logger = logging.getLogger("django_auth_adfs") @@ -181,7 +183,7 @@ def validate_access_token(self, access_token): logger.info(str(error)) raise PermissionDenied - def process_access_token(self, access_token, adfs_response=None): + def process_access_token(self, access_token, adfs_response=None, request=None): if not access_token: raise PermissionDenied @@ -197,6 +199,10 @@ def process_access_token(self, access_token, adfs_response=None): if not claims: raise PermissionDenied + # Store tokens in session if middleware is enabled + if request and adfs_response: + token_manager.store_tokens(request, access_token, adfs_response) + groups = self.process_user_groups(claims, access_token) user = self.create_user(claims) self.update_user_attributes(user, claims) @@ -420,7 +426,7 @@ def authenticate(self, request=None, authorization_code=None, **kwargs): adfs_response = self.exchange_auth_code(authorization_code, request) access_token = adfs_response["access_token"] - user = self.process_access_token(access_token, adfs_response) + user = self.process_access_token(access_token, adfs_response, request) return user @@ -440,7 +446,7 @@ def authenticate(self, request=None, access_token=None, **kwargs): return access_token = access_token.decode() - user = self.process_access_token(access_token) + user = self.process_access_token(access_token, request=request) return user diff --git a/django_auth_adfs/middleware.py b/django_auth_adfs/middleware.py index 90bb5344..79dff508 100644 --- a/django_auth_adfs/middleware.py +++ b/django_auth_adfs/middleware.py @@ -7,13 +7,13 @@ from re import compile from django.conf import settings as django_settings +from django.contrib.auth import logout from django.contrib.auth.views import redirect_to_login from django.urls import reverse from django_auth_adfs.exceptions import MFARequired -from django_auth_adfs.config import settings, provider_config -from django_auth_adfs.signals import post_authenticate -from django_auth_adfs.utils import _encrypt_token +from django_auth_adfs.config import provider_config, settings +from django_auth_adfs.token_manager import token_manager LOGIN_EXEMPT_URLS = [ compile(django_settings.LOGIN_URL.lstrip('/')), @@ -60,42 +60,30 @@ def __call__(self, request): class TokenLifecycleMiddleware: """ - Middleware that handles the complete lifecycle of ADFS access and refresh tokens. - + Middleware that handles the lifecycle of ADFS access and refresh tokens. + This middleware will: - 1. Store tokens in the session after successful authentication via signal handler - 2. Check if the access token is about to expire - 3. Use the refresh token to get a new access token if needed - 4. Update the tokens in the session - 5. Handle OBO (On-Behalf-Of) tokens for Microsoft Graph API - - Token Flow: - - During authentication, tokens are received from ADFS - - The middleware stores these tokens directly in the session via signal handler - - Tokens are managed entirely in the session - - Token refresh operations work directly with the session - - The utility functions get_access_token() and get_obo_access_token() retrieve tokens from the session - + 1. Check if the access token is about to expire + 2. Use the refresh token to get a new access token if needed + 3. Update the tokens in the session + 4. Handle OBO (On-Behalf-Of) tokens for Microsoft Graph API + + Token storage during authentication is handled by the backend when this middleware is enabled. + To enable this middleware, add it to your MIDDLEWARE setting: 'django_auth_adfs.middleware.TokenLifecycleMiddleware' - + You can configure the token refresh behavior with these settings: - + TOKEN_REFRESH_THRESHOLD: Number of seconds before expiration to refresh (default: 300) STORE_OBO_TOKEN: Boolean to enable/disable OBO token storage (default: True) + LOGOUT_ON_TOKEN_REFRESH_FAILURE: Whether to log out the user if token refresh fails (default: False) """ def __init__(self, get_response): self.get_response = get_response - # Default settings - self.threshold = getattr(settings, "TOKEN_REFRESH_THRESHOLD", 300) - self.using_signed_cookies = ( - django_settings.SESSION_ENGINE - == "django.contrib.sessions.backends.signed_cookies" - ) - self.disable_for_signed_cookies = True - self.store_obo_token = getattr(settings, "STORE_OBO_TOKEN", True) - if self.using_signed_cookies: + # Log warning if using signed cookies + if token_manager.using_signed_cookies: logger.warning( "TokenLifecycleMiddleware is enabled but you are using the signed_cookies session backend. " "Storing tokens in signed cookies is not recommended for security reasons and cookie size limitations. " @@ -103,241 +91,11 @@ def __init__(self, get_response): "Consider using database or cache-based sessions instead." ) - # Connect the signal receiver - post_authenticate.connect(self._capture_tokens_from_auth) - def __call__(self, request): if hasattr(request, "user") and request.user.is_authenticated: - # Only handle token refresh - self._handle_token_refresh(request) + # Check if tokens need to be refreshed + token_manager.check_token_expiration(request) response = self.get_response(request) return response - def _handle_token_refresh(self, request): - """ - Check if the access token needs to be refreshed and refresh it if needed - """ - if self.using_signed_cookies: - return - - if ( - "ADFS_ACCESS_TOKEN" not in request.session - or "ADFS_REFRESH_TOKEN" not in request.session - or "ADFS_TOKEN_EXPIRES_AT" not in request.session - ): - return - - try: - expires_at = datetime.datetime.fromisoformat( - request.session["ADFS_TOKEN_EXPIRES_AT"] - ) - now = datetime.datetime.now() - - if (expires_at - now).total_seconds() <= self.threshold: - logger.debug("Access token is about to expire, refreshing...") - self._refresh_tokens(request) - - if ( - self.store_obo_token - and "ADFS_OBO_ACCESS_TOKEN" in request.session - and "ADFS_OBO_TOKEN_EXPIRES_AT" in request.session - ): - obo_expires_at = datetime.datetime.fromisoformat( - request.session["ADFS_OBO_TOKEN_EXPIRES_AT"] - ) - if (obo_expires_at - now).total_seconds() <= self.threshold: - logger.debug("OBO token is about to expire, refreshing...") - self._refresh_obo_token(request) - - except Exception as e: - logger.warning(f"Error checking token expiration: {e}") - - def _capture_tokens_from_auth( - self, sender, user, claims, adfs_response=None, request=None, **kwargs - ): - """ - Signal handler to capture tokens during authentication and store them directly in the session. - - The request can be provided directly or obtained from the kwargs. - """ - if not user: - return - - # Try to get the request from kwargs if not explicitly provided - if not request and 'request' in kwargs: - request = kwargs['request'] - - # If we still don't have a request, we can't store tokens - if not request: - return - - if not hasattr(request, "session"): - return - - if self.using_signed_cookies: - return - - session_modified = False - - # Store access token - access_token = None - if hasattr(sender, "access_token"): - access_token = sender.access_token - elif adfs_response and "access_token" in adfs_response: - access_token = adfs_response["access_token"] - - if access_token: - encrypted_token = _encrypt_token(access_token) - if encrypted_token: - request.session["ADFS_ACCESS_TOKEN"] = encrypted_token - session_modified = True - - # Store refresh token - if adfs_response and "refresh_token" in adfs_response: - refresh_token = adfs_response["refresh_token"] - encrypted_token = _encrypt_token(refresh_token) - if encrypted_token: - request.session["ADFS_REFRESH_TOKEN"] = encrypted_token - session_modified = True - - # Store token expiration - if adfs_response and "expires_in" in adfs_response: - expires_at = datetime.datetime.now() + datetime.timedelta( - seconds=int(adfs_response["expires_in"]) - ) - request.session["ADFS_TOKEN_EXPIRES_AT"] = expires_at.isoformat() - session_modified = True - - # Store OBO token if enabled - if self.store_obo_token and access_token: - try: - obo_token = sender.get_obo_access_token(access_token) - if obo_token: - encrypted_token = _encrypt_token(obo_token) - if encrypted_token: - request.session["ADFS_OBO_ACCESS_TOKEN"] = encrypted_token - obo_expires_at = datetime.datetime.now() + datetime.timedelta(hours=1) - request.session["ADFS_OBO_TOKEN_EXPIRES_AT"] = obo_expires_at.isoformat() - session_modified = True - except Exception as e: - logger.warning(f"Error getting OBO token: {e}") - - if session_modified: - request.session.modified = True - logger.debug("Stored tokens directly in session during authentication") - - def _refresh_tokens(self, request): - """ - Refresh the access token using the refresh token - """ - if self.using_signed_cookies: - return - - if "ADFS_REFRESH_TOKEN" not in request.session: - return - - try: - from django_auth_adfs.utils import _decrypt_token, _encrypt_token - - refresh_token = _decrypt_token(request.session["ADFS_REFRESH_TOKEN"]) - if not refresh_token: - logger.warning("Failed to decrypt refresh token") - return - - provider_config.load_config() - - data = { - "grant_type": "refresh_token", - "client_id": settings.CLIENT_ID, - "refresh_token": refresh_token, - } - - if settings.CLIENT_SECRET: - data["client_secret"] = settings.CLIENT_SECRET - - response = provider_config.session.post( - provider_config.token_endpoint, data=data, timeout=settings.TIMEOUT - ) - if response.status_code == 200: - token_data = response.json() - request.session["ADFS_ACCESS_TOKEN"] = _encrypt_token( - token_data["access_token"] - ) - if "refresh_token" in token_data: - request.session["ADFS_REFRESH_TOKEN"] = _encrypt_token( - token_data["refresh_token"] - ) - expires_in = int( - token_data.get("expires_in", 3600) - ) - expires_at = datetime.datetime.now() + datetime.timedelta( - seconds=expires_in - ) - request.session["ADFS_TOKEN_EXPIRES_AT"] = expires_at.isoformat() - - request.session.modified = True - logger.debug("Successfully refreshed tokens") - - if self.store_obo_token: - self._refresh_obo_token(request) - else: - logger.warning( - f"Failed to refresh token: {response.status_code} {response.text}" - ) - if settings.LOGOUT_ON_TOKEN_REFRESH_FAILURE: - from django.contrib.auth import logout - - logger.info("Logging out user due to token refresh failure") - logout(request) - - except Exception as e: - logger.exception(f"Error refreshing tokens: {e}") - if settings.LOGOUT_ON_TOKEN_REFRESH_FAILURE: - from django.contrib.auth import logout - - logger.info("Logging out user due to token refresh error") - logout(request) - - def _refresh_obo_token(self, request): - """ - Refresh the OBO token for Microsoft Graph API - """ - if not self.store_obo_token: - return - - if self.using_signed_cookies: - return - - if "ADFS_ACCESS_TOKEN" not in request.session: - return - - try: - - provider_config.load_config() - - from django_auth_adfs.utils import _decrypt_token, _encrypt_token - - access_token = _decrypt_token(request.session["ADFS_ACCESS_TOKEN"]) - if not access_token: - logger.warning("Failed to decrypt access token") - return - - from django_auth_adfs.backend import AdfsBaseBackend - - backend = AdfsBaseBackend() - obo_token = backend.get_obo_access_token(access_token) - - if obo_token: - request.session["ADFS_OBO_ACCESS_TOKEN"] = _encrypt_token(obo_token) - - expires_at = datetime.datetime.now() + datetime.timedelta(hours=1) - request.session["ADFS_OBO_TOKEN_EXPIRES_AT"] = expires_at.isoformat() - - request.session.modified = True - logger.debug("Successfully refreshed OBO token") - else: - logger.warning("Failed to get OBO token") - - except Exception as e: - logger.exception(f"Error refreshing OBO token: {e}") diff --git a/django_auth_adfs/token_manager.py b/django_auth_adfs/token_manager.py new file mode 100644 index 00000000..acfe31dc --- /dev/null +++ b/django_auth_adfs/token_manager.py @@ -0,0 +1,489 @@ +""" +Token management for django-auth-adfs. + +This module provides a centralized way to manage tokens for django-auth-adfs. +""" + +import logging +import base64 +import datetime + +from cryptography.fernet import Fernet +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC + +from django.conf import settings as django_settings +from django.contrib.auth import logout +from django_auth_adfs.config import settings, provider_config + +logger = logging.getLogger("django_auth_adfs") + + +class TokenManager: + """ + Centralized manager for token lifecycle operations. + + This class handles: + - Token storage during authentication + - Token encryption/decryption + - Token refresh + - Token retrieval + - OBO token management + + It's designed to be lightweight when not actively performing operations, + and to handle all token operations in a safe, transparent, and error-free manner. + """ + + # Session key constants + ACCESS_TOKEN_KEY = "ADFS_ACCESS_TOKEN" + REFRESH_TOKEN_KEY = "ADFS_REFRESH_TOKEN" + TOKEN_EXPIRES_AT_KEY = "ADFS_TOKEN_EXPIRES_AT" + OBO_ACCESS_TOKEN_KEY = "ADFS_OBO_ACCESS_TOKEN" + OBO_TOKEN_EXPIRES_AT_KEY = "ADFS_OBO_TOKEN_EXPIRES_AT" + + def __init__(self): + """Initialize the TokenManager with settings.""" + # Load settings + self.refresh_threshold = getattr(settings, "TOKEN_REFRESH_THRESHOLD", 300) + self.store_obo_token = getattr(settings, "STORE_OBO_TOKEN", True) + self.logout_on_refresh_failure = getattr(settings, "LOGOUT_ON_TOKEN_REFRESH_FAILURE", False) + + # Check if using signed cookies + self.using_signed_cookies = ( + django_settings.SESSION_ENGINE == "django.contrib.sessions.backends.signed_cookies" + ) + + if self.using_signed_cookies: + logger.warning( + "TokenManager: Storing tokens in signed cookies is not recommended for security " + "reasons and cookie size limitations. Token storage will be disabled." + ) + + def is_middleware_enabled(self): + """Check if the TokenLifecycleMiddleware is enabled.""" + try: + for middleware in django_settings.MIDDLEWARE: + if middleware.endswith('TokenLifecycleMiddleware'): + return True + return False + except Exception as e: + logger.warning(f"Error checking if middleware is enabled: {e}") + return False + + def should_store_tokens(self, request): + """ + Check if tokens should be stored in the session. + + Tokens are stored if: + 1. We have a request with a session + 2. The TokenLifecycleMiddleware is enabled + 3. We're not using signed cookies + + Args: + request: The current request object + + Returns: + bool: True if tokens should be stored, False otherwise + """ + if not request or not hasattr(request, "session"): + return False + + if self.using_signed_cookies: + return False + + return self.is_middleware_enabled() + + def _get_encryption_key(self): + """ + Derive a Fernet encryption key from Django's SECRET_KEY. + + Returns: + bytes: A 32-byte key suitable for Fernet encryption + """ + # Use Django's SECRET_KEY to derive a suitable encryption key + default_salt = b"django_auth_adfs_token_encryption" + salt = getattr(settings, "TOKEN_ENCRYPTION_SALT", default_salt) + + if isinstance(salt, str): + salt = salt.encode() + + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=salt, + iterations=100000, + ) + key = base64.urlsafe_b64encode(kdf.derive(django_settings.SECRET_KEY.encode())) + return key + + def encrypt_token(self, token): + """ + Encrypt a token using Django's SECRET_KEY. + + Args: + token (str): The token to encrypt + + Returns: + str: The encrypted token as a string or None if encryption fails + """ + if not token: + return None + + try: + key = self._get_encryption_key() + f = Fernet(key) + encrypted_token = f.encrypt(token.encode()) + return encrypted_token.decode() + except Exception as e: + logger.error(f"Error encrypting token: {e}") + return None + + def decrypt_token(self, encrypted_token): + """ + Decrypt a token that was encrypted using Django's SECRET_KEY. + + Args: + encrypted_token (str): The encrypted token + + Returns: + str: The decrypted token or None if decryption fails + """ + if not encrypted_token: + return None + + try: + key = self._get_encryption_key() + f = Fernet(key) + decrypted_token = f.decrypt(encrypted_token.encode()) + return decrypted_token.decode() + except Exception as e: + logger.error(f"Error decrypting token: {e}") + return None + + def get_access_token(self, request): + """ + Get the current access token from the session. + + The token is automatically decrypted before being returned. + + Args: + request: The current request object + + Returns: + str: The access token or None if not available + """ + if not hasattr(request, "session"): + return None + + if self.using_signed_cookies: + logger.debug("Token retrieval from signed_cookies session is disabled") + return None + + encrypted_token = request.session.get(self.ACCESS_TOKEN_KEY) + return self.decrypt_token(encrypted_token) + + def get_obo_access_token(self, request): + """ + Get the current OBO access token from the session. + + The token is automatically decrypted before being returned. + + Args: + request: The current request object + + Returns: + str: The OBO access token or None if not available + """ + if not hasattr(request, "session"): + return None + + if self.using_signed_cookies: + logger.debug("Token retrieval from signed_cookies session is disabled") + return None + + if not self.store_obo_token: + logger.debug("OBO token storage is disabled") + return None + + encrypted_token = request.session.get(self.OBO_ACCESS_TOKEN_KEY) + return self.decrypt_token(encrypted_token) + + def store_tokens(self, request, access_token, adfs_response=None): + """ + Store tokens in the session. + + Args: + request: The current request object + access_token (str): The access token to store + adfs_response (dict, optional): The full response from ADFS containing refresh token and expiration + + Returns: + bool: True if tokens were stored, False otherwise + """ + if not self.should_store_tokens(request): + return False + + try: + session_modified = False + + # Store access token + encrypted_token = self.encrypt_token(access_token) + if encrypted_token: + request.session[self.ACCESS_TOKEN_KEY] = encrypted_token + session_modified = True + + # Store refresh token + if adfs_response and "refresh_token" in adfs_response: + refresh_token = adfs_response["refresh_token"] + encrypted_token = self.encrypt_token(refresh_token) + if encrypted_token: + request.session[self.REFRESH_TOKEN_KEY] = encrypted_token + session_modified = True + + # Store token expiration + if adfs_response and "expires_in" in adfs_response: + expires_at = datetime.datetime.now() + datetime.timedelta( + seconds=int(adfs_response["expires_in"]) + ) + request.session[self.TOKEN_EXPIRES_AT_KEY] = expires_at.isoformat() + session_modified = True + + # Store OBO token if enabled + if self.store_obo_token: + try: + # Import here to avoid circular imports + from django_auth_adfs.backend import AdfsBaseBackend + + backend = AdfsBaseBackend() + obo_token = backend.get_obo_access_token(access_token) + if obo_token: + encrypted_token = self.encrypt_token(obo_token) + if encrypted_token: + request.session[self.OBO_ACCESS_TOKEN_KEY] = encrypted_token + obo_expires_at = datetime.datetime.now() + datetime.timedelta(hours=1) + request.session[self.OBO_TOKEN_EXPIRES_AT_KEY] = obo_expires_at.isoformat() + session_modified = True + except Exception as e: + logger.warning(f"Error getting OBO token: {e}") + + if session_modified: + request.session.modified = True + logger.debug("Stored tokens in session") + return True + + return False + + except Exception as e: + logger.warning(f"Error storing tokens in session: {e}") + return False + + def check_token_expiration(self, request): + """ + Check if tokens need to be refreshed and refresh them if needed. + + Args: + request: The current request object + + Returns: + bool: True if tokens were checked, False otherwise + """ + if not hasattr(request, "user") or not request.user.is_authenticated: + return False + + if self.using_signed_cookies: + return False + + try: + if self.TOKEN_EXPIRES_AT_KEY not in request.session: + return False + + # Check if token is about to expire + expires_at = datetime.datetime.fromisoformat(request.session[self.TOKEN_EXPIRES_AT_KEY]) + remaining = expires_at - datetime.datetime.now() + + if remaining.total_seconds() < self.refresh_threshold: + logger.debug("Token is about to expire. Refreshing...") + self.refresh_tokens(request) + + # Check if OBO token is about to expire + if self.store_obo_token and self.OBO_TOKEN_EXPIRES_AT_KEY in request.session: + obo_expires_at = datetime.datetime.fromisoformat(request.session[self.OBO_TOKEN_EXPIRES_AT_KEY]) + obo_remaining = obo_expires_at - datetime.datetime.now() + + if obo_remaining.total_seconds() < self.refresh_threshold: + logger.debug("OBO token is about to expire. Refreshing...") + self.refresh_obo_token(request) + + return True + + except Exception as e: + logger.warning(f"Error checking token expiration: {e}") + return False + + def refresh_tokens(self, request): + """ + Refresh the access token using the refresh token. + + Args: + request: The current request object + + Returns: + bool: True if tokens were refreshed, False otherwise + """ + if self.using_signed_cookies: + return False + + if self.REFRESH_TOKEN_KEY not in request.session: + return False + + try: + refresh_token = self.decrypt_token(request.session[self.REFRESH_TOKEN_KEY]) + if not refresh_token: + logger.warning("Failed to decrypt refresh token") + return False + + provider_config.load_config() + + data = { + "grant_type": "refresh_token", + "client_id": settings.CLIENT_ID, + "refresh_token": refresh_token, + } + + if settings.CLIENT_SECRET: + data["client_secret"] = settings.CLIENT_SECRET + + # Ensure token_endpoint is a string + token_endpoint = provider_config.token_endpoint + if token_endpoint is None: + logger.error("Token endpoint is None, cannot refresh tokens") + return False + + response = provider_config.session.post( + token_endpoint, data=data, timeout=settings.TIMEOUT + ) + + if response.status_code == 200: + token_data = response.json() + request.session[self.ACCESS_TOKEN_KEY] = self.encrypt_token( + token_data["access_token"] + ) + request.session[self.REFRESH_TOKEN_KEY] = self.encrypt_token( + token_data["refresh_token"] + ) + expires_at = datetime.datetime.now() + datetime.timedelta( + seconds=int(token_data["expires_in"]) + ) + request.session[self.TOKEN_EXPIRES_AT_KEY] = expires_at.isoformat() + request.session.modified = True + logger.debug("Refreshed tokens successfully") + + # Also refresh the OBO token if needed + if self.store_obo_token: + self.refresh_obo_token(request) + + return True + else: + logger.warning( + f"Failed to refresh token: {response.status_code} {response.text}" + ) + if self.logout_on_refresh_failure: + logger.info("Logging out user due to token refresh failure") + logout(request) + return False + + except Exception as e: + logger.exception(f"Error refreshing tokens: {e}") + if self.logout_on_refresh_failure: + logger.info("Logging out user due to token refresh error") + logout(request) + return False + + def refresh_obo_token(self, request): + """ + Refresh the OBO token for Microsoft Graph API. + + Args: + request: The current request object + + Returns: + bool: True if OBO token was refreshed, False otherwise + """ + if not self.store_obo_token: + return False + + if self.using_signed_cookies: + return False + + if self.ACCESS_TOKEN_KEY not in request.session: + return False + + try: + provider_config.load_config() + + access_token = self.decrypt_token(request.session[self.ACCESS_TOKEN_KEY]) + if not access_token: + logger.warning("Failed to decrypt access token") + return False + + # Import here to avoid circular imports + from django_auth_adfs.backend import AdfsBaseBackend + + backend = AdfsBaseBackend() + obo_token = backend.get_obo_access_token(access_token) + + if obo_token: + request.session[self.OBO_ACCESS_TOKEN_KEY] = self.encrypt_token(obo_token) + obo_expires_at = datetime.datetime.now() + datetime.timedelta(hours=1) + request.session[self.OBO_TOKEN_EXPIRES_AT_KEY] = obo_expires_at.isoformat() + request.session.modified = True + logger.debug("Refreshed OBO token successfully") + return True + + return False + + except Exception as e: + logger.warning(f"Error refreshing OBO token: {e}") + return False + + def clear_tokens(self, request): + """ + Clear all tokens from the session. + + Args: + request: The current request object + + Returns: + bool: True if tokens were cleared, False otherwise + """ + if not hasattr(request, "session"): + return False + + try: + session_modified = False + + for key in [ + self.ACCESS_TOKEN_KEY, + self.REFRESH_TOKEN_KEY, + self.TOKEN_EXPIRES_AT_KEY, + self.OBO_ACCESS_TOKEN_KEY, + self.OBO_TOKEN_EXPIRES_AT_KEY + ]: + if key in request.session: + del request.session[key] + session_modified = True + + if session_modified: + request.session.modified = True + logger.debug("Cleared tokens from session") + return True + + return False + + except Exception as e: + logger.warning(f"Error clearing tokens from session: {e}") + return False + + +# Create a singleton instance +token_manager = TokenManager() \ No newline at end of file diff --git a/django_auth_adfs/utils.py b/django_auth_adfs/utils.py deleted file mode 100644 index a34de790..00000000 --- a/django_auth_adfs/utils.py +++ /dev/null @@ -1,151 +0,0 @@ -""" -Utility functions for django-auth-adfs. - -Only relevant if you are using the Token Lifecycle Middleware. -""" - -import logging -import base64 - -from cryptography.fernet import Fernet -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC - -from django.conf import settings as django_settings -from django_auth_adfs.config import settings - -logger = logging.getLogger("django_auth_adfs") - - -def _get_encryption_key(): - """ - Derive a Fernet encryption key from Django's SECRET_KEY. - - The salt can be customized through the TOKEN_ENCRYPTION_SALT setting. - - Returns: - bytes: A 32-byte key suitable for Fernet encryption - """ - # Use Django's SECRET_KEY to derive a suitable encryption key - default_salt = b"django_auth_adfs_token_encryption" - salt = getattr(settings, "TOKEN_ENCRYPTION_SALT", default_salt) - - if isinstance(salt, str): - salt = salt.encode() - - kdf = PBKDF2HMAC( - algorithm=hashes.SHA256(), - length=32, - salt=salt, - iterations=100000, - ) - key = base64.urlsafe_b64encode(kdf.derive(django_settings.SECRET_KEY.encode())) - return key - - -def _encrypt_token(token): - """ - Encrypt a token using Django's SECRET_KEY. - - Args: - token (str): The token to encrypt - - Returns: - str: The encrypted token as a string - """ - if not token: - return None - - try: - key = _get_encryption_key() - f = Fernet(key) - encrypted_token = f.encrypt(token.encode()) - return encrypted_token.decode() - except Exception as e: - logger.error(f"Error encrypting token: {e}") - return None - - -def _decrypt_token(encrypted_token): - """ - Decrypt a token that was encrypted using Django's SECRET_KEY. - - Args: - encrypted_token (str): The encrypted token - - Returns: - str: The decrypted token or None if decryption fails - """ - if not encrypted_token: - return None - - try: - key = _get_encryption_key() - f = Fernet(key) - decrypted_token = f.decrypt(encrypted_token.encode()) - return decrypted_token.decode() - except Exception as e: - logger.error(f"Error decrypting token: {e}") - return None - - -def _is_signed_cookies_disabled(): - """ - Check if token storage is disabled for signed_cookies session backend - """ - using_signed_cookies = ( - django_settings.SESSION_ENGINE - == "django.contrib.sessions.backends.signed_cookies" - ) - return using_signed_cookies - - -def get_access_token(request): - """ - Get the current access token from the session. - - The token is automatically decrypted before being returned. - - Args: - request: The current request object - - Returns: - str: The access token or None if not available - """ - if not hasattr(request, "session"): - return None - - if _is_signed_cookies_disabled(): - logger.debug("Token retrieval from signed_cookies session is disabled") - return None - - encrypted_token = request.session.get("ADFS_ACCESS_TOKEN") - return _decrypt_token(encrypted_token) - - -def get_obo_access_token(request): - """ - Get the current OBO (On-Behalf-Of) access token for Microsoft Graph API from the session. - - The token is automatically decrypted before being returned. - - Args: - request: The current request object - - Returns: - str: The OBO access token or None if not available - """ - if not hasattr(request, "session"): - return None - - if _is_signed_cookies_disabled(): - logger.debug("Token retrieval from signed_cookies session is disabled") - return None - - store_obo_token = getattr(settings, "STORE_OBO_TOKEN", True) - if not store_obo_token: - logger.debug("OBO token storage is disabled") - return None - - encrypted_token = request.session.get("ADFS_OBO_ACCESS_TOKEN") - return _decrypt_token(encrypted_token) diff --git a/docs/token_lifecycle.rst b/docs/token_lifecycle.rst index f19869a5..c27b62cf 100644 --- a/docs/token_lifecycle.rst +++ b/docs/token_lifecycle.rst @@ -4,31 +4,35 @@ Token Lifecycle Middleware Traditionally, django-auth-adfs is used **exclusively** as an authentication solution - it handles user authentication via ADFS/Azure AD and maps claims to Django users. It doesn't really care about the access tokens from Azure/ADFS after you've been authenticated. -The Token Lifecycle Middleware extends django-auth-adfs beyond pure authentication to also handle the complete lifecycle of access tokens +The Token Lifecycle system extends django-auth-adfs beyond pure authentication to also handle the complete lifecycle of access tokens after the authentication process. This creates a more integrated approach where: * The same application registration handles both authentication and resource access * Tokens obtained during authentication are stored and refreshed automatically in the session * The application can make delegated API calls on behalf of the user -* The middleware can optionally log out users when token refresh fails +* The system can optionally log out users when token refresh fails How it works ------------ -The ``TokenLifecycleMiddleware`` handles the entire token lifecycle: +The token lifecycle system consists of two main components: -1. **Initial Token Capture**: Uses the ``post_authenticate`` signal to capture tokens during authentication -2. **Token Storage**: Automatically stores tokens in the session after successful authentication -3. **Token Refresh**: Checks if the access token is about to expire and refreshes it if needed -4. **Optional Security Enforcement**: Can be configured to log out users when token refresh fails -5. **Session Management**: Keeps the session updated with the latest tokens -6. **OBO Token Management**: Handles On-Behalf-Of tokens for Microsoft Graph API access +1. **TokenManager**: A centralized singleton that handles all token operations including storage, retrieval, encryption, refresh, and OBO token management +2. **TokenLifecycleMiddleware**: A middleware that monitors token expiration and triggers refresh when needed -Read more: https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-on-behalf-of-flow#protocol-diagram +Together, they handle the entire token lifecycle: + +1. **Token Storage**: The django-auth-adfs backend automatically stores and encrypts tokens during authentication when the ``TokenLifecycleMiddleware`` is enabled +2. **Token Monitoring**: The middleware checks token expiration on each request +3. **Token Refresh**: When a token is about to expire, it is automatically refreshed +4. **OBO Token Management**: When enabled (by default), OBO tokens are automatically acquired and refreshed for Microsoft Graph API access +5. **Security Controls**: Optional automatic logout on token refresh failures + +Read more about the OBO flow: https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-on-behalf-of-flow#protocol-diagram .. warning:: - The Token Lifecycle Middleware is a new feature in django-auth-adfs and is considered experimental. + The Token Lifecycle system is a new feature in django-auth-adfs and is considered experimental. Please be aware: **Currently no community support is guaranteed to be available for this feature** @@ -42,7 +46,7 @@ Read more: https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-o Configuration ------------- -To enable the token lifecycle middleware, add it to your ``MIDDLEWARE`` setting in your Django settings file: +To enable the token lifecycle system, add the middleware to your ``MIDDLEWARE`` setting in your Django settings file: .. code-block:: python @@ -86,7 +90,7 @@ You can configure the token lifecycle behavior with these settings in your Djang Consider this when deploying changes to the salt in production environments. .. note:: - By default (``STORE_OBO_TOKEN = True``), the middleware will automatically request and store OBO tokens + By default (``STORE_OBO_TOKEN = True``), the system will automatically request and store OBO tokens for Microsoft Graph API access. If your application doesn't need to access Microsoft Graph API, you can set ``STORE_OBO_TOKEN = False`` to disable this functionality completely. See `the OBO token configuration section <#disabling-obo-token-functionality>`_ for more details. @@ -94,42 +98,41 @@ You can configure the token lifecycle behavior with these settings in your Djang Considerations -------------- -- The middleware will automatically capture and store tokens in the session during authentication using signals. -- You don't need to modify your views or authentication backends to store tokens. -- Token refresh only works for authenticated users with valid sessions. -- If the refresh token is invalid or expired, the middleware will not be able to refresh the access token. -- By default, the middleware will not log the user out if token refresh fails, but this behavior can be changed with the ``LOGOUT_ON_TOKEN_REFRESH_FAILURE`` setting. -- The middleware will not store tokens in the session when using the ``signed_cookies`` session backend by default. -- OBO token storage is enabled by default but can be disabled with the ``STORE_OBO_TOKEN`` setting. +- Token storage and encryption are handled automatically by the django-auth-adfs backend during authentication +- Token refresh only works for authenticated users with valid sessions +- If the refresh token is invalid or expired, the system will not be able to refresh the access token +- By default, the system will not log the user out if token refresh fails, but this behavior can be changed with the ``LOGOUT_ON_TOKEN_REFRESH_FAILURE`` setting +- The system will not store tokens in the session when using the ``signed_cookies`` session backend +- OBO token storage is enabled by default but can be disabled with the ``STORE_OBO_TOKEN`` setting - Using the OBO token versus the regular access token is dependent on the resources you are accessing and the permissions granted to your ADFS/Azure AD application. See `the token types section <#understanding-access-tokens-vs-obo-tokens>`_ for more details. **Token Refresh Failures** -By default, when token refresh fails, the middleware logs the error but allows the user to continue using the application until their session expires naturally. This behavior can be changed with the ``LOGOUT_ON_TOKEN_REFRESH_FAILURE`` setting: +By default, when token refresh fails, the system logs the error but allows the user to continue using the application until their session expires naturally. This behavior can be changed with the ``LOGOUT_ON_TOKEN_REFRESH_FAILURE`` setting: - When set to ``False`` (default), users remain logged in even if their tokens can't be refreshed - When set to ``True``, users are automatically logged out when token refresh fails When a user's account is disabled in Azure AD/ADFS, their existing Django sessions will remain active by default until they expire naturally. This can create a security gap where revoked users maintain access to your application. -The ``LOGOUT_ON_TOKEN_REFRESH_FAILURE`` setting provides an option to address this concern by allowing you to configure the middleware to automatically log out users when their token refresh fails, which happens when their account has been disabled in the identity provider. +The ``LOGOUT_ON_TOKEN_REFRESH_FAILURE`` setting provides an option which helps address this concern by allowing you to automatically log out users when their token refresh fails, which will happen some time after their account has been disabled in the identity provider. **Existing Sessions** -When deploying the Token Lifecycle Middleware to an existing application with active user sessions, be aware of the following: +When deploying the Token Lifecycle system to an existing application with active user sessions, be aware of the following: -The middleware only captures tokens during the authentication process. Existing authenticated sessions won't have tokens stored in them, which means: +The system only captures tokens during the authentication process. Existing authenticated sessions won't have tokens stored in them, which means: - Users with existing sessions won't have access to token-dependent features until they re-authenticate - Utility functions like ``get_access_token()`` and ``get_obo_access_token()`` will return ``None`` for these sessions - API calls that depend on these tokens will fail for existing sessions -The best approach is to ensure that all users re-authenticate after the middleware is deployed. +The best approach is to ensure that all users re-authenticate after the system is deployed. Azure AD Application Configuration ---------------------------------- -When using the Token Lifecycle Middleware, your Azure AD application registration needs additional permissions +When using the Token Lifecycle system, your Azure AD application registration needs additional permissions beyond those required for simple authentication. This extends the standard authentication-only setup described in the :doc:`azure_ad_config_guide` with additional API permissions needed for delegated access. @@ -144,7 +147,7 @@ Security Overview **Token Encryption** Tokens are automatically encrypted before being stored in the session and decrypted when they are retrieved. -The encryption is handled transparently by the middleware and utility functions. This provides an additional layer of security: +The encryption is handled transparently by the TokenManager and utility functions. This provides an additional layer of security: - **Always Enabled**: Token encryption is always enabled and cannot be disabled - **Encryption Method**: Tokens are encrypted using the Fernet symmetric encryption algorithm @@ -155,11 +158,11 @@ The encryption is handled transparently by the middleware and utility functions. **Signed Cookies Session Backend Restriction** -The middleware will not store tokens in the session when using Django's ``signed_cookies`` session backend: +The system will not store tokens in the session when using Django's ``signed_cookies`` session backend: .. code-block:: python - # This will not work with the token lifecycle middleware + # This will not work with the token lifecycle system SESSION_ENGINE = 'django.contrib.sessions.backends.signed_cookies' This is for a few reasons: @@ -168,7 +171,7 @@ This is for a few reasons: 2. **Security Risks**: Storing sensitive tokens in cookies increases the risk of token theft 3. **Performance**: Large cookies are sent with every request, increasing bandwidth usage -If you're using the ``signed_cookies`` session backend and need token storage, you won't be able to use the token lifecycle middleware. +If you're using the ``signed_cookies`` session backend and need token storage, you won't be able to use the token lifecycle system. .. note:: This restriction only applies to the ``signed_cookies`` session backend. For other session backends (database, cache, file), @@ -176,12 +179,12 @@ If you're using the ``signed_cookies`` session backend and need token storage, y **Automatic OBO Token Acquisition** -By default, the middleware automatically requests OBO tokens during authentication. If your application doesn't need OBO tokens, you can disable this behavior to reduce unnecessary token requests (see `the OBO token configuration section <#disabling-obo-token-functionality>`_ for more details). +By default, the system automatically requests OBO tokens when storing tokens. If your application doesn't need OBO tokens, you can disable this behavior to reduce unnecessary token requests (see `the OBO token configuration section <#disabling-obo-token-functionality>`_ for more details). Disabling OBO Token Functionality --------------------------------- -By default, the Token Lifecycle Middleware automatically requests and stores OBO tokens for Microsoft Graph API access. If you don't need this functionality (for example, if your application doesn't interact with Microsoft Graph API), you can disable it completely: +By default, the Token Lifecycle system automatically requests and stores OBO tokens for Microsoft Graph API access. If you don't need this functionality (for example, if your application doesn't interact with Microsoft Graph API), you can disable it completely: .. code-block:: python @@ -192,9 +195,9 @@ By default, the Token Lifecycle Middleware automatically requests and stores OBO When this setting is ``False``: -1. The middleware will not request OBO tokens during authentication -2. The middleware will not store OBO tokens in the session -3. The middleware will not refresh OBO tokens +1. The system will not request OBO tokens during token storage +2. The system will not store OBO tokens in the session +3. The system will not refresh OBO tokens 4. The ``get_obo_access_token`` utility function will always return ``None`` Note that disabling OBO tokens doesn't affect the regular access token functionality. Your application will still be able to use the access token obtained during authentication for its own resources and APIs that directly trust your application. @@ -204,33 +207,33 @@ See `the token types section <#understanding-access-tokens-vs-obo-tokens>`_ for Accessing Tokens in Your Views ------------------------------ -When building views that need to make requests using the Azure AD/ADFS tokens, you'll need to access the tokens stored in the session. - -Since tokens are encrypted in the session, Token Lifecycle Middleware provides utility functions in the ``django_auth_adfs.utils`` module to help you access tokens safely: +Since tokens are encrypted in the session, the Token Lifecycle system provides a centralized TokenManager to help you access tokens safely: .. code-block:: python + from django_auth_adfs.token_manager import token_manager + # For your own APIs or APIs that trust your application directly - from django_auth_adfs.utils import get_access_token + access_token = token_manager.get_access_token(request) # For Microsoft Graph API or other APIs requiring delegated access - from django_auth_adfs.utils import get_obo_access_token + obo_token = token_manager.get_obo_access_token(request) -These utility functions automatically handle decryption of the tokens, so you don't need to worry about the encryption details. +The TokenManager automatically handles encryption/decryption of tokens, so you don't need to worry about the encryption details. .. warning:: - You should always use these utility functions to access tokens rather than accessing them directly from the session. + You should always use the TokenManager to access tokens rather than accessing them directly from the session. Direct access to ``request.session["ADFS_ACCESS_TOKEN"]`` will give you the encrypted token, not the actual token value. Examples ---------------------- -Here are practical examples of using these utility functions in your views: +Here are practical examples of using the TokenManager in your views: Using with Microsoft Graph API ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -In this flow, we will exchange our access token from the authentication process for an OBO token to access Microsoft Graph API. +In this flow, we will use the OBO token to access Microsoft Graph API. This is the recommended flow for delegated access to Microsoft Graph API. @@ -238,13 +241,13 @@ This is the recommended flow for delegated access to Microsoft Graph API. from django.contrib.auth.decorators import login_required from django.http import JsonResponse - from django_auth_adfs.utils import get_obo_access_token + from django_auth_adfs.token_manager import token_manager import requests @login_required def me_view(request): """Get the user's profile from Microsoft Graph API""" - obo_token = get_obo_access_token(request) + obo_token = token_manager.get_obo_access_token(request) if not obo_token: return JsonResponse({"error": "No OBO token available"}, status=401) @@ -267,21 +270,21 @@ This is the recommended flow for delegated access to Microsoft Graph API. Using with other resources ~~~~~~~~~~~~~~~~~~~~~~~~~~ -The key difference here is to use the ``get_access_token`` function to get the token for the resource you are accessing. +The key difference here is to use the ``get_access_token`` method to get the token for the resource you are accessing. -This is different than the ``get_obo_access_token`` function, which is used for Microsoft Graph API delegated access in the previous example. +This is different than the ``get_obo_access_token`` method, which is used for Microsoft Graph API delegated access in the previous example. .. code-block:: python from rest_framework.views import APIView from rest_framework.response import Response - from django_auth_adfs.utils import get_access_token + from django_auth_adfs.token_manager import token_manager import requests class ExternalApiView(APIView): def get(self, request): """Call an API that accepts your application's token""" - token = get_access_token(request) + token = token_manager.get_access_token(request) if not token: return Response({"error": "No access token available"}, status=401) @@ -298,9 +301,10 @@ The following example code demonstrates a debug view to check the values of the .. code-block:: python + import requests from django.contrib.auth.decorators import login_required from django.http import JsonResponse - from django_auth_adfs.utils import get_access_token, get_obo_access_token + from django_auth_adfs.token_manager import token_manager from datetime import datetime @login_required @@ -314,18 +318,16 @@ The following example code demonstrates a debug view to check the values of the # Basic session token info session_info = { - "has_access_token": "ADFS_ACCESS_TOKEN" in request.session, - "has_refresh_token": "ADFS_REFRESH_TOKEN" in request.session, - "has_expires_at": "ADFS_TOKEN_EXPIRES_AT" in request.session, + "has_access_token": token_manager.ACCESS_TOKEN_KEY in request.session, + "has_refresh_token": token_manager.REFRESH_TOKEN_KEY in request.session, + "has_expires_at": token_manager.TOKEN_EXPIRES_AT_KEY in request.session, } # Add token expiration details if available - if "ADFS_TOKEN_EXPIRES_AT" in request.session: - from datetime import datetime - + if token_manager.TOKEN_EXPIRES_AT_KEY in request.session: try: expires_at = datetime.fromisoformat( - request.session["ADFS_TOKEN_EXPIRES_AT"] + request.session[token_manager.TOKEN_EXPIRES_AT_KEY] ) now = datetime.now() session_info["token_expires_at"] = expires_at.isoformat() @@ -337,15 +339,14 @@ The following example code demonstrates a debug view to check the values of the session_info["expiration_parse_error"] = str(e) # Show raw encrypted tokens for debugging - if "ADFS_ACCESS_TOKEN" in request.session: - raw_token = request.session["ADFS_ACCESS_TOKEN"] + if token_manager.ACCESS_TOKEN_KEY in request.session: + raw_token = request.session[token_manager.ACCESS_TOKEN_KEY] session_info["raw_token_preview"] = f"{raw_token[:10]}...{raw_token[-10:]}" session_info["raw_token_length"] = len(raw_token) # Try to decode as JWT without decryption (should fail if properly encrypted) try: import jwt - jwt.decode(raw_token, options={"verify_signature": False}) session_info["is_encrypted"] = False except: @@ -353,9 +354,7 @@ The following example code demonstrates a debug view to check the values of the # Get properly decrypted access token try: - from django_auth_adfs.utils import get_access_token - - access_token = get_access_token(request) + access_token = token_manager.get_access_token(request) session_info["decrypted_access_token_available"] = access_token is not None if access_token: @@ -368,13 +367,10 @@ The following example code demonstrates a debug view to check the values of the # Try to decode as JWT (should succeed if properly decrypted) try: import jwt - decoded = jwt.decode(access_token, options={"verify_signature": False}) session_info["jwt_decode_success"] = True # Add some basic JWT info without exposing sensitive data if "exp" in decoded: - from datetime import datetime - exp_time = datetime.fromtimestamp(decoded["exp"]) session_info["jwt_expiry"] = exp_time.isoformat() except Exception as e: @@ -384,16 +380,14 @@ The following example code demonstrates a debug view to check the values of the # Check if OBO token is available try: - from django_auth_adfs.utils import get_obo_access_token - - obo_token = get_obo_access_token(request) + obo_token = token_manager.get_obo_access_token(request) obo_info = { "has_obo_token": obo_token is not None, } # Show raw encrypted OBO token if available - if "ADFS_OBO_ACCESS_TOKEN" in request.session: - raw_obo = request.session["ADFS_OBO_ACCESS_TOKEN"] + if token_manager.OBO_ACCESS_TOKEN_KEY in request.session: + raw_obo = request.session[token_manager.OBO_ACCESS_TOKEN_KEY] obo_info["raw_obo_preview"] = f"{raw_obo[:10]}...{raw_obo[-10:]}" obo_info["raw_obo_length"] = len(raw_obo) @@ -405,13 +399,10 @@ The following example code demonstrates a debug view to check the values of the # Try to decode as JWT (should succeed if properly decrypted) try: import jwt - decoded = jwt.decode(obo_token, options={"verify_signature": False}) obo_info["jwt_decode_success"] = True # Add some basic JWT info without exposing sensitive data if "exp" in decoded: - from datetime import datetime - exp_time = datetime.fromtimestamp(decoded["exp"]) obo_info["jwt_expiry"] = exp_time.isoformat() except Exception as e: @@ -454,6 +445,6 @@ It's important to understand the difference between regular access tokens and OB **OBO (On-Behalf-Of) Token**: The OBO flow is specifically designed for delegated access scenarios where your application needs to access resources (like Microsoft Graph) on behalf of the authenticated user. - The middleware handles this exchange automatically when OBO token storage is enabled. + The TokenManager handles this exchange automatically when OBO token storage is enabled. For more information on the different types of permissions, see `the Microsoft documentation `_. diff --git a/tests/test_middleware.py b/tests/test_middleware.py index a487fc66..cfaf1a78 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -1,3 +1,7 @@ +""" +Tests for the TokenLifecycleMiddleware and TokenManager. +""" + import datetime from unittest.mock import Mock, patch import time @@ -8,997 +12,169 @@ from django.contrib.sessions.backends.db import SessionStore from django_auth_adfs.middleware import TokenLifecycleMiddleware -from django_auth_adfs.config import settings -from django_auth_adfs.utils import ( - get_access_token, - get_obo_access_token, - _encrypt_token, - _decrypt_token, -) +from django_auth_adfs.config import settings as adfs_settings +from django_auth_adfs.token_manager import token_manager, TokenManager from tests.settings import MIDDLEWARE User = get_user_model() -# Add TokenLifecycleMiddleware to the existing middleware MIDDLEWARE_WITH_TOKEN_LIFECYCLE = MIDDLEWARE + ( "django_auth_adfs.middleware.TokenLifecycleMiddleware", ) @override_settings(MIDDLEWARE=MIDDLEWARE_WITH_TOKEN_LIFECYCLE) -class TokenLifecycleMiddlewareTests(TestCase): +class TokenLifecycleTests(TestCase): """ - Tests for the TokenLifecycleMiddleware. - - The middleware handles the lifecycle of ADFS tokens: - 1. Storing tokens from user object to session - 2. Detecting when tokens need to be refreshed - 3. Refreshing tokens when needed - 4. Handling OBO (On-Behalf-Of) tokens + Tests for the token lifecycle functionality, covering both TokenManager and TokenLifecycleMiddleware. """ def setUp(self): - """Set up test environment before each test""" self.factory = RequestFactory() - self.middleware = TokenLifecycleMiddleware(lambda r: r) self.user = User.objects.create_user(username="testuser") self.request = self.factory.get("/") self.request.user = self.user self.request.session = SessionStore() + self.middleware = TokenLifecycleMiddleware(lambda r: None) - # Group 1: Initialization Tests - - def test_init_with_default_settings(self): - """Test middleware initialization with default settings""" - middleware = TokenLifecycleMiddleware(lambda r: r) - self.assertEqual(middleware.threshold, 300) - self.assertTrue(middleware.store_obo_token) - self.assertFalse(middleware.using_signed_cookies) - - def test_init_with_custom_settings(self): - """Test middleware initialization with custom settings""" - with patch("django_auth_adfs.middleware.getattr") as mock_getattr: - # Mock getattr to return custom values - mock_getattr.side_effect = lambda obj, name, default: { - "TOKEN_REFRESH_THRESHOLD": 600, - "STORE_OBO_TOKEN": False, - }.get(name, default) - - middleware = TokenLifecycleMiddleware(lambda r: r) - - # Verify custom settings are applied - self.assertEqual(middleware.threshold, 600) - self.assertFalse(middleware.store_obo_token) - - # Group 2: Token Storage Tests - - def test_store_tokens_from_auth(self): - """Test storing tokens directly in session during authentication""" - # Create a mock sender and adfs_response - sender = Mock() - sender.access_token = "test_access_token" - sender.get_obo_access_token.return_value = "test_obo_token" - - adfs_response = { - "refresh_token": "test_refresh_token", - "expires_in": 3600, - } - - # Call the signal handler - self.middleware._capture_tokens_from_auth( - sender=sender, - user=self.user, - claims={}, - adfs_response=adfs_response, - request=self.request - ) - - # Check session - decrypt tokens before comparing - self.assertEqual( - _decrypt_token(self.request.session["ADFS_ACCESS_TOKEN"]), - "test_access_token", - ) - self.assertEqual( - _decrypt_token(self.request.session["ADFS_REFRESH_TOKEN"]), - "test_refresh_token", - ) - self.assertEqual( - _decrypt_token(self.request.session["ADFS_OBO_ACCESS_TOKEN"]), - "test_obo_token", - ) - self.assertTrue("ADFS_TOKEN_EXPIRES_AT" in self.request.session) - self.assertTrue("ADFS_OBO_TOKEN_EXPIRES_AT" in self.request.session) - - def test_store_tokens_from_user(self): - """Test storing tokens directly in session during authentication""" - # Create a mock sender and adfs_response - sender = Mock() - sender.access_token = "test_access_token" - sender.get_obo_access_token.return_value = "test_obo_token" - - adfs_response = { - "refresh_token": "test_refresh_token", - "expires_in": 3600, - } - - # Call the signal handler - self.middleware._capture_tokens_from_auth( - sender=sender, - user=self.user, - claims={}, - adfs_response=adfs_response, - request=self.request - ) - - # Check session - decrypt tokens before comparing - self.assertEqual( - _decrypt_token(self.request.session["ADFS_ACCESS_TOKEN"]), - "test_access_token", - ) - self.assertEqual( - _decrypt_token(self.request.session["ADFS_REFRESH_TOKEN"]), - "test_refresh_token", - ) - self.assertEqual( - _decrypt_token(self.request.session["ADFS_OBO_ACCESS_TOKEN"]), - "test_obo_token", - ) - self.assertTrue("ADFS_TOKEN_EXPIRES_AT" in self.request.session) - self.assertTrue("ADFS_OBO_TOKEN_EXPIRES_AT" in self.request.session) - - def test_store_partial_tokens_from_auth(self): - """Test storing partial tokens during authentication""" - # Create a mock sender with only access token - sender = Mock() - sender.access_token = "test_access_token" - sender.get_obo_access_token.return_value = None - - # No refresh token in adfs_response - adfs_response = { - "expires_in": 3600, - } - - # Call the signal handler - self.middleware._capture_tokens_from_auth( - sender=sender, - user=self.user, - claims={}, - adfs_response=adfs_response, - request=self.request - ) - - # Check session - should have access token but not refresh token - self.assertEqual( - _decrypt_token(self.request.session["ADFS_ACCESS_TOKEN"]), - "test_access_token", - ) - self.assertFalse("ADFS_REFRESH_TOKEN" in self.request.session) - self.assertTrue("ADFS_TOKEN_EXPIRES_AT" in self.request.session) - self.assertFalse("ADFS_OBO_ACCESS_TOKEN" in self.request.session) - - def test_store_tokens_with_signed_cookies(self): - """Test that tokens are not stored when using signed cookies""" - # Set up middleware to use signed cookies - self.middleware.using_signed_cookies = True - - # Create a mock sender and adfs_response - sender = Mock() - sender.access_token = "test_access_token" - sender.get_obo_access_token.return_value = "test_obo_token" - - adfs_response = { - "refresh_token": "test_refresh_token", - "expires_in": 3600, - } - - # Call the signal handler - self.middleware._capture_tokens_from_auth( - sender=sender, - user=self.user, - claims={}, - adfs_response=adfs_response, - request=self.request - ) - - # Check session - no tokens should be stored - self.assertFalse("ADFS_ACCESS_TOKEN" in self.request.session) - self.assertFalse("ADFS_REFRESH_TOKEN" in self.request.session) - self.assertFalse("ADFS_TOKEN_EXPIRES_AT" in self.request.session) - self.assertFalse("ADFS_OBO_ACCESS_TOKEN" in self.request.session) - - def test_session_modified_flag(self): - """Test that the session modified flag is only set when needed""" - # Create a session and set modified to False - self.request.session = SessionStore() - self.request.session.modified = False - - # Call the signal handler with no tokens - sender = Mock(spec=[]) - self.middleware._capture_tokens_from_auth( - sender=sender, - user=self.user, - claims={}, - adfs_response={}, - request=self.request - ) - - # Session should not be modified - self.assertFalse(self.request.session.modified) - - # Call with tokens - sender = Mock() - sender.access_token = "test_token" - adfs_response = {"expires_in": 3600} - self.middleware._capture_tokens_from_auth( - sender=sender, - user=self.user, - claims={}, - adfs_response=adfs_response, - request=self.request - ) - - # Session should be modified - self.assertTrue(self.request.session.modified) - - # Group 3: Token Refresh Detection Tests - - def test_handle_token_refresh_not_needed(self): - """Test token refresh when it's not needed""" - self.request.session["ADFS_ACCESS_TOKEN"] = "test_access_token" - self.request.session["ADFS_REFRESH_TOKEN"] = "test_refresh_token" - expires_at = datetime.datetime.now() + datetime.timedelta( - hours=1 - ) # 1 hour to expiry - self.request.session["ADFS_TOKEN_EXPIRES_AT"] = expires_at.isoformat() - - with patch.object(self.middleware, "_refresh_tokens") as mock_refresh: - self.middleware._handle_token_refresh(self.request) - mock_refresh.assert_not_called() - - def test_handle_token_refresh_needed(self): - """Test token refresh when it's needed""" - self.request.session["ADFS_ACCESS_TOKEN"] = "test_access_token" - self.request.session["ADFS_REFRESH_TOKEN"] = "test_refresh_token" - expires_at = datetime.datetime.now() + datetime.timedelta( - seconds=60 - ) # 1 minute to expiry - self.request.session["ADFS_TOKEN_EXPIRES_AT"] = expires_at.isoformat() - - with patch.object(self.middleware, "_refresh_tokens") as mock_refresh: - self.middleware._handle_token_refresh(self.request) - mock_refresh.assert_called_once_with(self.request) - - def test_handle_expired_token(self): - """Test token refresh when token is already expired""" - self.request.session["ADFS_ACCESS_TOKEN"] = "test_access_token" - self.request.session["ADFS_REFRESH_TOKEN"] = "test_refresh_token" - expires_at = datetime.datetime.now() - datetime.timedelta( - hours=1 - ) # Expired 1 hour ago - self.request.session["ADFS_TOKEN_EXPIRES_AT"] = expires_at.isoformat() - - with patch.object(self.middleware, "_refresh_tokens") as mock_refresh: - self.middleware._handle_token_refresh(self.request) - mock_refresh.assert_called_once_with(self.request) - - def test_obo_token_expires_before_access_token(self): - """Test when OBO token expires before access token""" - # Set up access token with long expiry - self.request.session["ADFS_ACCESS_TOKEN"] = "access_token" - self.request.session["ADFS_REFRESH_TOKEN"] = "refresh_token" - access_token_expires_at = datetime.datetime.now() + datetime.timedelta(hours=1) - self.request.session["ADFS_TOKEN_EXPIRES_AT"] = ( - access_token_expires_at.isoformat() - ) - - # Set up OBO token with short expiry - self.request.session["ADFS_OBO_ACCESS_TOKEN"] = "obo_token" - obo_expires_at = datetime.datetime.now() + datetime.timedelta(seconds=30) - self.request.session["ADFS_OBO_TOKEN_EXPIRES_AT"] = obo_expires_at.isoformat() - - # Should refresh only OBO token - with patch.object( - self.middleware, "_refresh_tokens" - ) as mock_refresh_token, patch.object( - self.middleware, "_refresh_obo_token" - ) as mock_refresh_obo: - self.middleware._handle_token_refresh(self.request) - - # Verify only OBO token is refreshed, not the access token - mock_refresh_token.assert_not_called() - mock_refresh_obo.assert_called_once_with(self.request) - - # Verify session state remains unchanged for access token - self.assertEqual(self.request.session["ADFS_ACCESS_TOKEN"], "access_token") - self.assertEqual( - self.request.session["ADFS_REFRESH_TOKEN"], "refresh_token" - ) - self.assertEqual( - self.request.session["ADFS_TOKEN_EXPIRES_AT"], - access_token_expires_at.isoformat(), - ) - - # Group 4: Token Refresh Implementation Tests - - @patch("django_auth_adfs.middleware.provider_config") - def test_refresh_token_success(self, mock_provider_config): - """Test successful token refresh""" - # Set up mock response - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "access_token": "new_access_token", - "refresh_token": "new_refresh_token", - "expires_in": 3600, - } - - # Configure the mock - mock_provider_config.session.post.return_value = mock_response - mock_provider_config.token_endpoint = ( - "https://adfs.example.com/adfs/oauth2/token" - ) - - # Set up session with expired token - self.request.session["ADFS_ACCESS_TOKEN"] = _encrypt_token("old_access_token") - self.request.session["ADFS_REFRESH_TOKEN"] = _encrypt_token("old_refresh_token") - self.request.session["ADFS_TOKEN_EXPIRES_AT"] = ( - datetime.datetime.now() - datetime.timedelta(minutes=5) - ).isoformat() - - # Mock the OBO token refresh to prevent real HTTP requests - with patch.object(self.middleware, "_refresh_obo_token") as mock_refresh_obo: - # Call refresh method - self.middleware._refresh_tokens(self.request) - - # Check that tokens were updated - self.assertEqual( - _decrypt_token(self.request.session["ADFS_ACCESS_TOKEN"]), - "new_access_token", - ) - self.assertEqual( - _decrypt_token(self.request.session["ADFS_REFRESH_TOKEN"]), - "new_refresh_token", - ) - - @patch("django_auth_adfs.middleware.provider_config") - def test_refresh_token_without_new_refresh_token(self, mock_provider_config): - """Test token refresh when response doesn't include a new refresh token""" - # Set up mock response without refresh_token - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "access_token": "new_access_token", - "expires_in": 3600, - } - - # Configure the mock - mock_provider_config.session.post.return_value = mock_response - mock_provider_config.token_endpoint = ( - "https://adfs.example.com/adfs/oauth2/token" - ) - - # Set up session with expired token - self.request.session["ADFS_ACCESS_TOKEN"] = _encrypt_token("old_access_token") - self.request.session["ADFS_REFRESH_TOKEN"] = _encrypt_token("old_refresh_token") - self.request.session["ADFS_TOKEN_EXPIRES_AT"] = ( - datetime.datetime.now() - datetime.timedelta(minutes=5) - ).isoformat() - - # Mock the OBO token refresh to prevent real HTTP requests - with patch.object(self.middleware, "_refresh_obo_token") as mock_refresh_obo: - # Call refresh method - self.middleware._refresh_tokens(self.request) - - # Check that access token was updated but refresh token remains the same - self.assertEqual( - _decrypt_token(self.request.session["ADFS_ACCESS_TOKEN"]), - "new_access_token", - ) - self.assertEqual( - _decrypt_token(self.request.session["ADFS_REFRESH_TOKEN"]), - "old_refresh_token", - ) - - def test_refresh_obo_token_success(self): - """Test successful OBO token refresh""" - # Ensure OBO token storage is enabled - self.middleware.store_obo_token = True - - # Set up session with expired OBO token but valid access token - self.request.session["ADFS_ACCESS_TOKEN"] = _encrypt_token("valid_access_token") - self.request.session["ADFS_REFRESH_TOKEN"] = _encrypt_token( - "valid_refresh_token" - ) - self.request.session["ADFS_OBO_ACCESS_TOKEN"] = _encrypt_token( - "expired_obo_token" - ) - self.request.session["ADFS_TOKEN_EXPIRES_AT"] = ( - datetime.datetime.now() + datetime.timedelta(hours=1) - ).isoformat() - self.request.session["ADFS_OBO_TOKEN_EXPIRES_AT"] = ( - datetime.datetime.now() - datetime.timedelta(minutes=5) - ).isoformat() - - # Save the original method - original_refresh_obo_token = self.middleware._refresh_obo_token - - # Create a mock implementation - def mock_refresh_obo_token(request): - # This simulates a successful token refresh - request.session["ADFS_OBO_ACCESS_TOKEN"] = _encrypt_token("new_obo_token") - request.session["ADFS_OBO_TOKEN_EXPIRES_AT"] = ( - datetime.datetime.now() + datetime.timedelta(hours=1) - ).isoformat() - request.session.modified = True - - # Replace the method with our mock - self.middleware._refresh_obo_token = mock_refresh_obo_token - - try: - # Call handle token refresh directly - self.middleware._handle_token_refresh(self.request) + def test_settings_configuration(self): + """Test settings are properly loaded from Django settings""" + with patch.object(adfs_settings, 'TOKEN_REFRESH_THRESHOLD', 600), \ + patch.object(adfs_settings, 'STORE_OBO_TOKEN', False), \ + patch.object(adfs_settings, 'LOGOUT_ON_TOKEN_REFRESH_FAILURE', True): - # Verify the new token was stored in the session + manager = TokenManager() + self.assertEqual(manager.refresh_threshold, 600) + self.assertFalse(manager.store_obo_token) + self.assertTrue(manager.logout_on_refresh_failure) + + def test_token_storage_and_retrieval(self): + """Test the complete token storage and retrieval flow""" + # Store tokens + token_manager.store_tokens( + self.request, + "test_access", + { + "access_token": "test_access", + "refresh_token": "test_refresh", + "expires_in": 3600 + } + ) + + # Verify storage + self.assertEqual(token_manager.get_access_token(self.request), "test_access") + self.assertTrue(token_manager.TOKEN_EXPIRES_AT_KEY in self.request.session) + + # Verify encryption + encrypted = self.request.session[token_manager.ACCESS_TOKEN_KEY] + self.assertNotEqual(encrypted, "test_access") + self.assertEqual(token_manager.decrypt_token(encrypted), "test_access") + + def test_token_refresh_flow(self): + """Test the complete token refresh flow""" + # Setup expired token + token_manager.store_tokens( + self.request, + "old_access", + { + "access_token": "old_access", + "refresh_token": "old_refresh", + "expires_in": 60 # Will trigger refresh + } + ) + + # Mock refresh response + with patch("django_auth_adfs.token_manager.provider_config") as mock_config: + mock_response = Mock(status_code=200) + mock_response.json.return_value = { + "access_token": "new_access", + "refresh_token": "new_refresh", + "expires_in": 3600 + } + mock_config.session.post.return_value = mock_response + mock_config.token_endpoint = "https://example.com/token" + + # Trigger refresh via middleware + self.middleware(self.request) + + # Verify tokens were updated self.assertEqual( - _decrypt_token(self.request.session["ADFS_OBO_ACCESS_TOKEN"]), - "new_obo_token", + token_manager.get_access_token(self.request), + "new_access" ) - - # Verify the expiry time was updated - expires_at = datetime.datetime.fromisoformat( - self.request.session["ADFS_OBO_TOKEN_EXPIRES_AT"] - ) - now = datetime.datetime.now() - self.assertTrue( - (expires_at - now).total_seconds() > 0, - "Token expiry time should be in the future" - ) - finally: - # Restore the original method - self.middleware._refresh_obo_token = original_refresh_obo_token - - def test_refresh_obo_token_failure(self): - """Test OBO token refresh when it fails""" - # Ensure OBO token storage is enabled - self.middleware.store_obo_token = True - # Set up session with expired OBO token - self.request.session["ADFS_ACCESS_TOKEN"] = _encrypt_token("valid_access_token") - self.request.session["ADFS_OBO_ACCESS_TOKEN"] = _encrypt_token( - "expired_obo_token" + def test_obo_token_management(self): + """Test OBO token functionality when enabled""" + # Store regular token + token_manager.store_tokens( + self.request, + "test_access", + {"access_token": "test_access", "expires_in": 3600} ) - self.request.session["ADFS_OBO_TOKEN_EXPIRES_AT"] = ( - datetime.datetime.now() - datetime.timedelta(minutes=5) - ).isoformat() - # Store original session state to verify it's not modified - original_session_data = dict(self.request.session) - self.request.session.modified = False - - # Save the original method - original_refresh_obo_token = self.middleware._refresh_obo_token - - # Create a mock implementation that simulates failure - def mock_refresh_obo_token(request): - # This simulates a failed token refresh - no changes to session - pass + # Mock OBO flow + with patch("django_auth_adfs.backend.AdfsBaseBackend") as mock_backend: + mock_backend.return_value.get_obo_access_token.return_value = "test_obo" - # Replace the method with our mock - self.middleware._refresh_obo_token = mock_refresh_obo_token - - try: - # Call the method directly - self.middleware._refresh_obo_token(self.request) - - # Verify session was not modified - self.assertEqual(dict(self.request.session), original_session_data) - self.assertFalse(self.request.session.modified) - finally: - # Restore the original method - self.middleware._refresh_obo_token = original_refresh_obo_token - - def test_obo_token_without_access_token(self): - """Test OBO token handling when access token is missing""" - # Only OBO token exists - self.request.session["ADFS_OBO_ACCESS_TOKEN"] = "obo_token" - self.request.session["ADFS_OBO_TOKEN_EXPIRES_AT"] = ( - datetime.datetime.now().isoformat() - ) - # No ADFS_ACCESS_TOKEN - - # Store original session state to verify it's not modified - original_session_data = dict(self.request.session) - self.request.session.modified = False - - self.middleware._refresh_obo_token(self.request) - - # Verify session not modified - self.assertEqual(dict(self.request.session), original_session_data) - self.assertFalse(self.request.session.modified) - - # Group 5: Authentication Signal Tests - - def test_capture_tokens_from_auth(self): - """Test capturing tokens during authentication""" - sender = Mock() - sender.access_token = "sender_access_token" - sender.get_obo_access_token.return_value = "obo_token" - - adfs_response = { - "access_token": "response_access_token", - "refresh_token": "response_refresh_token", - "expires_in": 3600, - } - - # Create a request with a session - request = self.factory.get("/") - request.session = SessionStore() - - self.middleware._capture_tokens_from_auth( - sender=sender, - user=self.user, - claims={}, - adfs_response=adfs_response, - request=request - ) - - # Check tokens were stored in the session - self.assertEqual( - _decrypt_token(request.session["ADFS_ACCESS_TOKEN"]), - "sender_access_token", - ) - self.assertEqual( - _decrypt_token(request.session["ADFS_REFRESH_TOKEN"]), - "response_refresh_token", - ) - self.assertEqual( - _decrypt_token(request.session["ADFS_OBO_ACCESS_TOKEN"]), - "obo_token", - ) - self.assertTrue("ADFS_TOKEN_EXPIRES_AT" in request.session) - self.assertTrue("ADFS_OBO_TOKEN_EXPIRES_AT" in request.session) - - def test_capture_tokens_from_adfs_response_only(self): - """Test capturing tokens when they're only in the ADFS response, not on sender""" - sender = Mock(spec=[]) # Create a mock without access_token attribute - # Ensure get_obo_access_token is available but returns None - sender.get_obo_access_token = Mock(return_value=None) - - adfs_response = { - "access_token": "response_access_token", - "refresh_token": "response_refresh_token", - "expires_in": 3600, - } - - # Create a request with a session - request = self.factory.get("/") - request.session = SessionStore() - - self.middleware._capture_tokens_from_auth( - sender=sender, - user=self.user, - claims={}, - adfs_response=adfs_response, - request=request - ) - - # Check tokens were stored in the session - self.assertEqual( - _decrypt_token(request.session["ADFS_ACCESS_TOKEN"]), - "response_access_token", - ) - self.assertEqual( - _decrypt_token(request.session["ADFS_REFRESH_TOKEN"]), - "response_refresh_token", - ) - self.assertTrue("ADFS_TOKEN_EXPIRES_AT" in request.session) - self.assertFalse("ADFS_OBO_ACCESS_TOKEN" in request.session) - self.assertFalse("ADFS_OBO_TOKEN_EXPIRES_AT" in request.session) - - # Group 6: Middleware Call Tests - - def test_middleware_call_with_authenticated_user(self): - """Test the complete middleware request/response cycle with authenticated user""" - # Create request with authenticated user - request = self.factory.get("/") - request.user = self.user - request.session = SessionStore() - - # Add tokens directly to the session - access_token = "test_access_token" - refresh_token = "test_refresh_token" - expires_at = datetime.datetime.now() + datetime.timedelta(hours=1) - - request.session["ADFS_ACCESS_TOKEN"] = _encrypt_token(access_token) - request.session["ADFS_REFRESH_TOKEN"] = _encrypt_token(refresh_token) - request.session["ADFS_TOKEN_EXPIRES_AT"] = expires_at.isoformat() - request.session.modified = True - - # Call middleware - response = self.middleware(request) - - # Check that tokens are still in the session - self.assertEqual( - _decrypt_token(request.session["ADFS_ACCESS_TOKEN"]), access_token - ) - self.assertEqual( - _decrypt_token(request.session["ADFS_REFRESH_TOKEN"]), refresh_token - ) - self.assertEqual( - request.session["ADFS_TOKEN_EXPIRES_AT"], expires_at.isoformat() - ) - - def test_middleware_post_response_token_storage(self): - """Test tokens added during authentication are stored in the session""" - # Create a mock sender and adfs_response for the signal - sender = Mock() - sender.access_token = "view_added_token" - sender.get_obo_access_token.return_value = None - - adfs_response = { - "expires_in": 3600, - } - - # Create request with authenticated user - request = self.factory.get("/") - request.user = self.user - request.session = SessionStore() - - # Create a get_response function that simulates authentication - def get_response_with_auth_signal(request): - # Simulate authentication by calling the signal handler - self.middleware._capture_tokens_from_auth( - sender=sender, - user=request.user, - claims={}, - adfs_response=adfs_response, - request=request - ) - return Mock() - - # Create middleware with our custom get_response - middleware = TokenLifecycleMiddleware(get_response_with_auth_signal) - - # Call middleware - response = middleware(request) - - # Check that tokens were stored in session - self.assertEqual( - _decrypt_token(request.session["ADFS_ACCESS_TOKEN"]), "view_added_token" - ) - self.assertTrue("ADFS_TOKEN_EXPIRES_AT" in request.session) - - def test_middleware_without_user(self): - """Test middleware behavior when request has no user""" - request = self.factory.get("/") - request.session = SessionStore() - - response = self.middleware(request) - # Should not raise any errors - self.assertEqual(response, request) - - def test_middleware_with_unauthenticated_user(self): - """Test middleware behavior with unauthenticated user""" - request = self.factory.get("/") - request.user = Mock(is_authenticated=False) - request.session = SessionStore() - - with patch.object(self.middleware, "_handle_token_refresh") as mock_refresh: - response = self.middleware(request) - mock_refresh.assert_not_called() - - # Group 7: Error Handling Tests - - def test_handle_malformed_expiry_time(self): - """Test handling of malformed expiry time in session""" - self.request.session["ADFS_ACCESS_TOKEN"] = "test_access_token" - self.request.session["ADFS_REFRESH_TOKEN"] = "test_refresh_token" - self.request.session["ADFS_TOKEN_EXPIRES_AT"] = "invalid_datetime" - - # Store original session state to verify it's not modified inappropriately - original_session_data = dict(self.request.session) - self.request.session.modified = False - - # Should handle gracefully without error - self.middleware._handle_token_refresh(self.request) - - # Verify session wasn't modified inappropriately - self.assertEqual(dict(self.request.session), original_session_data) - self.assertFalse(self.request.session.modified) - - def test_handle_incomplete_token_state(self): - """Test handling when only some token data exists in session""" - # Only access token, no refresh token - self.request.session["ADFS_ACCESS_TOKEN"] = "test_access_token" - self.request.session["ADFS_TOKEN_EXPIRES_AT"] = ( - datetime.datetime.now().isoformat() - ) - # Missing ADFS_REFRESH_TOKEN - - # Store original session state to verify it's not modified inappropriately - original_session_data = dict(self.request.session) - self.request.session.modified = False - - self.middleware._handle_token_refresh(self.request) - - # Verify session wasn't modified inappropriately - self.assertEqual(dict(self.request.session), original_session_data) - self.assertFalse(self.request.session.modified) - - def test_handle_malformed_tokens(self): - """Test handling of malformed/corrupt token data in session""" - # Invalid token format - self.request.session["ADFS_ACCESS_TOKEN"] = {"malformed": "data"} - self.request.session["ADFS_REFRESH_TOKEN"] = None - self.request.session["ADFS_TOKEN_EXPIRES_AT"] = "not-a-date" - - # Store original session state to verify it's not modified inappropriately - original_session_data = dict(self.request.session) - self.request.session.modified = False - - self.middleware._handle_token_refresh(self.request) - - # Verify session wasn't modified inappropriately - self.assertEqual(dict(self.request.session), original_session_data) - self.assertFalse(self.request.session.modified) - - def test_disabled_obo_token_functionality(self): - """Test that OBO token functionality is disabled when STORE_OBO_TOKEN is False""" - # Create a mock sender and adfs_response - sender = Mock() - sender.access_token = "test_access_token" - sender.get_obo_access_token.return_value = "test_obo_token" - - adfs_response = { - "refresh_token": "test_refresh_token", - "expires_in": 3600, - } - - # Patch the middleware to disable OBO token storage - with patch.object(self.middleware, "store_obo_token", False): - # Call the signal handler - self.middleware._capture_tokens_from_auth( - sender=sender, - user=self.user, - claims={}, - adfs_response=adfs_response, - request=self.request + # Verify OBO token storage and retrieval + self.request.session[token_manager.OBO_ACCESS_TOKEN_KEY] = \ + token_manager.encrypt_token("test_obo") + self.request.session[token_manager.OBO_TOKEN_EXPIRES_AT_KEY] = \ + (datetime.datetime.now() + datetime.timedelta(hours=1)).isoformat() + + self.assertEqual(token_manager.get_obo_access_token(self.request), "test_obo") + + def test_error_handling(self): + """Test error handling in various scenarios""" + # Test invalid data handling + self.assertIsNone(token_manager.decrypt_token("invalid_data")) + self.assertIsNone(token_manager.encrypt_token(None)) + + # Test refresh failure + with patch("django_auth_adfs.token_manager.provider_config") as mock_config: + # Setup expired tokens first + token_manager.store_tokens( + self.request, + "old_access", + { + "access_token": "old_access", + "refresh_token": "old_refresh", + "expires_in": -60 # Already expired + } ) - - # Verify access token is stored but OBO token is not - self.assertTrue("ADFS_ACCESS_TOKEN" in self.request.session) - self.assertFalse("ADFS_OBO_ACCESS_TOKEN" in self.request.session) - - # Verify get_obo_access_token returns None when disabled - with patch("django_auth_adfs.utils.settings") as mock_settings: - mock_settings.STORE_OBO_TOKEN = False - self.assertIsNone(get_obo_access_token(self.request)) - - def test_token_encryption(self): - """Test that tokens are properly encrypted and decrypted""" - # Test encryption and decryption directly - original_token = "test_access_token" - encrypted_token = _encrypt_token(original_token) - - # Verify the token is encrypted (should be different from original) - self.assertNotEqual(original_token, encrypted_token) - - # Verify the token can be decrypted back to the original - decrypted_token = _decrypt_token(encrypted_token) - self.assertEqual(original_token, decrypted_token) - - # Test the middleware stores encrypted tokens - sender = Mock() - sender.access_token = original_token - sender.get_obo_access_token.return_value = None - - # Call the signal handler - self.middleware._capture_tokens_from_auth( - sender=sender, - user=self.user, - claims={}, - adfs_response={}, - request=self.request - ) - - # Verify the token in the session is encrypted - session_token = self.request.session.get("ADFS_ACCESS_TOKEN") - self.assertNotEqual(original_token, session_token) - - # Test the utility function decrypts the token - retrieved_token = get_access_token(self.request) - self.assertEqual(original_token, retrieved_token) - - # Test with OBO token - original_obo_token = "test_obo_token" - sender.get_obo_access_token.return_value = original_obo_token - - # Call the signal handler - self.middleware._capture_tokens_from_auth( - sender=sender, - user=self.user, - claims={}, - adfs_response={}, - request=self.request - ) - - # Verify the OBO token in the session is encrypted - session_obo_token = self.request.session.get("ADFS_OBO_ACCESS_TOKEN") - self.assertNotEqual(original_obo_token, session_obo_token) - - # Test the utility function decrypts the OBO token - retrieved_obo_token = get_obo_access_token(self.request) - self.assertEqual(original_obo_token, retrieved_obo_token) - - @override_settings(TOKEN_ENCRYPTION_SALT="custom-salt-for-testing") - def test_custom_encryption_salt(self): - """Test that custom encryption salt changes the encrypted token value""" - # First, encrypt a token with the default salt - original_token = "test_access_token" - default_encrypted_token = _encrypt_token(original_token) - - # Now, encrypt the same token with a custom salt (set via override_settings) - with patch("django_auth_adfs.utils.settings") as mock_settings: - mock_settings.TOKEN_ENCRYPTION_SALT = "custom-salt-for-testing" - custom_encrypted_token = _encrypt_token(original_token) - - # The encrypted tokens should be different due to different salts - self.assertNotEqual(default_encrypted_token, custom_encrypted_token) - - # But both should decrypt to the original token when using the correct salt - with patch("django_auth_adfs.utils.settings") as mock_settings: - mock_settings.TOKEN_ENCRYPTION_SALT = "custom-salt-for-testing" - decrypted_token = _decrypt_token(custom_encrypted_token) - - self.assertEqual(original_token, decrypted_token) - - # A token encrypted with one salt should not be decryptable with another - with patch("django_auth_adfs.utils.settings") as mock_settings: - mock_settings.TOKEN_ENCRYPTION_SALT = "different-salt" - # The function catches exceptions and returns None, so check for None - self.assertIsNone(_decrypt_token(custom_encrypted_token)) - - @patch("django_auth_adfs.middleware.provider_config") - def test_refresh_token_failure_with_logout(self, mock_provider_config): - """Test token refresh failure with LOGOUT_ON_TOKEN_REFRESH_FAILURE enabled""" - # Setup - self.request.session["ADFS_ACCESS_TOKEN"] = _encrypt_token("test_access_token") - self.request.session["ADFS_REFRESH_TOKEN"] = _encrypt_token( - "test_refresh_token" - ) - expires_at = datetime.datetime.now() - datetime.timedelta(minutes=5) - self.request.session["ADFS_TOKEN_EXPIRES_AT"] = expires_at.isoformat() - - # Mock the response from the token endpoint - mock_response = Mock() - mock_response.status_code = 400 - mock_response.text = "Invalid refresh token" - mock_provider_config.session.post.return_value = mock_response - - # Enable the setting - with patch("django_auth_adfs.middleware.settings") as mock_settings: - mock_settings.LOGOUT_ON_TOKEN_REFRESH_FAILURE = True - mock_settings.CLIENT_ID = "test_client_id" - mock_settings.CLIENT_SECRET = "test_client_secret" - mock_settings.TIMEOUT = 5 - - # Mock the logout function - with patch("django.contrib.auth.logout") as mock_logout: - self.middleware._refresh_tokens(self.request) - - # Verify logout was called - mock_logout.assert_called_once_with(self.request) - - @patch("django_auth_adfs.middleware.provider_config") - def test_refresh_token_failure_without_logout(self, mock_provider_config): - """Test token refresh failure with LOGOUT_ON_TOKEN_REFRESH_FAILURE disabled""" - # Setup - self.request.session["ADFS_ACCESS_TOKEN"] = _encrypt_token("test_access_token") - self.request.session["ADFS_REFRESH_TOKEN"] = _encrypt_token( - "test_refresh_token" - ) - expires_at = datetime.datetime.now() - datetime.timedelta(minutes=5) - self.request.session["ADFS_TOKEN_EXPIRES_AT"] = expires_at.isoformat() - - # Mock the response from the token endpoint - mock_response = Mock() - mock_response.status_code = 400 - mock_response.text = "Invalid refresh token" - mock_provider_config.session.post.return_value = mock_response - - # Disable the setting (default) - with patch("django_auth_adfs.middleware.settings") as mock_settings: - mock_settings.LOGOUT_ON_TOKEN_REFRESH_FAILURE = False - mock_settings.CLIENT_ID = "test_client_id" - mock_settings.CLIENT_SECRET = "test_client_secret" - mock_settings.TIMEOUT = 5 - - # Mock the logout function - with patch("django.contrib.auth.logout") as mock_logout: - self.middleware._refresh_tokens(self.request) - - # Verify logout was not called - mock_logout.assert_not_called() - - @patch("django_auth_adfs.middleware.provider_config") - def test_refresh_token_exception_with_logout(self, mock_provider_config): - """Test token refresh exception with LOGOUT_ON_TOKEN_REFRESH_FAILURE enabled""" - # Setup - self.request.session["ADFS_ACCESS_TOKEN"] = _encrypt_token("test_access_token") - self.request.session["ADFS_REFRESH_TOKEN"] = _encrypt_token( - "test_refresh_token" - ) - expires_at = datetime.datetime.now() - datetime.timedelta(minutes=5) - self.request.session["ADFS_TOKEN_EXPIRES_AT"] = expires_at.isoformat() - - # Make the request raise an exception - mock_provider_config.session.post.side_effect = Exception("Connection error") - - # Enable the setting - with patch("django_auth_adfs.middleware.settings") as mock_settings: - mock_settings.LOGOUT_ON_TOKEN_REFRESH_FAILURE = True - mock_settings.CLIENT_ID = "test_client_id" - mock_settings.CLIENT_SECRET = "test_client_secret" - mock_settings.TIMEOUT = 5 - - # Mock the logout function - with patch("django.contrib.auth.logout") as mock_logout: - self.middleware._refresh_tokens(self.request) - - # Verify logout was called - mock_logout.assert_called_once_with(self.request) - - def test_handle_token_refresh_calls_refresh_obo_token(self): - """ - Test that _handle_token_refresh calls _refresh_obo_token when the OBO token is expired. - """ - # Ensure OBO token storage is enabled - self.middleware.store_obo_token = True - - # Set up session with valid access token but expired OBO token - self.request.session["ADFS_ACCESS_TOKEN"] = _encrypt_token("valid_access_token") - self.request.session["ADFS_REFRESH_TOKEN"] = _encrypt_token("valid_refresh_token") - self.request.session["ADFS_OBO_ACCESS_TOKEN"] = _encrypt_token("expired_obo_token") - - # Set access token to not expire soon - self.request.session["ADFS_TOKEN_EXPIRES_AT"] = ( - datetime.datetime.now() + datetime.timedelta(hours=1) - ).isoformat() - - # Set OBO token to be expired - expired_time = datetime.datetime.now() - datetime.timedelta(minutes=5) - self.request.session["ADFS_OBO_TOKEN_EXPIRES_AT"] = expired_time.isoformat() - - # Save the original method - original_refresh_obo_token = self.middleware._refresh_obo_token - - # Create a spy function to track if the method is called - refresh_called = [False] - - def spy_refresh_obo_token(request): - refresh_called[0] = True - # Simulate successful refresh - request.session["ADFS_OBO_ACCESS_TOKEN"] = _encrypt_token("new_obo_token") - request.session["ADFS_OBO_TOKEN_EXPIRES_AT"] = ( - datetime.datetime.now() + datetime.timedelta(hours=1) - ).isoformat() - # Replace the method with our spy - self.middleware._refresh_obo_token = spy_refresh_obo_token - + mock_config.session.post.return_value = Mock(status_code=400, text="Error") + mock_config.token_endpoint = "https://example.com/token" + + token_manager.logout_on_refresh_failure = True + try: + with patch("django_auth_adfs.token_manager.logout") as mock_logout: + token_manager.refresh_tokens(self.request) + mock_logout.assert_called_once_with(self.request) + finally: + token_manager.logout_on_refresh_failure = False + + def test_signed_cookies_handling(self): + """Test behavior with signed cookies session backend""" + token_manager.using_signed_cookies = True try: - # Call handle token refresh - self.middleware._handle_token_refresh(self.request) - - # Verify _refresh_obo_token was called - self.assertTrue(refresh_called[0], - "_refresh_obo_token should be called when OBO token is expired") - - # Verify the token was updated - self.assertEqual( - _decrypt_token(self.request.session["ADFS_OBO_ACCESS_TOKEN"]), - "new_obo_token" + success = token_manager.store_tokens( + self.request, "test_token", {"refresh_token": "test_refresh"} ) + self.assertFalse(success) + self.assertFalse(token_manager.ACCESS_TOKEN_KEY in self.request.session) finally: - # Restore the original method - self.middleware._refresh_obo_token = original_refresh_obo_token + token_manager.using_signed_cookies = False + + def test_middleware_integration(self): + """Test TokenLifecycleMiddleware integration""" + # Test with unauthenticated user + self.request.user = AnonymousUser() + response = self.middleware(self.request) + self.assertIsNone(response) # Middleware should pass through + + # Test with authenticated user + self.request.user = self.user + with patch.object(token_manager, "check_token_expiration") as mock_check: + self.middleware(self.request) + mock_check.assert_called_once_with(self.request) From 63009a0256dfaa879bbdbd72b2ba47694768173e Mon Sep 17 00:00:00 2001 From: tnware Date: Sun, 9 Mar 2025 12:09:05 -0700 Subject: [PATCH 4/9] Cleanup implementation * make middleware check more rigid * token integrity validation * don't be so strict on refresh tokens --- django_auth_adfs/token_manager.py | 75 ++++++++--- tests/test_middleware.py | 199 ++++++++++++++++++++++++++---- 2 files changed, 238 insertions(+), 36 deletions(-) diff --git a/django_auth_adfs/token_manager.py b/django_auth_adfs/token_manager.py index acfe31dc..ab15a0a4 100644 --- a/django_auth_adfs/token_manager.py +++ b/django_auth_adfs/token_manager.py @@ -61,11 +61,9 @@ def __init__(self): def is_middleware_enabled(self): """Check if the TokenLifecycleMiddleware is enabled.""" + EXPECTED_MIDDLEWARE = 'django_auth_adfs.middleware.TokenLifecycleMiddleware' try: - for middleware in django_settings.MIDDLEWARE: - if middleware.endswith('TokenLifecycleMiddleware'): - return True - return False + return EXPECTED_MIDDLEWARE in django_settings.MIDDLEWARE except Exception as e: logger.warning(f"Error checking if middleware is enabled: {e}") return False @@ -208,37 +206,78 @@ def get_obo_access_token(self, request): encrypted_token = request.session.get(self.OBO_ACCESS_TOKEN_KEY) return self.decrypt_token(encrypted_token) + def validate_token_format(self, token): + """ + Basic validation of token format before storage. + + Args: + token (str): Token to validate + + Returns: + bool: True if token appears valid, False otherwise + """ + if not isinstance(token, str): + return False + + try: + # Check if it's a valid JWT format + parts = token.split('.') + if len(parts) != 3: + return False + + # Check if each part is valid base64 + for part in parts: + base64.urlsafe_b64decode(part + '=' * (-len(part) % 4)) + + return True + except Exception: + return False + def store_tokens(self, request, access_token, adfs_response=None): """ Store tokens in the session. Args: request: The current request object - access_token (str): The access token to store + access_token (str): The access token to store (must be a JWT) adfs_response (dict, optional): The full response from ADFS containing refresh token and expiration Returns: bool: True if tokens were stored, False otherwise """ if not self.should_store_tokens(request): + logger.debug("Token storage is disabled") + return False + + if not self.validate_token_format(access_token): + logger.warning("Invalid access token format, refusing to store") return False try: session_modified = False - # Store access token + # Store access token (JWT) encrypted_token = self.encrypt_token(access_token) if encrypted_token: request.session[self.ACCESS_TOKEN_KEY] = encrypted_token session_modified = True + logger.debug("Stored access token") - # Store refresh token + # Store refresh token (can be any string) if adfs_response and "refresh_token" in adfs_response: refresh_token = adfs_response["refresh_token"] - encrypted_token = self.encrypt_token(refresh_token) - if encrypted_token: - request.session[self.REFRESH_TOKEN_KEY] = encrypted_token - session_modified = True + if refresh_token: # Just check it's not empty + encrypted_token = self.encrypt_token(refresh_token) + if encrypted_token: + request.session[self.REFRESH_TOKEN_KEY] = encrypted_token + session_modified = True + logger.debug("Stored refresh token") + else: + logger.warning("Failed to encrypt refresh token") + else: + logger.warning("Empty refresh token received from ADFS") + else: + logger.debug("No refresh token in ADFS response") # Store token expiration if adfs_response and "expires_in" in adfs_response: @@ -247,8 +286,9 @@ def store_tokens(self, request, access_token, adfs_response=None): ) request.session[self.TOKEN_EXPIRES_AT_KEY] = expires_at.isoformat() session_modified = True + logger.debug("Stored token expiration") - # Store OBO token if enabled + # Store OBO token if enabled (must be JWT) if self.store_obo_token: try: # Import here to avoid circular imports @@ -256,21 +296,23 @@ def store_tokens(self, request, access_token, adfs_response=None): backend = AdfsBaseBackend() obo_token = backend.get_obo_access_token(access_token) - if obo_token: + if obo_token and self.validate_token_format(obo_token): encrypted_token = self.encrypt_token(obo_token) if encrypted_token: request.session[self.OBO_ACCESS_TOKEN_KEY] = encrypted_token obo_expires_at = datetime.datetime.now() + datetime.timedelta(hours=1) request.session[self.OBO_TOKEN_EXPIRES_AT_KEY] = obo_expires_at.isoformat() session_modified = True + logger.debug("Stored OBO token") except Exception as e: logger.warning(f"Error getting OBO token: {e}") if session_modified: request.session.modified = True - logger.debug("Stored tokens in session") + logger.debug("All tokens stored successfully") return True + logger.warning("No tokens were stored") return False except Exception as e: @@ -323,6 +365,7 @@ def check_token_expiration(self, request): def refresh_tokens(self, request): """ Refresh the access token using the refresh token. + Args: request: The current request object @@ -353,7 +396,6 @@ def refresh_tokens(self, request): if settings.CLIENT_SECRET: data["client_secret"] = settings.CLIENT_SECRET - # Ensure token_endpoint is a string token_endpoint = provider_config.token_endpoint if token_endpoint is None: logger.error("Token endpoint is None, cannot refresh tokens") @@ -365,6 +407,9 @@ def refresh_tokens(self, request): if response.status_code == 200: token_data = response.json() + + # Store new tokens - if another refresh happened, these will just overwrite + # with fresher tokens, which is fine request.session[self.ACCESS_TOKEN_KEY] = self.encrypt_token( token_data["access_token"] ) diff --git a/tests/test_middleware.py b/tests/test_middleware.py index cfaf1a78..2229d217 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -3,6 +3,8 @@ """ import datetime +import json +import base64 from unittest.mock import Mock, patch import time @@ -22,6 +24,39 @@ "django_auth_adfs.middleware.TokenLifecycleMiddleware", ) +def create_test_token(claims=None, exp_delta=3600): + """Create a test JWT token with the given claims and expiration delta.""" + if claims is None: + claims = {} + + # Create a basic JWT token with ADFS-like structure + header = { + "typ": "JWT", + "alg": "RS256", + "x5t": "example-thumbprint" + } + + # Add standard ADFS claims if not present + now = int(time.time()) + if "iat" not in claims: + claims["iat"] = now + if "exp" not in claims: + claims["exp"] = now + exp_delta + if "aud" not in claims: + claims["aud"] = "microsoft:identityserver:your-RelyingPartyTrust-identifier" + if "iss" not in claims: + claims["iss"] = "https://sts.windows.net/01234567-89ab-cdef-0123-456789abcdef/" + if "sub" not in claims: + claims["sub"] = "john.doe@example.com" + + # Encode each part + header_part = base64.urlsafe_b64encode(json.dumps(header).encode()).rstrip(b"=").decode() + claims_part = base64.urlsafe_b64encode(json.dumps(claims).encode()).rstrip(b"=").decode() + signature_part = base64.urlsafe_b64encode(b"test_signature").rstrip(b"=").decode() + + # Combine parts + return f"{header_part}.{claims_part}.{signature_part}" + @override_settings(MIDDLEWARE=MIDDLEWARE_WITH_TOKEN_LIFECYCLE) class TokenLifecycleTests(TestCase): @@ -50,35 +85,43 @@ def test_settings_configuration(self): def test_token_storage_and_retrieval(self): """Test the complete token storage and retrieval flow""" + access_token = create_test_token({"type": "access"}) + refresh_token = create_test_token({"type": "refresh"}) + # Store tokens token_manager.store_tokens( self.request, - "test_access", + access_token, { - "access_token": "test_access", - "refresh_token": "test_refresh", + "access_token": access_token, + "refresh_token": refresh_token, "expires_in": 3600 } ) # Verify storage - self.assertEqual(token_manager.get_access_token(self.request), "test_access") + self.assertEqual(token_manager.get_access_token(self.request), access_token) self.assertTrue(token_manager.TOKEN_EXPIRES_AT_KEY in self.request.session) # Verify encryption encrypted = self.request.session[token_manager.ACCESS_TOKEN_KEY] - self.assertNotEqual(encrypted, "test_access") - self.assertEqual(token_manager.decrypt_token(encrypted), "test_access") + self.assertNotEqual(encrypted, access_token) + self.assertEqual(token_manager.decrypt_token(encrypted), access_token) def test_token_refresh_flow(self): """Test the complete token refresh flow""" + old_access_token = create_test_token({"type": "access"}, exp_delta=60) + old_refresh_token = create_test_token({"type": "refresh"}) + new_access_token = create_test_token({"type": "access"}) + new_refresh_token = create_test_token({"type": "refresh"}) + # Setup expired token token_manager.store_tokens( self.request, - "old_access", + old_access_token, { - "access_token": "old_access", - "refresh_token": "old_refresh", + "access_token": old_access_token, + "refresh_token": old_refresh_token, "expires_in": 60 # Will trigger refresh } ) @@ -87,8 +130,8 @@ def test_token_refresh_flow(self): with patch("django_auth_adfs.token_manager.provider_config") as mock_config: mock_response = Mock(status_code=200) mock_response.json.return_value = { - "access_token": "new_access", - "refresh_token": "new_refresh", + "access_token": new_access_token, + "refresh_token": new_refresh_token, "expires_in": 3600 } mock_config.session.post.return_value = mock_response @@ -100,29 +143,32 @@ def test_token_refresh_flow(self): # Verify tokens were updated self.assertEqual( token_manager.get_access_token(self.request), - "new_access" + new_access_token ) def test_obo_token_management(self): """Test OBO token functionality when enabled""" + access_token = create_test_token({"type": "access"}) + obo_token = create_test_token({"type": "obo"}) + # Store regular token token_manager.store_tokens( self.request, - "test_access", - {"access_token": "test_access", "expires_in": 3600} + access_token, + {"access_token": access_token, "expires_in": 3600} ) # Mock OBO flow with patch("django_auth_adfs.backend.AdfsBaseBackend") as mock_backend: - mock_backend.return_value.get_obo_access_token.return_value = "test_obo" + mock_backend.return_value.get_obo_access_token.return_value = obo_token # Verify OBO token storage and retrieval self.request.session[token_manager.OBO_ACCESS_TOKEN_KEY] = \ - token_manager.encrypt_token("test_obo") + token_manager.encrypt_token(obo_token) self.request.session[token_manager.OBO_TOKEN_EXPIRES_AT_KEY] = \ (datetime.datetime.now() + datetime.timedelta(hours=1)).isoformat() - self.assertEqual(token_manager.get_obo_access_token(self.request), "test_obo") + self.assertEqual(token_manager.get_obo_access_token(self.request), obo_token) def test_error_handling(self): """Test error handling in various scenarios""" @@ -131,14 +177,17 @@ def test_error_handling(self): self.assertIsNone(token_manager.encrypt_token(None)) # Test refresh failure + access_token = create_test_token({"type": "access"}, exp_delta=-60) + refresh_token = create_test_token({"type": "refresh"}) + with patch("django_auth_adfs.token_manager.provider_config") as mock_config: # Setup expired tokens first token_manager.store_tokens( self.request, - "old_access", + access_token, { - "access_token": "old_access", - "refresh_token": "old_refresh", + "access_token": access_token, + "refresh_token": refresh_token, "expires_in": -60 # Already expired } ) @@ -156,10 +205,15 @@ def test_error_handling(self): def test_signed_cookies_handling(self): """Test behavior with signed cookies session backend""" + test_token = create_test_token({"type": "access"}) + refresh_token = create_test_token({"type": "refresh"}) + token_manager.using_signed_cookies = True try: success = token_manager.store_tokens( - self.request, "test_token", {"refresh_token": "test_refresh"} + self.request, + test_token, + {"refresh_token": refresh_token} ) self.assertFalse(success) self.assertFalse(token_manager.ACCESS_TOKEN_KEY in self.request.session) @@ -178,3 +232,106 @@ def test_middleware_integration(self): with patch.object(token_manager, "check_token_expiration") as mock_check: self.middleware(self.request) mock_check.assert_called_once_with(self.request) + + def test_middleware_detection(self): + """Test middleware enabled detection""" + # Test with correct middleware path + with patch('django.conf.settings.MIDDLEWARE', [ + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django_auth_adfs.middleware.TokenLifecycleMiddleware' + ]): + self.assertTrue(token_manager.is_middleware_enabled()) + + # Test with incorrect middleware path + with patch('django.conf.settings.MIDDLEWARE', [ + 'django.contrib.sessions.middleware.SessionMiddleware', + 'some_other_package.TokenLifecycleMiddleware', # Wrong package + 'django_auth_adfs.middleware.SomeOtherMiddleware', # Wrong middleware + 'django_auth_adfs.TokenLifecycleMiddleware', # Wrong path + ]): + self.assertFalse(token_manager.is_middleware_enabled()) + + def test_clear_tokens(self): + """Test clearing tokens from session""" + access_token = create_test_token({"type": "access"}) + refresh_token = create_test_token({"type": "refresh"}) + + # Store some tokens first + token_manager.store_tokens( + self.request, + access_token, + { + "access_token": access_token, + "refresh_token": refresh_token, + "expires_in": 3600 + } + ) + + # Verify tokens were stored + self.assertTrue(token_manager.ACCESS_TOKEN_KEY in self.request.session) + self.assertTrue(token_manager.REFRESH_TOKEN_KEY in self.request.session) + + # Clear tokens + success = token_manager.clear_tokens(self.request) + self.assertTrue(success) + + # Verify tokens were cleared + self.assertFalse(token_manager.ACCESS_TOKEN_KEY in self.request.session) + self.assertFalse(token_manager.REFRESH_TOKEN_KEY in self.request.session) + self.assertFalse(token_manager.TOKEN_EXPIRES_AT_KEY in self.request.session) + self.assertFalse(token_manager.OBO_ACCESS_TOKEN_KEY in self.request.session) + self.assertFalse(token_manager.OBO_TOKEN_EXPIRES_AT_KEY in self.request.session) + + def test_refresh_obo_token_directly(self): + """Test direct OBO token refresh""" + access_token = create_test_token({"type": "access"}) + new_obo_token = create_test_token({"type": "obo"}) + + # Store access token first + token_manager.store_tokens( + self.request, + access_token, + {"access_token": access_token, "expires_in": 3600} + ) + + # Mock OBO token acquisition and provider config + with patch("django_auth_adfs.backend.AdfsBaseBackend") as mock_backend, \ + patch("django_auth_adfs.token_manager.provider_config") as mock_provider: + + mock_backend.return_value.get_obo_access_token.return_value = new_obo_token + mock_provider.load_config.return_value = None + mock_provider.token_endpoint = "https://example.com/token" + mock_provider.session.verify = False # Disable cert validation + + # Refresh OBO token + success = token_manager.refresh_obo_token(self.request) + self.assertTrue(success) + + # Verify new OBO token was stored + obo_token = token_manager.get_obo_access_token(self.request) + self.assertEqual(obo_token, new_obo_token) + self.assertTrue(token_manager.OBO_TOKEN_EXPIRES_AT_KEY in self.request.session) + + def test_should_store_tokens_edge_cases(self): + """Test edge cases for token storage decisions""" + # Test with no request + self.assertFalse(token_manager.should_store_tokens(None)) + + # Test with request but no session + request_without_session = self.factory.get("/") + # Instead of deleting session attribute that doesn't exist, + # we'll create a Mock object with no session attribute + from unittest.mock import Mock + request_without_session = Mock(spec=[]) # Empty spec means no attributes + self.assertFalse(token_manager.should_store_tokens(request_without_session)) + + # Test with signed cookies + token_manager.using_signed_cookies = True + try: + self.assertFalse(token_manager.should_store_tokens(self.request)) + finally: + token_manager.using_signed_cookies = False + + # Test with middleware disabled + with patch.object(token_manager, "is_middleware_enabled", return_value=False): + self.assertFalse(token_manager.should_store_tokens(self.request)) From d9214e31a7838b194c32163c99af25bfded7a91b Mon Sep 17 00:00:00 2001 From: tnware Date: Sun, 9 Mar 2025 14:23:26 -0500 Subject: [PATCH 5/9] remove unused imports --- django_auth_adfs/middleware.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/django_auth_adfs/middleware.py b/django_auth_adfs/middleware.py index 79dff508..42ce7ac8 100644 --- a/django_auth_adfs/middleware.py +++ b/django_auth_adfs/middleware.py @@ -2,17 +2,15 @@ Based on https://djangosnippets.org/snippets/1179/ """ -import datetime import logging from re import compile from django.conf import settings as django_settings -from django.contrib.auth import logout from django.contrib.auth.views import redirect_to_login from django.urls import reverse from django_auth_adfs.exceptions import MFARequired -from django_auth_adfs.config import provider_config, settings +from django_auth_adfs.config import settings from django_auth_adfs.token_manager import token_manager LOGIN_EXEMPT_URLS = [ From 285717e6f4a3bf8824901a1681b9b63ce256b782 Mon Sep 17 00:00:00 2001 From: tnware Date: Sun, 9 Mar 2025 15:34:52 -0500 Subject: [PATCH 6/9] fix flake8 feedback --- django_auth_adfs/backend.py | 1 - django_auth_adfs/middleware.py | 14 +- django_auth_adfs/token_manager.py | 208 +++++++++++++++--------------- tests/test_middleware.py | 39 +++--- 4 files changed, 131 insertions(+), 131 deletions(-) diff --git a/django_auth_adfs/backend.py b/django_auth_adfs/backend.py index a63da931..d6d6d5af 100644 --- a/django_auth_adfs/backend.py +++ b/django_auth_adfs/backend.py @@ -1,5 +1,4 @@ import logging -import datetime import jwt from django.contrib.auth import get_user_model diff --git a/django_auth_adfs/middleware.py b/django_auth_adfs/middleware.py index 42ce7ac8..df98de89 100644 --- a/django_auth_adfs/middleware.py +++ b/django_auth_adfs/middleware.py @@ -35,6 +35,7 @@ class LoginRequiredMiddleware: Requires authentication middleware and template context processors to be loaded. You'll get an error if they aren't. """ + def __init__(self, get_response): self.get_response = get_response @@ -59,20 +60,20 @@ def __call__(self, request): class TokenLifecycleMiddleware: """ Middleware that handles the lifecycle of ADFS access and refresh tokens. - + This middleware will: 1. Check if the access token is about to expire 2. Use the refresh token to get a new access token if needed 3. Update the tokens in the session 4. Handle OBO (On-Behalf-Of) tokens for Microsoft Graph API - + Token storage during authentication is handled by the backend when this middleware is enabled. - + To enable this middleware, add it to your MIDDLEWARE setting: 'django_auth_adfs.middleware.TokenLifecycleMiddleware' - + You can configure the token refresh behavior with these settings: - + TOKEN_REFRESH_THRESHOLD: Number of seconds before expiration to refresh (default: 300) STORE_OBO_TOKEN: Boolean to enable/disable OBO token storage (default: True) LOGOUT_ON_TOKEN_REFRESH_FAILURE: Whether to log out the user if token refresh fails (default: False) @@ -93,7 +94,6 @@ def __call__(self, request): if hasattr(request, "user") and request.user.is_authenticated: # Check if tokens need to be refreshed token_manager.check_token_expiration(request) - + response = self.get_response(request) return response - diff --git a/django_auth_adfs/token_manager.py b/django_auth_adfs/token_manager.py index ab15a0a4..0bd9ca28 100644 --- a/django_auth_adfs/token_manager.py +++ b/django_auth_adfs/token_manager.py @@ -22,43 +22,43 @@ class TokenManager: """ Centralized manager for token lifecycle operations. - + This class handles: - Token storage during authentication - Token encryption/decryption - Token refresh - Token retrieval - OBO token management - + It's designed to be lightweight when not actively performing operations, and to handle all token operations in a safe, transparent, and error-free manner. """ - + # Session key constants ACCESS_TOKEN_KEY = "ADFS_ACCESS_TOKEN" REFRESH_TOKEN_KEY = "ADFS_REFRESH_TOKEN" TOKEN_EXPIRES_AT_KEY = "ADFS_TOKEN_EXPIRES_AT" OBO_ACCESS_TOKEN_KEY = "ADFS_OBO_ACCESS_TOKEN" OBO_TOKEN_EXPIRES_AT_KEY = "ADFS_OBO_TOKEN_EXPIRES_AT" - + def __init__(self): """Initialize the TokenManager with settings.""" # Load settings self.refresh_threshold = getattr(settings, "TOKEN_REFRESH_THRESHOLD", 300) self.store_obo_token = getattr(settings, "STORE_OBO_TOKEN", True) self.logout_on_refresh_failure = getattr(settings, "LOGOUT_ON_TOKEN_REFRESH_FAILURE", False) - + # Check if using signed cookies self.using_signed_cookies = ( django_settings.SESSION_ENGINE == "django.contrib.sessions.backends.signed_cookies" ) - + if self.using_signed_cookies: logger.warning( "TokenManager: Storing tokens in signed cookies is not recommended for security " "reasons and cookie size limitations. Token storage will be disabled." ) - + def is_middleware_enabled(self): """Check if the TokenLifecycleMiddleware is enabled.""" EXPECTED_MIDDLEWARE = 'django_auth_adfs.middleware.TokenLifecycleMiddleware' @@ -67,44 +67,44 @@ def is_middleware_enabled(self): except Exception as e: logger.warning(f"Error checking if middleware is enabled: {e}") return False - + def should_store_tokens(self, request): """ Check if tokens should be stored in the session. - + Tokens are stored if: 1. We have a request with a session 2. The TokenLifecycleMiddleware is enabled 3. We're not using signed cookies - + Args: request: The current request object - + Returns: bool: True if tokens should be stored, False otherwise """ if not request or not hasattr(request, "session"): return False - + if self.using_signed_cookies: return False - + return self.is_middleware_enabled() - + def _get_encryption_key(self): """ Derive a Fernet encryption key from Django's SECRET_KEY. - + Returns: bytes: A 32-byte key suitable for Fernet encryption """ # Use Django's SECRET_KEY to derive a suitable encryption key default_salt = b"django_auth_adfs_token_encryption" salt = getattr(settings, "TOKEN_ENCRYPTION_SALT", default_salt) - + if isinstance(salt, str): salt = salt.encode() - + kdf = PBKDF2HMAC( algorithm=hashes.SHA256(), length=32, @@ -113,20 +113,20 @@ def _get_encryption_key(self): ) key = base64.urlsafe_b64encode(kdf.derive(django_settings.SECRET_KEY.encode())) return key - + def encrypt_token(self, token): """ Encrypt a token using Django's SECRET_KEY. - + Args: token (str): The token to encrypt - + Returns: str: The encrypted token as a string or None if encryption fails """ if not token: return None - + try: key = self._get_encryption_key() f = Fernet(key) @@ -135,20 +135,20 @@ def encrypt_token(self, token): except Exception as e: logger.error(f"Error encrypting token: {e}") return None - + def decrypt_token(self, encrypted_token): """ Decrypt a token that was encrypted using Django's SECRET_KEY. - + Args: encrypted_token (str): The encrypted token - + Returns: str: The decrypted token or None if decryption fails """ if not encrypted_token: return None - + try: key = self._get_encryption_key() f = Fernet(key) @@ -157,78 +157,78 @@ def decrypt_token(self, encrypted_token): except Exception as e: logger.error(f"Error decrypting token: {e}") return None - + def get_access_token(self, request): """ Get the current access token from the session. - + The token is automatically decrypted before being returned. - + Args: request: The current request object - + Returns: str: The access token or None if not available """ if not hasattr(request, "session"): return None - + if self.using_signed_cookies: logger.debug("Token retrieval from signed_cookies session is disabled") return None - + encrypted_token = request.session.get(self.ACCESS_TOKEN_KEY) return self.decrypt_token(encrypted_token) - + def get_obo_access_token(self, request): """ Get the current OBO access token from the session. - + The token is automatically decrypted before being returned. - + Args: request: The current request object - + Returns: str: The OBO access token or None if not available """ if not hasattr(request, "session"): return None - + if self.using_signed_cookies: logger.debug("Token retrieval from signed_cookies session is disabled") return None - + if not self.store_obo_token: logger.debug("OBO token storage is disabled") return None - + encrypted_token = request.session.get(self.OBO_ACCESS_TOKEN_KEY) return self.decrypt_token(encrypted_token) - + def validate_token_format(self, token): """ Basic validation of token format before storage. - + Args: token (str): Token to validate - + Returns: bool: True if token appears valid, False otherwise """ if not isinstance(token, str): return False - + try: # Check if it's a valid JWT format parts = token.split('.') if len(parts) != 3: return False - + # Check if each part is valid base64 for part in parts: base64.urlsafe_b64decode(part + '=' * (-len(part) % 4)) - + return True except Exception: return False @@ -236,33 +236,33 @@ def validate_token_format(self, token): def store_tokens(self, request, access_token, adfs_response=None): """ Store tokens in the session. - + Args: request: The current request object access_token (str): The access token to store (must be a JWT) adfs_response (dict, optional): The full response from ADFS containing refresh token and expiration - + Returns: bool: True if tokens were stored, False otherwise """ if not self.should_store_tokens(request): logger.debug("Token storage is disabled") return False - + if not self.validate_token_format(access_token): logger.warning("Invalid access token format, refusing to store") return False - + try: session_modified = False - + # Store access token (JWT) encrypted_token = self.encrypt_token(access_token) if encrypted_token: request.session[self.ACCESS_TOKEN_KEY] = encrypted_token session_modified = True logger.debug("Stored access token") - + # Store refresh token (can be any string) if adfs_response and "refresh_token" in adfs_response: refresh_token = adfs_response["refresh_token"] @@ -278,7 +278,7 @@ def store_tokens(self, request, access_token, adfs_response=None): logger.warning("Empty refresh token received from ADFS") else: logger.debug("No refresh token in ADFS response") - + # Store token expiration if adfs_response and "expires_in" in adfs_response: expires_at = datetime.datetime.now() + datetime.timedelta( @@ -287,13 +287,13 @@ def store_tokens(self, request, access_token, adfs_response=None): request.session[self.TOKEN_EXPIRES_AT_KEY] = expires_at.isoformat() session_modified = True logger.debug("Stored token expiration") - + # Store OBO token if enabled (must be JWT) if self.store_obo_token: try: # Import here to avoid circular imports from django_auth_adfs.backend import AdfsBaseBackend - + backend = AdfsBaseBackend() obo_token = backend.get_obo_access_token(access_token) if obo_token and self.validate_token_format(obo_token): @@ -306,108 +306,108 @@ def store_tokens(self, request, access_token, adfs_response=None): logger.debug("Stored OBO token") except Exception as e: logger.warning(f"Error getting OBO token: {e}") - + if session_modified: request.session.modified = True logger.debug("All tokens stored successfully") return True - + logger.warning("No tokens were stored") return False - + except Exception as e: logger.warning(f"Error storing tokens in session: {e}") return False - + def check_token_expiration(self, request): """ Check if tokens need to be refreshed and refresh them if needed. - + Args: request: The current request object - + Returns: bool: True if tokens were checked, False otherwise """ if not hasattr(request, "user") or not request.user.is_authenticated: return False - + if self.using_signed_cookies: return False - + try: if self.TOKEN_EXPIRES_AT_KEY not in request.session: return False - + # Check if token is about to expire expires_at = datetime.datetime.fromisoformat(request.session[self.TOKEN_EXPIRES_AT_KEY]) remaining = expires_at - datetime.datetime.now() - + if remaining.total_seconds() < self.refresh_threshold: logger.debug("Token is about to expire. Refreshing...") self.refresh_tokens(request) - + # Check if OBO token is about to expire if self.store_obo_token and self.OBO_TOKEN_EXPIRES_AT_KEY in request.session: obo_expires_at = datetime.datetime.fromisoformat(request.session[self.OBO_TOKEN_EXPIRES_AT_KEY]) obo_remaining = obo_expires_at - datetime.datetime.now() - + if obo_remaining.total_seconds() < self.refresh_threshold: logger.debug("OBO token is about to expire. Refreshing...") self.refresh_obo_token(request) - + return True - + except Exception as e: logger.warning(f"Error checking token expiration: {e}") return False - + def refresh_tokens(self, request): """ Refresh the access token using the refresh token. - - + + Args: request: The current request object - + Returns: bool: True if tokens were refreshed, False otherwise """ if self.using_signed_cookies: return False - + if self.REFRESH_TOKEN_KEY not in request.session: return False - + try: refresh_token = self.decrypt_token(request.session[self.REFRESH_TOKEN_KEY]) if not refresh_token: logger.warning("Failed to decrypt refresh token") return False - + provider_config.load_config() - + data = { "grant_type": "refresh_token", "client_id": settings.CLIENT_ID, "refresh_token": refresh_token, } - + if settings.CLIENT_SECRET: data["client_secret"] = settings.CLIENT_SECRET - + token_endpoint = provider_config.token_endpoint if token_endpoint is None: logger.error("Token endpoint is None, cannot refresh tokens") return False - + response = provider_config.session.post( token_endpoint, data=data, timeout=settings.TIMEOUT ) - + if response.status_code == 200: token_data = response.json() - + # Store new tokens - if another refresh happened, these will just overwrite # with fresher tokens, which is fine request.session[self.ACCESS_TOKEN_KEY] = self.encrypt_token( @@ -422,11 +422,11 @@ def refresh_tokens(self, request): request.session[self.TOKEN_EXPIRES_AT_KEY] = expires_at.isoformat() request.session.modified = True logger.debug("Refreshed tokens successfully") - + # Also refresh the OBO token if needed if self.store_obo_token: self.refresh_obo_token(request) - + return True else: logger.warning( @@ -436,47 +436,47 @@ def refresh_tokens(self, request): logger.info("Logging out user due to token refresh failure") logout(request) return False - + except Exception as e: logger.exception(f"Error refreshing tokens: {e}") if self.logout_on_refresh_failure: logger.info("Logging out user due to token refresh error") logout(request) return False - + def refresh_obo_token(self, request): """ Refresh the OBO token for Microsoft Graph API. - + Args: request: The current request object - + Returns: bool: True if OBO token was refreshed, False otherwise """ if not self.store_obo_token: return False - + if self.using_signed_cookies: return False - + if self.ACCESS_TOKEN_KEY not in request.session: return False - + try: provider_config.load_config() - + access_token = self.decrypt_token(request.session[self.ACCESS_TOKEN_KEY]) if not access_token: logger.warning("Failed to decrypt access token") return False - + # Import here to avoid circular imports from django_auth_adfs.backend import AdfsBaseBackend - + backend = AdfsBaseBackend() obo_token = backend.get_obo_access_token(access_token) - + if obo_token: request.session[self.OBO_ACCESS_TOKEN_KEY] = self.encrypt_token(obo_token) obo_expires_at = datetime.datetime.now() + datetime.timedelta(hours=1) @@ -484,29 +484,29 @@ def refresh_obo_token(self, request): request.session.modified = True logger.debug("Refreshed OBO token successfully") return True - + return False - + except Exception as e: logger.warning(f"Error refreshing OBO token: {e}") return False - + def clear_tokens(self, request): """ Clear all tokens from the session. - + Args: request: The current request object - + Returns: bool: True if tokens were cleared, False otherwise """ if not hasattr(request, "session"): return False - + try: session_modified = False - + for key in [ self.ACCESS_TOKEN_KEY, self.REFRESH_TOKEN_KEY, @@ -517,18 +517,18 @@ def clear_tokens(self, request): if key in request.session: del request.session[key] session_modified = True - + if session_modified: request.session.modified = True logger.debug("Cleared tokens from session") return True - + return False - + except Exception as e: logger.warning(f"Error clearing tokens from session: {e}") return False # Create a singleton instance -token_manager = TokenManager() \ No newline at end of file +token_manager = TokenManager() diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 2229d217..fbe7d76e 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -24,18 +24,19 @@ "django_auth_adfs.middleware.TokenLifecycleMiddleware", ) + def create_test_token(claims=None, exp_delta=3600): """Create a test JWT token with the given claims and expiration delta.""" if claims is None: claims = {} - + # Create a basic JWT token with ADFS-like structure header = { "typ": "JWT", "alg": "RS256", "x5t": "example-thumbprint" } - + # Add standard ADFS claims if not present now = int(time.time()) if "iat" not in claims: @@ -48,12 +49,12 @@ def create_test_token(claims=None, exp_delta=3600): claims["iss"] = "https://sts.windows.net/01234567-89ab-cdef-0123-456789abcdef/" if "sub" not in claims: claims["sub"] = "john.doe@example.com" - + # Encode each part header_part = base64.urlsafe_b64encode(json.dumps(header).encode()).rstrip(b"=").decode() claims_part = base64.urlsafe_b64encode(json.dumps(claims).encode()).rstrip(b"=").decode() signature_part = base64.urlsafe_b64encode(b"test_signature").rstrip(b"=").decode() - + # Combine parts return f"{header_part}.{claims_part}.{signature_part}" @@ -75,9 +76,9 @@ def setUp(self): def test_settings_configuration(self): """Test settings are properly loaded from Django settings""" with patch.object(adfs_settings, 'TOKEN_REFRESH_THRESHOLD', 600), \ - patch.object(adfs_settings, 'STORE_OBO_TOKEN', False), \ - patch.object(adfs_settings, 'LOGOUT_ON_TOKEN_REFRESH_FAILURE', True): - + patch.object(adfs_settings, 'STORE_OBO_TOKEN', False), \ + patch.object(adfs_settings, 'LOGOUT_ON_TOKEN_REFRESH_FAILURE', True): + manager = TokenManager() self.assertEqual(manager.refresh_threshold, 600) self.assertFalse(manager.store_obo_token) @@ -87,7 +88,7 @@ def test_token_storage_and_retrieval(self): """Test the complete token storage and retrieval flow""" access_token = create_test_token({"type": "access"}) refresh_token = create_test_token({"type": "refresh"}) - + # Store tokens token_manager.store_tokens( self.request, @@ -114,7 +115,7 @@ def test_token_refresh_flow(self): old_refresh_token = create_test_token({"type": "refresh"}) new_access_token = create_test_token({"type": "access"}) new_refresh_token = create_test_token({"type": "refresh"}) - + # Setup expired token token_manager.store_tokens( self.request, @@ -150,7 +151,7 @@ def test_obo_token_management(self): """Test OBO token functionality when enabled""" access_token = create_test_token({"type": "access"}) obo_token = create_test_token({"type": "obo"}) - + # Store regular token token_manager.store_tokens( self.request, @@ -161,7 +162,7 @@ def test_obo_token_management(self): # Mock OBO flow with patch("django_auth_adfs.backend.AdfsBaseBackend") as mock_backend: mock_backend.return_value.get_obo_access_token.return_value = obo_token - + # Verify OBO token storage and retrieval self.request.session[token_manager.OBO_ACCESS_TOKEN_KEY] = \ token_manager.encrypt_token(obo_token) @@ -179,7 +180,7 @@ def test_error_handling(self): # Test refresh failure access_token = create_test_token({"type": "access"}, exp_delta=-60) refresh_token = create_test_token({"type": "refresh"}) - + with patch("django_auth_adfs.token_manager.provider_config") as mock_config: # Setup expired tokens first token_manager.store_tokens( @@ -191,7 +192,7 @@ def test_error_handling(self): "expires_in": -60 # Already expired } ) - + mock_config.session.post.return_value = Mock(status_code=400, text="Error") mock_config.token_endpoint = "https://example.com/token" @@ -207,7 +208,7 @@ def test_signed_cookies_handling(self): """Test behavior with signed cookies session backend""" test_token = create_test_token({"type": "access"}) refresh_token = create_test_token({"type": "refresh"}) - + token_manager.using_signed_cookies = True try: success = token_manager.store_tokens( @@ -255,7 +256,7 @@ def test_clear_tokens(self): """Test clearing tokens from session""" access_token = create_test_token({"type": "access"}) refresh_token = create_test_token({"type": "refresh"}) - + # Store some tokens first token_manager.store_tokens( self.request, @@ -286,7 +287,7 @@ def test_refresh_obo_token_directly(self): """Test direct OBO token refresh""" access_token = create_test_token({"type": "access"}) new_obo_token = create_test_token({"type": "obo"}) - + # Store access token first token_manager.store_tokens( self.request, @@ -296,13 +297,13 @@ def test_refresh_obo_token_directly(self): # Mock OBO token acquisition and provider config with patch("django_auth_adfs.backend.AdfsBaseBackend") as mock_backend, \ - patch("django_auth_adfs.token_manager.provider_config") as mock_provider: - + patch("django_auth_adfs.token_manager.provider_config") as mock_provider: + mock_backend.return_value.get_obo_access_token.return_value = new_obo_token mock_provider.load_config.return_value = None mock_provider.token_endpoint = "https://example.com/token" mock_provider.session.verify = False # Disable cert validation - + # Refresh OBO token success = token_manager.refresh_obo_token(self.request) self.assertTrue(success) From 0d821b7c609328c033b9b94d669e652fb12b7a2f Mon Sep 17 00:00:00 2001 From: tnware Date: Mon, 10 Mar 2025 08:49:45 -0700 Subject: [PATCH 7/9] cleanup (#5) * remove useless validation method * use actual expiration time from obo token --- django_auth_adfs/token_manager.py | 103 ++++++++++++++++-------------- 1 file changed, 55 insertions(+), 48 deletions(-) diff --git a/django_auth_adfs/token_manager.py b/django_auth_adfs/token_manager.py index 0bd9ca28..14fde1c8 100644 --- a/django_auth_adfs/token_manager.py +++ b/django_auth_adfs/token_manager.py @@ -46,11 +46,14 @@ def __init__(self): # Load settings self.refresh_threshold = getattr(settings, "TOKEN_REFRESH_THRESHOLD", 300) self.store_obo_token = getattr(settings, "STORE_OBO_TOKEN", True) - self.logout_on_refresh_failure = getattr(settings, "LOGOUT_ON_TOKEN_REFRESH_FAILURE", False) + self.logout_on_refresh_failure = getattr( + settings, "LOGOUT_ON_TOKEN_REFRESH_FAILURE", False + ) # Check if using signed cookies self.using_signed_cookies = ( - django_settings.SESSION_ENGINE == "django.contrib.sessions.backends.signed_cookies" + django_settings.SESSION_ENGINE + == "django.contrib.sessions.backends.signed_cookies" ) if self.using_signed_cookies: @@ -61,7 +64,7 @@ def __init__(self): def is_middleware_enabled(self): """Check if the TokenLifecycleMiddleware is enabled.""" - EXPECTED_MIDDLEWARE = 'django_auth_adfs.middleware.TokenLifecycleMiddleware' + EXPECTED_MIDDLEWARE = "django_auth_adfs.middleware.TokenLifecycleMiddleware" try: return EXPECTED_MIDDLEWARE in django_settings.MIDDLEWARE except Exception as e: @@ -206,33 +209,6 @@ def get_obo_access_token(self, request): encrypted_token = request.session.get(self.OBO_ACCESS_TOKEN_KEY) return self.decrypt_token(encrypted_token) - def validate_token_format(self, token): - """ - Basic validation of token format before storage. - - Args: - token (str): Token to validate - - Returns: - bool: True if token appears valid, False otherwise - """ - if not isinstance(token, str): - return False - - try: - # Check if it's a valid JWT format - parts = token.split('.') - if len(parts) != 3: - return False - - # Check if each part is valid base64 - for part in parts: - base64.urlsafe_b64decode(part + '=' * (-len(part) % 4)) - - return True - except Exception: - return False - def store_tokens(self, request, access_token, adfs_response=None): """ Store tokens in the session. @@ -249,10 +225,6 @@ def store_tokens(self, request, access_token, adfs_response=None): logger.debug("Token storage is disabled") return False - if not self.validate_token_format(access_token): - logger.warning("Invalid access token format, refusing to store") - return False - try: session_modified = False @@ -296,14 +268,27 @@ def store_tokens(self, request, access_token, adfs_response=None): backend = AdfsBaseBackend() obo_token = backend.get_obo_access_token(access_token) - if obo_token and self.validate_token_format(obo_token): + if obo_token: encrypted_token = self.encrypt_token(obo_token) if encrypted_token: request.session[self.OBO_ACCESS_TOKEN_KEY] = encrypted_token - obo_expires_at = datetime.datetime.now() + datetime.timedelta(hours=1) - request.session[self.OBO_TOKEN_EXPIRES_AT_KEY] = obo_expires_at.isoformat() - session_modified = True - logger.debug("Stored OBO token") + # Decode the OBO token to get its actual expiration time + import jwt + + decoded_token = jwt.decode( + obo_token, options={"verify_signature": False} + ) + if "exp" in decoded_token: + obo_expires_at = datetime.datetime.fromtimestamp( + decoded_token["exp"] + ) + request.session[self.OBO_TOKEN_EXPIRES_AT_KEY] = ( + obo_expires_at.isoformat() + ) + session_modified = True + logger.debug( + "Stored OBO token with expiration from token claims" + ) except Exception as e: logger.warning(f"Error getting OBO token: {e}") @@ -340,7 +325,9 @@ def check_token_expiration(self, request): return False # Check if token is about to expire - expires_at = datetime.datetime.fromisoformat(request.session[self.TOKEN_EXPIRES_AT_KEY]) + expires_at = datetime.datetime.fromisoformat( + request.session[self.TOKEN_EXPIRES_AT_KEY] + ) remaining = expires_at - datetime.datetime.now() if remaining.total_seconds() < self.refresh_threshold: @@ -348,8 +335,13 @@ def check_token_expiration(self, request): self.refresh_tokens(request) # Check if OBO token is about to expire - if self.store_obo_token and self.OBO_TOKEN_EXPIRES_AT_KEY in request.session: - obo_expires_at = datetime.datetime.fromisoformat(request.session[self.OBO_TOKEN_EXPIRES_AT_KEY]) + if ( + self.store_obo_token + and self.OBO_TOKEN_EXPIRES_AT_KEY in request.session + ): + obo_expires_at = datetime.datetime.fromisoformat( + request.session[self.OBO_TOKEN_EXPIRES_AT_KEY] + ) obo_remaining = obo_expires_at - datetime.datetime.now() if obo_remaining.total_seconds() < self.refresh_threshold: @@ -478,11 +470,26 @@ def refresh_obo_token(self, request): obo_token = backend.get_obo_access_token(access_token) if obo_token: - request.session[self.OBO_ACCESS_TOKEN_KEY] = self.encrypt_token(obo_token) - obo_expires_at = datetime.datetime.now() + datetime.timedelta(hours=1) - request.session[self.OBO_TOKEN_EXPIRES_AT_KEY] = obo_expires_at.isoformat() - request.session.modified = True - logger.debug("Refreshed OBO token successfully") + request.session[self.OBO_ACCESS_TOKEN_KEY] = self.encrypt_token( + obo_token + ) + # Decode the OBO token to get its actual expiration time + import jwt + + decoded_token = jwt.decode( + obo_token, options={"verify_signature": False} + ) + if "exp" in decoded_token: + obo_expires_at = datetime.datetime.fromtimestamp( + decoded_token["exp"] + ) + request.session[self.OBO_TOKEN_EXPIRES_AT_KEY] = ( + obo_expires_at.isoformat() + ) + request.session.modified = True + logger.debug( + "Refreshed OBO token with expiration from token claims" + ) return True return False @@ -512,7 +519,7 @@ def clear_tokens(self, request): self.REFRESH_TOKEN_KEY, self.TOKEN_EXPIRES_AT_KEY, self.OBO_ACCESS_TOKEN_KEY, - self.OBO_TOKEN_EXPIRES_AT_KEY + self.OBO_TOKEN_EXPIRES_AT_KEY, ]: if key in request.session: del request.session[key] From 7cdeb9b8df9638d8cd808b7d953d8bee159fc3d0 Mon Sep 17 00:00:00 2001 From: tnware Date: Mon, 10 Mar 2025 14:26:49 -0700 Subject: [PATCH 8/9] actually middleware --- django_auth_adfs/backend.py | 5 +- django_auth_adfs/middleware.py | 10 +- django_auth_adfs/token_manager.py | 188 +++++++++++------------------- tests/test_middleware.py | 170 +++++++++++++-------------- 4 files changed, 154 insertions(+), 219 deletions(-) diff --git a/django_auth_adfs/backend.py b/django_auth_adfs/backend.py index d6d6d5af..61258261 100644 --- a/django_auth_adfs/backend.py +++ b/django_auth_adfs/backend.py @@ -10,7 +10,6 @@ from django_auth_adfs import signals from django_auth_adfs.config import provider_config, settings from django_auth_adfs.exceptions import MFARequired -from django_auth_adfs.token_manager import token_manager logger = logging.getLogger("django_auth_adfs") @@ -199,8 +198,8 @@ def process_access_token(self, access_token, adfs_response=None, request=None): raise PermissionDenied # Store tokens in session if middleware is enabled - if request and adfs_response: - token_manager.store_tokens(request, access_token, adfs_response) + if request and adfs_response and hasattr(request, "token_storage"): + request.token_storage.store_tokens(request, access_token, adfs_response) groups = self.process_user_groups(claims, access_token) user = self.create_user(claims) diff --git a/django_auth_adfs/middleware.py b/django_auth_adfs/middleware.py index df98de89..159dcb93 100644 --- a/django_auth_adfs/middleware.py +++ b/django_auth_adfs/middleware.py @@ -81,6 +81,7 @@ class TokenLifecycleMiddleware: def __init__(self, get_response): self.get_response = get_response + self.token_manager = token_manager # Log warning if using signed cookies if token_manager.using_signed_cookies: logger.warning( @@ -91,9 +92,10 @@ def __init__(self, get_response): ) def __call__(self, request): + if hasattr(request, "session") and not self.token_manager.using_signed_cookies: + request.token_storage = self.token_manager + if hasattr(request, "user") and request.user.is_authenticated: - # Check if tokens need to be refreshed - token_manager.check_token_expiration(request) + self.token_manager.check_token_expiration(request) - response = self.get_response(request) - return response + return self.get_response(request) diff --git a/django_auth_adfs/token_manager.py b/django_auth_adfs/token_manager.py index 14fde1c8..4876d29a 100644 --- a/django_auth_adfs/token_manager.py +++ b/django_auth_adfs/token_manager.py @@ -43,14 +43,12 @@ class TokenManager: def __init__(self): """Initialize the TokenManager with settings.""" - # Load settings self.refresh_threshold = getattr(settings, "TOKEN_REFRESH_THRESHOLD", 300) self.store_obo_token = getattr(settings, "STORE_OBO_TOKEN", True) self.logout_on_refresh_failure = getattr( settings, "LOGOUT_ON_TOKEN_REFRESH_FAILURE", False ) - # Check if using signed cookies self.using_signed_cookies = ( django_settings.SESSION_ENGINE == "django.contrib.sessions.backends.signed_cookies" @@ -62,37 +60,82 @@ def __init__(self): "reasons and cookie size limitations. Token storage will be disabled." ) - def is_middleware_enabled(self): - """Check if the TokenLifecycleMiddleware is enabled.""" - EXPECTED_MIDDLEWARE = "django_auth_adfs.middleware.TokenLifecycleMiddleware" - try: - return EXPECTED_MIDDLEWARE in django_settings.MIDDLEWARE - except Exception as e: - logger.warning(f"Error checking if middleware is enabled: {e}") + def store_tokens(self, request, access_token, adfs_response=None): + if not hasattr(request, "session"): return False - def should_store_tokens(self, request): - """ - Check if tokens should be stored in the session. + try: + session_modified = False - Tokens are stored if: - 1. We have a request with a session - 2. The TokenLifecycleMiddleware is enabled - 3. We're not using signed cookies + encrypted_token = self.encrypt_token(access_token) + if encrypted_token: + request.session[self.ACCESS_TOKEN_KEY] = encrypted_token + session_modified = True + logger.debug("Stored access token") - Args: - request: The current request object + if adfs_response and "refresh_token" in adfs_response: + refresh_token = adfs_response["refresh_token"] + if refresh_token: + encrypted_token = self.encrypt_token(refresh_token) + if encrypted_token: + request.session[self.REFRESH_TOKEN_KEY] = encrypted_token + session_modified = True + logger.debug("Stored refresh token") + else: + logger.warning("Failed to encrypt refresh token") + else: + logger.warning("Empty refresh token received from ADFS") + else: + logger.debug("No refresh token in ADFS response") - Returns: - bool: True if tokens should be stored, False otherwise - """ - if not request or not hasattr(request, "session"): - return False + if adfs_response and "expires_in" in adfs_response: + expires_at = datetime.datetime.now() + datetime.timedelta( + seconds=int(adfs_response["expires_in"]) + ) + request.session[self.TOKEN_EXPIRES_AT_KEY] = expires_at.isoformat() + session_modified = True + logger.debug("Stored token expiration") - if self.using_signed_cookies: + if self.store_obo_token: + try: + from django_auth_adfs.backend import AdfsBaseBackend + + backend = AdfsBaseBackend() + obo_token = backend.get_obo_access_token(access_token) + if obo_token: + encrypted_token = self.encrypt_token(obo_token) + if encrypted_token: + request.session[self.OBO_ACCESS_TOKEN_KEY] = encrypted_token + import jwt + + decoded_token = jwt.decode( + obo_token, options={"verify_signature": False} + ) + if "exp" in decoded_token: + obo_expires_at = datetime.datetime.fromtimestamp( + decoded_token["exp"] + ) + request.session[self.OBO_TOKEN_EXPIRES_AT_KEY] = ( + obo_expires_at.isoformat() + ) + session_modified = True + logger.debug( + "Stored OBO token with expiration from token claims" + ) + except Exception as e: + logger.warning(f"Error getting OBO token: {e}") + + if session_modified: + request.session.modified = True + logger.debug("All tokens stored successfully") + return True + + logger.warning("No tokens were stored") return False - return self.is_middleware_enabled() + except Exception as e: + logger.warning(f"Error storing tokens in session: {e}") + return False def _get_encryption_key(self): """ @@ -209,101 +252,6 @@ def get_obo_access_token(self, request): encrypted_token = request.session.get(self.OBO_ACCESS_TOKEN_KEY) return self.decrypt_token(encrypted_token) - def store_tokens(self, request, access_token, adfs_response=None): - """ - Store tokens in the session. - - Args: - request: The current request object - access_token (str): The access token to store (must be a JWT) - adfs_response (dict, optional): The full response from ADFS containing refresh token and expiration - - Returns: - bool: True if tokens were stored, False otherwise - """ - if not self.should_store_tokens(request): - logger.debug("Token storage is disabled") - return False - - try: - session_modified = False - - # Store access token (JWT) - encrypted_token = self.encrypt_token(access_token) - if encrypted_token: - request.session[self.ACCESS_TOKEN_KEY] = encrypted_token - session_modified = True - logger.debug("Stored access token") - - # Store refresh token (can be any string) - if adfs_response and "refresh_token" in adfs_response: - refresh_token = adfs_response["refresh_token"] - if refresh_token: # Just check it's not empty - encrypted_token = self.encrypt_token(refresh_token) - if encrypted_token: - request.session[self.REFRESH_TOKEN_KEY] = encrypted_token - session_modified = True - logger.debug("Stored refresh token") - else: - logger.warning("Failed to encrypt refresh token") - else: - logger.warning("Empty refresh token received from ADFS") - else: - logger.debug("No refresh token in ADFS response") - - # Store token expiration - if adfs_response and "expires_in" in adfs_response: - expires_at = datetime.datetime.now() + datetime.timedelta( - seconds=int(adfs_response["expires_in"]) - ) - request.session[self.TOKEN_EXPIRES_AT_KEY] = expires_at.isoformat() - session_modified = True - logger.debug("Stored token expiration") - - # Store OBO token if enabled (must be JWT) - if self.store_obo_token: - try: - # Import here to avoid circular imports - from django_auth_adfs.backend import AdfsBaseBackend - - backend = AdfsBaseBackend() - obo_token = backend.get_obo_access_token(access_token) - if obo_token: - encrypted_token = self.encrypt_token(obo_token) - if encrypted_token: - request.session[self.OBO_ACCESS_TOKEN_KEY] = encrypted_token - # Decode the OBO token to get its actual expiration time - import jwt - - decoded_token = jwt.decode( - obo_token, options={"verify_signature": False} - ) - if "exp" in decoded_token: - obo_expires_at = datetime.datetime.fromtimestamp( - decoded_token["exp"] - ) - request.session[self.OBO_TOKEN_EXPIRES_AT_KEY] = ( - obo_expires_at.isoformat() - ) - session_modified = True - logger.debug( - "Stored OBO token with expiration from token claims" - ) - except Exception as e: - logger.warning(f"Error getting OBO token: {e}") - - if session_modified: - request.session.modified = True - logger.debug("All tokens stored successfully") - return True - - logger.warning("No tokens were stored") - return False - - except Exception as e: - logger.warning(f"Error storing tokens in session: {e}") - return False - def check_token_expiration(self, request): """ Check if tokens need to be refreshed and refresh them if needed. diff --git a/tests/test_middleware.py b/tests/test_middleware.py index fbe7d76e..11fa6d6c 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -31,11 +31,7 @@ def create_test_token(claims=None, exp_delta=3600): claims = {} # Create a basic JWT token with ADFS-like structure - header = { - "typ": "JWT", - "alg": "RS256", - "x5t": "example-thumbprint" - } + header = {"typ": "JWT", "alg": "RS256", "x5t": "example-thumbprint"} # Add standard ADFS claims if not present now = int(time.time()) @@ -51,8 +47,12 @@ def create_test_token(claims=None, exp_delta=3600): claims["sub"] = "john.doe@example.com" # Encode each part - header_part = base64.urlsafe_b64encode(json.dumps(header).encode()).rstrip(b"=").decode() - claims_part = base64.urlsafe_b64encode(json.dumps(claims).encode()).rstrip(b"=").decode() + header_part = ( + base64.urlsafe_b64encode(json.dumps(header).encode()).rstrip(b"=").decode() + ) + claims_part = ( + base64.urlsafe_b64encode(json.dumps(claims).encode()).rstrip(b"=").decode() + ) signature_part = base64.urlsafe_b64encode(b"test_signature").rstrip(b"=").decode() # Combine parts @@ -75,29 +75,52 @@ def setUp(self): def test_settings_configuration(self): """Test settings are properly loaded from Django settings""" - with patch.object(adfs_settings, 'TOKEN_REFRESH_THRESHOLD', 600), \ - patch.object(adfs_settings, 'STORE_OBO_TOKEN', False), \ - patch.object(adfs_settings, 'LOGOUT_ON_TOKEN_REFRESH_FAILURE', True): + with patch.object(adfs_settings, "TOKEN_REFRESH_THRESHOLD", 600), patch.object( + adfs_settings, "STORE_OBO_TOKEN", False + ), patch.object(adfs_settings, "LOGOUT_ON_TOKEN_REFRESH_FAILURE", True): manager = TokenManager() self.assertEqual(manager.refresh_threshold, 600) self.assertFalse(manager.store_obo_token) self.assertTrue(manager.logout_on_refresh_failure) + def test_token_storage_capability(self): + """Test token storage capability is properly added by middleware""" + # Test with no session + request_without_session = self.factory.get("/") + self.middleware(request_without_session) + self.assertFalse(hasattr(request_without_session, "token_storage")) + + # Test with signed cookies + token_manager.using_signed_cookies = True + try: + self.middleware(self.request) + self.assertFalse(hasattr(self.request, "token_storage")) + finally: + token_manager.using_signed_cookies = False + + # Test with valid session + self.middleware(self.request) + self.assertTrue(hasattr(self.request, "token_storage")) + self.assertIs(self.request.token_storage, token_manager) + def test_token_storage_and_retrieval(self): """Test the complete token storage and retrieval flow""" access_token = create_test_token({"type": "access"}) refresh_token = create_test_token({"type": "refresh"}) + # Add token storage capability + self.middleware(self.request) + # Store tokens - token_manager.store_tokens( + self.request.token_storage.store_tokens( self.request, access_token, { "access_token": access_token, "refresh_token": refresh_token, - "expires_in": 3600 - } + "expires_in": 3600, + }, ) # Verify storage @@ -116,15 +139,16 @@ def test_token_refresh_flow(self): new_access_token = create_test_token({"type": "access"}) new_refresh_token = create_test_token({"type": "refresh"}) - # Setup expired token - token_manager.store_tokens( + # Add token storage capability and setup expired token + self.middleware(self.request) + self.request.token_storage.store_tokens( self.request, old_access_token, { "access_token": old_access_token, "refresh_token": old_refresh_token, - "expires_in": 60 # Will trigger refresh - } + "expires_in": 60, # Will trigger refresh + }, ) # Mock refresh response @@ -133,7 +157,7 @@ def test_token_refresh_flow(self): mock_response.json.return_value = { "access_token": new_access_token, "refresh_token": new_refresh_token, - "expires_in": 3600 + "expires_in": 3600, } mock_config.session.post.return_value = mock_response mock_config.token_endpoint = "https://example.com/token" @@ -143,8 +167,7 @@ def test_token_refresh_flow(self): # Verify tokens were updated self.assertEqual( - token_manager.get_access_token(self.request), - new_access_token + token_manager.get_access_token(self.request), new_access_token ) def test_obo_token_management(self): @@ -152,11 +175,12 @@ def test_obo_token_management(self): access_token = create_test_token({"type": "access"}) obo_token = create_test_token({"type": "obo"}) - # Store regular token - token_manager.store_tokens( + # Add token storage capability and store regular token + self.middleware(self.request) + self.request.token_storage.store_tokens( self.request, access_token, - {"access_token": access_token, "expires_in": 3600} + {"access_token": access_token, "expires_in": 3600}, ) # Mock OBO flow @@ -164,15 +188,22 @@ def test_obo_token_management(self): mock_backend.return_value.get_obo_access_token.return_value = obo_token # Verify OBO token storage and retrieval - self.request.session[token_manager.OBO_ACCESS_TOKEN_KEY] = \ + self.request.session[token_manager.OBO_ACCESS_TOKEN_KEY] = ( token_manager.encrypt_token(obo_token) - self.request.session[token_manager.OBO_TOKEN_EXPIRES_AT_KEY] = \ - (datetime.datetime.now() + datetime.timedelta(hours=1)).isoformat() + ) + self.request.session[token_manager.OBO_TOKEN_EXPIRES_AT_KEY] = ( + datetime.datetime.now() + datetime.timedelta(hours=1) + ).isoformat() - self.assertEqual(token_manager.get_obo_access_token(self.request), obo_token) + self.assertEqual( + token_manager.get_obo_access_token(self.request), obo_token + ) def test_error_handling(self): """Test error handling in various scenarios""" + # Add token storage capability + self.middleware(self.request) + # Test invalid data handling self.assertIsNone(token_manager.decrypt_token("invalid_data")) self.assertIsNone(token_manager.encrypt_token(None)) @@ -183,14 +214,14 @@ def test_error_handling(self): with patch("django_auth_adfs.token_manager.provider_config") as mock_config: # Setup expired tokens first - token_manager.store_tokens( + self.request.token_storage.store_tokens( self.request, access_token, { "access_token": access_token, "refresh_token": refresh_token, - "expires_in": -60 # Already expired - } + "expires_in": -60, # Already expired + }, ) mock_config.session.post.return_value = Mock(status_code=400, text="Error") @@ -206,18 +237,10 @@ def test_error_handling(self): def test_signed_cookies_handling(self): """Test behavior with signed cookies session backend""" - test_token = create_test_token({"type": "access"}) - refresh_token = create_test_token({"type": "refresh"}) - token_manager.using_signed_cookies = True try: - success = token_manager.store_tokens( - self.request, - test_token, - {"refresh_token": refresh_token} - ) - self.assertFalse(success) - self.assertFalse(token_manager.ACCESS_TOKEN_KEY in self.request.session) + self.middleware(self.request) + self.assertFalse(hasattr(self.request, "token_storage")) finally: token_manager.using_signed_cookies = False @@ -234,38 +257,21 @@ def test_middleware_integration(self): self.middleware(self.request) mock_check.assert_called_once_with(self.request) - def test_middleware_detection(self): - """Test middleware enabled detection""" - # Test with correct middleware path - with patch('django.conf.settings.MIDDLEWARE', [ - 'django.contrib.sessions.middleware.SessionMiddleware', - 'django_auth_adfs.middleware.TokenLifecycleMiddleware' - ]): - self.assertTrue(token_manager.is_middleware_enabled()) - - # Test with incorrect middleware path - with patch('django.conf.settings.MIDDLEWARE', [ - 'django.contrib.sessions.middleware.SessionMiddleware', - 'some_other_package.TokenLifecycleMiddleware', # Wrong package - 'django_auth_adfs.middleware.SomeOtherMiddleware', # Wrong middleware - 'django_auth_adfs.TokenLifecycleMiddleware', # Wrong path - ]): - self.assertFalse(token_manager.is_middleware_enabled()) - def test_clear_tokens(self): """Test clearing tokens from session""" access_token = create_test_token({"type": "access"}) refresh_token = create_test_token({"type": "refresh"}) - # Store some tokens first - token_manager.store_tokens( + # Add token storage capability and store tokens + self.middleware(self.request) + self.request.token_storage.store_tokens( self.request, access_token, { "access_token": access_token, "refresh_token": refresh_token, - "expires_in": 3600 - } + "expires_in": 3600, + }, ) # Verify tokens were stored @@ -288,16 +294,18 @@ def test_refresh_obo_token_directly(self): access_token = create_test_token({"type": "access"}) new_obo_token = create_test_token({"type": "obo"}) - # Store access token first - token_manager.store_tokens( + # Add token storage capability and store access token + self.middleware(self.request) + self.request.token_storage.store_tokens( self.request, access_token, - {"access_token": access_token, "expires_in": 3600} + {"access_token": access_token, "expires_in": 3600}, ) # Mock OBO token acquisition and provider config - with patch("django_auth_adfs.backend.AdfsBaseBackend") as mock_backend, \ - patch("django_auth_adfs.token_manager.provider_config") as mock_provider: + with patch("django_auth_adfs.backend.AdfsBaseBackend") as mock_backend, patch( + "django_auth_adfs.token_manager.provider_config" + ) as mock_provider: mock_backend.return_value.get_obo_access_token.return_value = new_obo_token mock_provider.load_config.return_value = None @@ -311,28 +319,6 @@ def test_refresh_obo_token_directly(self): # Verify new OBO token was stored obo_token = token_manager.get_obo_access_token(self.request) self.assertEqual(obo_token, new_obo_token) - self.assertTrue(token_manager.OBO_TOKEN_EXPIRES_AT_KEY in self.request.session) - - def test_should_store_tokens_edge_cases(self): - """Test edge cases for token storage decisions""" - # Test with no request - self.assertFalse(token_manager.should_store_tokens(None)) - - # Test with request but no session - request_without_session = self.factory.get("/") - # Instead of deleting session attribute that doesn't exist, - # we'll create a Mock object with no session attribute - from unittest.mock import Mock - request_without_session = Mock(spec=[]) # Empty spec means no attributes - self.assertFalse(token_manager.should_store_tokens(request_without_session)) - - # Test with signed cookies - token_manager.using_signed_cookies = True - try: - self.assertFalse(token_manager.should_store_tokens(self.request)) - finally: - token_manager.using_signed_cookies = False - - # Test with middleware disabled - with patch.object(token_manager, "is_middleware_enabled", return_value=False): - self.assertFalse(token_manager.should_store_tokens(self.request)) + self.assertTrue( + token_manager.OBO_TOKEN_EXPIRES_AT_KEY in self.request.session + ) From a43b0873de37800b067351873aa025e4f506d436 Mon Sep 17 00:00:00 2001 From: Tyler Woods Date: Mon, 10 Mar 2025 19:55:58 -0500 Subject: [PATCH 9/9] docs update --- docs/token_lifecycle.rst | 124 ++++++++++++++++----------------------- 1 file changed, 52 insertions(+), 72 deletions(-) diff --git a/docs/token_lifecycle.rst b/docs/token_lifecycle.rst index c27b62cf..71f4592b 100644 --- a/docs/token_lifecycle.rst +++ b/docs/token_lifecycle.rst @@ -15,17 +15,12 @@ after the authentication process. This creates a more integrated approach where: How it works ------------ -The token lifecycle system consists of two main components: - -1. **TokenManager**: A centralized singleton that handles all token operations including storage, retrieval, encryption, refresh, and OBO token management -2. **TokenLifecycleMiddleware**: A middleware that monitors token expiration and triggers refresh when needed - -Together, they handle the entire token lifecycle: +The token lifecycle system performs the following: 1. **Token Storage**: The django-auth-adfs backend automatically stores and encrypts tokens during authentication when the ``TokenLifecycleMiddleware`` is enabled 2. **Token Monitoring**: The middleware checks token expiration on each request 3. **Token Refresh**: When a token is about to expire, it is automatically refreshed -4. **OBO Token Management**: When enabled (by default), OBO tokens are automatically acquired and refreshed for Microsoft Graph API access +4. **OBO Token Management**: When enabled (by default), OBO tokens are automatically acquired and refreshed 5. **Security Controls**: Optional automatic logout on token refresh failures Read more about the OBO flow: https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-on-behalf-of-flow#protocol-diagram @@ -72,7 +67,7 @@ You can configure the token lifecycle behavior with these settings in your Djang # Number of seconds before expiration to refresh (default: 300, i.e., 5 minutes) "TOKEN_REFRESH_THRESHOLD": 300, - # Enable or disable OBO token storage for Microsoft Graph API (default: True) + # Enable or disable OBO token functionality (default: True) "STORE_OBO_TOKEN": True, # Custom salt for token encryption (optional) @@ -89,16 +84,10 @@ You can configure the token lifecycle behavior with these settings in your Djang Consider this when deploying changes to the salt in production environments. -.. note:: - By default (``STORE_OBO_TOKEN = True``), the system will automatically request and store OBO tokens - for Microsoft Graph API access. If your application doesn't need to access Microsoft Graph API, - you can set ``STORE_OBO_TOKEN = False`` to disable this functionality completely. - See `the OBO token configuration section <#disabling-obo-token-functionality>`_ for more details. - Considerations -------------- -- Token storage and encryption are handled automatically by the django-auth-adfs backend during authentication +- Token storage and encryption are handled automatically by the middleware during authentication - Token refresh only works for authenticated users with valid sessions - If the refresh token is invalid or expired, the system will not be able to refresh the access token - By default, the system will not log the user out if token refresh fails, but this behavior can be changed with the ``LOGOUT_ON_TOKEN_REFRESH_FAILURE`` setting @@ -113,10 +102,6 @@ By default, when token refresh fails, the system logs the error but allows the u - When set to ``False`` (default), users remain logged in even if their tokens can't be refreshed - When set to ``True``, users are automatically logged out when token refresh fails -When a user's account is disabled in Azure AD/ADFS, their existing Django sessions will remain active by default until they expire naturally. This can create a security gap where revoked users maintain access to your application. - -The ``LOGOUT_ON_TOKEN_REFRESH_FAILURE`` setting provides an option which helps address this concern by allowing you to automatically log out users when their token refresh fails, which will happen some time after their account has been disabled in the identity provider. - **Existing Sessions** When deploying the Token Lifecycle system to an existing application with active user sessions, be aware of the following: @@ -147,30 +132,10 @@ Security Overview **Token Encryption** Tokens are automatically encrypted before being stored in the session and decrypted when they are retrieved. -The encryption is handled transparently by the TokenManager and utility functions. This provides an additional layer of security: - -- **Always Enabled**: Token encryption is always enabled and cannot be disabled -- **Encryption Method**: Tokens are encrypted using the Fernet symmetric encryption algorithm -- **Encryption Key**: The key is derived from Django's ``SECRET_KEY`` using PBKDF2 -- **Customizable Salt**: You can customize the encryption salt using the ``TOKEN_ENCRYPTION_SALT`` setting -- **Transparent Operation**: Encryption and decryption happen automatically when tokens are stored or retrieved - +The encryption is handled transparently by the TokenManager and utility functions. **Signed Cookies Session Backend Restriction** -The system will not store tokens in the session when using Django's ``signed_cookies`` session backend: - -.. code-block:: python - - # This will not work with the token lifecycle system - SESSION_ENGINE = 'django.contrib.sessions.backends.signed_cookies' - -This is for a few reasons: - -1. **Size Limitations**: Cookies have size limitations (typically 4KB), which may be exceeded by tokens -2. **Security Risks**: Storing sensitive tokens in cookies increases the risk of token theft -3. **Performance**: Large cookies are sent with every request, increasing bandwidth usage - If you're using the ``signed_cookies`` session backend and need token storage, you won't be able to use the token lifecycle system. .. note:: @@ -184,7 +149,9 @@ By default, the system automatically requests OBO tokens when storing tokens. If Disabling OBO Token Functionality --------------------------------- -By default, the Token Lifecycle system automatically requests and stores OBO tokens for Microsoft Graph API access. If you don't need this functionality (for example, if your application doesn't interact with Microsoft Graph API), you can disable it completely: +By default, the Token Lifecycle system automatically requests and stores OBO (On-Behalf-Of) tokens. + +If you don't need this functionality, you can disable it completely: .. code-block:: python @@ -193,13 +160,6 @@ By default, the Token Lifecycle system automatically requests and stores OBO tok "STORE_OBO_TOKEN": False, } -When this setting is ``False``: - -1. The system will not request OBO tokens during token storage -2. The system will not store OBO tokens in the session -3. The system will not refresh OBO tokens -4. The ``get_obo_access_token`` utility function will always return ``None`` - Note that disabling OBO tokens doesn't affect the regular access token functionality. Your application will still be able to use the access token obtained during authentication for its own resources and APIs that directly trust your application. See `the token types section <#understanding-access-tokens-vs-obo-tokens>`_ for more details. @@ -233,9 +193,7 @@ Here are practical examples of using the TokenManager in your views: Using with Microsoft Graph API ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -In this flow, we will use the OBO token to access Microsoft Graph API. - -This is the recommended flow for delegated access to Microsoft Graph API. +This example demonstrates using the OBO token to access Microsoft Graph API .. code-block:: python @@ -267,12 +225,48 @@ This is the recommended flow for delegated access to Microsoft Graph API. status=500 ) -Using with other resources -~~~~~~~~~~~~~~~~~~~~~~~~~~ +Using with Custom ADFS-Protected API +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This example shows how to use the OBO token to access a custom API protected by ADFS that supports the OBO flow. + +.. code-block:: python + + from django.contrib.auth.decorators import login_required + from django.http import JsonResponse + from django_auth_adfs.token_manager import token_manager + import requests + + @login_required + def custom_api_view(request): + """Access a custom API using OBO token""" + obo_token = token_manager.get_obo_access_token(request) + + if not obo_token: + return JsonResponse({"error": "No OBO token available"}, status=401) + + headers = { + "Authorization": f"Bearer {obo_token}", + "Content-Type": "application/json", + } + + try: + response = requests.get( + "https://your-custom-api.example.com/data", + headers=headers + ) + response.raise_for_status() + return JsonResponse(response.json()) + except requests.exceptions.RequestException as e: + return JsonResponse( + {"error": "Failed to fetch data", "details": str(e)}, + status=500 + ) -The key difference here is to use the ``get_access_token`` method to get the token for the resource you are accessing. +Using with Direct Resource Access +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -This is different than the ``get_obo_access_token`` method, which is used for Microsoft Graph API delegated access in the previous example. +For APIs that directly trust your application (no OBO flow needed), use the regular access token: .. code-block:: python @@ -430,21 +424,7 @@ The following example code demonstrates a debug view to check the values of the Understanding Access Tokens vs. OBO Tokens ------------------------------------------ -It's important to understand the difference between regular access tokens and OBO (On-Behalf-Of) tokens, especially in the context of delegated access versus application access: - -**Delegated Access vs. Application Access**: - There are two primary ways an application can access resources in Azure AD/ADFS: - - * **Application Access**: The application accesses resources directly with its own identity, not on behalf of a user. - - * **Delegated Access**: The application accesses resources on behalf of a signed-in user. - -**Regular Access Token**: - The token obtained during authentication with ADFS. - -**OBO (On-Behalf-Of) Token**: - The OBO flow is specifically designed for delegated access scenarios where your application needs to access resources (like Microsoft Graph) on behalf of the authenticated user. - - The TokenManager handles this exchange automatically when OBO token storage is enabled. +For more information on the different types of permissions and flows, see: -For more information on the different types of permissions, see `the Microsoft documentation `_. +* `OAuth 2.0 On-Behalf-Of flow `_ +* `Permission types `_