diff --git a/01-Login/web/app/callback/callback.go b/01-Login/web/app/callback/callback.go index ba66aea..d920eed 100644 --- a/01-Login/web/app/callback/callback.go +++ b/01-Login/web/app/callback/callback.go @@ -7,6 +7,7 @@ import ( "github.com/gin-gonic/gin" "01-Login/platform/authenticator" + "01-Login/web/app/pkce" ) // Handler for our callback. @@ -18,8 +19,11 @@ func Handler(auth *authenticator.Authenticator) gin.HandlerFunc { return } + verifier := session.Get("verifier").(string) + exchangeOptions := pkce.ExchangeVerifier(verifier) + // Exchange an authorization code for a token. - token, err := auth.Exchange(ctx.Request.Context(), ctx.Query("code")) + token, err := auth.Exchange(ctx.Request.Context(), ctx.Query("code"), exchangeOptions...) if err != nil { ctx.String(http.StatusUnauthorized, "Failed to convert an authorization code into a token.") return diff --git a/01-Login/web/app/login/login.go b/01-Login/web/app/login/login.go index e702662..5719362 100644 --- a/01-Login/web/app/login/login.go +++ b/01-Login/web/app/login/login.go @@ -9,6 +9,7 @@ import ( "github.com/gin-gonic/gin" "01-Login/platform/authenticator" + "01-Login/web/app/pkce" ) // Handler for our login. @@ -20,15 +21,23 @@ func Handler(auth *authenticator.Authenticator) gin.HandlerFunc { return } + verifier, err := pkce.RandomVerifier(32) + if err != nil { + ctx.String(http.StatusInternalServerError, err.Error()) + return + } + // Save the state inside the session. session := sessions.Default(ctx) session.Set("state", state) + session.Set("verifier", verifier) if err := session.Save(); err != nil { ctx.String(http.StatusInternalServerError, err.Error()) return } - ctx.Redirect(http.StatusTemporaryRedirect, auth.AuthCodeURL(state)) + authCodeOptions := pkce.Sha256Challenge(verifier) + ctx.Redirect(http.StatusTemporaryRedirect, auth.AuthCodeURL(state, authCodeOptions...)) } } diff --git a/01-Login/web/app/pkce/pkce.go b/01-Login/web/app/pkce/pkce.go new file mode 100644 index 0000000..7997882 --- /dev/null +++ b/01-Login/web/app/pkce/pkce.go @@ -0,0 +1,56 @@ +package pkce + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "fmt" + + "golang.org/x/oauth2" +) + +const ( + codeChallengeKey = "code_challenge" + codeChallengeMethodKey = "code_challenge_method" + codeVerifierKey = "code_verifier" + + challengeMethodSha256 = "S256" + challengeMethodPlain = "plain" +) + +func RandomVerifier(length int) (string, error) { + // RFC7636 4.1: Result must be between 43 and 128 characters. + if length < 32 || length > 96 { + return "", fmt.Errorf("expected length between 32 and 96, got %d", length) + } + + buf := make([]byte, length) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(buf), nil +} + +func Sha256Challenge(verifier string) []oauth2.AuthCodeOption { + sha := sha256.Sum256([]byte(verifier)) + challenge := base64.RawURLEncoding.EncodeToString(sha[:]) + + return challengeOptions(challengeMethodSha256, challenge) +} + +func PlainChallenge(verifier string) []oauth2.AuthCodeOption { + return challengeOptions(challengeMethodPlain, verifier) +} + +func challengeOptions(method string, challenge string) []oauth2.AuthCodeOption { + return []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam(codeChallengeMethodKey, method), + oauth2.SetAuthURLParam(codeChallengeKey, challenge), + } +} + +func ExchangeVerifier(verifier string) []oauth2.AuthCodeOption { + return []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam(codeVerifierKey, verifier), + } +}