Skip to content

Commit cee8d2f

Browse files
committed
Fix for some tests.
Signed-off-by: Pavel Kirilin <[email protected]>
1 parent 99da6b8 commit cee8d2f

File tree

4 files changed

+84
-43
lines changed

4 files changed

+84
-43
lines changed

Diff for: tests/conftest.py

+16
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,22 @@
11
import pytest
22
from fastapi_jwt_auth import AuthJWT
3+
from fastapi_jwt_auth.config import LoadConfig
34

45
@pytest.fixture(scope="module")
56
def Authorize():
67
return AuthJWT()
8+
9+
10+
@pytest.fixture(autouse=True)
11+
def reset_config():
12+
"""
13+
Resets config to default to
14+
guarantee that config is unchanged after test.
15+
"""
16+
yield
17+
@AuthJWT.load_config
18+
def default_conf():
19+
return LoadConfig(
20+
authjwt_secret_key="secret",
21+
authjwt_cookie_samesite='strict',
22+
)

Diff for: tests/test_create_token.py

+14-9
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
from pydantic import BaseSettings
44
from datetime import timedelta, datetime, timezone
55

6-
def test_create_access_token(Authorize):
6+
7+
8+
@pytest.fixture()
9+
def test_settings() -> None:
710
class Settings(BaseSettings):
811
AUTHJWT_SECRET_KEY: str = "testing"
912
AUTHJWT_ACCESS_TOKEN_EXPIRES: int = 2
@@ -13,6 +16,8 @@ class Settings(BaseSettings):
1316
def get_settings():
1417
return Settings()
1518

19+
def test_create_access_token(Authorize, test_settings):
20+
1621
with pytest.raises(TypeError,match=r"missing 1 required positional argument"):
1722
Authorize.create_access_token()
1823

@@ -25,7 +30,7 @@ def get_settings():
2530
with pytest.raises(ValueError,match=r"dictionary update sequence element"):
2631
Authorize.create_access_token(subject=1,headers="test")
2732

28-
def test_create_refresh_token(Authorize):
33+
def test_create_refresh_token(Authorize, test_settings):
2934
with pytest.raises(TypeError,match=r"missing 1 required positional argument"):
3035
Authorize.create_refresh_token()
3136

@@ -35,7 +40,7 @@ def test_create_refresh_token(Authorize):
3540
with pytest.raises(ValueError,match=r"dictionary update sequence element"):
3641
Authorize.create_refresh_token(subject=1,headers="test")
3742

38-
def test_create_dynamic_access_token_expires(Authorize):
43+
def test_create_dynamic_access_token_expires(Authorize, test_settings):
3944
expires_time = int(datetime.now(timezone.utc).timestamp()) + 90
4045
token = Authorize.create_access_token(subject=1,expires_time=90)
4146
assert jwt.decode(token,"testing",algorithms="HS256")['exp'] == expires_time
@@ -54,7 +59,7 @@ def test_create_dynamic_access_token_expires(Authorize):
5459
with pytest.raises(TypeError,match=r"expires_time"):
5560
Authorize.create_access_token(subject=1,expires_time="test")
5661

57-
def test_create_dynamic_refresh_token_expires(Authorize):
62+
def test_create_dynamic_refresh_token_expires(Authorize, test_settings):
5863
expires_time = int(datetime.now(timezone.utc).timestamp()) + 90
5964
token = Authorize.create_refresh_token(subject=1,expires_time=90)
6065
assert jwt.decode(token,"testing",algorithms="HS256")['exp'] == expires_time
@@ -73,34 +78,34 @@ def test_create_dynamic_refresh_token_expires(Authorize):
7378
with pytest.raises(TypeError,match=r"expires_time"):
7479
Authorize.create_refresh_token(subject=1,expires_time="test")
7580

76-
def test_create_token_invalid_type_data_audience(Authorize):
81+
def test_create_token_invalid_type_data_audience(Authorize, test_settings):
7782
with pytest.raises(TypeError,match=r"audience"):
7883
Authorize.create_access_token(subject=1,audience=1)
7984

8085
with pytest.raises(TypeError,match=r"audience"):
8186
Authorize.create_refresh_token(subject=1,audience=1)
8287

83-
def test_create_token_invalid_algorithm(Authorize):
88+
def test_create_token_invalid_algorithm(Authorize, test_settings):
8489
with pytest.raises(ValueError,match=r"Algorithm"):
8590
Authorize.create_access_token(subject=1,algorithm="test")
8691

8792
with pytest.raises(ValueError,match=r"Algorithm"):
8893
Authorize.create_refresh_token(subject=1,algorithm="test")
8994

90-
def test_create_token_invalid_type_data_algorithm(Authorize):
95+
def test_create_token_invalid_type_data_algorithm(Authorize, test_settings):
9196
with pytest.raises(TypeError,match=r"algorithm"):
9297
Authorize.create_access_token(subject=1,algorithm=1)
9398

9499
with pytest.raises(TypeError,match=r"algorithm"):
95100
Authorize.create_refresh_token(subject=1,algorithm=1)
96101

97-
def test_create_token_invalid_user_claims(Authorize):
102+
def test_create_token_invalid_user_claims(Authorize, test_settings):
98103
with pytest.raises(TypeError,match=r"user_claims"):
99104
Authorize.create_access_token(subject=1,user_claims="asd")
100105
with pytest.raises(TypeError,match=r"user_claims"):
101106
Authorize.create_refresh_token(subject=1,user_claims="asd")
102107

103-
def test_create_valid_user_claims(Authorize):
108+
def test_create_valid_user_claims(Authorize, test_settings):
104109
access_token = Authorize.create_access_token(subject=1,user_claims={"my_access":"yeah"})
105110
refresh_token = Authorize.create_refresh_token(subject=1,user_claims={"my_refresh":"hello"})
106111

Diff for: tests/test_decode_token.py

+23-10
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,19 @@ def default_access_token():
4949
'fresh': True,
5050
}
5151

52+
@pytest.fixture()
53+
def test_settings() -> None:
54+
class TestSettings(BaseSettings):
55+
AUTHJWT_SECRET_KEY: str = "secret-key"
56+
AUTHJWT_ACCESS_TOKEN_EXPIRES: int = 1
57+
AUTHJWT_REFRESH_TOKEN_EXPIRES: int = 1
58+
AUTHJWT_DECODE_LEEWAY: int = 2
59+
60+
@AuthJWT.load_config
61+
def load():
62+
return TestSettings()
63+
64+
5265
@pytest.fixture(scope='function')
5366
def encoded_token(default_access_token):
5467
return jwt.encode(default_access_token,'secret-key',algorithm='HS256').decode('utf-8')
@@ -111,23 +124,23 @@ def get_settings_two():
111124
assert response.status_code == 200
112125
assert response.json() == {'hello':'world'}
113126

114-
def test_get_raw_token(client,default_access_token,encoded_token):
127+
def test_get_raw_token(client,default_access_token,encoded_token,test_settings):
115128
response = client.get('/raw_token',headers={"Authorization":f"Bearer {encoded_token}"})
116129
assert response.status_code == 200
117130
assert response.json() == default_access_token
118131

119-
def test_get_raw_jwt(default_access_token,encoded_token,Authorize):
132+
def test_get_raw_jwt(default_access_token,encoded_token,Authorize,test_settings):
120133
assert Authorize.get_raw_jwt(encoded_token) == default_access_token
121134

122-
def test_get_jwt_jti(client,default_access_token,encoded_token,Authorize):
135+
def test_get_jwt_jti(client,default_access_token,encoded_token,Authorize,test_settings):
123136
assert Authorize.get_jti(encoded_token=encoded_token) == default_access_token['jti']
124137

125-
def test_get_jwt_subject(client,default_access_token,encoded_token):
138+
def test_get_jwt_subject(client,default_access_token,encoded_token,test_settings):
126139
response = client.get('/get_subject',headers={"Authorization":f"Bearer {encoded_token}"})
127140
assert response.status_code == 200
128141
assert response.json() == default_access_token['sub']
129142

130-
def test_invalid_jwt_issuer(client,Authorize):
143+
def test_invalid_jwt_issuer(client,Authorize,test_settings):
131144
# No issuer claim expected or provided - OK
132145
token = Authorize.create_access_token(subject='test')
133146
response = client.get('/protected',headers={'Authorization':f"Bearer {token}"})
@@ -154,7 +167,7 @@ def test_invalid_jwt_issuer(client,Authorize):
154167
AuthJWT._encode_issuer = None
155168

156169
@pytest.mark.parametrize("token_aud",['foo', ['bar'], ['foo', 'bar', 'baz']])
157-
def test_valid_aud(client,Authorize,token_aud):
170+
def test_valid_aud(client,Authorize,token_aud,test_settings):
158171
AuthJWT._decode_audience = ['foo','bar']
159172

160173
access_token = Authorize.create_access_token(subject=1,audience=token_aud)
@@ -171,7 +184,7 @@ def test_valid_aud(client,Authorize,token_aud):
171184
AuthJWT._decode_audience = None
172185

173186
@pytest.mark.parametrize("token_aud",['bar', ['bar'], ['bar', 'baz']])
174-
def test_invalid_aud_and_missing_aud(client,Authorize,token_aud):
187+
def test_invalid_aud_and_missing_aud(client,Authorize,token_aud,test_settings):
175188
AuthJWT._decode_audience = 'foo'
176189

177190
access_token = Authorize.create_access_token(subject=1,audience=token_aud)
@@ -187,7 +200,7 @@ def test_invalid_aud_and_missing_aud(client,Authorize,token_aud):
187200
if token_aud == ['bar','baz']:
188201
AuthJWT._decode_audience = None
189202

190-
def test_invalid_decode_algorithms(client,Authorize):
203+
def test_invalid_decode_algorithms(client,Authorize,test_settings):
191204
class SettingsAlgorithms(BaseSettings):
192205
authjwt_secret_key: str = "secret"
193206
authjwt_decode_algorithms: list = ['HS384','RS256']
@@ -203,7 +216,7 @@ def get_settings_algorithms():
203216

204217
AuthJWT._decode_algorithms = None
205218

206-
def test_valid_asymmetric_algorithms(client,Authorize):
219+
def test_valid_asymmetric_algorithms(client,Authorize,test_settings):
207220
hs256_token = Authorize.create_access_token(subject=1)
208221

209222
DIR = os.path.abspath(os.path.dirname(__file__))
@@ -236,7 +249,7 @@ def get_settings_asymmetric():
236249
assert response.status_code == 200
237250
assert response.json() == {'hello':'world'}
238251

239-
def test_invalid_asymmetric_algorithms(client,Authorize):
252+
def test_invalid_asymmetric_algorithms(client,Authorize,test_settings):
240253
class SettingsAsymmetricOne(BaseSettings):
241254
authjwt_algorithm: str = "RS256"
242255

Diff for: tests/test_token_types.py

+31-24
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,16 @@
11
import jwt
22
import pytest
3-
from fastapi import Depends, FastAPI, Request
4-
from fastapi.responses import JSONResponse
3+
from fastapi import Depends, FastAPI
54
from fastapi.testclient import TestClient
65
from pydantic import BaseSettings
76

87
from fastapi_jwt_auth import AuthJWT
9-
from fastapi_jwt_auth.exceptions import AuthJWTException
108

119

1210
@pytest.fixture(scope="function")
1311
def client() -> TestClient:
1412
app = FastAPI()
1513

16-
@app.exception_handler(AuthJWTException)
17-
def authjwt_exception_handler(request: Request, exc: AuthJWTException):
18-
return JSONResponse(
19-
status_code=exc.status_code, content={"detail": exc.message}
20-
)
21-
2214
@app.get("/protected")
2315
def protected(Authorize: AuthJWT = Depends()):
2416
Authorize.jwt_required()
@@ -51,7 +43,10 @@ def test_config():
5143

5244
# Checking that created token has custom type claim
5345
access = Authorize.create_access_token(subject="test")
54-
assert jwt.decode(access, key="secret", algorithms=['HS256'])["custom_type"] == "access"
46+
assert (
47+
jwt.decode(access, key="secret", algorithms=["HS256"])["custom_type"]
48+
== "access"
49+
)
5550

5651
# Checking that protected endpoint validates token correctly
5752
response = client.get("/protected", headers={"Authorization": f"Bearer {access}"})
@@ -60,22 +55,26 @@ def test_config():
6055

6156
# Checking that endpoint with optional protection validates token with
6257
# custom type claim correctly.
63-
response = client.get("/semi_protected", headers={"Authorization": f"Bearer {access}"})
58+
response = client.get(
59+
"/semi_protected", headers={"Authorization": f"Bearer {access}"}
60+
)
6461
assert response.status_code == 200
6562
assert response.json() == {"hello": "world"}
6663

67-
# Creating refresh token and checking if it has correct
64+
# Creating refresh token and checking if it has correct
6865
# type claim.
6966
refresh = Authorize.create_refresh_token(subject="test")
70-
assert jwt.decode(refresh, key="secret", algorithms=['HS256'])["custom_type"] == "refresh"
67+
assert (
68+
jwt.decode(refresh, key="secret", algorithms=["HS256"])["custom_type"]
69+
== "refresh"
70+
)
7171

7272
# Checking that refreshing with custom claim works.
7373
response = client.get("/refresh", headers={"Authorization": f"Bearer {refresh}"})
7474
assert response.status_code == 200
7575
assert response.json() == {"hello": "world"}
7676

7777

78-
7978
def test_custom_token_type_names_validation(
8079
client: TestClient, Authorize: AuthJWT
8180
) -> None:
@@ -88,33 +87,39 @@ class TestConfig(BaseSettings):
8887
def test_config():
8988
return TestConfig()
9089

91-
# Creating access token and checking that
90+
# Creating access token and checking that
9291
# it has custom type
9392
access = Authorize.create_access_token(subject="test")
94-
assert jwt.decode(access, key="secret", algorithms=['HS256'])["type"] == "access_custom"
93+
assert (
94+
jwt.decode(access, key="secret", algorithms=["HS256"])["type"]
95+
== "access_custom"
96+
)
9597

9698
# Checking that validation for custom type works as expected.
9799
response = client.get("/protected", headers={"Authorization": f"Bearer {access}"})
98100
assert response.status_code == 200
99101
assert response.json() == {"hello": "world"}
100102

101-
response = client.get("/semi_protected", headers={"Authorization": f"Bearer {access}"})
103+
response = client.get(
104+
"/semi_protected", headers={"Authorization": f"Bearer {access}"}
105+
)
102106
assert response.status_code == 200
103107
assert response.json() == {"hello": "world"}
104108

105109
# Creating refresh token and checking if it has correct type claim.
106110
refresh = Authorize.create_refresh_token(subject="test")
107-
assert jwt.decode(refresh, key="secret", algorithms=['HS256'])["type"] == "refresh_custom"
111+
assert (
112+
jwt.decode(refresh, key="secret", algorithms=["HS256"])["type"]
113+
== "refresh_custom"
114+
)
108115

109116
# Checking that refreshing with custom type works.
110117
response = client.get("/refresh", headers={"Authorization": f"Bearer {refresh}"})
111118
assert response.status_code == 200
112119
assert response.json() == {"hello": "world"}
113120

114121

115-
def test_without_type_claims(
116-
client: TestClient, Authorize: AuthJWT
117-
) -> None:
122+
def test_without_type_claims(client: TestClient, Authorize: AuthJWT) -> None:
118123
class TestConfig(BaseSettings):
119124
authjwt_secret_key: str = "secret"
120125
authjwt_token_type_claim: bool = False
@@ -125,19 +130,21 @@ def test_config():
125130

126131
# Creating access token and checking if it doesn't have type claim.
127132
access = Authorize.create_access_token(subject="test")
128-
assert "type" not in jwt.decode(access, key="secret", algorithms=['HS256'])
133+
assert "type" not in jwt.decode(access, key="secret", algorithms=["HS256"])
129134

130135
response = client.get("/protected", headers={"Authorization": f"Bearer {access}"})
131136
assert response.status_code == 200
132137
assert response.json() == {"hello": "world"}
133138

134-
response = client.get("/semi_protected", headers={"Authorization": f"Bearer {access}"})
139+
response = client.get(
140+
"/semi_protected", headers={"Authorization": f"Bearer {access}"}
141+
)
135142
assert response.status_code == 200
136143
assert response.json() == {"hello": "world"}
137144

138145
# Creating refresh token and checking if it doesn't have type claim.
139146
refresh = Authorize.create_refresh_token(subject="test")
140-
assert "type" not in jwt.decode(refresh, key="secret", algorithms=['HS256'])
147+
assert "type" not in jwt.decode(refresh, key="secret", algorithms=["HS256"])
141148

142149
# Checking that refreshing without type works.
143150
response = client.get("/refresh", headers={"Authorization": f"Bearer {refresh}"})

0 commit comments

Comments
 (0)