Skip to content

Commit 7bead41

Browse files
committed
feat: move oauth logic into auth service and handle multiple sessions
1 parent 2491d45 commit 7bead41

8 files changed

Lines changed: 168 additions & 159 deletions

File tree

internal/bootstrap/app_bootstrap.go

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,17 @@ import (
2222
type BootstrapApp struct {
2323
config config.Config
2424
context struct {
25-
appUrl string
26-
uuid string
27-
cookieDomain string
28-
sessionCookieName string
29-
csrfCookieName string
30-
redirectCookieName string
31-
users []config.User
32-
oauthProviders map[string]config.OAuthServiceConfig
33-
configuredProviders []controller.Provider
34-
oidcClients []config.OIDCClientConfig
25+
appUrl string
26+
uuid string
27+
cookieDomain string
28+
sessionCookieName string
29+
csrfCookieName string
30+
redirectCookieName string
31+
oauthSessionCookieName string
32+
users []config.User
33+
oauthProviders map[string]config.OAuthServiceConfig
34+
configuredProviders []controller.Provider
35+
oidcClients []config.OIDCClientConfig
3536
}
3637
services Services
3738
}
@@ -113,6 +114,7 @@ func (app *BootstrapApp) Setup() error {
113114
app.context.sessionCookieName = fmt.Sprintf("%s-%s", config.SessionCookieName, cookieId)
114115
app.context.csrfCookieName = fmt.Sprintf("%s-%s", config.CSRFCookieName, cookieId)
115116
app.context.redirectCookieName = fmt.Sprintf("%s-%s", config.RedirectCookieName, cookieId)
117+
app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", config.OAuthSessionCookieName, cookieId)
116118

117119
// Dumps
118120
tlog.App.Trace().Interface("config", app.config).Msg("Config dump")
@@ -190,12 +192,12 @@ func (app *BootstrapApp) Setup() error {
190192

191193
// Start db cleanup routine
192194
tlog.App.Debug().Msg("Starting database cleanup routine")
193-
go app.dbCleanup(queries)
195+
go app.dbCleanupRoutine(queries)
194196

195197
// If analytics are not disabled, start heartbeat
196198
if app.config.Analytics.Enabled {
197199
tlog.App.Debug().Msg("Starting heartbeat routine")
198-
go app.heartbeat()
200+
go app.heartbeatRoutine()
199201
}
200202

201203
// If we have an socket path, bind to it
@@ -226,7 +228,7 @@ func (app *BootstrapApp) Setup() error {
226228
return nil
227229
}
228230

229-
func (app *BootstrapApp) heartbeat() {
231+
func (app *BootstrapApp) heartbeatRoutine() {
230232
ticker := time.NewTicker(time.Duration(12) * time.Hour)
231233
defer ticker.Stop()
232234

@@ -280,7 +282,7 @@ func (app *BootstrapApp) heartbeat() {
280282
}
281283
}
282284

283-
func (app *BootstrapApp) dbCleanup(queries *repository.Queries) {
285+
func (app *BootstrapApp) dbCleanupRoutine(queries *repository.Queries) {
284286
ticker := time.NewTicker(time.Duration(30) * time.Minute)
285287
defer ticker.Stop()
286288
ctx := context.Background()

internal/bootstrap/router_bootstrap.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,13 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
7777
contextController.SetupRoutes()
7878

7979
oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{
80-
AppURL: app.config.AppURL,
81-
SecureCookie: app.config.Auth.SecureCookie,
82-
CSRFCookieName: app.context.csrfCookieName,
83-
RedirectCookieName: app.context.redirectCookieName,
84-
CookieDomain: app.context.cookieDomain,
85-
}, apiRouter, app.services.authService, app.services.oauthBrokerService)
80+
AppURL: app.config.AppURL,
81+
SecureCookie: app.config.Auth.SecureCookie,
82+
CSRFCookieName: app.context.csrfCookieName,
83+
RedirectCookieName: app.context.redirectCookieName,
84+
CookieDomain: app.context.cookieDomain,
85+
OAuthSessionCookieName: app.context.oauthSessionCookieName,
86+
}, apiRouter, app.services.authService)
8687

8788
oauthController.SetupRoutes()
8889

internal/bootstrap/service_bootstrap.go

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,16 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
5858

5959
services.accessControlService = accessControlsService
6060

61+
oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders)
62+
63+
err = oauthBrokerService.Init()
64+
65+
if err != nil {
66+
return Services{}, err
67+
}
68+
69+
services.oauthBrokerService = oauthBrokerService
70+
6171
authService := service.NewAuthService(service.AuthServiceConfig{
6272
Users: app.context.users,
6373
OauthWhitelist: app.config.OAuth.Whitelist,
@@ -70,7 +80,7 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
7080
SessionCookieName: app.context.sessionCookieName,
7181
IP: app.config.Auth.IP,
7282
LDAPGroupsCacheTTL: app.config.Ldap.GroupCacheTTL,
73-
}, dockerService, services.ldapService, queries)
83+
}, dockerService, services.ldapService, queries, services.oauthBrokerService)
7484

7585
err = authService.Init()
7686

@@ -80,16 +90,6 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
8090

8191
services.authService = authService
8292

83-
oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders)
84-
85-
err = oauthBrokerService.Init()
86-
87-
if err != nil {
88-
return Services{}, err
89-
}
90-
91-
services.oauthBrokerService = oauthBrokerService
92-
9393
oidcService := service.NewOIDCService(service.OIDCServiceConfig{
9494
Clients: app.config.OIDC.Clients,
9595
PrivateKeyPath: app.config.OIDC.PrivateKeyPath,

internal/config/config.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ var BuildTimestamp = "0000-00-00T00:00:00Z"
7373
var SessionCookieName = "tinyauth-session"
7474
var CSRFCookieName = "tinyauth-csrf"
7575
var RedirectCookieName = "tinyauth-redirect"
76+
var OAuthSessionCookieName = "tinyauth-oauth"
7677

7778
// Main app config
7879

internal/controller/oauth_controller.go

Lines changed: 55 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -21,26 +21,25 @@ type OAuthRequest struct {
2121
}
2222

2323
type OAuthControllerConfig struct {
24-
CSRFCookieName string
25-
RedirectCookieName string
26-
SecureCookie bool
27-
AppURL string
28-
CookieDomain string
24+
CSRFCookieName string
25+
OAuthSessionCookieName string
26+
RedirectCookieName string
27+
SecureCookie bool
28+
AppURL string
29+
CookieDomain string
2930
}
3031

3132
type OAuthController struct {
3233
config OAuthControllerConfig
3334
router *gin.RouterGroup
3435
auth *service.AuthService
35-
broker *service.OAuthBrokerService
3636
}
3737

38-
func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService, broker *service.OAuthBrokerService) *OAuthController {
38+
func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *OAuthController {
3939
return &OAuthController{
4040
config: config,
4141
router: router,
4242
auth: auth,
43-
broker: broker,
4443
}
4544
}
4645

@@ -63,21 +62,32 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
6362
return
6463
}
6564

66-
service, exists := controller.broker.GetService(req.Provider)
65+
sessionId, session, err := controller.auth.NewOAuthSession(req.Provider)
6766

68-
if !exists {
69-
tlog.App.Warn().Msgf("OAuth provider not found: %s", req.Provider)
70-
c.JSON(404, gin.H{
71-
"status": 404,
72-
"message": "Not Found",
67+
if err != nil {
68+
tlog.App.Error().Err(err).Msg("Failed to create OAuth session")
69+
c.JSON(500, gin.H{
70+
"status": 500,
71+
"message": "Internal Server Error",
72+
})
73+
return
74+
}
75+
76+
tlog.App.Debug().Interface("session", session).Msg("Created new OAuth session")
77+
78+
authUrl, err := controller.auth.GetOAuthURL(sessionId)
79+
80+
if err != nil {
81+
tlog.App.Error().Err(err).Msg("Failed to get OAuth URL")
82+
c.JSON(500, gin.H{
83+
"status": 500,
84+
"message": "Internal Server Error",
7385
})
7486
return
7587
}
7688

77-
service.GenerateVerifier()
78-
state := service.GenerateState()
79-
authURL := service.GetAuthURL(state)
80-
c.SetCookie(controller.config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
89+
c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
90+
c.SetCookie(controller.config.CSRFCookieName, session.State, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
8191

8292
redirectURI := c.Query("redirect_uri")
8393
isRedirectSafe := utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain)
@@ -95,7 +105,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
95105
c.JSON(200, gin.H{
96106
"status": 200,
97107
"message": "OK",
98-
"url": authURL,
108+
"url": authUrl,
99109
})
100110
}
101111

@@ -112,6 +122,16 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
112122
return
113123
}
114124

125+
sessionIdCookie, err := c.Cookie(controller.config.OAuthSessionCookieName)
126+
127+
if err != nil {
128+
tlog.App.Warn().Err(err).Msg("OAuth session cookie missing")
129+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
130+
return
131+
}
132+
133+
c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
134+
115135
state := c.Query("state")
116136
csrfCookie, err := c.Cookie(controller.config.CSRFCookieName)
117137

@@ -125,28 +145,17 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
125145
c.SetCookie(controller.config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
126146

127147
code := c.Query("code")
128-
service, exists := controller.broker.GetService(req.Provider)
129148

130-
if !exists {
131-
tlog.App.Warn().Msgf("OAuth provider not found: %s", req.Provider)
132-
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
133-
return
134-
}
149+
tlog.App.Debug().Str("code", code).Str("state", state).Msg("Received OAuth callback")
150+
_, err = controller.auth.GetOAuthToken(sessionIdCookie, code)
135151

136-
err = service.VerifyCode(code)
137152
if err != nil {
138-
tlog.App.Error().Err(err).Msg("Failed to verify OAuth code")
153+
tlog.App.Error().Err(err).Msg("Failed to exchange code for token")
139154
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
140155
return
141156
}
142157

143-
user, err := controller.broker.GetUser(req.Provider)
144-
145-
if err != nil {
146-
tlog.App.Error().Err(err).Msg("Failed to get user from OAuth provider")
147-
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
148-
return
149-
}
158+
user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie)
150159

151160
if user.Email == "" {
152161
tlog.App.Error().Msg("OAuth provider did not return an email")
@@ -192,13 +201,21 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
192201
username = strings.Replace(user.Email, "@", "_", 1)
193202
}
194203

204+
service, err := controller.auth.GetOAuthService(sessionIdCookie)
205+
206+
if err != nil {
207+
tlog.App.Error().Err(err).Msg("Failed to get OAuth service for session")
208+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
209+
return
210+
}
211+
195212
sessionCookie := repository.Session{
196213
Username: username,
197214
Name: name,
198215
Email: user.Email,
199216
Provider: req.Provider,
200217
OAuthGroups: utils.CoalesceToString(user.Groups),
201-
OAuthName: service.GetName(),
218+
OAuthName: service.Name(),
202219
OAuthSub: user.Sub,
203220
}
204221

@@ -214,6 +231,9 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
214231

215232
tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider)
216233

234+
// Clear OAuth session
235+
controller.auth.EndOAuthSession(sessionIdCookie)
236+
217237
redirectURI, err := c.Cookie(controller.config.RedirectCookieName)
218238

219239
if err != nil || !utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain) {

0 commit comments

Comments
 (0)