@@ -21,26 +21,25 @@ type OAuthRequest struct {
2121}
2222
2323type OAuthControllerConfig struct {
24- CSRFCookieName string
25- RedirectCookieName string
26- SecureCookie bool
27- AppURL string
28- CookieDomain string
24+ CSRFCookieName string
25+ OAuthSessionCookieName string
26+ RedirectCookieName string
27+ SecureCookie bool
28+ AppURL string
29+ CookieDomain string
2930}
3031
3132type OAuthController struct {
3233 config OAuthControllerConfig
3334 router * gin.RouterGroup
3435 auth * service.AuthService
35- broker * service.OAuthBrokerService
3636}
3737
38- func NewOAuthController (config OAuthControllerConfig , router * gin.RouterGroup , auth * service.AuthService , broker * service. OAuthBrokerService ) * OAuthController {
38+ func NewOAuthController (config OAuthControllerConfig , router * gin.RouterGroup , auth * service.AuthService ) * OAuthController {
3939 return & OAuthController {
4040 config : config ,
4141 router : router ,
4242 auth : auth ,
43- broker : broker ,
4443 }
4544}
4645
@@ -63,21 +62,30 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
6362 return
6463 }
6564
66- service , exists := controller .broker . GetService (req .Provider )
65+ sessionId , session , err := controller .auth . NewOAuthSession (req .Provider )
6766
68- if ! exists {
69- tlog .App .Warn ().Msgf ("OAuth provider not found: %s" , req .Provider )
70- c .JSON (404 , gin.H {
71- "status" : 404 ,
72- "message" : "Not Found" ,
67+ if err != nil {
68+ tlog .App .Error ().Err (err ).Msg ("Failed to create OAuth session" )
69+ c .JSON (500 , gin.H {
70+ "status" : 500 ,
71+ "message" : "Internal Server Error" ,
72+ })
73+ return
74+ }
75+
76+ authUrl , err := controller .auth .GetOAuthURL (sessionId )
77+
78+ if err != nil {
79+ tlog .App .Error ().Err (err ).Msg ("Failed to get OAuth URL" )
80+ c .JSON (500 , gin.H {
81+ "status" : 500 ,
82+ "message" : "Internal Server Error" ,
7383 })
7484 return
7585 }
7686
77- service .GenerateVerifier ()
78- state := service .GenerateState ()
79- authURL := service .GetAuthURL (state )
80- c .SetCookie (controller .config .CSRFCookieName , state , int (time .Hour .Seconds ()), "/" , fmt .Sprintf (".%s" , controller .config .CookieDomain ), controller .config .SecureCookie , true )
87+ c .SetCookie (controller .config .OAuthSessionCookieName , sessionId , int (time .Hour .Seconds ()), "/" , fmt .Sprintf (".%s" , controller .config .CookieDomain ), controller .config .SecureCookie , true )
88+ c .SetCookie (controller .config .CSRFCookieName , session .State , int (time .Hour .Seconds ()), "/" , fmt .Sprintf (".%s" , controller .config .CookieDomain ), controller .config .SecureCookie , true )
8189
8290 redirectURI := c .Query ("redirect_uri" )
8391 isRedirectSafe := utils .IsRedirectSafe (redirectURI , controller .config .CookieDomain )
@@ -95,7 +103,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
95103 c .JSON (200 , gin.H {
96104 "status" : 200 ,
97105 "message" : "OK" ,
98- "url" : authURL ,
106+ "url" : authUrl ,
99107 })
100108}
101109
@@ -112,6 +120,17 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
112120 return
113121 }
114122
123+ sessionIdCookie , err := c .Cookie (controller .config .OAuthSessionCookieName )
124+
125+ if err != nil {
126+ tlog .App .Warn ().Err (err ).Msg ("OAuth session cookie missing" )
127+ c .Redirect (http .StatusTemporaryRedirect , fmt .Sprintf ("%s/error" , controller .config .AppURL ))
128+ return
129+ }
130+
131+ c .SetCookie (controller .config .OAuthSessionCookieName , "" , - 1 , "/" , fmt .Sprintf (".%s" , controller .config .CookieDomain ), controller .config .SecureCookie , true )
132+ defer controller .auth .EndOAuthSession (sessionIdCookie )
133+
115134 state := c .Query ("state" )
116135 csrfCookie , err := c .Cookie (controller .config .CSRFCookieName )
117136
@@ -125,28 +144,15 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
125144 c .SetCookie (controller .config .CSRFCookieName , "" , - 1 , "/" , fmt .Sprintf (".%s" , controller .config .CookieDomain ), controller .config .SecureCookie , true )
126145
127146 code := c .Query ("code" )
128- service , exists : = controller .broker . GetService ( req . Provider )
147+ _ , err = controller .auth . GetOAuthToken ( sessionIdCookie , code )
129148
130- if ! exists {
131- tlog .App .Warn ().Msgf ("OAuth provider not found: %s" , req .Provider )
132- c .Redirect (http .StatusTemporaryRedirect , fmt .Sprintf ("%s/error" , controller .config .AppURL ))
133- return
134- }
135-
136- err = service .VerifyCode (code )
137149 if err != nil {
138- tlog .App .Error ().Err (err ).Msg ("Failed to verify OAuth code" )
150+ tlog .App .Error ().Err (err ).Msg ("Failed to exchange code for token " )
139151 c .Redirect (http .StatusTemporaryRedirect , fmt .Sprintf ("%s/error" , controller .config .AppURL ))
140152 return
141153 }
142154
143- user , err := controller .broker .GetUser (req .Provider )
144-
145- if err != nil {
146- tlog .App .Error ().Err (err ).Msg ("Failed to get user from OAuth provider" )
147- c .Redirect (http .StatusTemporaryRedirect , fmt .Sprintf ("%s/error" , controller .config .AppURL ))
148- return
149- }
155+ user , err := controller .auth .GetOAuthUserinfo (sessionIdCookie )
150156
151157 if user .Email == "" {
152158 tlog .App .Error ().Msg ("OAuth provider did not return an email" )
@@ -192,13 +198,21 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
192198 username = strings .Replace (user .Email , "@" , "_" , 1 )
193199 }
194200
201+ service , err := controller .auth .GetOAuthService (sessionIdCookie )
202+
203+ if err != nil {
204+ tlog .App .Error ().Err (err ).Msg ("Failed to get OAuth service for session" )
205+ c .Redirect (http .StatusTemporaryRedirect , fmt .Sprintf ("%s/error" , controller .config .AppURL ))
206+ return
207+ }
208+
195209 sessionCookie := repository.Session {
196210 Username : username ,
197211 Name : name ,
198212 Email : user .Email ,
199213 Provider : req .Provider ,
200214 OAuthGroups : utils .CoalesceToString (user .Groups ),
201- OAuthName : service .GetName (),
215+ OAuthName : service .Name (),
202216 OAuthSub : user .Sub ,
203217 }
204218
0 commit comments