Skip to content

Commit 7810ce5

Browse files
committed
Adds callback mode that is direct to vault
2 parents bbbb17e + aab6e60 commit 7810ce5

File tree

7 files changed

+313
-85
lines changed

7 files changed

+313
-85
lines changed

backend.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ func backend() *jwtAuthBackend {
5555
"login",
5656
"oidc/auth_url",
5757
"oidc/callback",
58+
"oidc/poll",
5859

5960
// Uncomment to mount simple UI handler for local development
6061
// "ui",

cli.go

Lines changed: 108 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"net"
77
"net/http"
8+
"net/url"
89
"os"
910
"os/signal"
1011
"path"
@@ -24,9 +25,11 @@ const (
2425
defaultPort = "8250"
2526
defaultCallbackHost = "localhost"
2627
defaultCallbackMethod = "http"
28+
defaultCallbackMode = "client"
2729

2830
FieldCallbackHost = "callbackhost"
2931
FieldCallbackMethod = "callbackmethod"
32+
FieldCallbackMode = "callbackmode"
3033
FieldListenAddress = "listenaddress"
3134
FieldPort = "port"
3235
FieldCallbackPort = "callbackport"
@@ -66,19 +69,42 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
6669
port = defaultPort
6770
}
6871

72+
var vaultURL *url.URL
73+
callbackMode, ok := m[FieldCallbackMode]
74+
if !ok {
75+
callbackMode = defaultCallbackMode
76+
} else if callbackMode == "direct" {
77+
vaultAddr := os.Getenv("VAULT_ADDR")
78+
if vaultAddr != "" {
79+
vaultURL, _ = url.Parse(vaultAddr)
80+
}
81+
}
82+
6983
callbackHost, ok := m[FieldCallbackHost]
7084
if !ok {
71-
callbackHost = defaultCallbackHost
85+
if vaultURL != nil {
86+
callbackHost = vaultURL.Hostname()
87+
} else {
88+
callbackHost = defaultCallbackHost
89+
}
7290
}
7391

7492
callbackMethod, ok := m[FieldCallbackMethod]
7593
if !ok {
76-
callbackMethod = defaultCallbackMethod
94+
if vaultURL != nil {
95+
callbackMethod = vaultURL.Scheme
96+
} else {
97+
callbackMethod = defaultCallbackMethod
98+
}
7799
}
78100

79101
callbackPort, ok := m[FieldCallbackPort]
80102
if !ok {
81-
callbackPort = port
103+
if vaultURL != nil {
104+
callbackPort = vaultURL.Port() + "/v1/auth/" + mount
105+
} else {
106+
callbackPort = port
107+
}
82108
}
83109

84110
parseBool := func(f string, d bool) (bool, error) {
@@ -112,20 +138,49 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
112138

113139
role := m["role"]
114140

115-
authURL, clientNonce, err := fetchAuthURL(c, role, mount, callbackPort, callbackMethod, callbackHost)
141+
authURL, clientNonce, secret, err := fetchAuthURL(c, role, mount, callbackPort, callbackMethod, callbackHost)
116142
if err != nil {
117143
return nil, err
118144
}
119145

120-
// Set up callback handler
121146
doneCh := make(chan loginResp)
122-
http.HandleFunc("/oidc/callback", callbackHandler(c, mount, clientNonce, doneCh))
123147

124-
listener, err := net.Listen("tcp", listenAddress+":"+port)
125-
if err != nil {
126-
return nil, err
148+
var pollInterval string
149+
var interval int
150+
var state string
151+
var listener net.Listener
152+
153+
if secret != nil {
154+
pollInterval, _ = secret.Data["poll_interval"].(string)
155+
state, _ = secret.Data["state"].(string)
156+
}
157+
if callbackMode == "direct" {
158+
if state == "" {
159+
return nil, errors.New("no state returned in direct callback mode")
160+
}
161+
if pollInterval == "" {
162+
return nil, errors.New("no poll_interval returned in direct callback mode")
163+
}
164+
interval, err = strconv.Atoi(pollInterval)
165+
if err != nil {
166+
return nil, errors.New("cannot convert poll_interval " + pollInterval + " to integer")
167+
}
168+
} else {
169+
if state != "" {
170+
return nil, errors.New("state returned in client callback mode, try direct")
171+
}
172+
if pollInterval != "" {
173+
return nil, errors.New("poll_interval returned in client callback mode")
174+
}
175+
// Set up callback handler
176+
http.HandleFunc("/oidc/callback", callbackHandler(c, mount, clientNonce, doneCh))
177+
178+
listener, err := net.Listen("tcp", listenAddress+":"+port)
179+
if err != nil {
180+
return nil, err
181+
}
182+
defer listener.Close()
127183
}
128-
defer listener.Close()
129184

130185
// Open the default browser to the callback URL.
131186
if !skipBrowserLaunch {
@@ -141,6 +196,26 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
141196
}
142197
fmt.Fprintf(os.Stderr, "Waiting for OIDC authentication to complete...\n")
143198

199+
if callbackMode == "direct" {
200+
data := map[string]interface{}{
201+
"state": state,
202+
"client_nonce": clientNonce,
203+
}
204+
pollUrl := fmt.Sprintf("auth/%s/oidc/poll", mount)
205+
for {
206+
time.Sleep(time.Duration(interval) * time.Second)
207+
208+
secret, err := c.Logical().Write(pollUrl, data)
209+
if err == nil {
210+
return secret, nil
211+
}
212+
if !strings.HasSuffix(err.Error(), "authorization_pending") {
213+
return nil, err
214+
}
215+
// authorization is pending, try again
216+
}
217+
}
218+
144219
// Start local server
145220
go func() {
146221
err := http.Serve(listener, nil)
@@ -207,12 +282,12 @@ func callbackHandler(c *api.Client, mount string, clientNonce string, doneCh cha
207282
}
208283
}
209284

210-
func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMethod string, callbackHost string) (string, string, error) {
285+
func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMethod string, callbackHost string) (string, string, *api.Secret, error) {
211286
var authURL string
212287

213288
clientNonce, err := base62.Random(20)
214289
if err != nil {
215-
return "", "", err
290+
return "", "", nil, err
216291
}
217292

218293
redirectURI := fmt.Sprintf("%s://%s:%s/oidc/callback", callbackMethod, callbackHost, callbackPort)
@@ -224,18 +299,18 @@ func fetchAuthURL(c *api.Client, role, mount, callbackPort string, callbackMetho
224299

225300
secret, err := c.Logical().Write(fmt.Sprintf("auth/%s/oidc/auth_url", mount), data)
226301
if err != nil {
227-
return "", "", err
302+
return "", "", nil, err
228303
}
229304

230305
if secret != nil {
231306
authURL = secret.Data["auth_url"].(string)
232307
}
233308

234309
if authURL == "" {
235-
return "", "", fmt.Errorf("Unable to authorize role %q with redirect_uri %q. Check Vault logs for more information.", role, redirectURI)
310+
return "", "", nil, fmt.Errorf("Unable to authorize role %q with redirect_uri %q. Check Vault logs for more information.", role, redirectURI)
236311
}
237312

238-
return authURL, clientNonce, nil
313+
return authURL, clientNonce, secret, nil
239314
}
240315

241316
// parseError converts error from the API into summary and detailed portions.
@@ -289,35 +364,46 @@ Usage: vault login -method=oidc [CONFIG K=V...]
289364
290365
https://accounts.google.com/o/oauth2/v2/...
291366
292-
The default browser will be opened for the user to complete the login. Alternatively,
293-
the user may visit the provided URL directly.
367+
The default browser will be opened for the user to complete the login.
368+
Alternatively, the user may visit the provided URL directly.
294369
295370
Configuration:
296371
297372
role=<string>
298373
Vault role of type "OIDC" to use for authentication.
299374
300375
%s=<string>
301-
Optional address to bind the OIDC callback listener to (default: localhost).
376+
Mode of callback: "direct" for direct connection to Vault or "client"
377+
for connection to command line client (default: client).
378+
379+
%s=<string>
380+
Optional address to bind the OIDC callback listener to in client callback
381+
mode (default: localhost).
302382
303383
%s=<string>
304-
Optional localhost port to use for OIDC callback (default: 8250).
384+
Optional localhost port to use for OIDC callback in client callback mode
385+
(default: 8250).
305386
306387
%s=<string>
307-
Optional method to to use in OIDC redirect_uri (default: http).
388+
Optional method to use in OIDC redirect_uri (default: the method from
389+
$VAULT_ADDR in direct callback mode, else http)
308390
309391
%s=<string>
310-
Optional callback host address to use in OIDC redirect_uri (default: localhost).
392+
Optional callback host address to use in OIDC redirect_uri (default:
393+
the host from $VAULT_ADDR in direct callback mode, else localhost).
311394
312395
%s=<string>
313-
Optional port to to use in OIDC redirect_uri (default: the value set for port).
396+
Optional port to use in OIDC redirect_uri (default: the value set for
397+
port in client callback mode, else the port from $VAULT_ADDR with an
398+
added /v1/auth/<path> where <path> is from the login -path option).
314399
315400
%s=<bool>
316401
Toggle the automatic launching of the default browser to the login URL. (default: false).
317402
318403
%s=<bool>
319404
Abort on any error. (default: false).
320405
`,
406+
FieldCallbackMode,
321407
FieldListenAddress, FieldPort, FieldCallbackMethod,
322408
FieldCallbackHost, FieldCallbackPort, FieldSkipBrowser,
323409
FieldAbortOnError,
File renamed without changes.

0 commit comments

Comments
 (0)