Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removed error prone key decode looping #7325

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 31 additions & 48 deletions redash/authentication/jwt_auth.py
Original file line number Diff line number Diff line change
@@ -1,82 +1,65 @@
import json
import logging

import jwt
import requests
from jwt import PyJWKClient


logger = logging.getLogger("jwt_auth")

FILE_SCHEME_PREFIX = "file://"


def get_public_key_from_file(url):
def get_signing_key_from_file(url):
file_path = url[len(FILE_SCHEME_PREFIX) :]
with open(file_path) as key_file:
key_str = key_file.read()

get_public_keys.key_cache[url] = [key_str]
get_signing_key.key_cache[url] = key_str
return key_str


def get_public_key_from_net(url):
r = requests.get(url)
r.raise_for_status()
data = r.json()
if "keys" in data:
public_keys = []
for key_dict in data["keys"]:
public_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(key_dict))
public_keys.append(public_key)

get_public_keys.key_cache[url] = public_keys
return public_keys
else:
get_public_keys.key_cache[url] = data
return data
def get_signing_key_from_net(url, jwt_token):
optional_custom_headers = {"User-agent": "redash"}
client = PyJWKClient(url, headers=optional_custom_headers)
# Gets the matching signing key from the JWKS endpoint
signing_key = client.get_signing_key_from_jwt(jwt_token)
get_signing_key.key_cache[url] = signing_key
return signing_key


def get_public_keys(url):
def get_signing_key(url, jwt_token):
"""
Returns:
List of RSA public keys usable by PyJWT.
Signing key for given jwt_token.
"""
key_cache = get_public_keys.key_cache
keys = {}
key_cache = get_signing_key.key_cache
key = {}
if url in key_cache:
keys = key_cache[url]
key = key_cache[url]
else:
if url.startswith(FILE_SCHEME_PREFIX):
keys = [get_public_key_from_file(url)]
key = get_signing_key_from_file(url)
else:
keys = get_public_key_from_net(url)
return keys


get_public_keys.key_cache = {}
key = get_signing_key_from_net(url, jwt_token)
return key

#This cache shoud have a lifespan
get_signing_key.key_cache = {}

def verify_jwt_token(jwt_token, expected_issuer, expected_audience, algorithms, public_certs_url):
# https://developers.cloudflare.com/access/setting-up-access/validate-jwt-tokens/
# https://cloud.google.com/iap/docs/signed-headers-howto
# Loop through the keys since we can't pass the key set to the decoder
keys = get_public_keys(public_certs_url)

key_id = jwt.get_unverified_header(jwt_token).get("kid", "")
if key_id and isinstance(keys, dict):
keys = [keys.get(key_id)]

key = get_signing_key(public_certs_url, jwt_token)
valid_token = False
payload = None
for key in keys:
try:
# decode returns the claims which has the email if you need it
payload = jwt.decode(jwt_token, key=key, audience=expected_audience, algorithms=algorithms)
issuer = payload["iss"]
if issuer != expected_issuer:
raise Exception("Wrong issuer: {}".format(issuer))
valid_token = True
break
except Exception as e:
logging.exception(e)
try:
# decode returns the claims which has the email if you need it
payload = jwt.decode(jwt_token, key=key, audience=expected_audience, algorithms=algorithms)
issuer = payload["iss"]
if issuer != expected_issuer:
raise Exception("Wrong issuer: {}".format(issuer))
valid_token = True
except Exception as e:
logging.exception(e)

return payload, valid_token