Skip to content

Commit

Permalink
fix: make token refresh robust to network errors
Browse files Browse the repository at this point in the history
  • Loading branch information
irishrain committed Dec 31, 2024
1 parent ffe905d commit 2d488eb
Show file tree
Hide file tree
Showing 4 changed files with 369 additions and 12 deletions.
44 changes: 41 additions & 3 deletions backend/app/controller/auth/auth_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,52 @@
# Callback function to check if a JWT exists in the database blocklist
@jwt.token_in_blocklist_loader
def check_if_token_revoked(jwt_header, jwt_payload: dict) -> bool:
from app import db
jti = jwt_payload["jti"]
token = Token.find_by_jti(jti)
if token is not None:
# Check for invalidated refresh tokens first
if token.type == "invalidated_refresh":
# Delete any remaining tokens in the family
token.delete_token_familiy()
return True

# Check if this token's chain has been superseded
if token.type == "access":
if token.refresh_token: # This token has a parent
# Check if parent is not a valid refresh token
if token.refresh_token.type != "refresh":
token.refresh_token.delete_token_familiy()
return True
# Check if there are any newer tokens that have been used
newer_used = db.session.query(Token).filter(
Token.refresh_token_id == token.refresh_token.id,
Token.last_used_at != None,
db.or_(
Token.type == "refresh",
db.and_(Token.type == "access", Token.id != token.id)
)
).first()
if newer_used:
return True

if token.last_used_at is None:
# First use of this token
if token.type == "access":
# When an access token is first used, invalidate all other access tokens from the same refresh token chain
if token.refresh_token: # This token has a parent
# Delete any access tokens associated with the parent, except this one
token.refresh_token.delete_created_access_tokens(exclude_token_id=token.id)
# Also delete any access tokens from the parents parent chain
if token.refresh_token.refresh_token:
token.refresh_token.refresh_token.delete_created_access_tokens()

token.last_used_at = datetime.now(timezone.utc)
token.user.last_seen = token.last_used_at
token.save()

return token is None
db.session.commit()
return False
else:
return True


# Register a callback function that takes whatever object is passed in as the
Expand Down
57 changes: 48 additions & 9 deletions backend/app/models/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def find_by_jti(cls, jti: str) -> Self:
@classmethod
def delete_expired_refresh(cls):
filter_before = datetime.now(timezone.utc) - JWT_REFRESH_TOKEN_EXPIRES

# Delete expired regular refresh tokens with no children
for token in (
db.session.query(cls)
.filter(
Expand All @@ -58,6 +60,13 @@ def delete_expired_refresh(cls):
.all()
):
token.delete_token_familiy(commit=False)

# Delete expired invalidated refresh tokens
db.session.query(cls).filter(
cls.created_at <= filter_before,
cls.type == "invalidated_refresh"
).delete()

db.session.commit()

@classmethod
Expand All @@ -71,8 +80,9 @@ def delete_expired_access(cls):
# Delete oldest refresh token -> log out device
# Used e.g. when a refresh token is used twice
def delete_token_familiy(self, commit=True):
if self.type != "refresh":
if self.type not in ["refresh", "invalidated_refresh"]:
return

token = self
while token:
if token.refresh_token:
Expand All @@ -91,12 +101,15 @@ def has_created_refresh_token(self) -> bool:
> 0
)

def delete_created_access_tokens(self):
def delete_created_access_tokens(self, exclude_token_id=None):
if self.type != "refresh":
return
db.session.query(Token).filter(
query = db.session.query(Token).filter(
Token.refresh_token_id == self.id, Token.type == "access"
).delete()
)
if exclude_token_id is not None:
query = query.filter(Token.id != exclude_token_id)
query.delete()
db.session.commit()

@classmethod
Expand All @@ -118,25 +131,51 @@ def create_refresh_token(
cls, user: User, device: str | None = None, oldRefreshToken: Self | None = None
) -> Tuple[str, Self]:
assert device or oldRefreshToken
if oldRefreshToken and (
oldRefreshToken.type != "refresh"
or oldRefreshToken.has_created_refresh_token()
):
if oldRefreshToken and oldRefreshToken.type != "refresh":
oldRefreshToken.delete_token_familiy()
raise UnauthorizedRequest(
message="Unauthorized: IP {} reused the same refresh token, logging out user".format(
request.remote_addr
)
)

# Check if this refresh token has already been used to create another refresh token
if oldRefreshToken and oldRefreshToken.has_created_refresh_token():
newer_token = db.session.query(Token).filter(
Token.refresh_token_id == oldRefreshToken.id,
Token.type == "refresh"
).first()

if newer_token:
newer_access_used = db.session.query(Token).filter(
Token.refresh_token_id == newer_token.id,
Token.type == "access",
Token.last_used_at != None
).count() > 0

if newer_token.last_used_at is not None or newer_access_used:
# The newer tokens have been used, this is a reuse attack
oldRefreshToken.delete_token_familiy()
raise UnauthorizedRequest(
message="Unauthorized: IP {} reused the same refresh token, logging out user".format(
request.remote_addr
)
)
else:
# Only invalidate the unused parallel refresh token chain
for token in db.session.query(Token).filter(
Token.refresh_token_id == newer_token.id
).all():
db.session.delete(token)
newer_token.type = "invalidated_refresh"

refreshToken = create_refresh_token(identity=user)
model = cls()
model.jti = get_jti(refreshToken)
model.type = "refresh"
model.name = device or oldRefreshToken.name
model.user = user
if oldRefreshToken:
oldRefreshToken.delete_created_access_tokens()
model.refresh_token = oldRefreshToken
model.save()
return refreshToken, model
Expand Down
4 changes: 4 additions & 0 deletions backend/tests/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ def admin_client(client, admin_username, admin_name, admin_password):
'password': admin_password
}
response = client.post('/api/onboarding', json=onboard_data)
assert response.status_code == 200, f"Failed to onboard admin: {response.get_json()}"
data = response.get_json()
assert 'access_token' in data, f"No access token in response: {data}"
client.environ_base['HTTP_AUTHORIZATION'] = f'Bearer {data["access_token"]}'
return client

Expand All @@ -103,11 +105,13 @@ def user_client(admin_client, username, name, password):
'password': password
}
response = admin_client.post('/api/user/new', json=data)
assert response.status_code == 200, f"Failed to create user: {response.get_json()}"
data = {
'username': username,
'password': password
}
response = admin_client.post('/api/auth', json=data)
assert response.status_code == 200, f"Failed to login: {response.get_json()}"
data = response.get_json()
admin_client.environ_base['HTTP_AUTHORIZATION'] = f'Bearer {data["access_token"]}'
return admin_client
Expand Down
Loading

0 comments on commit 2d488eb

Please sign in to comment.