-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathimplicit.go
127 lines (115 loc) · 4.19 KB
/
implicit.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package callback
import (
"context"
"fmt"
"net/http"
"github.com/hashicorp/cap/oidc"
"golang.org/x/oauth2"
)
// Implicit creates an OIDC implicit flow callback handler. The handler uses
// rw to look up the oidc.Request that was previously stored under the
// callback's "state" parameter.
//
// If your OIDC provider supports PKCE, prefer the authorization code flow
// with PKCE over the implicit flow.
//
// The returned handler calls eFn and stops when:
//   - the provider redirected with an "error" parameter (eFn receives an
//     *AuthenErrorResponse built from error, error_description and error_uri);
//   - the "state" parameter is empty;
//   - rw fails to read the request, or returns no request for the state;
//   - the stored request is not configured for the implicit flow;
//   - the stored request has expired;
//   - the stored request's state differs from the callback's state;
//   - the id_token fails verification by p;
//   - an access_token was requested and returned but fails verification
//     against the id_token;
//   - the response token cannot be constructed.
//
// Otherwise it calls sFn with the verified token.
//
// ctx is captured when the handler is created and is used for every
// request's lookup (rw.Read) and id_token verification (p.VerifyIDToken).
// It is not derived from the incoming *http.Request.
//
// Implicit returns an error wrapping oidc.ErrInvalidParameter if p, rw, sFn,
// or eFn is nil.
func Implicit(ctx context.Context, p *oidc.Provider, rw RequestReader, sFn SuccessResponseFunc, eFn ErrorResponseFunc) (http.HandlerFunc, error) {
	const op = "callback.Implicit"
	if p == nil {
		return nil, fmt.Errorf("%s: provider is empty: %w", op, oidc.ErrInvalidParameter)
	}
	if rw == nil {
		return nil, fmt.Errorf("%s: request reader is empty: %w", op, oidc.ErrInvalidParameter)
	}
	if sFn == nil {
		return nil, fmt.Errorf("%s: success response func is empty: %w", op, oidc.ErrInvalidParameter)
	}
	if eFn == nil {
		return nil, fmt.Errorf("%s: error response func is empty: %w", op, oidc.ErrInvalidParameter)
	}

	return func(w http.ResponseWriter, req *http.Request) {
		// FormValue reads parameters from either the request body or the URL
		// query; body values take precedence when both are present.
		reqState := req.FormValue("state")

		// The provider reported an authentication error: forward its details
		// to the error handler.
		if errCode := req.FormValue("error"); errCode != "" {
			reqError := &AuthenErrorResponse{
				Error:       errCode,
				Description: req.FormValue("error_description"),
				Uri:         req.FormValue("error_uri"),
			}
			eFn(reqState, reqError, nil, w, req)
			return
		}

		if reqState == "" {
			responseErr := fmt.Errorf("%s: empty state parameter: %w", op, oidc.ErrInvalidParameter)
			eFn(reqState, nil, responseErr, w, req)
			return
		}

		oidcRequest, err := rw.Read(ctx, reqState)
		if err != nil {
			responseErr := fmt.Errorf("%s: unable to read auth code request: %w", op, err)
			eFn(reqState, nil, responseErr, w, req)
			return
		}
		if oidcRequest == nil {
			// The request could have expired or could be invalid; there is
			// no way to know for sure.
			responseErr := fmt.Errorf("%s: auth code request not found: %w", op, oidc.ErrNotFound)
			eFn(reqState, nil, responseErr, w, req)
			return
		}

		useImplicit, includeAccessToken := oidcRequest.ImplicitFlow()
		if !useImplicit {
			responseErr := fmt.Errorf("%s: request (%s) should not be using the implicit flow: %w", op, oidcRequest.State(), oidc.ErrInvalidFlow)
			eFn(reqState, nil, responseErr, w, req)
			return
		}
		if oidcRequest.IsExpired() {
			responseErr := fmt.Errorf("%s: authentication request is expired: %w", op, oidc.ErrExpiredRequest)
			eFn(reqState, nil, responseErr, w, req)
			return
		}
		if reqState != oidcRequest.State() {
			// The reader returned a request whose state does not match the
			// key it was looked up by. This is an internal error in the
			// reader, not a bad response from the provider.
			responseErr := fmt.Errorf("%s: authen state (%s) and response state (%s) are not equal: %w", op, oidcRequest.State(), reqState, oidc.ErrInvalidResponseState)
			eFn(reqState, nil, responseErr, w, req)
			return
		}

		reqIDToken := oidc.IDToken(req.FormValue("id_token"))
		if _, err := p.VerifyIDToken(ctx, reqIDToken, oidcRequest); err != nil {
			responseErr := fmt.Errorf("%s: unable to verify id_token: %w", op, err)
			eFn(reqState, nil, responseErr, w, req)
			return
		}

		// The access token is only considered when the request asked for
		// one. An absent access_token is not treated as an error; in that
		// case a nil *oauth2.Token is passed to oidc.NewToken.
		var oauth2Token *oauth2.Token
		if includeAccessToken {
			if reqAccessToken := req.FormValue("access_token"); reqAccessToken != "" {
				if _, err := reqIDToken.VerifyAccessToken(oidc.AccessToken(reqAccessToken)); err != nil {
					responseErr := fmt.Errorf("%s: unable to verify access_token: %w", op, err)
					eFn(reqState, nil, responseErr, w, req)
					return
				}
				oauth2Token = &oauth2.Token{
					AccessToken: reqAccessToken,
				}
			}
		}

		responseToken, err := oidc.NewToken(reqIDToken, oauth2Token)
		if err != nil {
			responseErr := fmt.Errorf("%s: unable to create response tokens: %w", op, err)
			eFn(reqState, nil, responseErr, w, req)
			return
		}
		sFn(reqState, responseToken, w, req)
	}, nil
}