diff --git a/jwt/jwks_client.py b/jwt/jwks_client.py index 120eb256..e83f84c8 100644 --- a/jwt/jwks_client.py +++ b/jwt/jwks_client.py @@ -87,7 +87,7 @@ def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]: signing_keys = [ jwk_set_key for jwk_set_key in jwk_set.keys - if jwk_set_key.public_key_use in ["sig", None] and jwk_set_key.key_id + if jwk_set_key.public_key_use in ["sig", None] ] if not signing_keys: @@ -95,7 +95,7 @@ def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]: return signing_keys - def get_signing_key(self, kid: str) -> PyJWK: + def get_signing_key(self, kid: Optional[str]) -> PyJWK: signing_keys = self.get_signing_keys() signing_key = self.match_kid(signing_keys, kid) @@ -117,7 +117,12 @@ def get_signing_key_from_jwt(self, token: str | bytes) -> PyJWK: return self.get_signing_key(header.get("kid")) @staticmethod - def match_kid(signing_keys: List[PyJWK], kid: str) -> Optional[PyJWK]: + def match_kid(signing_keys: List[PyJWK], kid: Optional[str]) -> Optional[PyJWK]: + if kid is None: + if len(signing_keys) == 1: + return signing_keys[0] + else: + return None signing_key = None for key in signing_keys: