Skip to content

Commit 7ae16d6

Browse files
committed
fix: prevent ddos attacks in oauth rate limit
1 parent db73c56 commit 7ae16d6

2 files changed

Lines changed: 45 additions & 3 deletions

File tree

internal/controller/oauth_controller.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
129129
}
130130

131131
c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
132+
defer controller.auth.EndOAuthSession(sessionIdCookie)
132133

133134
state := c.Query("state")
134135
csrfCookie, err := c.Cookie(controller.config.CSRFCookieName)
@@ -227,9 +228,6 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
227228

228229
tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider)
229230

230-
// Clear OAuth session
231-
controller.auth.EndOAuthSession(sessionIdCookie)
232-
233231
redirectURI, err := c.Cookie(controller.config.RedirectCookieName)
234232

235233
if err != nil || !utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain) {

internal/service/auth_service.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,13 @@ import (
1717
"github.com/gin-gonic/gin"
1818
"github.com/google/uuid"
1919
"golang.org/x/crypto/bcrypt"
20+
"golang.org/x/exp/slices"
2021
"golang.org/x/oauth2"
2122
)
2223

24+
const MaxOAuthPendingSessions = 256
25+
const OAuthCleanupCount = 16
26+
2327
type OAuthPendingSession struct {
2428
State string
2529
Verifier string
@@ -570,6 +574,8 @@ func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool {
570574
}
571575

572576
func (auth *AuthService) NewOAuthSession(serviceName string) (string, OAuthPendingSession, error) {
577+
auth.ensureOAuthSessionLimit()
578+
573579
service, ok := auth.oauthBroker.GetService(serviceName)
574580

575581
if !ok {
@@ -685,6 +691,8 @@ func (auth *AuthService) CleanupOAuthSessionsRoutine() {
685691
}
686692

687693
func (auth *AuthService) getOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) {
694+
auth.ensureOAuthSessionLimit()
695+
688696
auth.oauthMutex.RLock()
689697
session, exists := auth.oauthPendingSessions[sessionId]
690698
auth.oauthMutex.RUnlock()
@@ -702,3 +710,39 @@ func (auth *AuthService) getOAuthPendingSession(sessionId string) (*OAuthPending
702710

703711
return session, nil
704712
}
713+
714+
func (auth *AuthService) ensureOAuthSessionLimit() {
715+
auth.oauthMutex.Lock()
716+
defer auth.oauthMutex.Unlock()
717+
718+
if len(auth.oauthPendingSessions) >= MaxOAuthPendingSessions {
719+
720+
cleanupIds := make([]string, 0, OAuthCleanupCount)
721+
722+
for range OAuthCleanupCount {
723+
oldestId := ""
724+
oldestTime := int64(0)
725+
726+
for id, session := range auth.oauthPendingSessions {
727+
if oldestTime == 0 {
728+
oldestId = id
729+
oldestTime = session.ExpiresAt.Unix()
730+
continue
731+
}
732+
if slices.Contains(cleanupIds, id) {
733+
continue
734+
}
735+
if session.ExpiresAt.Unix() < oldestTime {
736+
oldestId = id
737+
oldestTime = session.ExpiresAt.Unix()
738+
}
739+
}
740+
741+
cleanupIds = append(cleanupIds, oldestId)
742+
}
743+
744+
for _, id := range cleanupIds {
745+
delete(auth.oauthPendingSessions, id)
746+
}
747+
}
748+
}

0 commit comments

Comments
 (0)