Skip to content

Commit 2491d45

Browse files
committed
feat: add oauth session impl in auth service
1 parent 1a1712e commit 2491d45

1 file changed

Lines changed: 155 additions & 9 deletions

File tree

internal/service/auth_service.go

Lines changed: 155 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,17 @@ import (
1717
"github.com/gin-gonic/gin"
1818
"github.com/google/uuid"
1919
"golang.org/x/crypto/bcrypt"
20+
"golang.org/x/oauth2"
2021
)
2122

23+
type OAuthPendingSession struct {
24+
State string
25+
Verifier string
26+
Token *oauth2.Token
27+
Service *OAuthServiceImpl
28+
ExpiresAt time.Time
29+
}
30+
2231
type LdapGroupsCache struct {
2332
Groups []string
2433
Expires time.Time
@@ -45,28 +54,33 @@ type AuthServiceConfig struct {
4554
}
4655

4756
type AuthService struct {
48-
config AuthServiceConfig
49-
docker *DockerService
50-
loginAttempts map[string]*LoginAttempt
51-
ldapGroupsCache map[string]*LdapGroupsCache
52-
loginMutex sync.RWMutex
53-
ldapGroupsMutex sync.RWMutex
54-
ldap *LdapService
55-
queries *repository.Queries
57+
config AuthServiceConfig
58+
docker *DockerService
59+
loginAttempts map[string]*LoginAttempt
60+
ldapGroupsCache map[string]*LdapGroupsCache
61+
oauthPendingSessions map[string]*OAuthPendingSession
62+
oauthMutex sync.RWMutex
63+
loginMutex sync.RWMutex
64+
ldapGroupsMutex sync.RWMutex
65+
ldap *LdapService
66+
queries *repository.Queries
67+
oauthBroker *OAuthBrokerService
5668
}
5769

58-
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries) *AuthService {
70+
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
5971
return &AuthService{
6072
config: config,
6173
docker: docker,
6274
loginAttempts: make(map[string]*LoginAttempt),
6375
ldapGroupsCache: make(map[string]*LdapGroupsCache),
6476
ldap: ldap,
6577
queries: queries,
78+
oauthBroker: oauthBroker,
6679
}
6780
}
6881

6982
func (auth *AuthService) Init() error {
83+
go auth.CleanupOAuthSessionsRoutine()
7084
return nil
7185
}
7286

@@ -553,3 +567,135 @@ func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool {
553567
tlog.App.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication")
554568
return false
555569
}
570+
571+
func (auth *AuthService) NewOAuthSession(serviceName string) (string, error) {
572+
service, ok := auth.oauthBroker.GetService(serviceName)
573+
574+
if !ok {
575+
return "", fmt.Errorf("oauth service not found: %s", serviceName)
576+
}
577+
578+
sessionId, err := uuid.NewRandom()
579+
580+
if err != nil {
581+
return "", fmt.Errorf("failed to generate session ID: %w", err)
582+
}
583+
584+
state := uuid.New().String()
585+
verifier := uuid.New().String()
586+
587+
auth.oauthMutex.Lock()
588+
auth.oauthPendingSessions[sessionId.String()] = &OAuthPendingSession{
589+
State: state,
590+
Verifier: verifier,
591+
Service: &service,
592+
ExpiresAt: time.Now().Add(1 * time.Hour),
593+
}
594+
auth.oauthMutex.Unlock()
595+
596+
return sessionId.String(), nil
597+
}
598+
599+
func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
600+
auth.oauthMutex.RLock()
601+
defer auth.oauthMutex.RUnlock()
602+
603+
session, exists := auth.oauthPendingSessions[sessionId]
604+
605+
if !exists {
606+
return "", fmt.Errorf("oauth session not found: %s", sessionId)
607+
}
608+
609+
if time.Now().After(session.ExpiresAt) {
610+
delete(auth.oauthPendingSessions, sessionId)
611+
return "", fmt.Errorf("oauth session expired: %s", sessionId)
612+
}
613+
614+
return (*session.Service).GetAuthURL(session.State, session.Verifier), nil
615+
}
616+
617+
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
618+
auth.oauthMutex.RLock()
619+
session, exists := auth.oauthPendingSessions[sessionId]
620+
auth.oauthMutex.RUnlock()
621+
622+
if !exists {
623+
return nil, fmt.Errorf("oauth session not found: %s", sessionId)
624+
}
625+
626+
if time.Now().After(session.ExpiresAt) {
627+
auth.oauthMutex.Lock()
628+
delete(auth.oauthPendingSessions, sessionId)
629+
auth.oauthMutex.Unlock()
630+
return nil, fmt.Errorf("oauth session expired: %s", sessionId)
631+
}
632+
633+
token, err := (*session.Service).GetToken(code, session.Verifier)
634+
635+
if err != nil {
636+
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
637+
}
638+
639+
auth.oauthMutex.Lock()
640+
session.Token = token
641+
auth.oauthMutex.Unlock()
642+
643+
return token, nil
644+
}
645+
646+
func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) {
647+
auth.oauthMutex.RLock()
648+
session, exists := auth.oauthPendingSessions[sessionId]
649+
auth.oauthMutex.RUnlock()
650+
651+
if !exists {
652+
return config.Claims{}, fmt.Errorf("oauth session not found: %s", sessionId)
653+
}
654+
655+
if time.Now().After(session.ExpiresAt) {
656+
auth.oauthMutex.Lock()
657+
delete(auth.oauthPendingSessions, sessionId)
658+
auth.oauthMutex.Unlock()
659+
return config.Claims{}, fmt.Errorf("oauth session expired: %s", sessionId)
660+
}
661+
662+
if session.Token == nil {
663+
return config.Claims{}, fmt.Errorf("oauth token not found for session: %s", sessionId)
664+
}
665+
666+
userinfo, err := (*session.Service).GetUserinfo(session.Token)
667+
668+
if err != nil {
669+
return config.Claims{}, fmt.Errorf("failed to get userinfo: %w", err)
670+
}
671+
672+
auth.oauthMutex.Lock()
673+
delete(auth.oauthPendingSessions, sessionId)
674+
auth.oauthMutex.Unlock()
675+
676+
return userinfo, nil
677+
}
678+
679+
func (auth *AuthService) EndOAuthSession(sessionId string) {
680+
auth.oauthMutex.Lock()
681+
delete(auth.oauthPendingSessions, sessionId)
682+
auth.oauthMutex.Unlock()
683+
}
684+
685+
func (auth *AuthService) CleanupOAuthSessionsRoutine() {
686+
ticker := time.NewTicker(30 * time.Minute)
687+
defer ticker.Stop()
688+
689+
for range ticker.C {
690+
auth.oauthMutex.Lock()
691+
defer auth.oauthMutex.Unlock()
692+
693+
now := time.Now()
694+
695+
for sessionId, session := range auth.oauthPendingSessions {
696+
if now.After(session.ExpiresAt) {
697+
delete(auth.oauthPendingSessions, sessionId)
698+
}
699+
}
700+
}
701+
}

0 commit comments

Comments
 (0)