Skip to content

Commit

Permalink
Added client_id parameter to AssertionClient
Browse files Browse the repository at this point in the history
Per https://datatracker.ietf.org/doc/html/rfc7521#section-4.1,
client_id parameter, although optional, can still be passed
when using assertions as authorization grants. Adding a way to pass
that id to refresh token body.
  • Loading branch information
Dmitry Supranovich authored and Dmitry Supranovich committed Aug 9, 2022
1 parent ca01dc4 commit 727ee0e
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 7 deletions.
8 changes: 4 additions & 4 deletions authlib/integrations/httpx_client/assertion_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class AsyncAssertionClient(_AssertionClient, AsyncClient):
DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE

def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None,
claims=None, token_placement='header', scope=None, **kwargs):
claims=None, token_placement='header', scope=None, client_id=None, **kwargs):

client_kwargs = extract_client_kwargs(kwargs)
AsyncClient.__init__(self, **client_kwargs)
Expand All @@ -26,7 +26,7 @@ def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=No
self, session=None,
token_endpoint=token_endpoint, issuer=issuer, subject=subject,
audience=audience, grant_type=grant_type, claims=claims,
token_placement=token_placement, scope=scope, **kwargs
token_placement=token_placement, scope=scope, client_id=client_id, **kwargs
)

async def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs):
Expand Down Expand Up @@ -62,7 +62,7 @@ class AssertionClient(_AssertionClient, Client):
DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE

def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None,
claims=None, token_placement='header', scope=None, **kwargs):
claims=None, token_placement='header', scope=None, client_id=None, **kwargs):

client_kwargs = extract_client_kwargs(kwargs)
Client.__init__(self, **client_kwargs)
Expand All @@ -71,7 +71,7 @@ def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=No
self, session=self,
token_endpoint=token_endpoint, issuer=issuer, subject=subject,
audience=audience, grant_type=grant_type, claims=claims,
token_placement=token_placement, scope=scope, **kwargs
token_placement=token_placement, scope=scope, client_id=client_id, **kwargs
)

def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions authlib/integrations/requests_client/assertion_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ class AssertionSession(AssertionClient, Session):
DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE

def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None,
claims=None, token_placement='header', scope=None, **kwargs):
claims=None, token_placement='header', scope=None, client_id=None, **kwargs):
Session.__init__(self)
update_session_configure(self, kwargs)
AssertionClient.__init__(
self, session=self,
token_endpoint=token_endpoint, issuer=issuer, subject=subject,
audience=audience, grant_type=grant_type, claims=claims,
token_placement=token_placement, scope=scope, **kwargs
token_placement=token_placement, scope=scope, client_id=client_id, **kwargs
)

def request(self, method, url, withhold_token=False, auth=None, **kwargs):
Expand Down
5 changes: 4 additions & 1 deletion authlib/oauth2/rfc7521/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class AssertionClient(object):

def __init__(self, session, token_endpoint, issuer, subject,
audience=None, grant_type=None, claims=None,
token_placement='header', scope=None, **kwargs):
token_placement='header', scope=None, client_id=None, **kwargs):

self.session = session

Expand All @@ -34,6 +34,7 @@ def __init__(self, session, token_endpoint, issuer, subject,
self.audience = audience
self.claims = claims
self.scope = scope
self.client_id = client_id
if self.token_auth_class is not None:
self.token_auth = self.token_auth_class(None, token_placement, self)
self._kwargs = kwargs
Expand Down Expand Up @@ -66,6 +67,8 @@ def refresh_token(self):
}
if self.scope:
data['scope'] = self.scope
if self.client_id:
data['client_id'] = self.client_id

return self._refresh_token(data)

Expand Down
1 change: 1 addition & 0 deletions tests/clients/test_httpx/test_assertion_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def verifier(request):
header={'alg': 'HS256'},
key='secret',
scope='email',
client_id='client',
claims={'test_mode': 'true'},
app=MockDispatch(default_token, assert_func=verifier)
) as client:
Expand Down
1 change: 1 addition & 0 deletions tests/clients/test_httpx/test_async_assertion_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ async def verifier(request):
header={'alg': 'HS256'},
key='secret',
scope='email',
client_id='client',
claims={'test_mode': 'true'},
app=AsyncMockDispatch(default_token, assert_func=verifier)
) as client:
Expand Down

0 comments on commit 727ee0e

Please sign in to comment.