Skip to content

Commit 6f99e7a

Browse files
authored
fix: revoke access token on duplicate auth code user (#786)
* fix: revoke access token on duplicate auth code user * fix: review comments * tests: fix tests
1 parent 578172d commit 6f99e7a

9 files changed

Lines changed: 113 additions & 9 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ALTER TABLE "oidc_tokens" DROP COLUMN "code_hash";
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ALTER TABLE "oidc_tokens" ADD COLUMN "code_hash" TEXT NOT NULL DEFAULT "";

internal/controller/oidc_controller.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,9 @@ func (controller *OIDCController) Token(c *gin.Context) {
275275
case "authorization_code":
276276
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
277277
if err != nil {
278+
if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil {
279+
tlog.App.Error().Err(err).Msg("Failed to delete access token by code hash")
280+
}
278281
if errors.Is(err, service.ErrCodeNotFound) {
279282
tlog.App.Warn().Msg("Code not found")
280283
c.JSON(400, gin.H{

internal/controller/oidc_controller_test.go

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ func TestOIDCController(t *testing.T) {
387387
err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondRes)
388388
assert.NoError(t, err)
389389

390-
assert.Equal(t, secondRes["error"], "invalid_grant")
390+
assert.Equal(t, "invalid_grant", secondRes["error"])
391391
},
392392
},
393393
{
@@ -778,6 +778,74 @@ func TestOIDCController(t *testing.T) {
778778
assert.NotEmpty(t, error)
779779
},
780780
},
781+
{
782+
description: "Ensure access token gets invalidated on double code use",
783+
middlewares: []gin.HandlerFunc{
784+
simpleCtx,
785+
},
786+
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
787+
authorizeCodeTest, found := getTestByDescription("Ensure authorize succeeds with valid params")
788+
assert.True(t, found, "Authorize test not found")
789+
authorizeCodeTest(t, router, recorder)
790+
791+
var res map[string]any
792+
err := json.Unmarshal(recorder.Body.Bytes(), &res)
793+
assert.NoError(t, err)
794+
795+
redirectURI := res["redirect_uri"].(string)
796+
url, err := url.Parse(redirectURI)
797+
assert.NoError(t, err)
798+
799+
queryParams := url.Query()
800+
code := queryParams.Get("code")
801+
assert.NotEmpty(t, code)
802+
803+
reqBody := controller.TokenRequest{
804+
GrantType: "authorization_code",
805+
Code: code,
806+
RedirectURI: "https://test.example.com/callback",
807+
}
808+
reqBodyEncoded, err := query.Values(reqBody)
809+
assert.NoError(t, err)
810+
811+
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
812+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
813+
req.SetBasicAuth("some-client-id", "some-client-secret")
814+
recorder = httptest.NewRecorder()
815+
router.ServeHTTP(recorder, req)
816+
817+
assert.Equal(t, 200, recorder.Code)
818+
819+
err = json.Unmarshal(recorder.Body.Bytes(), &res)
820+
assert.NoError(t, err)
821+
822+
accessToken := res["access_token"].(string)
823+
assert.NotEmpty(t, accessToken)
824+
825+
req = httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
826+
req.Header.Set("Authorization", "Bearer "+accessToken)
827+
recorder = httptest.NewRecorder()
828+
router.ServeHTTP(recorder, req)
829+
assert.Equal(t, 200, recorder.Code)
830+
831+
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
832+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
833+
req.SetBasicAuth("some-client-id", "some-client-secret")
834+
recorder = httptest.NewRecorder()
835+
router.ServeHTTP(recorder, req)
836+
assert.Equal(t, 400, recorder.Code)
837+
838+
req = httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
839+
req.Header.Set("Authorization", "Bearer "+accessToken)
840+
recorder = httptest.NewRecorder()
841+
router.ServeHTTP(recorder, req)
842+
assert.Equal(t, 401, recorder.Code)
843+
844+
err = json.Unmarshal(recorder.Body.Bytes(), &res)
845+
assert.NoError(t, err)
846+
assert.Equal(t, "invalid_grant", res["error"])
847+
},
848+
},
781849
}
782850

783851
app := bootstrap.NewBootstrapApp(config.Config{})

internal/repository/models.go

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/repository/oidc_queries.sql.go

Lines changed: 26 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/service/oidc_service.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,7 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
506506
TokenExpiresAt: tokenExpiresAt,
507507
RefreshTokenExpiresAt: refrshTokenExpiresAt,
508508
Nonce: codeEntry.Nonce,
509+
CodeHash: codeEntry.CodeHash,
509510
})
510511

511512
if err != nil {
@@ -590,6 +591,10 @@ func (service *OIDCService) DeleteToken(c *gin.Context, tokenHash string) error
590591
return service.queries.DeleteOidcToken(c, tokenHash)
591592
}
592593

594+
func (service *OIDCService) DeleteTokenByCodeHash(c *gin.Context, codeHash string) error {
595+
return service.queries.DeleteOidcTokenByCodeHash(c, codeHash)
596+
}
597+
593598
func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (repository.OidcToken, error) {
594599
entry, err := service.queries.GetOidcToken(c, tokenHash)
595600

sql/oidc_queries.sql

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,10 @@ INSERT INTO "oidc_tokens" (
4848
"client_id",
4949
"token_expires_at",
5050
"refresh_token_expires_at",
51+
"code_hash",
5152
"nonce"
5253
) VALUES (
53-
?, ?, ?, ?, ?, ?, ?, ?
54+
?, ?, ?, ?, ?, ?, ?, ?, ?
5455
)
5556
RETURNING *;
5657

@@ -75,6 +76,10 @@ WHERE "refresh_token_hash" = ?;
7576
SELECT * FROM "oidc_tokens"
7677
WHERE "sub" = ?;
7778

79+
-- name: DeleteOidcTokenByCodeHash :exec
80+
DELETE FROM "oidc_tokens"
81+
WHERE "code_hash" = ?;
82+
7883
-- name: DeleteOidcToken :exec
7984
DELETE FROM "oidc_tokens"
8085
WHERE "access_token_hash" = ?;

sql/oidc_schemas.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ CREATE TABLE IF NOT EXISTS "oidc_tokens" (
1313
"sub" TEXT NOT NULL UNIQUE,
1414
"access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
1515
"refresh_token_hash" TEXT NOT NULL,
16+
"code_hash" TEXT NOT NULL,
1617
"scope" TEXT NOT NULL,
1718
"client_id" TEXT NOT NULL,
1819
"token_expires_at" INTEGER NOT NULL,

0 commit comments

Comments
 (0)