From 8ec620ff28312923154ee82e1b6b7109fe434b0c Mon Sep 17 00:00:00 2001 From: Trevor Gross Date: Fri, 7 Jan 2022 20:52:06 -0500 Subject: [PATCH 01/17] Added basics for verify_type, addes some type hints, added missing flake8 dependency --- .gitignore | 1 + flask_jwt_extended/view_decorators.py | 12 +++++++++--- requirements.txt | 1 + 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index cf70e9fe..66c16882 100644 --- a/.gitignore +++ b/.gitignore @@ -81,6 +81,7 @@ celerybeat-schedule # virtualenv venv/ +.venv/ ENV/ # Spyder project settings diff --git a/flask_jwt_extended/view_decorators.py b/flask_jwt_extended/view_decorators.py index 55cfec17..273cbb41 100644 --- a/flask_jwt_extended/view_decorators.py +++ b/flask_jwt_extended/view_decorators.py @@ -35,7 +35,9 @@ def _verify_token_is_fresh(jwt_header, jwt_data): raise FreshTokenRequired("Fresh token required", jwt_header, jwt_data) -def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations=None): +def verify_jwt_in_request( + optional: bool = False, fresh: bool = False, refresh: bool = False, locations=None +): """ Verify that a valid JWT is present in the request, unless ``optional=True`` in which case no JWT is also considered valid. @@ -255,7 +257,9 @@ def _decode_jwt_from_json(refresh): return encoded_token, None -def _decode_jwt_from_request(locations, fresh, refresh=False): +def _decode_jwt_from_request( + locations, fresh, refresh: bool = False, verify_type: bool = True +): # Figure out what locations to look for the JWT in this request if isinstance(locations, str): locations = [locations] @@ -314,7 +318,9 @@ def _decode_jwt_from_request(locations, fresh, refresh=False): raise NoAuthorizationError(errors[0]) # Additional verifications provided by this extension - verify_token_type(decoded_token, refresh) + if verify_type: + verify_token_type(decoded_token, refresh) + if fresh: _verify_token_is_fresh(jwt_header, decoded_token) verify_token_not_blocklisted(jwt_header, decoded_token) diff --git a/requirements.txt b/requirements.txt index 082b8775..403a17c4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ black==21.6b0 cryptography==35.0.0 Flask==2.0.1 +flake8==4.0.1 Pallets-Sphinx-Themes==2.0.1 pre-commit==2.13.0 PyJWT==2.1.0 From 578345c4e4bef6bb4006f26c6d2ba29fe16f97a1 Mon Sep 17 00:00:00 2001 From: Trevor Gross Date: Fri, 7 Jan 2022 21:36:54 -0500 Subject: [PATCH 02/17] More type hinting --- docs/conf.py | 5 +++ examples/blocklist_database.py | 2 +- examples/blocklist_redis.py | 2 +- examples/loaders.py | 2 +- flask_jwt_extended/default_callbacks.py | 29 ++++++++------ flask_jwt_extended/exceptions.py | 8 ++-- flask_jwt_extended/internal_utils.py | 11 +++--- flask_jwt_extended/jwt_manager.py | 51 +++++++++++++------------ flask_jwt_extended/utils.py | 19 +++++---- requirements.txt | 2 +- 10 files changed, 75 insertions(+), 56 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index e1cb34b3..4da5cf51 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -358,3 +358,8 @@ # If true, do not generate a @detailmenu in the "Top" node's menu. # # texinfo_no_detailmenu = False + +# Fix warnings about refernce targets. See link: +# https://stackoverflow.com/questions/11417221/ +# sphinx-autodoc-gives-warning-pyclass-reference-target-not-found-type-warning +nitpick_ignore = [("py:class", "flask.app.Flask"), ("py:class", "datetime.timedelta")] diff --git a/examples/blocklist_database.py b/examples/blocklist_database.py index f4278ee1..f7ec1f40 100644 --- a/examples/blocklist_database.py +++ b/examples/blocklist_database.py @@ -36,7 +36,7 @@ class TokenBlocklist(db.Model): # Callback function to check if a JWT exists in the database blocklist @jwt.token_in_blocklist_loader -def check_if_token_revoked(jwt_header, jwt_payload): +def check_if_token_revoked(jwt_header, jwt_payload: dict) -> bool: jti = jwt_payload["jti"] token = db.session.query(TokenBlocklist.id).filter_by(jti=jti).scalar() return token is not None diff --git a/examples/blocklist_redis.py b/examples/blocklist_redis.py index bcc055f8..e1aa1f0a 100644 --- a/examples/blocklist_redis.py +++ b/examples/blocklist_redis.py @@ -26,7 +26,7 @@ # Callback function to check if a JWT exists in the redis blocklist @jwt.token_in_blocklist_loader -def check_if_token_is_revoked(jwt_header, jwt_payload): +def check_if_token_is_revoked(jwt_header, jwt_payload: dict): jti = jwt_payload["jti"] token_in_redis = jwt_redis_blocklist.get(jti) return token_in_redis is not None diff --git a/examples/loaders.py b/examples/loaders.py index 0dc5b537..02bd48fe 100644 --- a/examples/loaders.py +++ b/examples/loaders.py @@ -17,7 +17,7 @@ # response. Check the API documentation to see the required argument and return # values for other callback functions. @jwt.expired_token_loader -def my_expired_token_callback(jwt_header, jwt_payload): +def my_expired_token_callback(jwt_header, jwt_payload: dict): return jsonify(code="dave", err="I can't let you do that"), 401 diff --git a/flask_jwt_extended/default_callbacks.py b/flask_jwt_extended/default_callbacks.py index 32e88719..e2c4e7aa 100644 --- a/flask_jwt_extended/default_callbacks.py +++ b/flask_jwt_extended/default_callbacks.py @@ -7,11 +7,12 @@ http://flask-jwt-extended.readthedocs.io/en/latest/tokens_from_complex_object.html """ from flask import jsonify +from flask import Response from flask_jwt_extended.config import config -def default_additional_claims_callback(userdata): +def default_additional_claims_callback(userdata) -> dict: """ By default, we add no additional claims to the access tokens. @@ -22,11 +23,11 @@ def default_additional_claims_callback(userdata): return {} -def default_blocklist_callback(jwt_headers, jwt_data): +def default_blocklist_callback(jwt_headers: dict, jwt_data: dict) -> bool: return False -def default_jwt_headers_callback(default_headers): +def default_jwt_headers_callback(default_headers) -> dict: """ By default header typically consists of two parts: the type of the token, which is JWT, and the signing algorithm being used, such as HMAC SHA256 @@ -49,7 +50,9 @@ def default_user_identity_callback(userdata): return userdata -def default_expired_token_callback(_expired_jwt_header, _expired_jwt_data): +def default_expired_token_callback( + _expired_jwt_header: dict, _expired_jwt_data: dict +) -> Response: """ By default, if an expired token attempts to access a protected endpoint, we return a generic error message with a 401 status @@ -57,7 +60,7 @@ def default_expired_token_callback(_expired_jwt_header, _expired_jwt_data): return jsonify({config.error_msg_key: "Token has expired"}), 401 -def default_invalid_token_callback(error_string): +def default_invalid_token_callback(error_string: str) -> Response: """ By default, if an invalid token attempts to access a protected endpoint, we return the error string for why it is not valid with a 422 status code @@ -67,7 +70,7 @@ def default_invalid_token_callback(error_string): return jsonify({config.error_msg_key: error_string}), 422 -def default_unauthorized_callback(error_string): +def default_unauthorized_callback(error_string: str) -> Response: """ By default, if a protected endpoint is accessed without a JWT, we return the error string indicating why this is unauthorized, with a 401 status code @@ -77,7 +80,7 @@ def default_unauthorized_callback(error_string): return jsonify({config.error_msg_key: error_string}), 401 -def default_needs_fresh_token_callback(jwt_header, jwt_data): +def default_needs_fresh_token_callback(jwt_header: dict, jwt_data: dict) -> Response: """ By default, if a non-fresh jwt is used to access a ```fresh_jwt_required``` endpoint, we return a general error message with a 401 status code @@ -85,7 +88,7 @@ def default_needs_fresh_token_callback(jwt_header, jwt_data): return jsonify({config.error_msg_key: "Fresh token required"}), 401 -def default_revoked_token_callback(jwt_header, jwt_data): +def default_revoked_token_callback(jwt_header: dict, jwt_data: dict) -> Response: """ By default, if a revoked token is used to access a protected endpoint, we return a general error message with a 401 status code @@ -93,7 +96,7 @@ def default_revoked_token_callback(jwt_header, jwt_data): return jsonify({config.error_msg_key: "Token has been revoked"}), 401 -def default_user_lookup_error_callback(_jwt_header, jwt_data): +def default_user_lookup_error_callback(_jwt_header: dict, jwt_data: dict) -> Response: """ By default, if a user_lookup callback is defined and the callback function returns None, we return a general error message with a 401 @@ -104,14 +107,16 @@ def default_user_lookup_error_callback(_jwt_header, jwt_data): return jsonify(result), 401 -def default_token_verification_callback(_jwt_header, _jwt_data): +def default_token_verification_callback(_jwt_header: dict, _jwt_data: dict) -> bool: """ By default, we do not do any verification of the user claims. """ return True -def default_token_verification_failed_callback(_jwt_header, _jwt_data): +def default_token_verification_failed_callback( + _jwt_header: dict, _jwt_data: dict +) -> Response: """ By default, if the user claims verification failed, we return a generic error message with a 400 status code @@ -119,7 +124,7 @@ def default_token_verification_failed_callback(_jwt_header, _jwt_data): return jsonify({config.error_msg_key: "User claims verification failed"}), 400 -def default_decode_key_callback(jwt_header, jwt_data): +def default_decode_key_callback(jwt_header: dict, jwt_data: dict): """ By default, the decode key specified via the JWT_SECRET_KEY or JWT_PUBLIC_KEY settings will be used to decode all tokens diff --git a/flask_jwt_extended/exceptions.py b/flask_jwt_extended/exceptions.py index bc80d889..8fc5405a 100644 --- a/flask_jwt_extended/exceptions.py +++ b/flask_jwt_extended/exceptions.py @@ -60,7 +60,7 @@ class RevokedTokenError(JWTExtendedException): Error raised when a revoked token attempt to access a protected endpoint """ - def __init__(self, jwt_header, jwt_data): + def __init__(self, jwt_header: dict, jwt_data: dict): super().__init__("Token has been revoked") self.jwt_header = jwt_header self.jwt_data = jwt_data @@ -72,7 +72,7 @@ class FreshTokenRequired(JWTExtendedException): protected by fresh_jwt_required """ - def __init__(self, message, jwt_header, jwt_data): + def __init__(self, message, jwt_header: dict, jwt_data: dict): super().__init__(message) self.jwt_header = jwt_header self.jwt_data = jwt_data @@ -84,7 +84,7 @@ class UserLookupError(JWTExtendedException): that it cannot or will not load a user for the given identity. """ - def __init__(self, message, jwt_header, jwt_data): + def __init__(self, message, jwt_header: dict, jwt_data: dict): super().__init__(message) self.jwt_header = jwt_header self.jwt_data = jwt_data @@ -96,7 +96,7 @@ class UserClaimsVerificationError(JWTExtendedException): indicating that the expected user claims are invalid """ - def __init__(self, message, jwt_header, jwt_data): + def __init__(self, message, jwt_header: dict, jwt_data: dict): super().__init__(message) self.jwt_header = jwt_header self.jwt_data = jwt_data diff --git a/flask_jwt_extended/internal_utils.py b/flask_jwt_extended/internal_utils.py index ac18a855..13eb2d53 100644 --- a/flask_jwt_extended/internal_utils.py +++ b/flask_jwt_extended/internal_utils.py @@ -1,11 +1,12 @@ from flask import current_app +from flask_jwt_extended import JWTManager from flask_jwt_extended.exceptions import RevokedTokenError from flask_jwt_extended.exceptions import UserClaimsVerificationError from flask_jwt_extended.exceptions import WrongTokenError -def get_jwt_manager(): +def get_jwt_manager() -> JWTManager: try: return current_app.extensions["flask-jwt-extended"] except KeyError: # pragma: no cover @@ -15,7 +16,7 @@ def get_jwt_manager(): ) from None -def has_user_lookup(): +def has_user_lookup() -> bool: jwt_manager = get_jwt_manager() return jwt_manager._user_lookup_callback is not None @@ -25,20 +26,20 @@ def user_lookup(*args, **kwargs): return jwt_manager._user_lookup_callback(*args, **kwargs) -def verify_token_type(decoded_token, refresh): +def verify_token_type(decoded_token, refresh) -> None: if not refresh and decoded_token["type"] == "refresh": raise WrongTokenError("Only non-refresh tokens are allowed") elif refresh and decoded_token["type"] != "refresh": raise WrongTokenError("Only refresh tokens are allowed") -def verify_token_not_blocklisted(jwt_header, jwt_data): +def verify_token_not_blocklisted(jwt_header: dict, jwt_data: dict) -> None: jwt_manager = get_jwt_manager() if jwt_manager._token_in_blocklist_callback(jwt_header, jwt_data): raise RevokedTokenError(jwt_header, jwt_data) -def custom_verification_for_token(jwt_header, jwt_data): +def custom_verification_for_token(jwt_header: dict, jwt_data: dict) -> None: jwt_manager = get_jwt_manager() if not jwt_manager._token_verification_callback(jwt_header, jwt_data): error_msg = "User claims verification failed" diff --git a/flask_jwt_extended/jwt_manager.py b/flask_jwt_extended/jwt_manager.py index b39d2c93..d895d9d9 100644 --- a/flask_jwt_extended/jwt_manager.py +++ b/flask_jwt_extended/jwt_manager.py @@ -1,6 +1,9 @@ import datetime +from typing import Any +from typing import Callable import jwt +from flask import Flask from jwt import DecodeError from jwt import ExpiredSignatureError from jwt import InvalidAudienceError @@ -49,7 +52,7 @@ class JWTManager(object): to your app in a factory function. """ - def __init__(self, app=None): + def __init__(self, app: Flask = None) -> None: """ Create the JWTManager instance. You can either pass a flask application in directly here to register this extension with the flask app, or @@ -82,7 +85,7 @@ def __init__(self, app=None): if app is not None: self.init_app(app) - def init_app(self, app): + def init_app(self, app: Flask) -> None: """ Register this extension with the flask app. @@ -98,7 +101,7 @@ def init_app(self, app): self._set_default_configuration_options(app) self._set_error_handler_callbacks(app) - def _set_error_handler_callbacks(self, app): + def _set_error_handler_callbacks(self, app: Flask) -> None: @app.errorhandler(CSRFError) def handle_csrf_error(e): return self._unauthorized_callback(str(e)) @@ -164,7 +167,7 @@ def handle_wrong_token_error(e): return self._invalid_token_callback(str(e)) @staticmethod - def _set_default_configuration_options(app): + def _set_default_configuration_options(app: Flask) -> None: app.config.setdefault( "JWT_ACCESS_TOKEN_EXPIRES", datetime.timedelta(minutes=15) ) @@ -210,7 +213,7 @@ def _set_default_configuration_options(app): app.config.setdefault("JWT_TOKEN_LOCATION", ("headers",)) app.config.setdefault("JWT_ENCODE_NBF", True) - def additional_claims_loader(self, callback): + def additional_claims_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function used to add additional claims when creating a JWT. The claims returned by this function will be merged @@ -227,7 +230,7 @@ def additional_claims_loader(self, callback): self._user_claims_callback = callback return callback - def additional_headers_loader(self, callback): + def additional_headers_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function used to add additional headers when creating a JWT. The headers returned by this function will be merged @@ -244,7 +247,7 @@ def additional_headers_loader(self, callback): self._jwt_additional_header_callback = callback return callback - def decode_key_loader(self, callback): + def decode_key_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function for dynamically setting the JWT decode key based on the **UNVERIFIED** contents of the token. Think @@ -265,7 +268,7 @@ def decode_key_loader(self, callback): self._decode_key_callback = callback return callback - def encode_key_loader(self, callback): + def encode_key_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function for dynamically setting the JWT encode key based on the tokens identity. Think carefully before using this @@ -281,7 +284,7 @@ def encode_key_loader(self, callback): self._encode_key_callback = callback return callback - def expired_token_loader(self, callback): + def expired_token_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function for returning a custom response when an expired JWT is encountered. @@ -297,7 +300,7 @@ def expired_token_loader(self, callback): self._expired_token_callback = callback return callback - def invalid_token_loader(self, callback): + def invalid_token_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function for returning a custom response when an invalid JWT is encountered. @@ -314,7 +317,7 @@ def invalid_token_loader(self, callback): self._invalid_token_callback = callback return callback - def needs_fresh_token_loader(self, callback): + def needs_fresh_token_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function for returning a custom response when a valid and non-fresh token is used on an endpoint @@ -331,7 +334,7 @@ def needs_fresh_token_loader(self, callback): self._needs_fresh_token_callback = callback return callback - def revoked_token_loader(self, callback): + def revoked_token_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function for returning a custom response when a revoked token is encountered. @@ -347,7 +350,7 @@ def revoked_token_loader(self, callback): self._revoked_token_callback = callback return callback - def token_in_blocklist_loader(self, callback): + def token_in_blocklist_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function used to check if a JWT has been revoked. @@ -364,7 +367,7 @@ def token_in_blocklist_loader(self, callback): self._token_in_blocklist_callback = callback return callback - def token_verification_failed_loader(self, callback): + def token_verification_failed_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function used to return a custom response when the claims verification check fails. @@ -380,7 +383,7 @@ def token_verification_failed_loader(self, callback): self._token_verification_failed_callback = callback return callback - def token_verification_loader(self, callback): + def token_verification_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function used for custom verification of a valid JWT. @@ -397,7 +400,7 @@ def token_verification_loader(self, callback): self._token_verification_callback = callback return callback - def unauthorized_loader(self, callback): + def unauthorized_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function used to return a custom response when no JWT is present. @@ -411,7 +414,7 @@ def unauthorized_loader(self, callback): self._unauthorized_callback = callback return callback - def user_identity_loader(self, callback): + def user_identity_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function used to convert an identity to a JSON serializable format when creating JWTs. This is useful for @@ -427,7 +430,7 @@ def user_identity_loader(self, callback): self._user_identity_callback = callback return callback - def user_lookup_loader(self, callback): + def user_lookup_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function used to convert a JWT into a python object that can be used in a protected endpoint. This is useful @@ -452,7 +455,7 @@ def user_lookup_loader(self, callback): self._user_lookup_callback = callback return callback - def user_lookup_error_loader(self, callback): + def user_lookup_error_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function used to return a custom response when loading a user via @@ -471,11 +474,11 @@ def user_lookup_error_loader(self, callback): def _encode_jwt_from_config( self, - identity, - token_type, + identity: Any, + token_type: str, claims=None, - fresh=False, - expires_delta=None, + fresh: bool = False, + expires_delta: datetime.timedelta = None, headers=None, ): header_overrides = self._jwt_additional_header_callback(identity) @@ -510,7 +513,7 @@ def _encode_jwt_from_config( ) def _decode_jwt_from_config( - self, encoded_token, csrf_value=None, allow_expired=False + self, encoded_token, csrf_value=None, allow_expired: bool = False ): unverified_claims = jwt.decode( encoded_token, diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index 4a560a01..a284b0a0 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -1,3 +1,6 @@ +import datetime +from typing import Any + import jwt from flask import _request_ctx_stack from werkzeug.local import LocalProxy @@ -5,12 +8,11 @@ from flask_jwt_extended.config import config from flask_jwt_extended.internal_utils import get_jwt_manager - # Proxy to access the current user current_user = LocalProxy(lambda: get_current_user()) -def get_jwt(): +def get_jwt() -> dict: """ In a protected endpoint, this will return the python dictionary which has the payload of the JWT that is accessing the endpoint. If no JWT is present @@ -28,7 +30,7 @@ def get_jwt(): return decoded_jwt -def get_jwt_header(): +def get_jwt_header() -> dict: """ In a protected endpoint, this will return the python dictionary which has the header of the JWT that is accessing the endpoint. If no JWT is present @@ -66,7 +68,7 @@ def get_jwt_request_location(): None is returned. :return: - The location of the JWT in the current request; e.g., cookies", + The location of the JWT in the current request; e.g., "cookies", "query-string", "headers", or "json" """ location = getattr(_request_ctx_stack.top, "jwt_location", None) @@ -97,7 +99,7 @@ def get_current_user(): return jwt_user_dict["loaded_user"] -def decode_token(encoded_token, csrf_value=None, allow_expired=False): +def decode_token(encoded_token, csrf_value=None, allow_expired: bool = False) -> dict: """ Returns the decoded token (python dict) from an encoded JWT. This does all the checks to insure that the decoded token is valid before returning it. @@ -124,7 +126,7 @@ def decode_token(encoded_token, csrf_value=None, allow_expired=False): def create_access_token( identity, - fresh=False, + fresh: bool = False, expires_delta=None, additional_claims=None, additional_headers=None, @@ -177,7 +179,10 @@ def create_access_token( def create_refresh_token( - identity, expires_delta=None, additional_claims=None, additional_headers=None + identity: Any, + expires_delta: datetime.timedelta = None, + additional_claims=None, + additional_headers=None, ): """ Create a new refresh token. diff --git a/requirements.txt b/requirements.txt index 403a17c4..5e24dec4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,5 +5,5 @@ flake8==4.0.1 Pallets-Sphinx-Themes==2.0.1 pre-commit==2.13.0 PyJWT==2.1.0 -Sphinx==4.0.2 +Sphinx==4.3.2 tox==3.23.1 From fb0b26980dd96b9a65571a24765290999a660703 Mon Sep 17 00:00:00 2001 From: Trevor Gross Date: Fri, 7 Jan 2022 22:51:26 -0500 Subject: [PATCH 03/17] Finished type hinting, tests passed --- .gitignore | 3 + docs/conf.py | 6 +- flask_jwt_extended/config.py | 110 ++++++++++++------------ flask_jwt_extended/default_callbacks.py | 30 +++++-- flask_jwt_extended/exceptions.py | 8 +- flask_jwt_extended/internal_utils.py | 2 +- flask_jwt_extended/jwt_manager.py | 6 +- flask_jwt_extended/tokens.py | 52 +++++------ flask_jwt_extended/utils.py | 34 +++++--- flask_jwt_extended/view_decorators.py | 38 +++++--- 10 files changed, 171 insertions(+), 118 deletions(-) diff --git a/.gitignore b/.gitignore index 66c16882..474e25fc 100644 --- a/.gitignore +++ b/.gitignore @@ -95,3 +95,6 @@ ENV/ # MacOS specific crap .DS_Store + +# Workspace +.vscode/ diff --git a/docs/conf.py b/docs/conf.py index 4da5cf51..1b83443d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -362,4 +362,8 @@ # Fix warnings about refernce targets. See link: # https://stackoverflow.com/questions/11417221/ # sphinx-autodoc-gives-warning-pyclass-reference-target-not-found-type-warning -nitpick_ignore = [("py:class", "flask.app.Flask"), ("py:class", "datetime.timedelta")] +nitpick_ignore = [ + ("py:class", "flask.app.Flask"), + ("py:class", "datetime.timedelta"), + ("py:class", "flask.wrappers.Response"), +] diff --git a/flask_jwt_extended/config.py b/flask_jwt_extended/config.py index 764d6552..cf1517d3 100644 --- a/flask_jwt_extended/config.py +++ b/flask_jwt_extended/config.py @@ -1,10 +1,12 @@ -from collections.abc import Sequence -from collections.abc import Set from datetime import datetime from datetime import timedelta from datetime import timezone +from typing import Iterable +from typing import List +from typing import Union from flask import current_app +from flask.json import JSONEncoder from jwt.algorithms import requires_cryptography @@ -20,23 +22,23 @@ class _Config(object): """ @property - def is_asymmetric(self): + def is_asymmetric(self) -> bool: return self.algorithm in requires_cryptography @property - def encode_key(self): + def encode_key(self) -> str: return self._private_key if self.is_asymmetric else self._secret_key @property - def decode_key(self): + def decode_key(self) -> str: return self._public_key if self.is_asymmetric else self._secret_key @property - def token_location(self): + def token_location(self) -> Iterable[str]: locations = current_app.config["JWT_TOKEN_LOCATION"] if isinstance(locations, str): locations = (locations,) - elif not isinstance(locations, (Sequence, Set)): + elif not isinstance(locations, Iterable): raise RuntimeError("JWT_TOKEN_LOCATION must be a sequence or a set") elif not locations: raise RuntimeError( @@ -52,90 +54,90 @@ def token_location(self): return locations @property - def jwt_in_cookies(self): + def jwt_in_cookies(self) -> bool: return "cookies" in self.token_location @property - def jwt_in_headers(self): + def jwt_in_headers(self) -> bool: return "headers" in self.token_location @property - def jwt_in_query_string(self): + def jwt_in_query_string(self) -> bool: return "query_string" in self.token_location @property - def jwt_in_json(self): + def jwt_in_json(self) -> bool: return "json" in self.token_location @property - def header_name(self): + def header_name(self) -> str: name = current_app.config["JWT_HEADER_NAME"] if not name: raise RuntimeError("JWT_ACCESS_HEADER_NAME cannot be empty") return name @property - def header_type(self): + def header_type(self) -> str: return current_app.config["JWT_HEADER_TYPE"] @property - def query_string_name(self): + def query_string_name(self) -> str: return current_app.config["JWT_QUERY_STRING_NAME"] @property - def query_string_value_prefix(self): + def query_string_value_prefix(self) -> str: return current_app.config["JWT_QUERY_STRING_VALUE_PREFIX"] @property - def access_cookie_name(self): + def access_cookie_name(self) -> str: return current_app.config["JWT_ACCESS_COOKIE_NAME"] @property - def refresh_cookie_name(self): + def refresh_cookie_name(self) -> str: return current_app.config["JWT_REFRESH_COOKIE_NAME"] @property - def access_cookie_path(self): + def access_cookie_path(self) -> str: return current_app.config["JWT_ACCESS_COOKIE_PATH"] @property - def refresh_cookie_path(self): + def refresh_cookie_path(self) -> str: return current_app.config["JWT_REFRESH_COOKIE_PATH"] @property - def cookie_secure(self): + def cookie_secure(self) -> bool: return current_app.config["JWT_COOKIE_SECURE"] @property - def cookie_domain(self): + def cookie_domain(self) -> str: return current_app.config["JWT_COOKIE_DOMAIN"] @property - def session_cookie(self): + def session_cookie(self) -> bool: return current_app.config["JWT_SESSION_COOKIE"] @property - def cookie_samesite(self): + def cookie_samesite(self) -> str: return current_app.config["JWT_COOKIE_SAMESITE"] @property - def json_key(self): + def json_key(self) -> str: return current_app.config["JWT_JSON_KEY"] @property - def refresh_json_key(self): + def refresh_json_key(self) -> str: return current_app.config["JWT_REFRESH_JSON_KEY"] @property - def csrf_protect(self): + def csrf_protect(self) -> bool: return self.jwt_in_cookies and current_app.config["JWT_COOKIE_CSRF_PROTECT"] @property - def csrf_request_methods(self): + def csrf_request_methods(self) -> Iterable[str]: return current_app.config["JWT_CSRF_METHODS"] @property - def csrf_in_cookies(self): + def csrf_in_cookies(self) -> bool: return current_app.config["JWT_CSRF_IN_COOKIES"] @property @@ -143,39 +145,39 @@ def access_csrf_cookie_name(self): return current_app.config["JWT_ACCESS_CSRF_COOKIE_NAME"] @property - def refresh_csrf_cookie_name(self): + def refresh_csrf_cookie_name(self) -> str: return current_app.config["JWT_REFRESH_CSRF_COOKIE_NAME"] @property - def access_csrf_cookie_path(self): + def access_csrf_cookie_path(self) -> str: return current_app.config["JWT_ACCESS_CSRF_COOKIE_PATH"] @property - def refresh_csrf_cookie_path(self): + def refresh_csrf_cookie_path(self) -> str: return current_app.config["JWT_REFRESH_CSRF_COOKIE_PATH"] @property - def access_csrf_header_name(self): + def access_csrf_header_name(self) -> str: return current_app.config["JWT_ACCESS_CSRF_HEADER_NAME"] @property - def refresh_csrf_header_name(self): + def refresh_csrf_header_name(self) -> str: return current_app.config["JWT_REFRESH_CSRF_HEADER_NAME"] @property - def csrf_check_form(self): + def csrf_check_form(self) -> bool: return current_app.config["JWT_CSRF_CHECK_FORM"] @property - def access_csrf_field_name(self): + def access_csrf_field_name(self) -> str: return current_app.config["JWT_ACCESS_CSRF_FIELD_NAME"] @property - def refresh_csrf_field_name(self): + def refresh_csrf_field_name(self) -> str: return current_app.config["JWT_REFRESH_CSRF_FIELD_NAME"] @property - def access_expires(self): + def access_expires(self) -> datetime: delta = current_app.config["JWT_ACCESS_TOKEN_EXPIRES"] if type(delta) is int: delta = timedelta(seconds=delta) @@ -190,7 +192,7 @@ def access_expires(self): return delta @property - def refresh_expires(self): + def refresh_expires(self) -> datetime: delta = current_app.config["JWT_REFRESH_TOKEN_EXPIRES"] if type(delta) is int: delta = timedelta(seconds=delta) @@ -205,11 +207,11 @@ def refresh_expires(self): return delta @property - def algorithm(self): + def algorithm(self) -> str: return current_app.config["JWT_ALGORITHM"] @property - def decode_algorithms(self): + def decode_algorithms(self) -> List[str]: algorithms = current_app.config["JWT_DECODE_ALGORITHMS"] if not algorithms: return [self.algorithm] @@ -218,7 +220,7 @@ def decode_algorithms(self): return algorithms @property - def _secret_key(self): + def _secret_key(self) -> str: key = current_app.config["JWT_SECRET_KEY"] if not key: key = current_app.config.get("SECRET_KEY", None) @@ -231,7 +233,7 @@ def _secret_key(self): return key @property - def _public_key(self): + def _public_key(self) -> str: key = current_app.config["JWT_PUBLIC_KEY"] if not key: raise RuntimeError( @@ -242,7 +244,7 @@ def _public_key(self): return key @property - def _private_key(self): + def _private_key(self) -> str: key = current_app.config["JWT_PRIVATE_KEY"] if not key: raise RuntimeError( @@ -253,50 +255,50 @@ def _private_key(self): return key @property - def cookie_max_age(self): + def cookie_max_age(self) -> int: # Returns the appropiate value for max_age for flask set_cookies. If # session cookie is true, return None, otherwise return a number of # seconds 1 year in the future return None if self.session_cookie else 31540000 # 1 year @property - def identity_claim_key(self): + def identity_claim_key(self) -> str: return current_app.config["JWT_IDENTITY_CLAIM"] @property - def exempt_methods(self): + def exempt_methods(self) -> Iterable[str]: return {"OPTIONS"} @property - def error_msg_key(self): + def error_msg_key(self) -> str: return current_app.config["JWT_ERROR_MESSAGE_KEY"] @property - def json_encoder(self): + def json_encoder(self) -> JSONEncoder: return current_app.json_encoder @property - def decode_audience(self): + def decode_audience(self) -> Union[str, Iterable[str]]: return current_app.config["JWT_DECODE_AUDIENCE"] @property - def encode_audience(self): + def encode_audience(self) -> Union[str, Iterable[str]]: return current_app.config["JWT_ENCODE_AUDIENCE"] @property - def encode_issuer(self): + def encode_issuer(self) -> str: return current_app.config["JWT_ENCODE_ISSUER"] @property - def decode_issuer(self): + def decode_issuer(self) -> str: return current_app.config["JWT_DECODE_ISSUER"] @property - def leeway(self): + def leeway(self) -> int: return current_app.config["JWT_DECODE_LEEWAY"] @property - def encode_nbf(self): + def encode_nbf(self) -> bool: return current_app.config["JWT_ENCODE_NBF"] diff --git a/flask_jwt_extended/default_callbacks.py b/flask_jwt_extended/default_callbacks.py index e2c4e7aa..72accea3 100644 --- a/flask_jwt_extended/default_callbacks.py +++ b/flask_jwt_extended/default_callbacks.py @@ -6,6 +6,8 @@ http://flask-jwt-extended.readthedocs.io/en/latest/changing_default_behavior.html http://flask-jwt-extended.readthedocs.io/en/latest/tokens_from_complex_object.html """ +from http import HTTPStatus + from flask import jsonify from flask import Response @@ -57,7 +59,7 @@ def default_expired_token_callback( By default, if an expired token attempts to access a protected endpoint, we return a generic error message with a 401 status """ - return jsonify({config.error_msg_key: "Token has expired"}), 401 + return jsonify({config.error_msg_key: "Token has expired"}), HTTPStatus.UNAUTHORIZED def default_invalid_token_callback(error_string: str) -> Response: @@ -67,7 +69,10 @@ def default_invalid_token_callback(error_string: str) -> Response: :param error_string: String indicating why the token is invalid """ - return jsonify({config.error_msg_key: error_string}), 422 + return ( + jsonify({config.error_msg_key: error_string}), + HTTPStatus.UNPROCESSABLE_ENTITY, + ) def default_unauthorized_callback(error_string: str) -> Response: @@ -77,7 +82,7 @@ def default_unauthorized_callback(error_string: str) -> Response: :param error_string: String indicating why this request is unauthorized """ - return jsonify({config.error_msg_key: error_string}), 401 + return jsonify({config.error_msg_key: error_string}), HTTPStatus.UNAUTHORIZED def default_needs_fresh_token_callback(jwt_header: dict, jwt_data: dict) -> Response: @@ -85,7 +90,10 @@ def default_needs_fresh_token_callback(jwt_header: dict, jwt_data: dict) -> Resp By default, if a non-fresh jwt is used to access a ```fresh_jwt_required``` endpoint, we return a general error message with a 401 status code """ - return jsonify({config.error_msg_key: "Fresh token required"}), 401 + return ( + jsonify({config.error_msg_key: "Fresh token required"}), + HTTPStatus.UNAUTHORIZED, + ) def default_revoked_token_callback(jwt_header: dict, jwt_data: dict) -> Response: @@ -93,7 +101,10 @@ def default_revoked_token_callback(jwt_header: dict, jwt_data: dict) -> Response By default, if a revoked token is used to access a protected endpoint, we return a general error message with a 401 status code """ - return jsonify({config.error_msg_key: "Token has been revoked"}), 401 + return ( + jsonify({config.error_msg_key: "Token has been revoked"}), + HTTPStatus.UNAUTHORIZED, + ) def default_user_lookup_error_callback(_jwt_header: dict, jwt_data: dict) -> Response: @@ -103,8 +114,8 @@ def default_user_lookup_error_callback(_jwt_header: dict, jwt_data: dict) -> Res status code """ identity = jwt_data[config.identity_claim_key] - result = {config.error_msg_key: "Error loading the user {}".format(identity)} - return jsonify(result), 401 + result = {config.error_msg_key: f"Error loading the user {identity}"} + return jsonify(result), HTTPStatus.UNAUTHORIZED def default_token_verification_callback(_jwt_header: dict, _jwt_data: dict) -> bool: @@ -121,7 +132,10 @@ def default_token_verification_failed_callback( By default, if the user claims verification failed, we return a generic error message with a 400 status code """ - return jsonify({config.error_msg_key: "User claims verification failed"}), 400 + return ( + jsonify({config.error_msg_key: "User claims verification failed"}), + HTTPStatus.BAD_REQUEST, + ) def default_decode_key_callback(jwt_header: dict, jwt_data: dict): diff --git a/flask_jwt_extended/exceptions.py b/flask_jwt_extended/exceptions.py index 8fc5405a..2ad414a8 100644 --- a/flask_jwt_extended/exceptions.py +++ b/flask_jwt_extended/exceptions.py @@ -60,7 +60,7 @@ class RevokedTokenError(JWTExtendedException): Error raised when a revoked token attempt to access a protected endpoint """ - def __init__(self, jwt_header: dict, jwt_data: dict): + def __init__(self, jwt_header: dict, jwt_data: dict) -> None: super().__init__("Token has been revoked") self.jwt_header = jwt_header self.jwt_data = jwt_data @@ -72,7 +72,7 @@ class FreshTokenRequired(JWTExtendedException): protected by fresh_jwt_required """ - def __init__(self, message, jwt_header: dict, jwt_data: dict): + def __init__(self, message, jwt_header: dict, jwt_data: dict) -> None: super().__init__(message) self.jwt_header = jwt_header self.jwt_data = jwt_data @@ -84,7 +84,7 @@ class UserLookupError(JWTExtendedException): that it cannot or will not load a user for the given identity. """ - def __init__(self, message, jwt_header: dict, jwt_data: dict): + def __init__(self, message, jwt_header: dict, jwt_data: dict) -> None: super().__init__(message) self.jwt_header = jwt_header self.jwt_data = jwt_data @@ -96,7 +96,7 @@ class UserClaimsVerificationError(JWTExtendedException): indicating that the expected user claims are invalid """ - def __init__(self, message, jwt_header: dict, jwt_data: dict): + def __init__(self, message, jwt_header: dict, jwt_data: dict) -> None: super().__init__(message) self.jwt_header = jwt_header self.jwt_data = jwt_data diff --git a/flask_jwt_extended/internal_utils.py b/flask_jwt_extended/internal_utils.py index 13eb2d53..a49de656 100644 --- a/flask_jwt_extended/internal_utils.py +++ b/flask_jwt_extended/internal_utils.py @@ -26,7 +26,7 @@ def user_lookup(*args, **kwargs): return jwt_manager._user_lookup_callback(*args, **kwargs) -def verify_token_type(decoded_token, refresh) -> None: +def verify_token_type(decoded_token: dict, refresh: bool) -> None: if not refresh and decoded_token["type"] == "refresh": raise WrongTokenError("Only non-refresh tokens are allowed") elif refresh and decoded_token["type"] != "refresh": diff --git a/flask_jwt_extended/jwt_manager.py b/flask_jwt_extended/jwt_manager.py index d895d9d9..a8e032f0 100644 --- a/flask_jwt_extended/jwt_manager.py +++ b/flask_jwt_extended/jwt_manager.py @@ -480,7 +480,7 @@ def _encode_jwt_from_config( fresh: bool = False, expires_delta: datetime.timedelta = None, headers=None, - ): + ) -> str: header_overrides = self._jwt_additional_header_callback(identity) if headers is not None: header_overrides.update(headers) @@ -513,8 +513,8 @@ def _encode_jwt_from_config( ) def _decode_jwt_from_config( - self, encoded_token, csrf_value=None, allow_expired: bool = False - ): + self, encoded_token: str, csrf_value=None, allow_expired: bool = False + ) -> dict: unverified_claims = jwt.decode( encoded_token, algorithms=config.decode_algorithms, diff --git a/flask_jwt_extended/tokens.py b/flask_jwt_extended/tokens.py index 2b98b8d0..c78651bc 100644 --- a/flask_jwt_extended/tokens.py +++ b/flask_jwt_extended/tokens.py @@ -3,29 +3,33 @@ from datetime import timedelta from datetime import timezone from hmac import compare_digest +from typing import Any +from typing import Iterable +from typing import Union import jwt +from flask.json import JSONEncoder from flask_jwt_extended.exceptions import CSRFError from flask_jwt_extended.exceptions import JWTDecodeError def _encode_jwt( - algorithm, - audience, - claim_overrides, - csrf, - expires_delta, - fresh, + algorithm: str, + audience: Union[str, Iterable[str]], + claim_overrides: dict, + csrf: bool, + expires_delta: timedelta, + fresh: bool, header_overrides, - identity, - identity_claim_key, - issuer, - json_encoder, - secret, - token_type, - nbf, -): + identity: Any, + identity_claim_key: str, + issuer: str, + json_encoder: JSONEncoder, + secret: str, + token_type: str, + nbf: bool, +) -> str: now = datetime.now(timezone.utc) if isinstance(fresh, timedelta): @@ -67,17 +71,17 @@ def _encode_jwt( def _decode_jwt( - algorithms, - allow_expired, - audience, - csrf_value, - encoded_token, + algorithms: Iterable, + allow_expired: bool, + audience: Union[str, Iterable[str]], + csrf_value: str, + encoded_token: str, identity_claim_key, - issuer, - leeway, - secret, - verify_aud, -): + issuer: str, + leeway: int, + secret: str, + verify_aud: bool, +) -> dict: options = {"verify_aud": verify_aud} if allow_expired: options["verify_exp"] = False diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index a284b0a0..b0708b89 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -3,6 +3,7 @@ import jwt from flask import _request_ctx_stack +from flask import Response from werkzeug.local import LocalProxy from flask_jwt_extended.config import config @@ -60,7 +61,7 @@ def get_jwt_identity(): return get_jwt().get(config.identity_claim_key, None) -def get_jwt_request_location(): +def get_jwt_request_location() -> str: """ In a protected endpoint, this will return the "location" at which the JWT that is accessing the endpoint was found--e.g., "cookies", "query-string", @@ -99,7 +100,9 @@ def get_current_user(): return jwt_user_dict["loaded_user"] -def decode_token(encoded_token, csrf_value=None, allow_expired: bool = False) -> dict: +def decode_token( + encoded_token: str, csrf_value: str = None, allow_expired: bool = False +) -> dict: """ Returns the decoded token (python dict) from an encoded JWT. This does all the checks to insure that the decoded token is valid before returning it. @@ -125,9 +128,9 @@ def decode_token(encoded_token, csrf_value=None, allow_expired: bool = False) -> def create_access_token( - identity, + identity: Any, fresh: bool = False, - expires_delta=None, + expires_delta: datetime.timedelta = None, additional_claims=None, additional_headers=None, ): @@ -224,7 +227,7 @@ def create_refresh_token( ) -def get_unverified_jwt_headers(encoded_token): +def get_unverified_jwt_headers(encoded_token: str) -> dict: """ Returns the Headers of an encoded JWT without verifying the signature of the JWT. @@ -237,7 +240,7 @@ def get_unverified_jwt_headers(encoded_token): return jwt.get_unverified_header(encoded_token) -def get_jti(encoded_token): +def get_jti(encoded_token: str) -> str: """ Returns the JTI (unique identifier) of an encoded JWT @@ -250,7 +253,7 @@ def get_jti(encoded_token): return decode_token(encoded_token).get("jti") -def get_csrf_token(encoded_token): +def get_csrf_token(encoded_token: str) -> str: """ Returns the CSRF double submit token from an encoded JWT. @@ -264,7 +267,9 @@ def get_csrf_token(encoded_token): return token["csrf"] -def set_access_cookies(response, encoded_access_token, max_age=None, domain=None): +def set_access_cookies( + response: Response, encoded_access_token: str, max_age=None, domain=None +) -> None: """ Modifiy a Flask Response to set a cookie containing the access JWT. Also sets the corresponding CSRF cookies if ``JWT_CSRF_IN_COOKIES`` is ``True`` @@ -312,7 +317,12 @@ def set_access_cookies(response, encoded_access_token, max_age=None, domain=None ) -def set_refresh_cookies(response, encoded_refresh_token, max_age=None, domain=None): +def set_refresh_cookies( + response: Response, + encoded_refresh_token: str, + max_age: int = None, + domain: str = None, +) -> None: """ Modifiy a Flask Response to set a cookie containing the refresh JWT. Also sets the corresponding CSRF cookies if ``JWT_CSRF_IN_COOKIES`` is ``True`` @@ -360,7 +370,7 @@ def set_refresh_cookies(response, encoded_refresh_token, max_age=None, domain=No ) -def unset_jwt_cookies(response, domain=None): +def unset_jwt_cookies(response: Response, domain: str = None) -> None: """ Modifiy a Flask Response to delete the cookies containing access or refresh JWTs. Also deletes the corresponding CSRF cookies if applicable. @@ -372,7 +382,7 @@ def unset_jwt_cookies(response, domain=None): unset_refresh_cookies(response, domain) -def unset_access_cookies(response, domain=None): +def unset_access_cookies(response: Response, domain: str = None) -> None: """ Modifiy a Flask Response to delete the cookie containing an access JWT. Also deletes the corresponding CSRF cookie if applicable. @@ -410,7 +420,7 @@ def unset_access_cookies(response, domain=None): ) -def unset_refresh_cookies(response, domain=None): +def unset_refresh_cookies(response: Response, domain: str = None): """ Modifiy a Flask Response to delete the cookie containing a refresh JWT. Also deletes the corresponding CSRF cookie if applicable. diff --git a/flask_jwt_extended/view_decorators.py b/flask_jwt_extended/view_decorators.py index 273cbb41..7a9864c0 100644 --- a/flask_jwt_extended/view_decorators.py +++ b/flask_jwt_extended/view_decorators.py @@ -2,6 +2,9 @@ from datetime import timezone from functools import wraps from re import split +from typing import Iterable +from typing import Tuple +from typing import Union from flask import _request_ctx_stack from flask import current_app @@ -23,8 +26,10 @@ from flask_jwt_extended.utils import decode_token from flask_jwt_extended.utils import get_unverified_jwt_headers +LocationType = Union[str, Iterable] -def _verify_token_is_fresh(jwt_header, jwt_data): + +def _verify_token_is_fresh(jwt_header: dict, jwt_data: dict) -> None: fresh = jwt_data["fresh"] if isinstance(fresh, bool): if not fresh: @@ -36,8 +41,11 @@ def _verify_token_is_fresh(jwt_header, jwt_data): def verify_jwt_in_request( - optional: bool = False, fresh: bool = False, refresh: bool = False, locations=None -): + optional: bool = False, + fresh: bool = False, + refresh: bool = False, + locations: LocationType = None, +) -> Tuple[dict, dict]: """ Verify that a valid JWT is present in the request, unless ``optional=True`` in which case no JWT is also considered valid. @@ -90,7 +98,12 @@ def verify_jwt_in_request( return jwt_header, jwt_data -def jwt_required(optional=False, fresh=False, refresh=False, locations=None): +def jwt_required( + optional: bool = False, + fresh: bool = False, + refresh: bool = False, + locations: LocationType = None, +): """ A decorator to protect a Flask endpoint with JSON Web Tokens. @@ -135,7 +148,7 @@ def decorator(*args, **kwargs): return wrapper -def _load_user(jwt_header, jwt_data): +def _load_user(jwt_header: dict, jwt_data: dict) -> dict: if not has_user_lookup(): return None @@ -147,7 +160,7 @@ def _load_user(jwt_header, jwt_data): return {"loaded_user": user} -def _decode_jwt_from_headers(): +def _decode_jwt_from_headers() -> Tuple[str, str]: header_name = config.header_name header_type = config.header_type @@ -191,7 +204,7 @@ def _decode_jwt_from_headers(): return encoded_token, None -def _decode_jwt_from_cookies(refresh): +def _decode_jwt_from_cookies(refresh: bool) -> Tuple[str, str]: if refresh: cookie_key = config.refresh_cookie_name csrf_header_key = config.refresh_csrf_header_name @@ -217,7 +230,7 @@ def _decode_jwt_from_cookies(refresh): return encoded_token, csrf_value -def _decode_jwt_from_query_string(): +def _decode_jwt_from_query_string() -> Tuple[str, str]: param_name = config.query_string_name prefix = config.query_string_value_prefix @@ -235,7 +248,7 @@ def _decode_jwt_from_query_string(): return encoded_token, None -def _decode_jwt_from_json(refresh): +def _decode_jwt_from_json(refresh: bool) -> Tuple[str, str]: content_type = request.content_type or "" if not content_type.startswith("application/json"): raise NoAuthorizationError("Invalid content-type. Must be application/json.") @@ -258,8 +271,11 @@ def _decode_jwt_from_json(refresh): def _decode_jwt_from_request( - locations, fresh, refresh: bool = False, verify_type: bool = True -): + locations: LocationType, + fresh: bool, + refresh: bool = False, + verify_type: bool = True, +) -> Tuple[dict, dict, str]: # Figure out what locations to look for the JWT in this request if isinstance(locations, str): locations = [locations] From 763ef23e7a42e56744e4b92d5628cfcaf3a839ec Mon Sep 17 00:00:00 2001 From: Trevor Gross Date: Fri, 7 Jan 2022 23:33:51 -0500 Subject: [PATCH 04/17] Working on docs --- docs/blocklist_and_token_revoking.rst | 7 +++++++ docs/refreshing_tokens.rst | 7 +++++++ flask_jwt_extended/__init__.py | 1 + flask_jwt_extended/utils.py | 13 ++++++++++++ flask_jwt_extended/view_decorators.py | 29 ++++++++++++++++++--------- 5 files changed, 47 insertions(+), 10 deletions(-) diff --git a/docs/blocklist_and_token_revoking.rst b/docs/blocklist_and_token_revoking.rst index 10deb047..8b5119d8 100644 --- a/docs/blocklist_and_token_revoking.rst +++ b/docs/blocklist_and_token_revoking.rst @@ -33,3 +33,10 @@ revoked tokens, such as when it was revoked, who revoked it, can it be un-revoke etc. Here is an example using SQLAlchemy: .. literalinclude:: ../examples/blocklist_database.py + +Handling Revoking of Refresh Tokens +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +It is very important to note that a user's refresh token must also be revoked +when logging out; otherwise, this refresh token could just be used to generate +a new access token. Usually this falls to the responsibility of the frontend, +which must diff --git a/docs/refreshing_tokens.rst b/docs/refreshing_tokens.rst index 0f39db4a..8fc9f4c5 100644 --- a/docs/refreshing_tokens.rst +++ b/docs/refreshing_tokens.rst @@ -80,3 +80,10 @@ option when creating JWTs: .. code-block :: python create_access_token(identity, fresh=datetime.timedelta(minutes=15)) + + +Revoking Refresh Tokens +~~~~~~~~~~~~~~~~~~~~~~~ +Note that when an access token is invalidated (e.g. logging a user out), the +corresponding refresh token(s) must be revoked too. +See :ref:`Handling Revoking Refresh Tokens` for details on how to handle this. diff --git a/flask_jwt_extended/__init__.py b/flask_jwt_extended/__init__.py index 21373cc3..c6f4454a 100644 --- a/flask_jwt_extended/__init__.py +++ b/flask_jwt_extended/__init__.py @@ -10,6 +10,7 @@ from .utils import get_jwt_header from .utils import get_jwt_identity from .utils import get_jwt_request_location +from .utils import get_token_type from .utils import get_unverified_jwt_headers from .utils import set_access_cookies from .utils import set_refresh_cookies diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index b0708b89..c21ea017 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -253,6 +253,19 @@ def get_jti(encoded_token: str) -> str: return decode_token(encoded_token).get("jti") +def get_token_type(encoded_token: str) -> str: + """ + Returns the JTI (unique identifier) of an encoded JWT + + :param encoded_token: + The encoded JWT to get the JTI from. + + :return: + The JTI (unique identifier) of a JWT. + """ + return decode_token(encoded_token).get("jti") + + def get_csrf_token(encoded_token: str) -> str: """ Returns the CSRF double submit token from an encoded JWT. diff --git a/flask_jwt_extended/view_decorators.py b/flask_jwt_extended/view_decorators.py index 7a9864c0..757ad3c7 100644 --- a/flask_jwt_extended/view_decorators.py +++ b/flask_jwt_extended/view_decorators.py @@ -45,6 +45,7 @@ def verify_jwt_in_request( fresh: bool = False, refresh: bool = False, locations: LocationType = None, + verify_type: bool = False, ) -> Tuple[dict, dict]: """ Verify that a valid JWT is present in the request, unless ``optional=True`` in @@ -59,26 +60,28 @@ def verify_jwt_in_request( Defaults to ``False``. :param refresh: - If ``True``, require a refresh JWT to be verified. + If ``True``, requires a refresh JWT to access this endpoint. If ``False``, + requires an access JWT to access this endpoint. Defaults to ``False`` :param locations: A location or list of locations to look for the JWT in this request, for example ``'headers'`` or ``['headers', 'cookies']``. Defaults to ``None`` which indicates that JWTs will be looked for in the locations defined by the ``JWT_TOKEN_LOCATION`` configuration option. + + :param verify_type: + If ``True``, the token type (access or refresh) will be checked according + to the ``refresh`` argument. If ``False``, type will not be checked and both + access and refresh tokens will be accepted. """ if request.method in config.exempt_methods: return try: - if refresh: - jwt_data, jwt_header, jwt_location = _decode_jwt_from_request( - locations, fresh, refresh=True - ) - else: - jwt_data, jwt_header, jwt_location = _decode_jwt_from_request( - locations, fresh - ) + jwt_data, jwt_header, jwt_location = _decode_jwt_from_request( + locations, fresh, refresh=refresh, verify_type=verify_type + ) + except NoAuthorizationError: if not optional: raise @@ -103,6 +106,7 @@ def jwt_required( fresh: bool = False, refresh: bool = False, locations: LocationType = None, + verify_type: bool = True, ): """ A decorator to protect a Flask endpoint with JSON Web Tokens. @@ -128,12 +132,17 @@ def jwt_required( example ``'headers'`` or ``['headers', 'cookies']``. Defaults to ``None`` which indicates that JWTs will be looked for in the locations defined by the ``JWT_TOKEN_LOCATION`` configuration option. + + :param verify_type: + If ``True``, the token type (access or refresh) will be checked according + to the ``refresh`` argument. If ``False``, type will not be checked and both + access and refresh tokens will be accepted. """ def wrapper(fn): @wraps(fn) def decorator(*args, **kwargs): - verify_jwt_in_request(optional, fresh, refresh, locations) + verify_jwt_in_request(optional, fresh, refresh, locations, verify_type) # Compatibility with flask < 2.0 if hasattr(current_app, "ensure_sync") and callable( From 056ffb2e442aa794a2521a8ff7d2e38134c5b028 Mon Sep 17 00:00:00 2001 From: Trevor Gross Date: Sat, 8 Jan 2022 00:10:07 -0500 Subject: [PATCH 05/17] Work on docs --- docs/blocklist_and_token_revoking.rst | 70 ++++++++++++++++++++++++++- examples/blocklist_database.py | 19 +++----- flask_jwt_extended/utils.py | 3 +- 3 files changed, 78 insertions(+), 14 deletions(-) diff --git a/docs/blocklist_and_token_revoking.rst b/docs/blocklist_and_token_revoking.rst index 8b5119d8..4b940b5f 100644 --- a/docs/blocklist_and_token_revoking.rst +++ b/docs/blocklist_and_token_revoking.rst @@ -25,6 +25,14 @@ Live (TTL) functionality when storing a JWT. Here is an example using redis: .. literalinclude:: ../examples/blocklist_redis.py + +.. warning:: + Note that configuring redis to be disk-persistent is an absolutely necessity for + production use. Otherwise, events like power outages or server crashes/reboots + would cause all formerly invalidated tokens to become valid again (assuming the + secret key does not change). This is especially concering for long-lived + refresh tokens, discussed below. + Database ~~~~~~~~ If you need to keep track of information about revoked JWTs our recommendation is @@ -39,4 +47,64 @@ Handling Revoking of Refresh Tokens It is very important to note that a user's refresh token must also be revoked when logging out; otherwise, this refresh token could just be used to generate a new access token. Usually this falls to the responsibility of the frontend, -which must +which must + +It is very important to note that a user's refresh token(s) must also be revoked +when logging out; otherwise, this refresh token could just be used to generate +a new access token. Usually this falls to the responsibility of the frontend, +which should request + + +It is possible to use two different routes with ``@jwt_required()`` and +``@jwt_required(refresh=True)`` to accomplish this. However, it is convenient to +provide a single endpoint where both users + +.. code-block:: python + @app.route("/logout", methods=["DELETE"]) + @jwt_required(verify_type=False) + def logout(): + token = get_jwt() + jti = token["jti"] + ttype = token["type"] + jwt_redis_blocklist.set(jti, "", ex=ACCESS_EXPIRES) + + # Returns "Access token revoked" or "Refresh token revoked" + return jsonify(msg=f"{ttype.capitalize()} token revoked") + +or, for the database format: + +.. code-block:: python + class TokenBlocklist(db.Model): + id = db.Column(db.Integer, primary_key=True) + jti = db.Column(db.String(36), nullable=False, index=True) + type = db.Column(db.Integer, nullable=False) + user_id = db.Column( + db.ForeignKey('person.id') + nullable=False, + default=lambda: get_current_user().id + ) + created_at = db.Column(db.DateTime, nullable=False) + + @app.route("/logout", methods=["DELETE"]) + @jwt_required(verify_type=False) + def modify_token(): + token = get_jwt() + jti = token["jti"] + ttype = token["type"] + now = datetime.now(timezone.utc) + db.session.add(TokenBlocklist(jti=jti, type=ttype, created_at=now)) + db.session.commit() + return jsonify(msg=f"{ttype.capitalize()} token revoked") + +Token type and user are not required and can be omitted. That being said, including +these columns can help to audit that the frontend is performing its revoking job +correctly and revoking both tokens. + + +An alternative, albeit more complex, implementation is to invalidate all issued +tokens at once. + +#. Store all generated access and refresh tokens in a database with a user_id column or similar +#. Change +#. token_in_blocklist_loader +#. diff --git a/examples/blocklist_database.py b/examples/blocklist_database.py index f7ec1f40..40357481 100644 --- a/examples/blocklist_database.py +++ b/examples/blocklist_database.py @@ -1,16 +1,9 @@ -from datetime import datetime -from datetime import timedelta -from datetime import timezone +from datetime import datetime, timedelta, timezone -from flask import Flask -from flask import jsonify +from flask import Flask, jsonify +from flask_jwt_extended import JWTManager, create_access_token, get_jwt, jwt_required from flask_sqlalchemy import SQLAlchemy -from flask_jwt_extended import create_access_token -from flask_jwt_extended import get_jwt -from flask_jwt_extended import jwt_required -from flask_jwt_extended import JWTManager - app = Flask(__name__) ACCESS_EXPIRES = timedelta(hours=1) @@ -28,9 +21,12 @@ # This could be expanded to fit the needs of your application. For example, # it could track who revoked a JWT, when a token expires, notes for why a # JWT was revoked, an endpoint to un-revoked a JWT, etc. +# Making jti an index can significantly speed up the search when there are +# tens of thousands of records. Remember this query will happen for every +# (protected) request, class TokenBlocklist(db.Model): id = db.Column(db.Integer, primary_key=True) - jti = db.Column(db.String(36), nullable=False) + jti = db.Column(db.String(36), nullable=False, index=True) created_at = db.Column(db.DateTime, nullable=False) @@ -39,6 +35,7 @@ class TokenBlocklist(db.Model): def check_if_token_revoked(jwt_header, jwt_payload: dict) -> bool: jti = jwt_payload["jti"] token = db.session.query(TokenBlocklist.id).filter_by(jti=jti).scalar() + return token is not None diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index c21ea017..adcedc16 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -2,8 +2,7 @@ from typing import Any import jwt -from flask import _request_ctx_stack -from flask import Response +from flask import Response, _request_ctx_stack from werkzeug.local import LocalProxy from flask_jwt_extended.config import config From 7ff6e8e53f2b0750df5ff6c63100b9294db42b1a Mon Sep 17 00:00:00 2001 From: Trevor Gross Date: Sat, 8 Jan 2022 00:16:15 -0500 Subject: [PATCH 06/17] Work on docs --- docs/blocklist_and_token_revoking.rst | 34 +++++++++++++-------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/docs/blocklist_and_token_revoking.rst b/docs/blocklist_and_token_revoking.rst index 4b940b5f..895fbc07 100644 --- a/docs/blocklist_and_token_revoking.rst +++ b/docs/blocklist_and_token_revoking.rst @@ -44,20 +44,16 @@ etc. Here is an example using SQLAlchemy: Handling Revoking of Refresh Tokens ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -It is very important to note that a user's refresh token must also be revoked +It is critical to note that a user's refresh token must also be revoked when logging out; otherwise, this refresh token could just be used to generate -a new access token. Usually this falls to the responsibility of the frontend, -which must - -It is very important to note that a user's refresh token(s) must also be revoked -when logging out; otherwise, this refresh token could just be used to generate -a new access token. Usually this falls to the responsibility of the frontend, -which should request +a new access token. Usually this falls to the responsibility of the frontend +application, which must send two separate requests to the backend in order to +revoke these tokens. - -It is possible to use two different routes with ``@jwt_required()`` and -``@jwt_required(refresh=True)`` to accomplish this. However, it is convenient to -provide a single endpoint where both users +This can be implemented via two separate routes marked with ``@jwt_required()`` +and ``@jwt_required(refresh=True)`` to revoke access and refresh tokens, respectively. +However, it is more convenient to provide a single endpoint where the frontend +can send a DELETE for each token. This can be done with the following: .. code-block:: python @app.route("/logout", methods=["DELETE"]) @@ -77,13 +73,17 @@ or, for the database format: class TokenBlocklist(db.Model): id = db.Column(db.Integer, primary_key=True) jti = db.Column(db.String(36), nullable=False, index=True) - type = db.Column(db.Integer, nullable=False) + type = db.Column(db.String(16), nullable=False) user_id = db.Column( - db.ForeignKey('person.id') + db.ForeignKey('person.id'), + default=lambda: get_current_user().id, + nullable=False, + ) + created_at = db.Column( + db.DateTime, + server_default=func.now(), nullable=False, - default=lambda: get_current_user().id - ) - created_at = db.Column(db.DateTime, nullable=False) + ) @app.route("/logout", methods=["DELETE"]) @jwt_required(verify_type=False) From 9bef2987d7df5d30c19e02c263cc59d4b2183b29 Mon Sep 17 00:00:00 2001 From: Trevor Gross Date: Sat, 8 Jan 2022 00:34:37 -0500 Subject: [PATCH 07/17] More documentation work --- docs/blocklist_and_token_revoking.rst | 15 +++++++++------ examples/blocklist_database.py | 2 ++ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/docs/blocklist_and_token_revoking.rst b/docs/blocklist_and_token_revoking.rst index 895fbc07..e7ad7799 100644 --- a/docs/blocklist_and_token_revoking.rst +++ b/docs/blocklist_and_token_revoking.rst @@ -96,15 +96,18 @@ or, for the database format: db.session.commit() return jsonify(msg=f"{ttype.capitalize()} token revoked") + Token type and user are not required and can be omitted. That being said, including these columns can help to audit that the frontend is performing its revoking job correctly and revoking both tokens. -An alternative, albeit more complex, implementation is to invalidate all issued -tokens at once. +An alternative, albeit much more complex, implementation is to invalidate all issued +tokens for a user at once. To do this, all issued tokens must be tracked (by default, +they verified by being validated against the secret key). A few steps would be +required: -#. Store all generated access and refresh tokens in a database with a user_id column or similar -#. Change -#. token_in_blocklist_loader -#. +#. Store all generated access and refresh tokens in a database, include a user_id column +#. Add a "valid" boolean column. Update the `token_in_blocklist_loader` to respond based on this column +#. Upon revoking a token, find all other tokens with the same user and created at the same time, + and mark them all as invalid diff --git a/examples/blocklist_database.py b/examples/blocklist_database.py index 40357481..0a5db723 100644 --- a/examples/blocklist_database.py +++ b/examples/blocklist_database.py @@ -24,6 +24,8 @@ # Making jti an index can significantly speed up the search when there are # tens of thousands of records. Remember this query will happen for every # (protected) request, +# If your database supports a UUID type, this can be used for the jti column +# as well class TokenBlocklist(db.Model): id = db.Column(db.Integer, primary_key=True) jti = db.Column(db.String(36), nullable=False, index=True) From 95ef41309e1709cbc1a1737970eb1cf94b71cb49 Mon Sep 17 00:00:00 2001 From: Trevor Gross Date: Sat, 8 Jan 2022 01:02:31 -0500 Subject: [PATCH 08/17] Finished documentation --- docs/blocklist_and_token_revoking.rst | 16 +++++++--------- docs/refreshing_tokens.rst | 6 +++--- examples/blocklist_database.py | 13 ++++++++++--- flask_jwt_extended/utils.py | 3 ++- 4 files changed, 22 insertions(+), 16 deletions(-) diff --git a/docs/blocklist_and_token_revoking.rst b/docs/blocklist_and_token_revoking.rst index e7ad7799..39611e3e 100644 --- a/docs/blocklist_and_token_revoking.rst +++ b/docs/blocklist_and_token_revoking.rst @@ -29,7 +29,7 @@ Live (TTL) functionality when storing a JWT. Here is an example using redis: .. warning:: Note that configuring redis to be disk-persistent is an absolutely necessity for production use. Otherwise, events like power outages or server crashes/reboots - would cause all formerly invalidated tokens to become valid again (assuming the + would cause all invalidated tokens to become valid again (assuming the secret key does not change). This is especially concering for long-lived refresh tokens, discussed below. @@ -42,18 +42,18 @@ etc. Here is an example using SQLAlchemy: .. literalinclude:: ../examples/blocklist_database.py -Handling Revoking of Refresh Tokens -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Revoking Refresh Tokens +~~~~~~~~~~~~~~~~~~~~~~~ It is critical to note that a user's refresh token must also be revoked when logging out; otherwise, this refresh token could just be used to generate -a new access token. Usually this falls to the responsibility of the frontend +a new access token. Usually this falls to the responsibility of the frontend application, which must send two separate requests to the backend in order to revoke these tokens. This can be implemented via two separate routes marked with ``@jwt_required()`` and ``@jwt_required(refresh=True)`` to revoke access and refresh tokens, respectively. However, it is more convenient to provide a single endpoint where the frontend -can send a DELETE for each token. This can be done with the following: +can send a DELETE for each token. Thee following is an example: .. code-block:: python @app.route("/logout", methods=["DELETE"]) @@ -101,13 +101,11 @@ Token type and user are not required and can be omitted. That being said, includ these columns can help to audit that the frontend is performing its revoking job correctly and revoking both tokens. - An alternative, albeit much more complex, implementation is to invalidate all issued tokens for a user at once. To do this, all issued tokens must be tracked (by default, -they verified by being validated against the secret key). A few steps would be -required: +they are not stored on the server). A few steps would be required: #. Store all generated access and refresh tokens in a database, include a user_id column #. Add a "valid" boolean column. Update the `token_in_blocklist_loader` to respond based on this column #. Upon revoking a token, find all other tokens with the same user and created at the same time, - and mark them all as invalid + (or all a user's tokens to log out on all devices) and mark each as invalid diff --git a/docs/refreshing_tokens.rst b/docs/refreshing_tokens.rst index 8fc9f4c5..77f05ba5 100644 --- a/docs/refreshing_tokens.rst +++ b/docs/refreshing_tokens.rst @@ -82,8 +82,8 @@ option when creating JWTs: create_access_token(identity, fresh=datetime.timedelta(minutes=15)) -Revoking Refresh Tokens -~~~~~~~~~~~~~~~~~~~~~~~ +Note on Revoking Refresh Tokens +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Note that when an access token is invalidated (e.g. logging a user out), the corresponding refresh token(s) must be revoked too. -See :ref:`Handling Revoking Refresh Tokens` for details on how to handle this. +See :ref:`Revoking Refresh Tokens` for details on how to handle this. diff --git a/examples/blocklist_database.py b/examples/blocklist_database.py index 0a5db723..1b75f45a 100644 --- a/examples/blocklist_database.py +++ b/examples/blocklist_database.py @@ -1,9 +1,16 @@ -from datetime import datetime, timedelta, timezone +from datetime import datetime +from datetime import timedelta +from datetime import timezone -from flask import Flask, jsonify -from flask_jwt_extended import JWTManager, create_access_token, get_jwt, jwt_required +from flask import Flask +from flask import jsonify from flask_sqlalchemy import SQLAlchemy +from flask_jwt_extended import create_access_token +from flask_jwt_extended import get_jwt +from flask_jwt_extended import jwt_required +from flask_jwt_extended import JWTManager + app = Flask(__name__) ACCESS_EXPIRES = timedelta(hours=1) diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index adcedc16..c21ea017 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -2,7 +2,8 @@ from typing import Any import jwt -from flask import Response, _request_ctx_stack +from flask import _request_ctx_stack +from flask import Response from werkzeug.local import LocalProxy from flask_jwt_extended.config import config From 6114f99211d4b1d6dd56233e48bf665249682256 Mon Sep 17 00:00:00 2001 From: Trevor Gross Date: Sat, 8 Jan 2022 01:03:57 -0500 Subject: [PATCH 09/17] Fixed docs --- docs/blocklist_and_token_revoking.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/blocklist_and_token_revoking.rst b/docs/blocklist_and_token_revoking.rst index 39611e3e..97d31060 100644 --- a/docs/blocklist_and_token_revoking.rst +++ b/docs/blocklist_and_token_revoking.rst @@ -56,6 +56,7 @@ However, it is more convenient to provide a single endpoint where the frontend can send a DELETE for each token. Thee following is an example: .. code-block:: python + @app.route("/logout", methods=["DELETE"]) @jwt_required(verify_type=False) def logout(): @@ -70,6 +71,7 @@ can send a DELETE for each token. Thee following is an example: or, for the database format: .. code-block:: python + class TokenBlocklist(db.Model): id = db.Column(db.Integer, primary_key=True) jti = db.Column(db.String(36), nullable=False, index=True) From 5777199073e4994399180579ef0b825e124bc119 Mon Sep 17 00:00:00 2001 From: Trevor Gross Date: Sat, 8 Jan 2022 01:17:41 -0500 Subject: [PATCH 10/17] Removed unused get_token_type --- flask_jwt_extended/__init__.py | 1 - flask_jwt_extended/utils.py | 13 ------------- 2 files changed, 14 deletions(-) diff --git a/flask_jwt_extended/__init__.py b/flask_jwt_extended/__init__.py index c6f4454a..21373cc3 100644 --- a/flask_jwt_extended/__init__.py +++ b/flask_jwt_extended/__init__.py @@ -10,7 +10,6 @@ from .utils import get_jwt_header from .utils import get_jwt_identity from .utils import get_jwt_request_location -from .utils import get_token_type from .utils import get_unverified_jwt_headers from .utils import set_access_cookies from .utils import set_refresh_cookies diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index c21ea017..b0708b89 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -253,19 +253,6 @@ def get_jti(encoded_token: str) -> str: return decode_token(encoded_token).get("jti") -def get_token_type(encoded_token: str) -> str: - """ - Returns the JTI (unique identifier) of an encoded JWT - - :param encoded_token: - The encoded JWT to get the JTI from. - - :return: - The JTI (unique identifier) of a JWT. - """ - return decode_token(encoded_token).get("jti") - - def get_csrf_token(encoded_token: str) -> str: """ Returns the CSRF double submit token from an encoded JWT. From 8ff71515ba3cbf2a7923c81627869ec1cbc220ee Mon Sep 17 00:00:00 2001 From: Trevor Gross Date: Sat, 8 Jan 2022 02:36:16 -0500 Subject: [PATCH 11/17] Added test for no typecheck --- docs/blocklist_and_token_revoking.rst | 35 +++++++++++-------- docs/options.rst | 5 +++ docs/refreshing_tokens.rst | 16 +++++---- tests/test_view_decorators.py | 48 ++++++++++++++++++++------- 4 files changed, 72 insertions(+), 32 deletions(-) diff --git a/docs/blocklist_and_token_revoking.rst b/docs/blocklist_and_token_revoking.rst index 97d31060..42146288 100644 --- a/docs/blocklist_and_token_revoking.rst +++ b/docs/blocklist_and_token_revoking.rst @@ -53,7 +53,7 @@ revoke these tokens. This can be implemented via two separate routes marked with ``@jwt_required()`` and ``@jwt_required(refresh=True)`` to revoke access and refresh tokens, respectively. However, it is more convenient to provide a single endpoint where the frontend -can send a DELETE for each token. Thee following is an example: +can send a DELETE for each token. The following is an example: .. code-block:: python @@ -66,7 +66,7 @@ can send a DELETE for each token. Thee following is an example: jwt_redis_blocklist.set(jti, "", ex=ACCESS_EXPIRES) # Returns "Access token revoked" or "Refresh token revoked" - return jsonify(msg=f"{ttype.capitalize()} token revoked") + return jsonify(msg=f"{ttype.capitalize()} token successfully revoked") or, for the database format: @@ -96,18 +96,27 @@ or, for the database format: now = datetime.now(timezone.utc) db.session.add(TokenBlocklist(jti=jti, type=ttype, created_at=now)) db.session.commit() - return jsonify(msg=f"{ttype.capitalize()} token revoked") + return jsonify(msg=f"{ttype.capitalize()} token successfully revoked") -Token type and user are not required and can be omitted. That being said, including -these columns can help to audit that the frontend is performing its revoking job -correctly and revoking both tokens. +Token type and user columns are not required and can be omitted. That being said, including +these can help to audit that the frontend is performing its revoking job correctly and revoking both tokens. -An alternative, albeit much more complex, implementation is to invalidate all issued -tokens for a user at once. To do this, all issued tokens must be tracked (by default, -they are not stored on the server). A few steps would be required: +Alternatively, there are a few ways to revoke both tokens at once: -#. Store all generated access and refresh tokens in a database, include a user_id column -#. Add a "valid" boolean column. Update the `token_in_blocklist_loader` to respond based on this column -#. Upon revoking a token, find all other tokens with the same user and created at the same time, - (or all a user's tokens to log out on all devices) and mark each as invalid +#. Send the access token in the header (per usual), and send the refresh token in + the DELETE request body. This saves a request but still needs frontend changes, so may not + be worth implementing +#. Embed the refresh token's jti in the access token. The revoke route should be authenticated + with the access token. Upon revoking the access token, extract the refresh jti from it + and invalidate both. This has the advantage of requiring no extra work from the frontend. +#. Store all generated tokens jtis in a database whenever they are created. Have a column to represent + whether it is valid or not, which the ``token_in_blocklist_loader`` should respond based upon. + Upon revoking a token, mark that token as invalid, as well as all other tokens from the same + user generated at the same time. This would also allow for a "log out everywhere" option where + all tokens for a user are invalidated at once, which is otherwise not easily possibile + + +The best option of course depends and needs to be chosen based upon the circumstances. If there +if ever a time where an unknown, untracked token needs to be immediately invalidated, this can +be accomplished by changing the secret key. diff --git a/docs/options.rst b/docs/options.rst index 2c920106..356fa9ae 100644 --- a/docs/options.rst +++ b/docs/options.rst @@ -93,6 +93,11 @@ General Options: **Do not reveal the secret key when posting questions or committing code.** + Note: there is ever a need to invalidate all issued tokens (e.g. a security flaw was found, + or the revoked token database was lost), this can be easily done by changing the JWT_SECRET_KEY + (or Flask's SECRET_KEY, if JWT_SECRET_KEY is unset). + + Default: ``None`` diff --git a/docs/refreshing_tokens.rst b/docs/refreshing_tokens.rst index 77f05ba5..02dc3d44 100644 --- a/docs/refreshing_tokens.rst +++ b/docs/refreshing_tokens.rst @@ -50,10 +50,19 @@ website (mobile, api only, etc). Making a request with a refresh token looks just like making a request with an access token. Here is an example using `HTTPie `_. + + + .. code-block :: bash $ http POST :5000/refresh Authorization:"Bearer $REFRESH_TOKEN" +.. warning:: + + Note that when an access token is invalidated (e.g. logging a user out), any + corresponding refresh token(s) must be revoked too. See + :ref:`Revoking Refresh Tokens` for details on how to handle this. + Token Freshness Pattern ~~~~~~~~~~~~~~~~~~~~~~~ @@ -80,10 +89,3 @@ option when creating JWTs: .. code-block :: python create_access_token(identity, fresh=datetime.timedelta(minutes=15)) - - -Note on Revoking Refresh Tokens -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Note that when an access token is invalidated (e.g. logging a user out), the -corresponding refresh token(s) must be revoked too. -See :ref:`Revoking Refresh Tokens` for details on how to handle this. diff --git a/tests/test_view_decorators.py b/tests/test_view_decorators.py index 6a34c0bb..272c1547 100644 --- a/tests/test_view_decorators.py +++ b/tests/test_view_decorators.py @@ -2,19 +2,18 @@ import pytest from dateutil.relativedelta import relativedelta -from flask import Flask -from flask import jsonify +from flask import Flask, jsonify +from flask_jwt_extended import ( + JWTManager, + create_access_token, + create_refresh_token, + decode_token, + get_jwt_identity, + jwt_required, + verify_jwt_in_request, +) -from flask_jwt_extended import create_access_token -from flask_jwt_extended import create_refresh_token -from flask_jwt_extended import decode_token -from flask_jwt_extended import get_jwt_identity -from flask_jwt_extended import jwt_required -from flask_jwt_extended import JWTManager -from flask_jwt_extended import verify_jwt_in_request -from tests.utils import encode_token -from tests.utils import get_jwt_manager -from tests.utils import make_headers +from tests.utils import encode_token, get_jwt_manager, make_headers @pytest.fixture(scope="function") @@ -46,6 +45,11 @@ def optional_protected(): else: return jsonify(foo="bar") + @app.route("/no_typecheck_protected", methods=["GET"]) + @jwt_required(verify_type=False) + def no_typecheck_protected(): + return jsonify(foo="bar") + return app @@ -153,6 +157,26 @@ def test_refresh_jwt_required(app): assert response.get_json() == {"foo": "bar"} +def test_jwt_required_no_typecheck(app): + """Verify this route works with access or refresh tokens.""" + url = "/no_typecheck_protected" + + test_client = app.test_client() + with app.test_request_context(): + access_token = create_access_token("username") + fresh_access_token = create_access_token("username", fresh=True) + refresh_token = create_refresh_token("username") + + for token in (access_token, fresh_access_token, refresh_token): + response = test_client.get(url, headers=make_headers(token)) + assert response.status_code == 200 + assert response.get_json() == {"foo": "bar"} + + response = test_client.get(url, headers=None) + assert response.status_code == 401 + assert response.get_json() == {"msg": "Missing Authorization Header"} + + @pytest.mark.parametrize("delta_func", [timedelta, relativedelta]) def test_jwt_optional(app, delta_func): url = "/optional_protected" From a6ab61d65978746836c6a97f18de9886eca3326f Mon Sep 17 00:00:00 2001 From: Trevor Gross Date: Sat, 8 Jan 2022 02:39:30 -0500 Subject: [PATCH 12/17] Updated docs, tests passed --- docs/blocklist_and_token_revoking.rst | 4 ++-- docs/refreshing_tokens.rst | 4 ++-- tests/test_view_decorators.py | 25 +++++++++++++------------ 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/docs/blocklist_and_token_revoking.rst b/docs/blocklist_and_token_revoking.rst index 42146288..9250cfc0 100644 --- a/docs/blocklist_and_token_revoking.rst +++ b/docs/blocklist_and_token_revoking.rst @@ -110,9 +110,9 @@ Alternatively, there are a few ways to revoke both tokens at once: #. Embed the refresh token's jti in the access token. The revoke route should be authenticated with the access token. Upon revoking the access token, extract the refresh jti from it and invalidate both. This has the advantage of requiring no extra work from the frontend. -#. Store all generated tokens jtis in a database whenever they are created. Have a column to represent +#. Store every generated tokens jti in a database upon creation. Have a boolean column to represent whether it is valid or not, which the ``token_in_blocklist_loader`` should respond based upon. - Upon revoking a token, mark that token as invalid, as well as all other tokens from the same + Upon revoking a token, mark that token row as invalid, as well as all other tokens from the same user generated at the same time. This would also allow for a "log out everywhere" option where all tokens for a user are invalidated at once, which is otherwise not easily possibile diff --git a/docs/refreshing_tokens.rst b/docs/refreshing_tokens.rst index 02dc3d44..97390f16 100644 --- a/docs/refreshing_tokens.rst +++ b/docs/refreshing_tokens.rst @@ -59,8 +59,8 @@ an access token. Here is an example using `HTTPie `_. .. warning:: - Note that when an access token is invalidated (e.g. logging a user out), any - corresponding refresh token(s) must be revoked too. See + Note that when an access token is invalidated (e.g. logging a user out), any + corresponding refresh token(s) must be revoked too. See :ref:`Revoking Refresh Tokens` for details on how to handle this. diff --git a/tests/test_view_decorators.py b/tests/test_view_decorators.py index 272c1547..7b88b522 100644 --- a/tests/test_view_decorators.py +++ b/tests/test_view_decorators.py @@ -2,18 +2,19 @@ import pytest from dateutil.relativedelta import relativedelta -from flask import Flask, jsonify -from flask_jwt_extended import ( - JWTManager, - create_access_token, - create_refresh_token, - decode_token, - get_jwt_identity, - jwt_required, - verify_jwt_in_request, -) - -from tests.utils import encode_token, get_jwt_manager, make_headers +from flask import Flask +from flask import jsonify + +from flask_jwt_extended import create_access_token +from flask_jwt_extended import create_refresh_token +from flask_jwt_extended import decode_token +from flask_jwt_extended import get_jwt_identity +from flask_jwt_extended import jwt_required +from flask_jwt_extended import JWTManager +from flask_jwt_extended import verify_jwt_in_request +from tests.utils import encode_token +from tests.utils import get_jwt_manager +from tests.utils import make_headers @pytest.fixture(scope="function") From faadc9535fb78a5a563892ecf7f82f96483f258f Mon Sep 17 00:00:00 2001 From: Trevor Gross Date: Sat, 8 Jan 2022 02:42:58 -0500 Subject: [PATCH 13/17] Fixed a minor grammatical error, nobody is getting insurance money from this:) --- README.md | 47 ++++++++++++++++++++------------- flask_jwt_extended/utils.py | 7 +++-- tests/test_asymmetric_crypto.py | 12 +++------ tests/test_cookies.py | 31 +++++++++++----------- tests/test_decode_tokens.py | 43 +++++++++++++++--------------- tests/test_headers.py | 21 +++++++-------- tests/test_query_string.py | 8 +++--- 7 files changed, 86 insertions(+), 83 deletions(-) diff --git a/README.md b/README.md index 03c99711..012f0d37 100644 --- a/README.md +++ b/README.md @@ -1,62 +1,73 @@ # Flask-JWT-Extended ### Features + Flask-JWT-Extended not only adds support for using JSON Web Tokens (JWT) to Flask for protecting routes, -but also many helpful (and **optional**) features built in to make working with JSON Web Tokens +but also many helpful (and **optional**) features built in to make working with JSON Web Tokens easier. These include: -* Adding custom claims to JSON Web Tokens -* Automatic user loading (`current_user`). -* Custom claims validation on received tokens -* [Refresh tokens](https://auth0.com/blog/refresh-tokens-what-are-they-and-when-to-use-them/) -* First class support for fresh tokens for making sensitive changes. -* Token revoking/blocklisting -* Storing tokens in cookies and CSRF protection +- Adding custom claims to JSON Web Tokens +- Automatic user loading (`current_user`). +- Custom claims validation on received tokens +- [Refresh tokens](https://auth0.com/blog/refresh-tokens-what-are-they-and-when-to-use-them/) +- First class support for fresh tokens for making sensitive changes. +- Token revoking/blocklisting +- Storing tokens in cookies and CSRF protection ### Usage + [View the documentation online](https://flask-jwt-extended.readthedocs.io/en/stable/) ### Upgrading from 3.x.x to 4.0.0 + [View the changes](https://flask-jwt-extended.readthedocs.io/en/stable/v4_upgrade_guide/) ### Changelog + You can view the changelog [here](https://github.com/vimalloc/flask-jwt-extended/releases). This project follows [semantic versioning](https://semver.org/). ### Chatting + Come chat with the community or ask questions at https://discord.gg/EJBsbFd ### Contributing + Before making any changes, make sure to install the development requirements and setup the git hooks which will automatically lint and format your changes. + ```bash pip install -r requirements.txt pre-commit install ``` We require 100% code coverage in our unit tests. You can run the tests locally -with `tox` which insures that all tests pass, tests provide complete code coverage, +with `tox` which ensures that all tests pass, tests provide complete code coverage, documentation builds, and style guide are adhered to + ```bash tox ``` A subset of checks can also be ran by adding an argument to tox. The available arguments are: - * py36, py37, py38, py39, pypy3 - * Run unit tests on the given python version - * coverage - * Run a code coverage check - * docs - * Insure documentation builds and there are no broken links - * style - * Insure style guide is adhered to + +- py36, py37, py38, py39, pypy3 + - Run unit tests on the given python version +- coverage + - Run a code coverage check +- docs + - Ensure documentation builds and there are no broken links +- style + - Ensure style guide is adhered to + ```bash tox -e py38 ``` -We also require features to be well documented. You can generate a local copy +We also require features to be well documented. You can generate a local copy of the documentation by going to the `docs` directory and running: + ```bash make clean && make html && open _build/html/index.html ``` diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index b0708b89..4dc7c3b3 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -2,8 +2,7 @@ from typing import Any import jwt -from flask import _request_ctx_stack -from flask import Response +from flask import Response, _request_ctx_stack from werkzeug.local import LocalProxy from flask_jwt_extended.config import config @@ -105,11 +104,11 @@ def decode_token( ) -> dict: """ Returns the decoded token (python dict) from an encoded JWT. This does all - the checks to insure that the decoded token is valid before returning it. + the checks to ensure that the decoded token is valid before returning it. This will not fire the user loader callbacks, save the token for access in protected endpoints, checked if a token is revoked, etc. This is puerly - used to insure that a JWT is valid. + used to ensure that a JWT is valid. :param encoded_token: The encoded JWT to decode. diff --git a/tests/test_asymmetric_crypto.py b/tests/test_asymmetric_crypto.py index 40d0d8ee..ef1ee7c6 100644 --- a/tests/test_asymmetric_crypto.py +++ b/tests/test_asymmetric_crypto.py @@ -1,10 +1,6 @@ import pytest -from flask import Flask -from flask import jsonify - -from flask_jwt_extended import create_access_token -from flask_jwt_extended import jwt_required -from flask_jwt_extended import JWTManager +from flask import Flask, jsonify +from flask_jwt_extended import JWTManager, create_access_token, jwt_required RSA_PRIVATE = """ -----BEGIN RSA PRIVATE KEY----- @@ -57,13 +53,13 @@ def test_asymmetric_cropto(app): app.config["JWT_ALGORITHM"] = "RS256" rs256_token = create_access_token("username") - # Insure the symmetric token does not work now + # Ensure the symmetric token does not work now access_headers = {"Authorization": "Bearer {}".format(hs256_token)} response = test_client.get("/protected", headers=access_headers) assert response.status_code == 422 assert response.get_json() == {"msg": "The specified alg value is not allowed"} - # Insure the asymmetric token does work + # Ensure the asymmetric token does work access_headers = {"Authorization": "Bearer {}".format(rs256_token)} response = test_client.get("/protected", headers=access_headers) assert response.status_code == 200 diff --git a/tests/test_cookies.py b/tests/test_cookies.py index 9cc33247..c028a8c1 100644 --- a/tests/test_cookies.py +++ b/tests/test_cookies.py @@ -1,17 +1,16 @@ import pytest -from flask import Flask -from flask import jsonify -from flask import request - -from flask_jwt_extended import create_access_token -from flask_jwt_extended import create_refresh_token -from flask_jwt_extended import jwt_required -from flask_jwt_extended import JWTManager -from flask_jwt_extended import set_access_cookies -from flask_jwt_extended import set_refresh_cookies -from flask_jwt_extended import unset_access_cookies -from flask_jwt_extended import unset_jwt_cookies -from flask_jwt_extended import unset_refresh_cookies +from flask import Flask, jsonify, request +from flask_jwt_extended import ( + JWTManager, + create_access_token, + create_refresh_token, + jwt_required, + set_access_cookies, + set_refresh_cookies, + unset_access_cookies, + unset_jwt_cookies, + unset_refresh_cookies, +) def _get_cookie_from_response(response, cookie_name): @@ -301,17 +300,17 @@ def test_custom_csrf_methods(app, options): response = test_client.get(auth_url) csrf_token = _get_cookie_from_response(response, csrf_cookie_name)[csrf_cookie_name] - # Insure we can now do posts without csrf + # Ensure we can now do posts without csrf response = test_client.post(post_url) assert response.status_code == 200 assert response.get_json() == {"foo": "bar"} - # Insure GET requests now fail without csrf + # Ensure GET requests now fail without csrf response = test_client.get(get_url) assert response.status_code == 401 assert response.get_json() == {"msg": "Missing CSRF token"} - # Insure GET requests now succeed with csrf + # Ensure GET requests now succeed with csrf csrf_headers = {"X-CSRF-TOKEN": csrf_token} response = test_client.get(get_url, headers=csrf_headers) assert response.status_code == 200 diff --git a/tests/test_decode_tokens.py b/tests/test_decode_tokens.py index 6c98e4f6..5a082eb0 100644 --- a/tests/test_decode_tokens.py +++ b/tests/test_decode_tokens.py @@ -1,28 +1,29 @@ -from datetime import datetime -from datetime import timedelta -from datetime import timezone +from datetime import datetime, timedelta, timezone import pytest from dateutil.relativedelta import relativedelta from flask import Flask -from jwt import DecodeError -from jwt import ExpiredSignatureError -from jwt import ImmatureSignatureError -from jwt import InvalidAudienceError -from jwt import InvalidIssuerError -from jwt import InvalidSignatureError -from jwt import MissingRequiredClaimError - -from flask_jwt_extended import create_access_token -from flask_jwt_extended import create_refresh_token -from flask_jwt_extended import decode_token -from flask_jwt_extended import get_jti -from flask_jwt_extended import get_unverified_jwt_headers -from flask_jwt_extended import JWTManager +from flask_jwt_extended import ( + JWTManager, + create_access_token, + create_refresh_token, + decode_token, + get_jti, + get_unverified_jwt_headers, +) from flask_jwt_extended.config import config from flask_jwt_extended.exceptions import JWTDecodeError -from tests.utils import encode_token -from tests.utils import get_jwt_manager +from jwt import ( + DecodeError, + ExpiredSignatureError, + ImmatureSignatureError, + InvalidAudienceError, + InvalidIssuerError, + InvalidSignatureError, + MissingRequiredClaimError, +) + +from tests.utils import encode_token, get_jwt_manager @pytest.fixture(scope="function") @@ -172,13 +173,13 @@ def test_nbf_token_in_future(app): def test_alternate_identity_claim(app, default_access_token): app.config["JWT_IDENTITY_CLAIM"] = "banana" - # Insure decoding fails if the claim isn't there + # Ensure decoding fails if the claim isn't there token = encode_token(app, default_access_token) with pytest.raises(JWTDecodeError): with app.test_request_context(): decode_token(token) - # Insure the claim exists in the decoded jwt + # Ensure the claim exists in the decoded jwt del default_access_token["sub"] default_access_token["banana"] = "username" token = encode_token(app, default_access_token) diff --git a/tests/test_headers.py b/tests/test_headers.py index a1a3aef5..7d061297 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -1,10 +1,7 @@ import pytest -from flask import Flask -from flask import jsonify +from flask import Flask, jsonify +from flask_jwt_extended import JWTManager, create_access_token, jwt_required -from flask_jwt_extended import create_access_token -from flask_jwt_extended import jwt_required -from flask_jwt_extended import JWTManager from tests.utils import get_jwt_manager @@ -76,13 +73,13 @@ def test_custom_header_name(app): with app.test_request_context(): access_token = create_access_token("username") - # Insure 'default' headers no longer work + # Ensure 'default' headers no longer work access_headers = {"Authorization": "Bearer {}".format(access_token)} response = test_client.get("/protected", headers=access_headers) assert response.status_code == 401 assert response.get_json() == {"msg": "Missing Foo Header"} - # Insure new headers do work + # Ensure new headers do work access_headers = {"Foo": "Bearer {}".format(access_token)} response = test_client.get("/protected", headers=access_headers) assert response.status_code == 200 @@ -108,7 +105,7 @@ def test_custom_header_type(app): with app.test_request_context(): access_token = create_access_token("username") - # Insure 'default' headers no longer work + # Ensure 'default' headers no longer work access_headers = {"Authorization": "Bearer {}".format(access_token)} response = test_client.get("/protected", headers=access_headers) error_msg = ( @@ -118,7 +115,7 @@ def test_custom_header_type(app): assert response.status_code == 401 assert response.get_json() == {"msg": error_msg} - # Insure new headers do work + # Ensure new headers do work access_headers = {"Authorization": "JWT {}".format(access_token)} response = test_client.get("/protected", headers=access_headers) assert response.status_code == 200 @@ -136,14 +133,14 @@ def test_custom_header_type(app): assert response.status_code == 200 assert response.get_json() == {"foo": "bar"} - # Insure new headers without a type also work + # Ensure new headers without a type also work app.config["JWT_HEADER_TYPE"] = "" access_headers = {"Authorization": access_token} response = test_client.get("/protected", headers=access_headers) assert response.status_code == 200 assert response.get_json() == {"foo": "bar"} - # Insure header with too many parts fails + # Ensure header with too many parts fails app.config["JWT_HEADER_TYPE"] = "" access_headers = {"Authorization": "Bearer {}".format(access_token)} response = test_client.get("/protected", headers=access_headers) @@ -156,7 +153,7 @@ def test_missing_headers(app): test_client = app.test_client() jwtM = get_jwt_manager(app) - # Insure 'default' no headers response + # Ensure 'default' no headers response response = test_client.get("/protected", headers=None) assert response.status_code == 401 assert response.get_json() == {"msg": "Missing Authorization Header"} diff --git a/tests/test_query_string.py b/tests/test_query_string.py index 2b5c6586..6fc9d8a8 100644 --- a/tests/test_query_string.py +++ b/tests/test_query_string.py @@ -66,13 +66,13 @@ def test_custom_query_paramater(app): with app.test_request_context(): access_token = create_access_token("username") - # Insure 'default' query paramaters no longer work + # Ensure 'default' query paramaters no longer work url = "/protected?jwt={}".format(access_token) response = test_client.get(url) assert response.status_code == 401 assert response.get_json() == {"msg": "Missing 'foo' query paramater"} - # Insure new query_string does work + # Ensure new query_string does work url = "/protected?foo={}".format(access_token) response = test_client.get(url) assert response.status_code == 200 @@ -86,12 +86,12 @@ def test_missing_query_paramater(app): with app.test_request_context(): access_token = create_access_token("username") - # Insure no query paramaters doesn't give a response + # Ensure no query paramaters doesn't give a response response = test_client.get("/protected") assert response.status_code == 401 assert response.get_json() == {"msg": "Missing 'jwt' query paramater"} - # Insure headers don't work + # Ensure headers don't work access_headers = {"Authorization": "Bearer {}".format(access_token)} response = test_client.get("/protected", headers=access_headers) assert response.status_code == 401 From fb99b9f256764dad78fbcf063ce46e3e30a33af7 Mon Sep 17 00:00:00 2001 From: Trevor Gross Date: Sat, 8 Jan 2022 02:51:10 -0500 Subject: [PATCH 14/17] Final formatting, ready for PR. Closes #453. --- flask_jwt_extended/utils.py | 3 ++- tests/test_asymmetric_crypto.py | 8 +++++-- tests/test_cookies.py | 25 +++++++++++---------- tests/test_decode_tokens.py | 39 ++++++++++++++++----------------- tests/test_headers.py | 7 ++++-- 5 files changed, 45 insertions(+), 37 deletions(-) diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index 4dc7c3b3..83284ef8 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -2,7 +2,8 @@ from typing import Any import jwt -from flask import Response, _request_ctx_stack +from flask import _request_ctx_stack +from flask import Response from werkzeug.local import LocalProxy from flask_jwt_extended.config import config diff --git a/tests/test_asymmetric_crypto.py b/tests/test_asymmetric_crypto.py index ef1ee7c6..5bd067d9 100644 --- a/tests/test_asymmetric_crypto.py +++ b/tests/test_asymmetric_crypto.py @@ -1,6 +1,10 @@ import pytest -from flask import Flask, jsonify -from flask_jwt_extended import JWTManager, create_access_token, jwt_required +from flask import Flask +from flask import jsonify + +from flask_jwt_extended import create_access_token +from flask_jwt_extended import jwt_required +from flask_jwt_extended import JWTManager RSA_PRIVATE = """ -----BEGIN RSA PRIVATE KEY----- diff --git a/tests/test_cookies.py b/tests/test_cookies.py index c028a8c1..a1d2d483 100644 --- a/tests/test_cookies.py +++ b/tests/test_cookies.py @@ -1,16 +1,17 @@ import pytest -from flask import Flask, jsonify, request -from flask_jwt_extended import ( - JWTManager, - create_access_token, - create_refresh_token, - jwt_required, - set_access_cookies, - set_refresh_cookies, - unset_access_cookies, - unset_jwt_cookies, - unset_refresh_cookies, -) +from flask import Flask +from flask import jsonify +from flask import request + +from flask_jwt_extended import create_access_token +from flask_jwt_extended import create_refresh_token +from flask_jwt_extended import jwt_required +from flask_jwt_extended import JWTManager +from flask_jwt_extended import set_access_cookies +from flask_jwt_extended import set_refresh_cookies +from flask_jwt_extended import unset_access_cookies +from flask_jwt_extended import unset_jwt_cookies +from flask_jwt_extended import unset_refresh_cookies def _get_cookie_from_response(response, cookie_name): diff --git a/tests/test_decode_tokens.py b/tests/test_decode_tokens.py index 5a082eb0..4d13f787 100644 --- a/tests/test_decode_tokens.py +++ b/tests/test_decode_tokens.py @@ -1,29 +1,28 @@ -from datetime import datetime, timedelta, timezone +from datetime import datetime +from datetime import timedelta +from datetime import timezone import pytest from dateutil.relativedelta import relativedelta from flask import Flask -from flask_jwt_extended import ( - JWTManager, - create_access_token, - create_refresh_token, - decode_token, - get_jti, - get_unverified_jwt_headers, -) +from jwt import DecodeError +from jwt import ExpiredSignatureError +from jwt import ImmatureSignatureError +from jwt import InvalidAudienceError +from jwt import InvalidIssuerError +from jwt import InvalidSignatureError +from jwt import MissingRequiredClaimError + +from flask_jwt_extended import create_access_token +from flask_jwt_extended import create_refresh_token +from flask_jwt_extended import decode_token +from flask_jwt_extended import get_jti +from flask_jwt_extended import get_unverified_jwt_headers +from flask_jwt_extended import JWTManager from flask_jwt_extended.config import config from flask_jwt_extended.exceptions import JWTDecodeError -from jwt import ( - DecodeError, - ExpiredSignatureError, - ImmatureSignatureError, - InvalidAudienceError, - InvalidIssuerError, - InvalidSignatureError, - MissingRequiredClaimError, -) - -from tests.utils import encode_token, get_jwt_manager +from tests.utils import encode_token +from tests.utils import get_jwt_manager @pytest.fixture(scope="function") diff --git a/tests/test_headers.py b/tests/test_headers.py index 7d061297..fb3b8252 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -1,7 +1,10 @@ import pytest -from flask import Flask, jsonify -from flask_jwt_extended import JWTManager, create_access_token, jwt_required +from flask import Flask +from flask import jsonify +from flask_jwt_extended import create_access_token +from flask_jwt_extended import jwt_required +from flask_jwt_extended import JWTManager from tests.utils import get_jwt_manager From c072a34d55e31395cb4c8b2a99e0c4d41362e143 Mon Sep 17 00:00:00 2001 From: Trevor Gross Date: Sat, 8 Jan 2022 02:57:40 -0500 Subject: [PATCH 15/17] Removed typehint from example --- examples/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/loaders.py b/examples/loaders.py index 02bd48fe..0dc5b537 100644 --- a/examples/loaders.py +++ b/examples/loaders.py @@ -17,7 +17,7 @@ # response. Check the API documentation to see the required argument and return # values for other callback functions. @jwt.expired_token_loader -def my_expired_token_callback(jwt_header, jwt_payload: dict): +def my_expired_token_callback(jwt_header, jwt_payload): return jsonify(code="dave", err="I can't let you do that"), 401 From 861917872f9b99941339b31d33d1a3fc037ea41f Mon Sep 17 00:00:00 2001 From: Trevor Gross Date: Sat, 8 Jan 2022 03:04:52 -0500 Subject: [PATCH 16/17] Final type hint updates --- flask_jwt_extended/config.py | 2 +- flask_jwt_extended/default_callbacks.py | 9 +++++---- flask_jwt_extended/internal_utils.py | 4 +++- flask_jwt_extended/tokens.py | 4 ++-- flask_jwt_extended/utils.py | 6 +++--- flask_jwt_extended/view_decorators.py | 3 ++- 6 files changed, 16 insertions(+), 12 deletions(-) diff --git a/flask_jwt_extended/config.py b/flask_jwt_extended/config.py index cf1517d3..7d8d93d3 100644 --- a/flask_jwt_extended/config.py +++ b/flask_jwt_extended/config.py @@ -141,7 +141,7 @@ def csrf_in_cookies(self) -> bool: return current_app.config["JWT_CSRF_IN_COOKIES"] @property - def access_csrf_cookie_name(self): + def access_csrf_cookie_name(self) -> str: return current_app.config["JWT_ACCESS_CSRF_COOKIE_NAME"] @property diff --git a/flask_jwt_extended/default_callbacks.py b/flask_jwt_extended/default_callbacks.py index 72accea3..9efbcc4e 100644 --- a/flask_jwt_extended/default_callbacks.py +++ b/flask_jwt_extended/default_callbacks.py @@ -7,6 +7,7 @@ http://flask-jwt-extended.readthedocs.io/en/latest/tokens_from_complex_object.html """ from http import HTTPStatus +from typing import Any from flask import jsonify from flask import Response @@ -14,7 +15,7 @@ from flask_jwt_extended.config import config -def default_additional_claims_callback(userdata) -> dict: +def default_additional_claims_callback(userdata: Any) -> dict: """ By default, we add no additional claims to the access tokens. @@ -40,7 +41,7 @@ def default_jwt_headers_callback(default_headers) -> dict: return {} -def default_user_identity_callback(userdata): +def default_user_identity_callback(userdata: Any) -> Any: """ By default, we use the passed in object directly as the jwt identity. See this for additional info: @@ -138,7 +139,7 @@ def default_token_verification_failed_callback( ) -def default_decode_key_callback(jwt_header: dict, jwt_data: dict): +def default_decode_key_callback(jwt_header: dict, jwt_data: dict) -> str: """ By default, the decode key specified via the JWT_SECRET_KEY or JWT_PUBLIC_KEY settings will be used to decode all tokens @@ -146,7 +147,7 @@ def default_decode_key_callback(jwt_header: dict, jwt_data: dict): return config.decode_key -def default_encode_key_callback(identity): +def default_encode_key_callback(identity: Any) -> str: """ By default, the encode key specified via the JWT_SECRET_KEY or JWT_PRIVATE_KEY settings will be used to encode all tokens diff --git a/flask_jwt_extended/internal_utils.py b/flask_jwt_extended/internal_utils.py index a49de656..3b6dc8ec 100644 --- a/flask_jwt_extended/internal_utils.py +++ b/flask_jwt_extended/internal_utils.py @@ -1,3 +1,5 @@ +from typing import Any + from flask import current_app from flask_jwt_extended import JWTManager @@ -21,7 +23,7 @@ def has_user_lookup() -> bool: return jwt_manager._user_lookup_callback is not None -def user_lookup(*args, **kwargs): +def user_lookup(*args, **kwargs) -> Any: jwt_manager = get_jwt_manager() return jwt_manager._user_lookup_callback(*args, **kwargs) diff --git a/flask_jwt_extended/tokens.py b/flask_jwt_extended/tokens.py index c78651bc..b72efe14 100644 --- a/flask_jwt_extended/tokens.py +++ b/flask_jwt_extended/tokens.py @@ -21,7 +21,7 @@ def _encode_jwt( csrf: bool, expires_delta: timedelta, fresh: bool, - header_overrides, + header_overrides: dict, identity: Any, identity_claim_key: str, issuer: str, @@ -76,7 +76,7 @@ def _decode_jwt( audience: Union[str, Iterable[str]], csrf_value: str, encoded_token: str, - identity_claim_key, + identity_claim_key: str, issuer: str, leeway: int, secret: str, diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index 83284ef8..e7df351c 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -49,7 +49,7 @@ def get_jwt_header() -> dict: return decoded_header -def get_jwt_identity(): +def get_jwt_identity() -> Any: """ In a protected endpoint, this will return the identity of the JWT that is accessing the endpoint. If no JWT is present due to @@ -76,7 +76,7 @@ def get_jwt_request_location() -> str: return location -def get_current_user(): +def get_current_user() -> Any: """ In a protected endpoint, this will return the user object for the JWT that is accessing the endpoint. @@ -420,7 +420,7 @@ def unset_access_cookies(response: Response, domain: str = None) -> None: ) -def unset_refresh_cookies(response: Response, domain: str = None): +def unset_refresh_cookies(response: Response, domain: str = None) -> None: """ Modifiy a Flask Response to delete the cookie containing a refresh JWT. Also deletes the corresponding CSRF cookie if applicable. diff --git a/flask_jwt_extended/view_decorators.py b/flask_jwt_extended/view_decorators.py index 757ad3c7..3b96f250 100644 --- a/flask_jwt_extended/view_decorators.py +++ b/flask_jwt_extended/view_decorators.py @@ -2,6 +2,7 @@ from datetime import timezone from functools import wraps from re import split +from typing import Any from typing import Iterable from typing import Tuple from typing import Union @@ -107,7 +108,7 @@ def jwt_required( refresh: bool = False, locations: LocationType = None, verify_type: bool = True, -): +) -> Any: """ A decorator to protect a Flask endpoint with JSON Web Tokens. From 2b10bd0607ba7a6d9d107444fab34d57ec6b1b8f Mon Sep 17 00:00:00 2001 From: Trevor Gross Date: Sun, 23 Jan 2022 19:27:52 -0500 Subject: [PATCH 17/17] Remove flake8 from requirements --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 5e24dec4..1a3cfd7a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ black==21.6b0 cryptography==35.0.0 Flask==2.0.1 -flake8==4.0.1 Pallets-Sphinx-Themes==2.0.1 pre-commit==2.13.0 PyJWT==2.1.0