diff --git a/.gitignore b/.gitignore index cf70e9fe..474e25fc 100644 --- a/.gitignore +++ b/.gitignore @@ -81,6 +81,7 @@ celerybeat-schedule # virtualenv venv/ +.venv/ ENV/ # Spyder project settings @@ -94,3 +95,6 @@ ENV/ # MacOS specific crap .DS_Store + +# Workspace +.vscode/ diff --git a/README.md b/README.md index 03c99711..012f0d37 100644 --- a/README.md +++ b/README.md @@ -1,62 +1,73 @@ # Flask-JWT-Extended ### Features + Flask-JWT-Extended not only adds support for using JSON Web Tokens (JWT) to Flask for protecting routes, -but also many helpful (and **optional**) features built in to make working with JSON Web Tokens +but also many helpful (and **optional**) features built in to make working with JSON Web Tokens easier. These include: -* Adding custom claims to JSON Web Tokens -* Automatic user loading (`current_user`). -* Custom claims validation on received tokens -* [Refresh tokens](https://auth0.com/blog/refresh-tokens-what-are-they-and-when-to-use-them/) -* First class support for fresh tokens for making sensitive changes. -* Token revoking/blocklisting -* Storing tokens in cookies and CSRF protection +- Adding custom claims to JSON Web Tokens +- Automatic user loading (`current_user`). +- Custom claims validation on received tokens +- [Refresh tokens](https://auth0.com/blog/refresh-tokens-what-are-they-and-when-to-use-them/) +- First class support for fresh tokens for making sensitive changes. +- Token revoking/blocklisting +- Storing tokens in cookies and CSRF protection ### Usage + [View the documentation online](https://flask-jwt-extended.readthedocs.io/en/stable/) ### Upgrading from 3.x.x to 4.0.0 + [View the changes](https://flask-jwt-extended.readthedocs.io/en/stable/v4_upgrade_guide/) ### Changelog + You can view the changelog [here](https://github.com/vimalloc/flask-jwt-extended/releases). This project follows [semantic versioning](https://semver.org/). ### Chatting + Come chat with the community or ask questions at https://discord.gg/EJBsbFd ### Contributing + Before making any changes, make sure to install the development requirements and setup the git hooks which will automatically lint and format your changes. + ```bash pip install -r requirements.txt pre-commit install ``` We require 100% code coverage in our unit tests. You can run the tests locally -with `tox` which insures that all tests pass, tests provide complete code coverage, +with `tox` which ensures that all tests pass, tests provide complete code coverage, documentation builds, and style guide are adhered to + ```bash tox ``` A subset of checks can also be ran by adding an argument to tox. The available arguments are: - * py36, py37, py38, py39, pypy3 - * Run unit tests on the given python version - * coverage - * Run a code coverage check - * docs - * Insure documentation builds and there are no broken links - * style - * Insure style guide is adhered to + +- py36, py37, py38, py39, pypy3 + - Run unit tests on the given python version +- coverage + - Run a code coverage check +- docs + - Ensure documentation builds and there are no broken links +- style + - Ensure style guide is adhered to + ```bash tox -e py38 ``` -We also require features to be well documented. You can generate a local copy +We also require features to be well documented. You can generate a local copy of the documentation by going to the `docs` directory and running: + ```bash make clean && make html && open _build/html/index.html ``` diff --git a/docs/blocklist_and_token_revoking.rst b/docs/blocklist_and_token_revoking.rst index 10deb047..9250cfc0 100644 --- a/docs/blocklist_and_token_revoking.rst +++ b/docs/blocklist_and_token_revoking.rst @@ -25,6 +25,14 @@ Live (TTL) functionality when storing a JWT. Here is an example using redis: .. literalinclude:: ../examples/blocklist_redis.py + +.. warning:: + Note that configuring redis to be disk-persistent is an absolutely necessity for + production use. Otherwise, events like power outages or server crashes/reboots + would cause all invalidated tokens to become valid again (assuming the + secret key does not change). This is especially concering for long-lived + refresh tokens, discussed below. + Database ~~~~~~~~ If you need to keep track of information about revoked JWTs our recommendation is @@ -33,3 +41,82 @@ revoked tokens, such as when it was revoked, who revoked it, can it be un-revoke etc. Here is an example using SQLAlchemy: .. literalinclude:: ../examples/blocklist_database.py + +Revoking Refresh Tokens +~~~~~~~~~~~~~~~~~~~~~~~ +It is critical to note that a user's refresh token must also be revoked +when logging out; otherwise, this refresh token could just be used to generate +a new access token. Usually this falls to the responsibility of the frontend +application, which must send two separate requests to the backend in order to +revoke these tokens. + +This can be implemented via two separate routes marked with ``@jwt_required()`` +and ``@jwt_required(refresh=True)`` to revoke access and refresh tokens, respectively. +However, it is more convenient to provide a single endpoint where the frontend +can send a DELETE for each token. The following is an example: + +.. code-block:: python + + @app.route("/logout", methods=["DELETE"]) + @jwt_required(verify_type=False) + def logout(): + token = get_jwt() + jti = token["jti"] + ttype = token["type"] + jwt_redis_blocklist.set(jti, "", ex=ACCESS_EXPIRES) + + # Returns "Access token revoked" or "Refresh token revoked" + return jsonify(msg=f"{ttype.capitalize()} token successfully revoked") + +or, for the database format: + +.. code-block:: python + + class TokenBlocklist(db.Model): + id = db.Column(db.Integer, primary_key=True) + jti = db.Column(db.String(36), nullable=False, index=True) + type = db.Column(db.String(16), nullable=False) + user_id = db.Column( + db.ForeignKey('person.id'), + default=lambda: get_current_user().id, + nullable=False, + ) + created_at = db.Column( + db.DateTime, + server_default=func.now(), + nullable=False, + ) + + @app.route("/logout", methods=["DELETE"]) + @jwt_required(verify_type=False) + def modify_token(): + token = get_jwt() + jti = token["jti"] + ttype = token["type"] + now = datetime.now(timezone.utc) + db.session.add(TokenBlocklist(jti=jti, type=ttype, created_at=now)) + db.session.commit() + return jsonify(msg=f"{ttype.capitalize()} token successfully revoked") + + +Token type and user columns are not required and can be omitted. That being said, including +these can help to audit that the frontend is performing its revoking job correctly and revoking both tokens. + +Alternatively, there are a few ways to revoke both tokens at once: + +#. Send the access token in the header (per usual), and send the refresh token in + the DELETE request body. This saves a request but still needs frontend changes, so may not + be worth implementing +#. Embed the refresh token's jti in the access token. The revoke route should be authenticated + with the access token. Upon revoking the access token, extract the refresh jti from it + and invalidate both. This has the advantage of requiring no extra work from the frontend. +#. Store every generated tokens jti in a database upon creation. Have a boolean column to represent + whether it is valid or not, which the ``token_in_blocklist_loader`` should respond based upon. + Upon revoking a token, mark that token row as invalid, as well as all other tokens from the same + user generated at the same time. This would also allow for a "log out everywhere" option where + all tokens for a user are invalidated at once, which is otherwise not easily possibile + + +The best option of course depends and needs to be chosen based upon the circumstances. If there +if ever a time where an unknown, untracked token needs to be immediately invalidated, this can +be accomplished by changing the secret key. diff --git a/docs/conf.py b/docs/conf.py index e1cb34b3..1b83443d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -358,3 +358,12 @@ # If true, do not generate a @detailmenu in the "Top" node's menu. # # texinfo_no_detailmenu = False + +# Fix warnings about refernce targets. See link: +# https://stackoverflow.com/questions/11417221/ +# sphinx-autodoc-gives-warning-pyclass-reference-target-not-found-type-warning +nitpick_ignore = [ + ("py:class", "flask.app.Flask"), + ("py:class", "datetime.timedelta"), + ("py:class", "flask.wrappers.Response"), +] diff --git a/docs/options.rst b/docs/options.rst index 2c920106..356fa9ae 100644 --- a/docs/options.rst +++ b/docs/options.rst @@ -93,6 +93,11 @@ General Options: **Do not reveal the secret key when posting questions or committing code.** + Note: there is ever a need to invalidate all issued tokens (e.g. a security flaw was found, + or the revoked token database was lost), this can be easily done by changing the JWT_SECRET_KEY + (or Flask's SECRET_KEY, if JWT_SECRET_KEY is unset). + + Default: ``None`` diff --git a/docs/refreshing_tokens.rst b/docs/refreshing_tokens.rst index 0f39db4a..97390f16 100644 --- a/docs/refreshing_tokens.rst +++ b/docs/refreshing_tokens.rst @@ -50,10 +50,19 @@ website (mobile, api only, etc). Making a request with a refresh token looks just like making a request with an access token. Here is an example using `HTTPie `_. + + + .. code-block :: bash $ http POST :5000/refresh Authorization:"Bearer $REFRESH_TOKEN" +.. warning:: + + Note that when an access token is invalidated (e.g. logging a user out), any + corresponding refresh token(s) must be revoked too. See + :ref:`Revoking Refresh Tokens` for details on how to handle this. + Token Freshness Pattern ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/examples/blocklist_database.py b/examples/blocklist_database.py index f4278ee1..1b75f45a 100644 --- a/examples/blocklist_database.py +++ b/examples/blocklist_database.py @@ -28,17 +28,23 @@ # This could be expanded to fit the needs of your application. For example, # it could track who revoked a JWT, when a token expires, notes for why a # JWT was revoked, an endpoint to un-revoked a JWT, etc. +# Making jti an index can significantly speed up the search when there are +# tens of thousands of records. Remember this query will happen for every +# (protected) request, +# If your database supports a UUID type, this can be used for the jti column +# as well class TokenBlocklist(db.Model): id = db.Column(db.Integer, primary_key=True) - jti = db.Column(db.String(36), nullable=False) + jti = db.Column(db.String(36), nullable=False, index=True) created_at = db.Column(db.DateTime, nullable=False) # Callback function to check if a JWT exists in the database blocklist @jwt.token_in_blocklist_loader -def check_if_token_revoked(jwt_header, jwt_payload): +def check_if_token_revoked(jwt_header, jwt_payload: dict) -> bool: jti = jwt_payload["jti"] token = db.session.query(TokenBlocklist.id).filter_by(jti=jti).scalar() + return token is not None diff --git a/examples/blocklist_redis.py b/examples/blocklist_redis.py index bcc055f8..e1aa1f0a 100644 --- a/examples/blocklist_redis.py +++ b/examples/blocklist_redis.py @@ -26,7 +26,7 @@ # Callback function to check if a JWT exists in the redis blocklist @jwt.token_in_blocklist_loader -def check_if_token_is_revoked(jwt_header, jwt_payload): +def check_if_token_is_revoked(jwt_header, jwt_payload: dict): jti = jwt_payload["jti"] token_in_redis = jwt_redis_blocklist.get(jti) return token_in_redis is not None diff --git a/flask_jwt_extended/config.py b/flask_jwt_extended/config.py index 764d6552..7d8d93d3 100644 --- a/flask_jwt_extended/config.py +++ b/flask_jwt_extended/config.py @@ -1,10 +1,12 @@ -from collections.abc import Sequence -from collections.abc import Set from datetime import datetime from datetime import timedelta from datetime import timezone +from typing import Iterable +from typing import List +from typing import Union from flask import current_app +from flask.json import JSONEncoder from jwt.algorithms import requires_cryptography @@ -20,23 +22,23 @@ class _Config(object): """ @property - def is_asymmetric(self): + def is_asymmetric(self) -> bool: return self.algorithm in requires_cryptography @property - def encode_key(self): + def encode_key(self) -> str: return self._private_key if self.is_asymmetric else self._secret_key @property - def decode_key(self): + def decode_key(self) -> str: return self._public_key if self.is_asymmetric else self._secret_key @property - def token_location(self): + def token_location(self) -> Iterable[str]: locations = current_app.config["JWT_TOKEN_LOCATION"] if isinstance(locations, str): locations = (locations,) - elif not isinstance(locations, (Sequence, Set)): + elif not isinstance(locations, Iterable): raise RuntimeError("JWT_TOKEN_LOCATION must be a sequence or a set") elif not locations: raise RuntimeError( @@ -52,130 +54,130 @@ def token_location(self): return locations @property - def jwt_in_cookies(self): + def jwt_in_cookies(self) -> bool: return "cookies" in self.token_location @property - def jwt_in_headers(self): + def jwt_in_headers(self) -> bool: return "headers" in self.token_location @property - def jwt_in_query_string(self): + def jwt_in_query_string(self) -> bool: return "query_string" in self.token_location @property - def jwt_in_json(self): + def jwt_in_json(self) -> bool: return "json" in self.token_location @property - def header_name(self): + def header_name(self) -> str: name = current_app.config["JWT_HEADER_NAME"] if not name: raise RuntimeError("JWT_ACCESS_HEADER_NAME cannot be empty") return name @property - def header_type(self): + def header_type(self) -> str: return current_app.config["JWT_HEADER_TYPE"] @property - def query_string_name(self): + def query_string_name(self) -> str: return current_app.config["JWT_QUERY_STRING_NAME"] @property - def query_string_value_prefix(self): + def query_string_value_prefix(self) -> str: return current_app.config["JWT_QUERY_STRING_VALUE_PREFIX"] @property - def access_cookie_name(self): + def access_cookie_name(self) -> str: return current_app.config["JWT_ACCESS_COOKIE_NAME"] @property - def refresh_cookie_name(self): + def refresh_cookie_name(self) -> str: return current_app.config["JWT_REFRESH_COOKIE_NAME"] @property - def access_cookie_path(self): + def access_cookie_path(self) -> str: return current_app.config["JWT_ACCESS_COOKIE_PATH"] @property - def refresh_cookie_path(self): + def refresh_cookie_path(self) -> str: return current_app.config["JWT_REFRESH_COOKIE_PATH"] @property - def cookie_secure(self): + def cookie_secure(self) -> bool: return current_app.config["JWT_COOKIE_SECURE"] @property - def cookie_domain(self): + def cookie_domain(self) -> str: return current_app.config["JWT_COOKIE_DOMAIN"] @property - def session_cookie(self): + def session_cookie(self) -> bool: return current_app.config["JWT_SESSION_COOKIE"] @property - def cookie_samesite(self): + def cookie_samesite(self) -> str: return current_app.config["JWT_COOKIE_SAMESITE"] @property - def json_key(self): + def json_key(self) -> str: return current_app.config["JWT_JSON_KEY"] @property - def refresh_json_key(self): + def refresh_json_key(self) -> str: return current_app.config["JWT_REFRESH_JSON_KEY"] @property - def csrf_protect(self): + def csrf_protect(self) -> bool: return self.jwt_in_cookies and current_app.config["JWT_COOKIE_CSRF_PROTECT"] @property - def csrf_request_methods(self): + def csrf_request_methods(self) -> Iterable[str]: return current_app.config["JWT_CSRF_METHODS"] @property - def csrf_in_cookies(self): + def csrf_in_cookies(self) -> bool: return current_app.config["JWT_CSRF_IN_COOKIES"] @property - def access_csrf_cookie_name(self): + def access_csrf_cookie_name(self) -> str: return current_app.config["JWT_ACCESS_CSRF_COOKIE_NAME"] @property - def refresh_csrf_cookie_name(self): + def refresh_csrf_cookie_name(self) -> str: return current_app.config["JWT_REFRESH_CSRF_COOKIE_NAME"] @property - def access_csrf_cookie_path(self): + def access_csrf_cookie_path(self) -> str: return current_app.config["JWT_ACCESS_CSRF_COOKIE_PATH"] @property - def refresh_csrf_cookie_path(self): + def refresh_csrf_cookie_path(self) -> str: return current_app.config["JWT_REFRESH_CSRF_COOKIE_PATH"] @property - def access_csrf_header_name(self): + def access_csrf_header_name(self) -> str: return current_app.config["JWT_ACCESS_CSRF_HEADER_NAME"] @property - def refresh_csrf_header_name(self): + def refresh_csrf_header_name(self) -> str: return current_app.config["JWT_REFRESH_CSRF_HEADER_NAME"] @property - def csrf_check_form(self): + def csrf_check_form(self) -> bool: return current_app.config["JWT_CSRF_CHECK_FORM"] @property - def access_csrf_field_name(self): + def access_csrf_field_name(self) -> str: return current_app.config["JWT_ACCESS_CSRF_FIELD_NAME"] @property - def refresh_csrf_field_name(self): + def refresh_csrf_field_name(self) -> str: return current_app.config["JWT_REFRESH_CSRF_FIELD_NAME"] @property - def access_expires(self): + def access_expires(self) -> datetime: delta = current_app.config["JWT_ACCESS_TOKEN_EXPIRES"] if type(delta) is int: delta = timedelta(seconds=delta) @@ -190,7 +192,7 @@ def access_expires(self): return delta @property - def refresh_expires(self): + def refresh_expires(self) -> datetime: delta = current_app.config["JWT_REFRESH_TOKEN_EXPIRES"] if type(delta) is int: delta = timedelta(seconds=delta) @@ -205,11 +207,11 @@ def refresh_expires(self): return delta @property - def algorithm(self): + def algorithm(self) -> str: return current_app.config["JWT_ALGORITHM"] @property - def decode_algorithms(self): + def decode_algorithms(self) -> List[str]: algorithms = current_app.config["JWT_DECODE_ALGORITHMS"] if not algorithms: return [self.algorithm] @@ -218,7 +220,7 @@ def decode_algorithms(self): return algorithms @property - def _secret_key(self): + def _secret_key(self) -> str: key = current_app.config["JWT_SECRET_KEY"] if not key: key = current_app.config.get("SECRET_KEY", None) @@ -231,7 +233,7 @@ def _secret_key(self): return key @property - def _public_key(self): + def _public_key(self) -> str: key = current_app.config["JWT_PUBLIC_KEY"] if not key: raise RuntimeError( @@ -242,7 +244,7 @@ def _public_key(self): return key @property - def _private_key(self): + def _private_key(self) -> str: key = current_app.config["JWT_PRIVATE_KEY"] if not key: raise RuntimeError( @@ -253,50 +255,50 @@ def _private_key(self): return key @property - def cookie_max_age(self): + def cookie_max_age(self) -> int: # Returns the appropiate value for max_age for flask set_cookies. If # session cookie is true, return None, otherwise return a number of # seconds 1 year in the future return None if self.session_cookie else 31540000 # 1 year @property - def identity_claim_key(self): + def identity_claim_key(self) -> str: return current_app.config["JWT_IDENTITY_CLAIM"] @property - def exempt_methods(self): + def exempt_methods(self) -> Iterable[str]: return {"OPTIONS"} @property - def error_msg_key(self): + def error_msg_key(self) -> str: return current_app.config["JWT_ERROR_MESSAGE_KEY"] @property - def json_encoder(self): + def json_encoder(self) -> JSONEncoder: return current_app.json_encoder @property - def decode_audience(self): + def decode_audience(self) -> Union[str, Iterable[str]]: return current_app.config["JWT_DECODE_AUDIENCE"] @property - def encode_audience(self): + def encode_audience(self) -> Union[str, Iterable[str]]: return current_app.config["JWT_ENCODE_AUDIENCE"] @property - def encode_issuer(self): + def encode_issuer(self) -> str: return current_app.config["JWT_ENCODE_ISSUER"] @property - def decode_issuer(self): + def decode_issuer(self) -> str: return current_app.config["JWT_DECODE_ISSUER"] @property - def leeway(self): + def leeway(self) -> int: return current_app.config["JWT_DECODE_LEEWAY"] @property - def encode_nbf(self): + def encode_nbf(self) -> bool: return current_app.config["JWT_ENCODE_NBF"] diff --git a/flask_jwt_extended/default_callbacks.py b/flask_jwt_extended/default_callbacks.py index 32e88719..9efbcc4e 100644 --- a/flask_jwt_extended/default_callbacks.py +++ b/flask_jwt_extended/default_callbacks.py @@ -6,12 +6,16 @@ http://flask-jwt-extended.readthedocs.io/en/latest/changing_default_behavior.html http://flask-jwt-extended.readthedocs.io/en/latest/tokens_from_complex_object.html """ +from http import HTTPStatus +from typing import Any + from flask import jsonify +from flask import Response from flask_jwt_extended.config import config -def default_additional_claims_callback(userdata): +def default_additional_claims_callback(userdata: Any) -> dict: """ By default, we add no additional claims to the access tokens. @@ -22,11 +26,11 @@ def default_additional_claims_callback(userdata): return {} -def default_blocklist_callback(jwt_headers, jwt_data): +def default_blocklist_callback(jwt_headers: dict, jwt_data: dict) -> bool: return False -def default_jwt_headers_callback(default_headers): +def default_jwt_headers_callback(default_headers) -> dict: """ By default header typically consists of two parts: the type of the token, which is JWT, and the signing algorithm being used, such as HMAC SHA256 @@ -37,7 +41,7 @@ def default_jwt_headers_callback(default_headers): return {} -def default_user_identity_callback(userdata): +def default_user_identity_callback(userdata: Any) -> Any: """ By default, we use the passed in object directly as the jwt identity. See this for additional info: @@ -49,77 +53,93 @@ def default_user_identity_callback(userdata): return userdata -def default_expired_token_callback(_expired_jwt_header, _expired_jwt_data): +def default_expired_token_callback( + _expired_jwt_header: dict, _expired_jwt_data: dict +) -> Response: """ By default, if an expired token attempts to access a protected endpoint, we return a generic error message with a 401 status """ - return jsonify({config.error_msg_key: "Token has expired"}), 401 + return jsonify({config.error_msg_key: "Token has expired"}), HTTPStatus.UNAUTHORIZED -def default_invalid_token_callback(error_string): +def default_invalid_token_callback(error_string: str) -> Response: """ By default, if an invalid token attempts to access a protected endpoint, we return the error string for why it is not valid with a 422 status code :param error_string: String indicating why the token is invalid """ - return jsonify({config.error_msg_key: error_string}), 422 + return ( + jsonify({config.error_msg_key: error_string}), + HTTPStatus.UNPROCESSABLE_ENTITY, + ) -def default_unauthorized_callback(error_string): +def default_unauthorized_callback(error_string: str) -> Response: """ By default, if a protected endpoint is accessed without a JWT, we return the error string indicating why this is unauthorized, with a 401 status code :param error_string: String indicating why this request is unauthorized """ - return jsonify({config.error_msg_key: error_string}), 401 + return jsonify({config.error_msg_key: error_string}), HTTPStatus.UNAUTHORIZED -def default_needs_fresh_token_callback(jwt_header, jwt_data): +def default_needs_fresh_token_callback(jwt_header: dict, jwt_data: dict) -> Response: """ By default, if a non-fresh jwt is used to access a ```fresh_jwt_required``` endpoint, we return a general error message with a 401 status code """ - return jsonify({config.error_msg_key: "Fresh token required"}), 401 + return ( + jsonify({config.error_msg_key: "Fresh token required"}), + HTTPStatus.UNAUTHORIZED, + ) -def default_revoked_token_callback(jwt_header, jwt_data): +def default_revoked_token_callback(jwt_header: dict, jwt_data: dict) -> Response: """ By default, if a revoked token is used to access a protected endpoint, we return a general error message with a 401 status code """ - return jsonify({config.error_msg_key: "Token has been revoked"}), 401 + return ( + jsonify({config.error_msg_key: "Token has been revoked"}), + HTTPStatus.UNAUTHORIZED, + ) -def default_user_lookup_error_callback(_jwt_header, jwt_data): +def default_user_lookup_error_callback(_jwt_header: dict, jwt_data: dict) -> Response: """ By default, if a user_lookup callback is defined and the callback function returns None, we return a general error message with a 401 status code """ identity = jwt_data[config.identity_claim_key] - result = {config.error_msg_key: "Error loading the user {}".format(identity)} - return jsonify(result), 401 + result = {config.error_msg_key: f"Error loading the user {identity}"} + return jsonify(result), HTTPStatus.UNAUTHORIZED -def default_token_verification_callback(_jwt_header, _jwt_data): +def default_token_verification_callback(_jwt_header: dict, _jwt_data: dict) -> bool: """ By default, we do not do any verification of the user claims. """ return True -def default_token_verification_failed_callback(_jwt_header, _jwt_data): +def default_token_verification_failed_callback( + _jwt_header: dict, _jwt_data: dict +) -> Response: """ By default, if the user claims verification failed, we return a generic error message with a 400 status code """ - return jsonify({config.error_msg_key: "User claims verification failed"}), 400 + return ( + jsonify({config.error_msg_key: "User claims verification failed"}), + HTTPStatus.BAD_REQUEST, + ) -def default_decode_key_callback(jwt_header, jwt_data): +def default_decode_key_callback(jwt_header: dict, jwt_data: dict) -> str: """ By default, the decode key specified via the JWT_SECRET_KEY or JWT_PUBLIC_KEY settings will be used to decode all tokens @@ -127,7 +147,7 @@ def default_decode_key_callback(jwt_header, jwt_data): return config.decode_key -def default_encode_key_callback(identity): +def default_encode_key_callback(identity: Any) -> str: """ By default, the encode key specified via the JWT_SECRET_KEY or JWT_PRIVATE_KEY settings will be used to encode all tokens diff --git a/flask_jwt_extended/exceptions.py b/flask_jwt_extended/exceptions.py index bc80d889..2ad414a8 100644 --- a/flask_jwt_extended/exceptions.py +++ b/flask_jwt_extended/exceptions.py @@ -60,7 +60,7 @@ class RevokedTokenError(JWTExtendedException): Error raised when a revoked token attempt to access a protected endpoint """ - def __init__(self, jwt_header, jwt_data): + def __init__(self, jwt_header: dict, jwt_data: dict) -> None: super().__init__("Token has been revoked") self.jwt_header = jwt_header self.jwt_data = jwt_data @@ -72,7 +72,7 @@ class FreshTokenRequired(JWTExtendedException): protected by fresh_jwt_required """ - def __init__(self, message, jwt_header, jwt_data): + def __init__(self, message, jwt_header: dict, jwt_data: dict) -> None: super().__init__(message) self.jwt_header = jwt_header self.jwt_data = jwt_data @@ -84,7 +84,7 @@ class UserLookupError(JWTExtendedException): that it cannot or will not load a user for the given identity. """ - def __init__(self, message, jwt_header, jwt_data): + def __init__(self, message, jwt_header: dict, jwt_data: dict) -> None: super().__init__(message) self.jwt_header = jwt_header self.jwt_data = jwt_data @@ -96,7 +96,7 @@ class UserClaimsVerificationError(JWTExtendedException): indicating that the expected user claims are invalid """ - def __init__(self, message, jwt_header, jwt_data): + def __init__(self, message, jwt_header: dict, jwt_data: dict) -> None: super().__init__(message) self.jwt_header = jwt_header self.jwt_data = jwt_data diff --git a/flask_jwt_extended/internal_utils.py b/flask_jwt_extended/internal_utils.py index ac18a855..3b6dc8ec 100644 --- a/flask_jwt_extended/internal_utils.py +++ b/flask_jwt_extended/internal_utils.py @@ -1,11 +1,14 @@ +from typing import Any + from flask import current_app +from flask_jwt_extended import JWTManager from flask_jwt_extended.exceptions import RevokedTokenError from flask_jwt_extended.exceptions import UserClaimsVerificationError from flask_jwt_extended.exceptions import WrongTokenError -def get_jwt_manager(): +def get_jwt_manager() -> JWTManager: try: return current_app.extensions["flask-jwt-extended"] except KeyError: # pragma: no cover @@ -15,30 +18,30 @@ def get_jwt_manager(): ) from None -def has_user_lookup(): +def has_user_lookup() -> bool: jwt_manager = get_jwt_manager() return jwt_manager._user_lookup_callback is not None -def user_lookup(*args, **kwargs): +def user_lookup(*args, **kwargs) -> Any: jwt_manager = get_jwt_manager() return jwt_manager._user_lookup_callback(*args, **kwargs) -def verify_token_type(decoded_token, refresh): +def verify_token_type(decoded_token: dict, refresh: bool) -> None: if not refresh and decoded_token["type"] == "refresh": raise WrongTokenError("Only non-refresh tokens are allowed") elif refresh and decoded_token["type"] != "refresh": raise WrongTokenError("Only refresh tokens are allowed") -def verify_token_not_blocklisted(jwt_header, jwt_data): +def verify_token_not_blocklisted(jwt_header: dict, jwt_data: dict) -> None: jwt_manager = get_jwt_manager() if jwt_manager._token_in_blocklist_callback(jwt_header, jwt_data): raise RevokedTokenError(jwt_header, jwt_data) -def custom_verification_for_token(jwt_header, jwt_data): +def custom_verification_for_token(jwt_header: dict, jwt_data: dict) -> None: jwt_manager = get_jwt_manager() if not jwt_manager._token_verification_callback(jwt_header, jwt_data): error_msg = "User claims verification failed" diff --git a/flask_jwt_extended/jwt_manager.py b/flask_jwt_extended/jwt_manager.py index b39d2c93..a8e032f0 100644 --- a/flask_jwt_extended/jwt_manager.py +++ b/flask_jwt_extended/jwt_manager.py @@ -1,6 +1,9 @@ import datetime +from typing import Any +from typing import Callable import jwt +from flask import Flask from jwt import DecodeError from jwt import ExpiredSignatureError from jwt import InvalidAudienceError @@ -49,7 +52,7 @@ class JWTManager(object): to your app in a factory function. """ - def __init__(self, app=None): + def __init__(self, app: Flask = None) -> None: """ Create the JWTManager instance. You can either pass a flask application in directly here to register this extension with the flask app, or @@ -82,7 +85,7 @@ def __init__(self, app=None): if app is not None: self.init_app(app) - def init_app(self, app): + def init_app(self, app: Flask) -> None: """ Register this extension with the flask app. @@ -98,7 +101,7 @@ def init_app(self, app): self._set_default_configuration_options(app) self._set_error_handler_callbacks(app) - def _set_error_handler_callbacks(self, app): + def _set_error_handler_callbacks(self, app: Flask) -> None: @app.errorhandler(CSRFError) def handle_csrf_error(e): return self._unauthorized_callback(str(e)) @@ -164,7 +167,7 @@ def handle_wrong_token_error(e): return self._invalid_token_callback(str(e)) @staticmethod - def _set_default_configuration_options(app): + def _set_default_configuration_options(app: Flask) -> None: app.config.setdefault( "JWT_ACCESS_TOKEN_EXPIRES", datetime.timedelta(minutes=15) ) @@ -210,7 +213,7 @@ def _set_default_configuration_options(app): app.config.setdefault("JWT_TOKEN_LOCATION", ("headers",)) app.config.setdefault("JWT_ENCODE_NBF", True) - def additional_claims_loader(self, callback): + def additional_claims_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function used to add additional claims when creating a JWT. The claims returned by this function will be merged @@ -227,7 +230,7 @@ def additional_claims_loader(self, callback): self._user_claims_callback = callback return callback - def additional_headers_loader(self, callback): + def additional_headers_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function used to add additional headers when creating a JWT. The headers returned by this function will be merged @@ -244,7 +247,7 @@ def additional_headers_loader(self, callback): self._jwt_additional_header_callback = callback return callback - def decode_key_loader(self, callback): + def decode_key_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function for dynamically setting the JWT decode key based on the **UNVERIFIED** contents of the token. Think @@ -265,7 +268,7 @@ def decode_key_loader(self, callback): self._decode_key_callback = callback return callback - def encode_key_loader(self, callback): + def encode_key_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function for dynamically setting the JWT encode key based on the tokens identity. Think carefully before using this @@ -281,7 +284,7 @@ def encode_key_loader(self, callback): self._encode_key_callback = callback return callback - def expired_token_loader(self, callback): + def expired_token_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function for returning a custom response when an expired JWT is encountered. @@ -297,7 +300,7 @@ def expired_token_loader(self, callback): self._expired_token_callback = callback return callback - def invalid_token_loader(self, callback): + def invalid_token_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function for returning a custom response when an invalid JWT is encountered. @@ -314,7 +317,7 @@ def invalid_token_loader(self, callback): self._invalid_token_callback = callback return callback - def needs_fresh_token_loader(self, callback): + def needs_fresh_token_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function for returning a custom response when a valid and non-fresh token is used on an endpoint @@ -331,7 +334,7 @@ def needs_fresh_token_loader(self, callback): self._needs_fresh_token_callback = callback return callback - def revoked_token_loader(self, callback): + def revoked_token_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function for returning a custom response when a revoked token is encountered. @@ -347,7 +350,7 @@ def revoked_token_loader(self, callback): self._revoked_token_callback = callback return callback - def token_in_blocklist_loader(self, callback): + def token_in_blocklist_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function used to check if a JWT has been revoked. @@ -364,7 +367,7 @@ def token_in_blocklist_loader(self, callback): self._token_in_blocklist_callback = callback return callback - def token_verification_failed_loader(self, callback): + def token_verification_failed_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function used to return a custom response when the claims verification check fails. @@ -380,7 +383,7 @@ def token_verification_failed_loader(self, callback): self._token_verification_failed_callback = callback return callback - def token_verification_loader(self, callback): + def token_verification_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function used for custom verification of a valid JWT. @@ -397,7 +400,7 @@ def token_verification_loader(self, callback): self._token_verification_callback = callback return callback - def unauthorized_loader(self, callback): + def unauthorized_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function used to return a custom response when no JWT is present. @@ -411,7 +414,7 @@ def unauthorized_loader(self, callback): self._unauthorized_callback = callback return callback - def user_identity_loader(self, callback): + def user_identity_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function used to convert an identity to a JSON serializable format when creating JWTs. This is useful for @@ -427,7 +430,7 @@ def user_identity_loader(self, callback): self._user_identity_callback = callback return callback - def user_lookup_loader(self, callback): + def user_lookup_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function used to convert a JWT into a python object that can be used in a protected endpoint. This is useful @@ -452,7 +455,7 @@ def user_lookup_loader(self, callback): self._user_lookup_callback = callback return callback - def user_lookup_error_loader(self, callback): + def user_lookup_error_loader(self, callback: Callable) -> Callable: """ This decorator sets the callback function used to return a custom response when loading a user via @@ -471,13 +474,13 @@ def user_lookup_error_loader(self, callback): def _encode_jwt_from_config( self, - identity, - token_type, + identity: Any, + token_type: str, claims=None, - fresh=False, - expires_delta=None, + fresh: bool = False, + expires_delta: datetime.timedelta = None, headers=None, - ): + ) -> str: header_overrides = self._jwt_additional_header_callback(identity) if headers is not None: header_overrides.update(headers) @@ -510,8 +513,8 @@ def _encode_jwt_from_config( ) def _decode_jwt_from_config( - self, encoded_token, csrf_value=None, allow_expired=False - ): + self, encoded_token: str, csrf_value=None, allow_expired: bool = False + ) -> dict: unverified_claims = jwt.decode( encoded_token, algorithms=config.decode_algorithms, diff --git a/flask_jwt_extended/tokens.py b/flask_jwt_extended/tokens.py index 2b98b8d0..b72efe14 100644 --- a/flask_jwt_extended/tokens.py +++ b/flask_jwt_extended/tokens.py @@ -3,29 +3,33 @@ from datetime import timedelta from datetime import timezone from hmac import compare_digest +from typing import Any +from typing import Iterable +from typing import Union import jwt +from flask.json import JSONEncoder from flask_jwt_extended.exceptions import CSRFError from flask_jwt_extended.exceptions import JWTDecodeError def _encode_jwt( - algorithm, - audience, - claim_overrides, - csrf, - expires_delta, - fresh, - header_overrides, - identity, - identity_claim_key, - issuer, - json_encoder, - secret, - token_type, - nbf, -): + algorithm: str, + audience: Union[str, Iterable[str]], + claim_overrides: dict, + csrf: bool, + expires_delta: timedelta, + fresh: bool, + header_overrides: dict, + identity: Any, + identity_claim_key: str, + issuer: str, + json_encoder: JSONEncoder, + secret: str, + token_type: str, + nbf: bool, +) -> str: now = datetime.now(timezone.utc) if isinstance(fresh, timedelta): @@ -67,17 +71,17 @@ def _encode_jwt( def _decode_jwt( - algorithms, - allow_expired, - audience, - csrf_value, - encoded_token, - identity_claim_key, - issuer, - leeway, - secret, - verify_aud, -): + algorithms: Iterable, + allow_expired: bool, + audience: Union[str, Iterable[str]], + csrf_value: str, + encoded_token: str, + identity_claim_key: str, + issuer: str, + leeway: int, + secret: str, + verify_aud: bool, +) -> dict: options = {"verify_aud": verify_aud} if allow_expired: options["verify_exp"] = False diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index 4a560a01..e7df351c 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -1,16 +1,19 @@ +import datetime +from typing import Any + import jwt from flask import _request_ctx_stack +from flask import Response from werkzeug.local import LocalProxy from flask_jwt_extended.config import config from flask_jwt_extended.internal_utils import get_jwt_manager - # Proxy to access the current user current_user = LocalProxy(lambda: get_current_user()) -def get_jwt(): +def get_jwt() -> dict: """ In a protected endpoint, this will return the python dictionary which has the payload of the JWT that is accessing the endpoint. If no JWT is present @@ -28,7 +31,7 @@ def get_jwt(): return decoded_jwt -def get_jwt_header(): +def get_jwt_header() -> dict: """ In a protected endpoint, this will return the python dictionary which has the header of the JWT that is accessing the endpoint. If no JWT is present @@ -46,7 +49,7 @@ def get_jwt_header(): return decoded_header -def get_jwt_identity(): +def get_jwt_identity() -> Any: """ In a protected endpoint, this will return the identity of the JWT that is accessing the endpoint. If no JWT is present due to @@ -58,7 +61,7 @@ def get_jwt_identity(): return get_jwt().get(config.identity_claim_key, None) -def get_jwt_request_location(): +def get_jwt_request_location() -> str: """ In a protected endpoint, this will return the "location" at which the JWT that is accessing the endpoint was found--e.g., "cookies", "query-string", @@ -66,14 +69,14 @@ def get_jwt_request_location(): None is returned. :return: - The location of the JWT in the current request; e.g., cookies", + The location of the JWT in the current request; e.g., "cookies", "query-string", "headers", or "json" """ location = getattr(_request_ctx_stack.top, "jwt_location", None) return location -def get_current_user(): +def get_current_user() -> Any: """ In a protected endpoint, this will return the user object for the JWT that is accessing the endpoint. @@ -97,14 +100,16 @@ def get_current_user(): return jwt_user_dict["loaded_user"] -def decode_token(encoded_token, csrf_value=None, allow_expired=False): +def decode_token( + encoded_token: str, csrf_value: str = None, allow_expired: bool = False +) -> dict: """ Returns the decoded token (python dict) from an encoded JWT. This does all - the checks to insure that the decoded token is valid before returning it. + the checks to ensure that the decoded token is valid before returning it. This will not fire the user loader callbacks, save the token for access in protected endpoints, checked if a token is revoked, etc. This is puerly - used to insure that a JWT is valid. + used to ensure that a JWT is valid. :param encoded_token: The encoded JWT to decode. @@ -123,9 +128,9 @@ def decode_token(encoded_token, csrf_value=None, allow_expired=False): def create_access_token( - identity, - fresh=False, - expires_delta=None, + identity: Any, + fresh: bool = False, + expires_delta: datetime.timedelta = None, additional_claims=None, additional_headers=None, ): @@ -177,7 +182,10 @@ def create_access_token( def create_refresh_token( - identity, expires_delta=None, additional_claims=None, additional_headers=None + identity: Any, + expires_delta: datetime.timedelta = None, + additional_claims=None, + additional_headers=None, ): """ Create a new refresh token. @@ -219,7 +227,7 @@ def create_refresh_token( ) -def get_unverified_jwt_headers(encoded_token): +def get_unverified_jwt_headers(encoded_token: str) -> dict: """ Returns the Headers of an encoded JWT without verifying the signature of the JWT. @@ -232,7 +240,7 @@ def get_unverified_jwt_headers(encoded_token): return jwt.get_unverified_header(encoded_token) -def get_jti(encoded_token): +def get_jti(encoded_token: str) -> str: """ Returns the JTI (unique identifier) of an encoded JWT @@ -245,7 +253,7 @@ def get_jti(encoded_token): return decode_token(encoded_token).get("jti") -def get_csrf_token(encoded_token): +def get_csrf_token(encoded_token: str) -> str: """ Returns the CSRF double submit token from an encoded JWT. @@ -259,7 +267,9 @@ def get_csrf_token(encoded_token): return token["csrf"] -def set_access_cookies(response, encoded_access_token, max_age=None, domain=None): +def set_access_cookies( + response: Response, encoded_access_token: str, max_age=None, domain=None +) -> None: """ Modifiy a Flask Response to set a cookie containing the access JWT. Also sets the corresponding CSRF cookies if ``JWT_CSRF_IN_COOKIES`` is ``True`` @@ -307,7 +317,12 @@ def set_access_cookies(response, encoded_access_token, max_age=None, domain=None ) -def set_refresh_cookies(response, encoded_refresh_token, max_age=None, domain=None): +def set_refresh_cookies( + response: Response, + encoded_refresh_token: str, + max_age: int = None, + domain: str = None, +) -> None: """ Modifiy a Flask Response to set a cookie containing the refresh JWT. Also sets the corresponding CSRF cookies if ``JWT_CSRF_IN_COOKIES`` is ``True`` @@ -355,7 +370,7 @@ def set_refresh_cookies(response, encoded_refresh_token, max_age=None, domain=No ) -def unset_jwt_cookies(response, domain=None): +def unset_jwt_cookies(response: Response, domain: str = None) -> None: """ Modifiy a Flask Response to delete the cookies containing access or refresh JWTs. Also deletes the corresponding CSRF cookies if applicable. @@ -367,7 +382,7 @@ def unset_jwt_cookies(response, domain=None): unset_refresh_cookies(response, domain) -def unset_access_cookies(response, domain=None): +def unset_access_cookies(response: Response, domain: str = None) -> None: """ Modifiy a Flask Response to delete the cookie containing an access JWT. Also deletes the corresponding CSRF cookie if applicable. @@ -405,7 +420,7 @@ def unset_access_cookies(response, domain=None): ) -def unset_refresh_cookies(response, domain=None): +def unset_refresh_cookies(response: Response, domain: str = None) -> None: """ Modifiy a Flask Response to delete the cookie containing a refresh JWT. Also deletes the corresponding CSRF cookie if applicable. diff --git a/flask_jwt_extended/view_decorators.py b/flask_jwt_extended/view_decorators.py index 55cfec17..3b96f250 100644 --- a/flask_jwt_extended/view_decorators.py +++ b/flask_jwt_extended/view_decorators.py @@ -2,6 +2,10 @@ from datetime import timezone from functools import wraps from re import split +from typing import Any +from typing import Iterable +from typing import Tuple +from typing import Union from flask import _request_ctx_stack from flask import current_app @@ -23,8 +27,10 @@ from flask_jwt_extended.utils import decode_token from flask_jwt_extended.utils import get_unverified_jwt_headers +LocationType = Union[str, Iterable] -def _verify_token_is_fresh(jwt_header, jwt_data): + +def _verify_token_is_fresh(jwt_header: dict, jwt_data: dict) -> None: fresh = jwt_data["fresh"] if isinstance(fresh, bool): if not fresh: @@ -35,7 +41,13 @@ def _verify_token_is_fresh(jwt_header, jwt_data): raise FreshTokenRequired("Fresh token required", jwt_header, jwt_data) -def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations=None): +def verify_jwt_in_request( + optional: bool = False, + fresh: bool = False, + refresh: bool = False, + locations: LocationType = None, + verify_type: bool = False, +) -> Tuple[dict, dict]: """ Verify that a valid JWT is present in the request, unless ``optional=True`` in which case no JWT is also considered valid. @@ -49,26 +61,28 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations= Defaults to ``False``. :param refresh: - If ``True``, require a refresh JWT to be verified. + If ``True``, requires a refresh JWT to access this endpoint. If ``False``, + requires an access JWT to access this endpoint. Defaults to ``False`` :param locations: A location or list of locations to look for the JWT in this request, for example ``'headers'`` or ``['headers', 'cookies']``. Defaults to ``None`` which indicates that JWTs will be looked for in the locations defined by the ``JWT_TOKEN_LOCATION`` configuration option. + + :param verify_type: + If ``True``, the token type (access or refresh) will be checked according + to the ``refresh`` argument. If ``False``, type will not be checked and both + access and refresh tokens will be accepted. """ if request.method in config.exempt_methods: return try: - if refresh: - jwt_data, jwt_header, jwt_location = _decode_jwt_from_request( - locations, fresh, refresh=True - ) - else: - jwt_data, jwt_header, jwt_location = _decode_jwt_from_request( - locations, fresh - ) + jwt_data, jwt_header, jwt_location = _decode_jwt_from_request( + locations, fresh, refresh=refresh, verify_type=verify_type + ) + except NoAuthorizationError: if not optional: raise @@ -88,7 +102,13 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations= return jwt_header, jwt_data -def jwt_required(optional=False, fresh=False, refresh=False, locations=None): +def jwt_required( + optional: bool = False, + fresh: bool = False, + refresh: bool = False, + locations: LocationType = None, + verify_type: bool = True, +) -> Any: """ A decorator to protect a Flask endpoint with JSON Web Tokens. @@ -113,12 +133,17 @@ def jwt_required(optional=False, fresh=False, refresh=False, locations=None): example ``'headers'`` or ``['headers', 'cookies']``. Defaults to ``None`` which indicates that JWTs will be looked for in the locations defined by the ``JWT_TOKEN_LOCATION`` configuration option. + + :param verify_type: + If ``True``, the token type (access or refresh) will be checked according + to the ``refresh`` argument. If ``False``, type will not be checked and both + access and refresh tokens will be accepted. """ def wrapper(fn): @wraps(fn) def decorator(*args, **kwargs): - verify_jwt_in_request(optional, fresh, refresh, locations) + verify_jwt_in_request(optional, fresh, refresh, locations, verify_type) # Compatibility with flask < 2.0 if hasattr(current_app, "ensure_sync") and callable( @@ -133,7 +158,7 @@ def decorator(*args, **kwargs): return wrapper -def _load_user(jwt_header, jwt_data): +def _load_user(jwt_header: dict, jwt_data: dict) -> dict: if not has_user_lookup(): return None @@ -145,7 +170,7 @@ def _load_user(jwt_header, jwt_data): return {"loaded_user": user} -def _decode_jwt_from_headers(): +def _decode_jwt_from_headers() -> Tuple[str, str]: header_name = config.header_name header_type = config.header_type @@ -189,7 +214,7 @@ def _decode_jwt_from_headers(): return encoded_token, None -def _decode_jwt_from_cookies(refresh): +def _decode_jwt_from_cookies(refresh: bool) -> Tuple[str, str]: if refresh: cookie_key = config.refresh_cookie_name csrf_header_key = config.refresh_csrf_header_name @@ -215,7 +240,7 @@ def _decode_jwt_from_cookies(refresh): return encoded_token, csrf_value -def _decode_jwt_from_query_string(): +def _decode_jwt_from_query_string() -> Tuple[str, str]: param_name = config.query_string_name prefix = config.query_string_value_prefix @@ -233,7 +258,7 @@ def _decode_jwt_from_query_string(): return encoded_token, None -def _decode_jwt_from_json(refresh): +def _decode_jwt_from_json(refresh: bool) -> Tuple[str, str]: content_type = request.content_type or "" if not content_type.startswith("application/json"): raise NoAuthorizationError("Invalid content-type. Must be application/json.") @@ -255,7 +280,12 @@ def _decode_jwt_from_json(refresh): return encoded_token, None -def _decode_jwt_from_request(locations, fresh, refresh=False): +def _decode_jwt_from_request( + locations: LocationType, + fresh: bool, + refresh: bool = False, + verify_type: bool = True, +) -> Tuple[dict, dict, str]: # Figure out what locations to look for the JWT in this request if isinstance(locations, str): locations = [locations] @@ -314,7 +344,9 @@ def _decode_jwt_from_request(locations, fresh, refresh=False): raise NoAuthorizationError(errors[0]) # Additional verifications provided by this extension - verify_token_type(decoded_token, refresh) + if verify_type: + verify_token_type(decoded_token, refresh) + if fresh: _verify_token_is_fresh(jwt_header, decoded_token) verify_token_not_blocklisted(jwt_header, decoded_token) diff --git a/requirements.txt b/requirements.txt index 082b8775..1a3cfd7a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,5 +4,5 @@ Flask==2.0.1 Pallets-Sphinx-Themes==2.0.1 pre-commit==2.13.0 PyJWT==2.1.0 -Sphinx==4.0.2 +Sphinx==4.3.2 tox==3.23.1 diff --git a/tests/test_asymmetric_crypto.py b/tests/test_asymmetric_crypto.py index 40d0d8ee..5bd067d9 100644 --- a/tests/test_asymmetric_crypto.py +++ b/tests/test_asymmetric_crypto.py @@ -57,13 +57,13 @@ def test_asymmetric_cropto(app): app.config["JWT_ALGORITHM"] = "RS256" rs256_token = create_access_token("username") - # Insure the symmetric token does not work now + # Ensure the symmetric token does not work now access_headers = {"Authorization": "Bearer {}".format(hs256_token)} response = test_client.get("/protected", headers=access_headers) assert response.status_code == 422 assert response.get_json() == {"msg": "The specified alg value is not allowed"} - # Insure the asymmetric token does work + # Ensure the asymmetric token does work access_headers = {"Authorization": "Bearer {}".format(rs256_token)} response = test_client.get("/protected", headers=access_headers) assert response.status_code == 200 diff --git a/tests/test_cookies.py b/tests/test_cookies.py index 9cc33247..a1d2d483 100644 --- a/tests/test_cookies.py +++ b/tests/test_cookies.py @@ -301,17 +301,17 @@ def test_custom_csrf_methods(app, options): response = test_client.get(auth_url) csrf_token = _get_cookie_from_response(response, csrf_cookie_name)[csrf_cookie_name] - # Insure we can now do posts without csrf + # Ensure we can now do posts without csrf response = test_client.post(post_url) assert response.status_code == 200 assert response.get_json() == {"foo": "bar"} - # Insure GET requests now fail without csrf + # Ensure GET requests now fail without csrf response = test_client.get(get_url) assert response.status_code == 401 assert response.get_json() == {"msg": "Missing CSRF token"} - # Insure GET requests now succeed with csrf + # Ensure GET requests now succeed with csrf csrf_headers = {"X-CSRF-TOKEN": csrf_token} response = test_client.get(get_url, headers=csrf_headers) assert response.status_code == 200 diff --git a/tests/test_decode_tokens.py b/tests/test_decode_tokens.py index 6c98e4f6..4d13f787 100644 --- a/tests/test_decode_tokens.py +++ b/tests/test_decode_tokens.py @@ -172,13 +172,13 @@ def test_nbf_token_in_future(app): def test_alternate_identity_claim(app, default_access_token): app.config["JWT_IDENTITY_CLAIM"] = "banana" - # Insure decoding fails if the claim isn't there + # Ensure decoding fails if the claim isn't there token = encode_token(app, default_access_token) with pytest.raises(JWTDecodeError): with app.test_request_context(): decode_token(token) - # Insure the claim exists in the decoded jwt + # Ensure the claim exists in the decoded jwt del default_access_token["sub"] default_access_token["banana"] = "username" token = encode_token(app, default_access_token) diff --git a/tests/test_headers.py b/tests/test_headers.py index a1a3aef5..fb3b8252 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -76,13 +76,13 @@ def test_custom_header_name(app): with app.test_request_context(): access_token = create_access_token("username") - # Insure 'default' headers no longer work + # Ensure 'default' headers no longer work access_headers = {"Authorization": "Bearer {}".format(access_token)} response = test_client.get("/protected", headers=access_headers) assert response.status_code == 401 assert response.get_json() == {"msg": "Missing Foo Header"} - # Insure new headers do work + # Ensure new headers do work access_headers = {"Foo": "Bearer {}".format(access_token)} response = test_client.get("/protected", headers=access_headers) assert response.status_code == 200 @@ -108,7 +108,7 @@ def test_custom_header_type(app): with app.test_request_context(): access_token = create_access_token("username") - # Insure 'default' headers no longer work + # Ensure 'default' headers no longer work access_headers = {"Authorization": "Bearer {}".format(access_token)} response = test_client.get("/protected", headers=access_headers) error_msg = ( @@ -118,7 +118,7 @@ def test_custom_header_type(app): assert response.status_code == 401 assert response.get_json() == {"msg": error_msg} - # Insure new headers do work + # Ensure new headers do work access_headers = {"Authorization": "JWT {}".format(access_token)} response = test_client.get("/protected", headers=access_headers) assert response.status_code == 200 @@ -136,14 +136,14 @@ def test_custom_header_type(app): assert response.status_code == 200 assert response.get_json() == {"foo": "bar"} - # Insure new headers without a type also work + # Ensure new headers without a type also work app.config["JWT_HEADER_TYPE"] = "" access_headers = {"Authorization": access_token} response = test_client.get("/protected", headers=access_headers) assert response.status_code == 200 assert response.get_json() == {"foo": "bar"} - # Insure header with too many parts fails + # Ensure header with too many parts fails app.config["JWT_HEADER_TYPE"] = "" access_headers = {"Authorization": "Bearer {}".format(access_token)} response = test_client.get("/protected", headers=access_headers) @@ -156,7 +156,7 @@ def test_missing_headers(app): test_client = app.test_client() jwtM = get_jwt_manager(app) - # Insure 'default' no headers response + # Ensure 'default' no headers response response = test_client.get("/protected", headers=None) assert response.status_code == 401 assert response.get_json() == {"msg": "Missing Authorization Header"} diff --git a/tests/test_query_string.py b/tests/test_query_string.py index 2b5c6586..6fc9d8a8 100644 --- a/tests/test_query_string.py +++ b/tests/test_query_string.py @@ -66,13 +66,13 @@ def test_custom_query_paramater(app): with app.test_request_context(): access_token = create_access_token("username") - # Insure 'default' query paramaters no longer work + # Ensure 'default' query paramaters no longer work url = "/protected?jwt={}".format(access_token) response = test_client.get(url) assert response.status_code == 401 assert response.get_json() == {"msg": "Missing 'foo' query paramater"} - # Insure new query_string does work + # Ensure new query_string does work url = "/protected?foo={}".format(access_token) response = test_client.get(url) assert response.status_code == 200 @@ -86,12 +86,12 @@ def test_missing_query_paramater(app): with app.test_request_context(): access_token = create_access_token("username") - # Insure no query paramaters doesn't give a response + # Ensure no query paramaters doesn't give a response response = test_client.get("/protected") assert response.status_code == 401 assert response.get_json() == {"msg": "Missing 'jwt' query paramater"} - # Insure headers don't work + # Ensure headers don't work access_headers = {"Authorization": "Bearer {}".format(access_token)} response = test_client.get("/protected", headers=access_headers) assert response.status_code == 401 diff --git a/tests/test_view_decorators.py b/tests/test_view_decorators.py index 6a34c0bb..7b88b522 100644 --- a/tests/test_view_decorators.py +++ b/tests/test_view_decorators.py @@ -46,6 +46,11 @@ def optional_protected(): else: return jsonify(foo="bar") + @app.route("/no_typecheck_protected", methods=["GET"]) + @jwt_required(verify_type=False) + def no_typecheck_protected(): + return jsonify(foo="bar") + return app @@ -153,6 +158,26 @@ def test_refresh_jwt_required(app): assert response.get_json() == {"foo": "bar"} +def test_jwt_required_no_typecheck(app): + """Verify this route works with access or refresh tokens.""" + url = "/no_typecheck_protected" + + test_client = app.test_client() + with app.test_request_context(): + access_token = create_access_token("username") + fresh_access_token = create_access_token("username", fresh=True) + refresh_token = create_refresh_token("username") + + for token in (access_token, fresh_access_token, refresh_token): + response = test_client.get(url, headers=make_headers(token)) + assert response.status_code == 200 + assert response.get_json() == {"foo": "bar"} + + response = test_client.get(url, headers=None) + assert response.status_code == 401 + assert response.get_json() == {"msg": "Missing Authorization Header"} + + @pytest.mark.parametrize("delta_func", [timedelta, relativedelta]) def test_jwt_optional(app, delta_func): url = "/optional_protected"