Skip to content

Commit d621a33

Browse files
committed
Add oidc callback mode that is direct to server
Signed-off-by: Dave Dykstra <[email protected]>
1 parent 0bb52bf commit d621a33

File tree

10 files changed

+442
-99
lines changed

10 files changed

+442
-99
lines changed

builtin/credential/jwt/backend.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ func backend() *jwtAuthBackend {
6262
"login",
6363
"oidc/auth_url",
6464
"oidc/callback",
65+
"oidc/poll",
6566

6667
// Uncomment to mount simple UI handler for local development
6768
// "ui",

builtin/credential/jwt/cli.go

Lines changed: 113 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"fmt"
99
"net"
1010
"net/http"
11+
"net/url"
1112
"os"
1213
"os/signal"
1314
"path"
@@ -27,9 +28,11 @@ const (
2728
defaultPort = "8250"
2829
defaultCallbackHost = "localhost"
2930
defaultCallbackMethod = "http"
31+
defaultCallbackMode = "client"
3032

3133
FieldCallbackHost = "callbackhost"
3234
FieldCallbackMethod = "callbackmethod"
35+
FieldCallbackMode = "callbackmode"
3336
FieldListenAddress = "listenaddress"
3437
FieldPort = "port"
3538
FieldCallbackPort = "callbackport"
@@ -69,19 +72,44 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string, nonInteractive boo
6972
port = defaultPort
7073
}
7174

75+
var serverURL *url.URL
76+
callbackMode, ok := m[FieldCallbackMode]
77+
if !ok || callbackMode == "" {
78+
callbackMode = defaultCallbackMode
79+
} else if callbackMode == "direct" {
80+
serverAddr := api.ReadBaoVariable("BAO_ADDR")
81+
if serverAddr != "" {
82+
serverURL, _ = url.Parse(serverAddr)
83+
}
84+
}
85+
7286
callbackHost, ok := m[FieldCallbackHost]
7387
if !ok {
74-
callbackHost = defaultCallbackHost
88+
if serverURL != nil {
89+
callbackHost = serverURL.Hostname()
90+
} else {
91+
// Note that since defaultCallbackHost is localhost,
92+
// this only works if the cli is run on the server
93+
callbackHost = defaultCallbackHost
94+
}
7595
}
7696

7797
callbackMethod, ok := m[FieldCallbackMethod]
7898
if !ok {
79-
callbackMethod = defaultCallbackMethod
99+
if serverURL != nil {
100+
callbackMethod = serverURL.Scheme
101+
} else {
102+
callbackMethod = defaultCallbackMethod
103+
}
80104
}
81105

82106
callbackPort, ok := m[FieldCallbackPort]
83107
if !ok {
84-
callbackPort = port
108+
if serverURL != nil {
109+
callbackPort = serverURL.Port() + "/v1/auth/" + mount
110+
} else {
111+
callbackPort = port
112+
}
85113
}
86114

87115
parseBool := func(f string, d bool) (bool, error) {
@@ -115,20 +143,49 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string, nonInteractive boo
115143

116144
role := m["role"]
117145

118-
authURL, clientNonce, err := fetchAuthURL(c, role, mount, callbackPort, callbackMethod, callbackHost)
146+
authURL, clientNonce, secret, err := fetchAuthURL(c, role, mount, callbackPort, callbackMethod, callbackHost)
119147
if err != nil {
120148
return nil, err
121149
}
122150

123-
// Set up callback handler
124151
doneCh := make(chan loginResp)
125-
http.HandleFunc("/oidc/callback", callbackHandler(c, mount, clientNonce, doneCh))
126152

127-
listener, err := net.Listen("tcp", listenAddress+":"+port)
128-
if err != nil {
129-
return nil, err
153+
var pollInterval string
154+
var interval int
155+
var state string
156+
var listener net.Listener
157+
158+
if secret != nil {
159+
pollInterval, _ = secret.Data["poll_interval"].(string)
160+
state, _ = secret.Data["state"].(string)
161+
}
162+
if callbackMode == "direct" {
163+
if state == "" {
164+
return nil, errors.New("no state returned in direct callback mode")
165+
}
166+
if pollInterval == "" {
167+
return nil, errors.New("no poll_interval returned in direct callback mode")
168+
}
169+
interval, err = strconv.Atoi(pollInterval)
170+
if err != nil {
171+
return nil, errors.New("cannot convert poll_interval " + pollInterval + " to integer")
172+
}
173+
} else {
174+
if state != "" {
175+
return nil, errors.New("state returned in client callback mode, try direct")
176+
}
177+
if pollInterval != "" {
178+
return nil, errors.New("poll_interval returned in client callback mode")
179+
}
180+
// Set up callback handler
181+
http.HandleFunc("/oidc/callback", callbackHandler(c, mount, clientNonce, doneCh))
182+
183+
listener, err := net.Listen("tcp", listenAddress+":"+port)
184+
if err != nil {
185+
return nil, err
186+
}
187+
defer listener.Close()
130188
}
131-
defer listener.Close()
132189

133190
// Open the default browser to the callback URL.
134191
if !skipBrowserLaunch {
@@ -144,6 +201,28 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string, nonInteractive boo
144201
}
145202
fmt.Fprintf(os.Stderr, "Waiting for OIDC authentication to complete...\n")
146203

204+
if callbackMode == "direct" {
205+
data := map[string]interface{}{
206+
"state": state,
207+
"client_nonce": clientNonce,
208+
}
209+
pollUrl := fmt.Sprintf("auth/%s/oidc/poll", mount)
210+
for {
211+
time.Sleep(time.Duration(interval) * time.Second)
212+
213+
secret, err := c.Logical().Write(pollUrl, data)
214+
if err == nil {
215+
return secret, nil
216+
}
217+
if strings.HasSuffix(err.Error(), "slow_down") {
218+
interval *= 2
219+
} else if !strings.HasSuffix(err.Error(), "authorization_pending") {
220+
return nil, err
221+
}
222+
// authorization is pending, try again
223+
}
224+
}
225+
147226
// Start local server
148227
go func() {
149228
err := http.Serve(listener, nil)
@@ -210,12 +289,12 @@ func callbackHandler(c *api.Client, mount string, clientNonce string, doneCh cha
210289
}
211290
}
212291

213-
func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMethod string, callbackHost string) (string, string, error) {
292+
func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMethod string, callbackHost string) (string, string, *api.Secret, error) {
214293
var authURL string
215294

216295
clientNonce, err := base62.Random(20)
217296
if err != nil {
218-
return "", "", err
297+
return "", "", nil, err
219298
}
220299

221300
redirectURI := fmt.Sprintf("%s://%s:%s/oidc/callback", callbackMethod, callbackHost, callbackPort)
@@ -227,18 +306,18 @@ func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMetho
227306

228307
secret, err := c.Logical().Write(fmt.Sprintf("auth/%s/oidc/auth_url", mount), data)
229308
if err != nil {
230-
return "", "", err
309+
return "", "", nil, err
231310
}
232311

233312
if secret != nil {
234313
authURL = secret.Data["auth_url"].(string)
235314
}
236315

237316
if authURL == "" {
238-
return "", "", fmt.Errorf("Unable to authorize role %q with redirect_uri %q. Check OpenBao logs for more information.", role, redirectURI)
317+
return "", "", nil, fmt.Errorf("Unable to authorize role %q with redirect_uri %q. Check OpenBao logs for more information.", role, redirectURI)
239318
}
240319

241-
return authURL, clientNonce, nil
320+
return authURL, clientNonce, secret, nil
242321
}
243322

244323
// parseError converts error from the API into summary and detailed portions.
@@ -292,35 +371,47 @@ Usage: bao login -method=oidc [CONFIG K=V...]
292371
293372
https://accounts.google.com/o/oauth2/v2/...
294373
295-
The default browser will be opened for the user to complete the login. Alternatively,
296-
the user may visit the provided URL directly.
374+
The default browser will be opened for the user to complete the login.
375+
Alternatively, the user may visit the provided URL directly.
297376
298377
Configuration:
299378
300379
role=<string>
301380
OpenBao role of type "OIDC" to use for authentication.
302381
303382
%s=<string>
304-
Optional address to bind the OIDC callback listener to (default: localhost).
383+
Mode of callback: "direct" for direct connection to the server or "client"
384+
for connection to the command line client (default: client).
385+
386+
%s=<string>
387+
Optional address to bind the OIDC callback listener to in client callback
388+
mode (default: localhost).
305389
306390
%s=<string>
307-
Optional localhost port to use for OIDC callback (default: 8250).
391+
Optional localhost port to use for OIDC callback in client callback mode
392+
(default: 8250).
308393
309394
%s=<string>
310-
Optional method to to use in OIDC redirect_uri (default: http).
395+
Optional method to use in OIDC redirect_uri (default: the method from
396+
$BAO_ADDR or $VAULT_ADDR in direct callback mode, else http)
311397
312398
%s=<string>
313-
Optional callback host address to use in OIDC redirect_uri (default: localhost).
399+
Optional callback host address to use in OIDC redirect_uri (default:
400+
the host from $BAO_ADDR or $VAULT_ADDR in direct callback mode, else
401+
localhost).
314402
315403
%s=<string>
316-
Optional port to to use in OIDC redirect_uri (default: the value set for port).
404+
Optional port to use in OIDC redirect_uri (default: the value set for
405+
port in client callback mode, else the port from $BAO_ADDR or $VAULT_ADDR
406+
with an added /v1/auth/<path> where <path> is from the login -path option).
317407
318408
%s=<bool>
319409
Toggle the automatic launching of the default browser to the login URL. (default: false).
320410
321411
%s=<bool>
322412
Abort on any error. (default: false).
323413
`,
414+
FieldCallbackMode,
324415
FieldListenAddress, FieldPort, FieldCallbackMethod,
325416
FieldCallbackHost, FieldCallbackPort, FieldSkipBrowser,
326417
FieldAbortOnError,
File renamed without changes.

0 commit comments

Comments
 (0)