Skip to content

Commit 02faabf

Browse files
committed
feat: add CSRF cookie protection
1 parent eb36b22 commit 02faabf

3 files changed

Lines changed: 46 additions & 10 deletions

File tree

internal/auth/auth.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package auth
22

33
import (
44
"fmt"
5-
"net/http"
65
"regexp"
76
"slices"
87
"strings"
@@ -42,7 +41,6 @@ func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) {
4241
MaxAge: auth.Config.SessionExpiry,
4342
Secure: auth.Config.CookieSecure,
4443
HttpOnly: true,
45-
SameSite: http.SameSiteDefaultMode,
4644
Domain: fmt.Sprintf(".%s", auth.Config.Domain),
4745
}
4846

internal/handlers/handlers.go

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -515,11 +515,17 @@ func (h *Handlers) OauthUrlHandler(c *gin.Context) {
515515

516516
log.Debug().Str("provider", request.Provider).Msg("Got provider")
517517

518+
// Create state
519+
state := provider.GenerateState()
520+
518521
// Get auth URL
519-
authURL := provider.GetAuthURL()
522+
authURL := provider.GetAuthURL(state)
520523

521524
log.Debug().Msg("Got auth URL")
522525

526+
// Set CSRF cookie
527+
c.SetCookie("tinyauth-csrf", state, int(time.Hour.Seconds()), "/", "", h.Config.CookieSecure, true)
528+
523529
// Get redirect URI
524530
redirectURI := c.Query("redirect_uri")
525531

@@ -553,16 +559,33 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) {
553559

554560
log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name")
555561

556-
// Get code
557-
code := c.Query("code")
562+
// Get state
563+
state := c.Query("state")
558564

559-
// Code empty so redirect to error
560-
if code == "" {
561-
log.Error().Msg("No code provided")
565+
// Get CSRF cookie
566+
csrfCookie, err := c.Cookie("tinyauth-csrf")
567+
568+
if err != nil {
569+
log.Debug().Msg("No CSRF cookie")
562570
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
563571
return
564572
}
565573

574+
log.Debug().Str("csrfCookie", csrfCookie).Msg("Got CSRF cookie")
575+
576+
// Check if CSRF cookie is valid
577+
if csrfCookie != state {
578+
log.Warn().Msg("Invalid CSRF cookie or CSRF cookie does not match with the state")
579+
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
580+
return
581+
}
582+
583+
// Clean up CSRF cookie
584+
c.SetCookie("tinyauth-csrf", "", -1, "/", "", h.Config.CookieSecure, true)
585+
586+
// Get code
587+
code := c.Query("code")
588+
566589
log.Debug().Msg("Got code")
567590

568591
// Get provider

internal/oauth/oauth.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package oauth
22

33
import (
44
"context"
5+
"crypto/rand"
6+
"encoding/base64"
57
"net/http"
68

79
"golang.org/x/oauth2"
@@ -26,9 +28,9 @@ func (oauth *OAuth) Init() {
2628
oauth.Verifier = oauth2.GenerateVerifier()
2729
}
2830

29-
func (oauth *OAuth) GetAuthURL() string {
31+
func (oauth *OAuth) GetAuthURL(state string) string {
3032
// Return the auth url
31-
return oauth.Config.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier))
33+
return oauth.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier))
3234
}
3335

3436
func (oauth *OAuth) ExchangeToken(code string) (string, error) {
@@ -51,3 +53,16 @@ func (oauth *OAuth) GetClient() *http.Client {
5153
// Return the http client with the token set
5254
return oauth.Config.Client(oauth.Context, oauth.Token)
5355
}
56+
57+
func (oauth *OAuth) GenerateState() string {
58+
// Generate a random state string
59+
b := make([]byte, 128)
60+
61+
// Fill the byte slice with random data
62+
rand.Read(b)
63+
64+
// Encode the byte slice to a base64 string
65+
state := base64.URLEncoding.EncodeToString(b)
66+
67+
return state
68+
}

0 commit comments

Comments
 (0)