From 727ee0e4f18bb6e42543fea0ca59ca635bf152d8 Mon Sep 17 00:00:00 2001 From: Dmitry Supranovich Date: Fri, 29 Jul 2022 20:41:42 +0000 Subject: [PATCH] Added client_id parameter to AssertionClient Per https://datatracker.ietf.org/doc/html/rfc7521#section-4.1, client_id parameter, although optional, can still be passed when using assertions as authorization grants. Adding a way to pass that id to refresh token body. --- authlib/integrations/httpx_client/assertion_client.py | 8 ++++---- authlib/integrations/requests_client/assertion_session.py | 4 ++-- authlib/oauth2/rfc7521/client.py | 5 ++++- tests/clients/test_httpx/test_assertion_client.py | 1 + tests/clients/test_httpx/test_async_assertion_client.py | 1 + 5 files changed, 12 insertions(+), 7 deletions(-) diff --git a/authlib/integrations/httpx_client/assertion_client.py b/authlib/integrations/httpx_client/assertion_client.py index 4832850c..46c9cf0e 100644 --- a/authlib/integrations/httpx_client/assertion_client.py +++ b/authlib/integrations/httpx_client/assertion_client.py @@ -17,7 +17,7 @@ class AsyncAssertionClient(_AssertionClient, AsyncClient): DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None, - claims=None, token_placement='header', scope=None, **kwargs): + claims=None, token_placement='header', scope=None, client_id=None, **kwargs): client_kwargs = extract_client_kwargs(kwargs) AsyncClient.__init__(self, **client_kwargs) @@ -26,7 +26,7 @@ def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=No self, session=None, token_endpoint=token_endpoint, issuer=issuer, subject=subject, audience=audience, grant_type=grant_type, claims=claims, - token_placement=token_placement, scope=scope, **kwargs + token_placement=token_placement, scope=scope, client_id=client_id, **kwargs ) async def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs): @@ -62,7 +62,7 @@ class AssertionClient(_AssertionClient, Client): DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None, - claims=None, token_placement='header', scope=None, **kwargs): + claims=None, token_placement='header', scope=None, client_id=None, **kwargs): client_kwargs = extract_client_kwargs(kwargs) Client.__init__(self, **client_kwargs) @@ -71,7 +71,7 @@ def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=No self, session=self, token_endpoint=token_endpoint, issuer=issuer, subject=subject, audience=audience, grant_type=grant_type, claims=claims, - token_placement=token_placement, scope=scope, **kwargs + token_placement=token_placement, scope=scope, client_id=client_id, **kwargs ) def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs): diff --git a/authlib/integrations/requests_client/assertion_session.py b/authlib/integrations/requests_client/assertion_session.py index b5eb3891..f91536b8 100644 --- a/authlib/integrations/requests_client/assertion_session.py +++ b/authlib/integrations/requests_client/assertion_session.py @@ -25,14 +25,14 @@ class AssertionSession(AssertionClient, Session): DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None, - claims=None, token_placement='header', scope=None, **kwargs): + claims=None, token_placement='header', scope=None, client_id=None, **kwargs): Session.__init__(self) update_session_configure(self, kwargs) AssertionClient.__init__( self, session=self, token_endpoint=token_endpoint, issuer=issuer, subject=subject, audience=audience, grant_type=grant_type, claims=claims, - token_placement=token_placement, scope=scope, **kwargs + token_placement=token_placement, scope=scope, client_id=client_id, **kwargs ) def request(self, method, url, withhold_token=False, auth=None, **kwargs): diff --git a/authlib/oauth2/rfc7521/client.py b/authlib/oauth2/rfc7521/client.py index d1b98ba5..d02dce3e 100644 --- a/authlib/oauth2/rfc7521/client.py +++ b/authlib/oauth2/rfc7521/client.py @@ -14,7 +14,7 @@ class AssertionClient(object): def __init__(self, session, token_endpoint, issuer, subject, audience=None, grant_type=None, claims=None, - token_placement='header', scope=None, **kwargs): + token_placement='header', scope=None, client_id=None, **kwargs): self.session = session @@ -34,6 +34,7 @@ def __init__(self, session, token_endpoint, issuer, subject, self.audience = audience self.claims = claims self.scope = scope + self.client_id = client_id if self.token_auth_class is not None: self.token_auth = self.token_auth_class(None, token_placement, self) self._kwargs = kwargs @@ -66,6 +67,8 @@ def refresh_token(self): } if self.scope: data['scope'] = self.scope + if self.client_id: + data['client_id'] = self.client_id return self._refresh_token(data) diff --git a/tests/clients/test_httpx/test_assertion_client.py b/tests/clients/test_httpx/test_assertion_client.py index 1e267b82..d8855635 100644 --- a/tests/clients/test_httpx/test_assertion_client.py +++ b/tests/clients/test_httpx/test_assertion_client.py @@ -42,6 +42,7 @@ def verifier(request): header={'alg': 'HS256'}, key='secret', scope='email', + client_id='client', claims={'test_mode': 'true'}, app=MockDispatch(default_token, assert_func=verifier) ) as client: diff --git a/tests/clients/test_httpx/test_async_assertion_client.py b/tests/clients/test_httpx/test_async_assertion_client.py index 9087b864..8371cc31 100644 --- a/tests/clients/test_httpx/test_async_assertion_client.py +++ b/tests/clients/test_httpx/test_async_assertion_client.py @@ -44,6 +44,7 @@ async def verifier(request): header={'alg': 'HS256'}, key='secret', scope='email', + client_id='client', claims={'test_mode': 'true'}, app=AsyncMockDispatch(default_token, assert_func=verifier) ) as client: