Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/assets/migrations/000008_oidc_coder_user.down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE "oidc_tokens" DROP COLUMN "code_hash";
1 change: 1 addition & 0 deletions internal/assets/migrations/000008_oidc_coder_user.up.sql
Comment thread
steveiliop56 marked this conversation as resolved.
Outdated
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE "oidc_tokens" ADD COLUMN "code_hash" TEXT DEFAULT "";
Comment thread
steveiliop56 marked this conversation as resolved.
Outdated
3 changes: 3 additions & 0 deletions internal/controller/oidc_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,9 @@ func (controller *OIDCController) Token(c *gin.Context) {
case "authorization_code":
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
if err != nil {
// Delete the access token just in case
controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code))

Comment thread
steveiliop56 marked this conversation as resolved.
Outdated
if errors.Is(err, service.ErrCodeNotFound) {
tlog.App.Warn().Msg("Code not found")
c.JSON(400, gin.H{
Expand Down
68 changes: 68 additions & 0 deletions internal/controller/oidc_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,74 @@ func TestOIDCController(t *testing.T) {
assert.NotEmpty(t, error)
},
},
{
description: "Ensure access token gets invalidated on double code use",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
authorizeCodeTest, found := getTestByDescription("Ensure authorize succeeds with valid params")
assert.True(t, found, "Authorize test not found")
authorizeCodeTest(t, router, recorder)

var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)

redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)

queryParams := url.Query()
code := queryParams.Get("code")
assert.NotEmpty(t, code)

reqBody := controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
RedirectURI: "https://test.example.com/callback",
}
reqBodyEncoded, err := query.Values(reqBody)
assert.NoError(t, err)

req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
recorder = httptest.NewRecorder()
router.ServeHTTP(recorder, req)

assert.Equal(t, 200, recorder.Code)

err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)

accessToken := res["access_token"].(string)
assert.NotEmpty(t, accessToken)

req = httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
req.Header.Set("Authorization", "Bearer "+accessToken)
recorder = httptest.NewRecorder()
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)

req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
recorder = httptest.NewRecorder()
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)

req = httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
req.Header.Set("Authorization", "Bearer "+accessToken)
recorder = httptest.NewRecorder()
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)

err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_grant", res["error"])
},
},
}

app := bootstrap.NewBootstrapApp(config.Config{})
Expand Down
1 change: 1 addition & 0 deletions internal/repository/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 26 additions & 7 deletions internal/repository/oidc_queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions internal/service/oidc_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,7 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OI
TokenExpiresAt: tokenExpiresAt,
RefreshTokenExpiresAt: refrshTokenExpiresAt,
Nonce: codeEntry.Nonce,
CodeHash: codeEntry.CodeHash,
})

if err != nil {
Expand Down Expand Up @@ -590,6 +591,10 @@ func (service *OIDCService) DeleteToken(c *gin.Context, tokenHash string) error
return service.queries.DeleteOidcToken(c, tokenHash)
}

func (service *OIDCService) DeleteTokenByCodeHash(c *gin.Context, codeHash string) error {
return service.queries.DeleteOidcTokenByCodeHash(c, codeHash)
}

func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (repository.OidcToken, error) {
entry, err := service.queries.GetOidcToken(c, tokenHash)

Expand Down
7 changes: 6 additions & 1 deletion sql/oidc_queries.sql
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ INSERT INTO "oidc_tokens" (
"client_id",
"token_expires_at",
"refresh_token_expires_at",
"code_hash",
"nonce"
) VALUES (
?, ?, ?, ?, ?, ?, ?, ?
?, ?, ?, ?, ?, ?, ?, ?, ?
)
RETURNING *;

Expand All @@ -75,6 +76,10 @@ WHERE "refresh_token_hash" = ?;
SELECT * FROM "oidc_tokens"
WHERE "sub" = ?;

-- name: DeleteOidcTokenByCodeHash :exec
DELETE FROM "oidc_tokens"
WHERE "code_hash" = ?;

-- name: DeleteOidcToken :exec
DELETE FROM "oidc_tokens"
WHERE "access_token_hash" = ?;
Expand Down
1 change: 1 addition & 0 deletions sql/oidc_schemas.sql
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ CREATE TABLE IF NOT EXISTS "oidc_tokens" (
"sub" TEXT NOT NULL UNIQUE,
"access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE,
"refresh_token_hash" TEXT NOT NULL,
"code_hash" TEXT NOT NULL,
"scope" TEXT NOT NULL,
"client_id" TEXT NOT NULL,
"token_expires_at" INTEGER NOT NULL,
Expand Down
Loading