@@ -75,12 +75,14 @@ type TokenResponse struct {
7575}
7676
7777type AuthorizeRequest struct {
78- Scope string `json:"scope" binding:"required"`
79- ResponseType string `json:"response_type" binding:"required"`
80- ClientID string `json:"client_id" binding:"required"`
81- RedirectURI string `json:"redirect_uri" binding:"required"`
82- State string `json:"state"`
83- Nonce string `json:"nonce"`
78+ Scope string `json:"scope" binding:"required"`
79+ ResponseType string `json:"response_type" binding:"required"`
80+ ClientID string `json:"client_id" binding:"required"`
81+ RedirectURI string `json:"redirect_uri" binding:"required"`
82+ State string `json:"state"`
83+ Nonce string `json:"nonce"`
84+ CodeChallenge string `json:"code_challenge"`
85+ CodeChallengeMethod string `json:"code_challenge_method"`
8486}
8587
8688type OIDCServiceConfig struct {
@@ -293,6 +295,13 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error
293295 return errors .New ("invalid_request_uri" )
294296 }
295297
298+ // PKCE code challenge method if set
299+ if req .CodeChallenge != "" && req .CodeChallengeMethod != "" {
300+ if req .CodeChallengeMethod != "S256" || req .CodeChallenge == "plain" {
301+ return errors .New ("invalid_request" )
302+ }
303+ }
304+
296305 return nil
297306}
298307
@@ -306,8 +315,7 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
306315 // Fixed 10 minutes
307316 expiresAt := time .Now ().Add (time .Minute * time .Duration (10 )).Unix ()
308317
309- // Insert the code into the database
310- _ , err := service .queries .CreateOidcCode (c , repository.CreateOidcCodeParams {
318+ entry := repository.CreateOidcCodeParams {
311319 Sub : sub ,
312320 CodeHash : service .Hash (code ),
313321 // Here it's safe to split and trust the output since, we validated the scopes before
@@ -316,7 +324,21 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
316324 ClientID : req .ClientID ,
317325 ExpiresAt : expiresAt ,
318326 Nonce : req .Nonce ,
319- })
327+ }
328+
329+ if req .CodeChallenge != "" {
330+ if req .CodeChallengeMethod == "S256" {
331+ entry .CodeChallenge = req .CodeChallenge
332+ entry .CodeChallengeMethod = "S256"
333+ } else {
334+ entry .CodeChallenge = service .hashAndEncodePKCE (req .CodeChallenge )
335+ entry .CodeChallengeMethod = "plain"
336+ tlog .App .Warn ().Msg ("Received plain PKCE code challenge, it's recommended to use S256 for better security" )
337+ }
338+ }
339+
340+ // Insert the code into the database
341+ _ , err := service .queries .CreateOidcCode (c , entry )
320342
321343 return err
322344}
@@ -728,3 +750,20 @@ func (service *OIDCService) GetJWK() ([]byte, error) {
728750
729751 return jwk .Public ().MarshalJSON ()
730752}
753+
754+ func (service * OIDCService ) ValidatePKCE (codeChallenge string , codeChallengeMethod string , codeVerifier string ) bool {
755+ if codeChallenge == "" {
756+ return true
757+ }
758+ if codeChallengeMethod == "plain" {
759+ // Code challenge is hashed and encoded in the database for security reasons
760+ return codeChallenge == service .hashAndEncodePKCE (codeVerifier )
761+ }
762+ return codeChallenge == codeVerifier
763+ }
764+
765+ func (service * OIDCService ) hashAndEncodePKCE (codeVerifier string ) string {
766+ hasher := sha256 .New ()
767+ hasher .Write ([]byte (codeVerifier ))
768+ return base64 .URLEncoding .EncodeToString (hasher .Sum (nil ))
769+ }
0 commit comments