Skip to content

Commit b762685

Browse files
committed
1 parent 5dab416 commit b762685

File tree

2 files changed

+212
-3
lines changed

2 files changed

+212
-3
lines changed

oauth2.go

+77
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ package oauth2 // import "golang.org/x/oauth2"
1111
import (
1212
"bytes"
1313
"context"
14+
"encoding/json"
1415
"errors"
1516
"net/http"
1617
"net/url"
@@ -61,6 +62,10 @@ type Config struct {
6162

6263
// Scope specifies optional requested permissions.
6364
Scopes []string
65+
66+
// ClaimSet is optional requested parameter used to that
67+
// specific Claims be returned.
68+
ClaimSet *ClaimSet
6469
}
6570

6671
// A TokenSource is anything that can return a token.
@@ -78,6 +83,26 @@ type Endpoint struct {
7883
TokenURL string
7984
}
8085

86+
// JSON keys for top-level member of the Claims request JSON.
87+
const (
88+
UserInfoClaim = "userinfo"
89+
IdTokenClaim = "id_token"
90+
)
91+
92+
// claim describes JSON object used to specify additional information
93+
// about the Claim being requested.
94+
type claim struct {
95+
Essential bool `json:"essential,omitempty"`
96+
Value string `json:"value,omitempty"`
97+
Values []string `json:"values,omitempty"`
98+
}
99+
100+
// ClaimSet contains map with members of the Claims request.
101+
// It provides methods to add specific Claims.
102+
type ClaimSet struct {
103+
claims map[string]map[string]*claim
104+
}
105+
81106
var (
82107
// AccessTypeOnline and AccessTypeOffline are options passed
83108
// to the Options.AuthCodeURL method. They modify the
@@ -151,6 +176,13 @@ func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
151176
} else {
152177
buf.WriteByte('?')
153178
}
179+
if c.ClaimSet != nil {
180+
body, err := json.Marshal(c.ClaimSet)
181+
if err != nil {
182+
panic(err)
183+
}
184+
v.Set("claims", string(body))
185+
}
154186
buf.WriteString(v.Encode())
155187
return buf.String()
156188
}
@@ -229,6 +261,51 @@ func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource {
229261
}
230262
}
231263

264+
// MarshalJSON is part of the json.Marshaler interface.
265+
// It's used to encode hidden map that contains Claims objects.
266+
func (c *ClaimSet) MarshalJSON() ([]byte, error) {
267+
return json.Marshal(c.claims)
268+
}
269+
270+
// AddVoluntaryClaim adds the Claim being requested in default manner,
271+
// as a Voluntary Claim.
272+
func (c *ClaimSet) AddVoluntaryClaim(topLevelName string, claimName string) {
273+
c.initializeTopLevelMember(topLevelName)
274+
c.claims[topLevelName][claimName] = nil
275+
}
276+
277+
// AddClaimWithValue adds the Claim being requested to return a particular value.
278+
// The Claim can be defined as an Essential Claim.
279+
func (c *ClaimSet) AddClaimWithValue(topLevelName string, claimName string, essential bool, value string) {
280+
c.initializeTopLevelMember(topLevelName)
281+
c.claims[topLevelName][claimName] = &claim{
282+
Essential: essential,
283+
Value: value,
284+
}
285+
}
286+
287+
// AddClaimWithValues adds the Claim being requested to return
288+
// one of a set of values, with the values appearing in order of preference.
289+
// The Claim can be defined as an Essential Claim.
290+
func (c *ClaimSet) AddClaimWithValues(topLevelName string, claimName string, essential bool, values ...string) {
291+
c.initializeTopLevelMember(topLevelName)
292+
c.claims[topLevelName][claimName] = &claim{
293+
Essential: essential,
294+
Values: values,
295+
}
296+
}
297+
298+
// initializeTopLevelMember checks if top-level member is initialized
299+
// as a map of JSON Claims objects and initializes it, if it's needed.
300+
func (c *ClaimSet) initializeTopLevelMember(topLevelName string) {
301+
if c.claims == nil {
302+
c.claims = map[string]map[string]*claim{}
303+
}
304+
if c.claims[topLevelName] == nil {
305+
c.claims[topLevelName] = map[string]*claim{}
306+
}
307+
}
308+
232309
// tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token"
233310
// HTTP requests to renew a token using a RefreshToken.
234311
type tokenRefresher struct {

oauth2_test.go

+135-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package oauth2
66

77
import (
88
"context"
9+
"encoding/json"
910
"errors"
1011
"fmt"
1112
"io/ioutil"
@@ -71,6 +72,62 @@ func TestAuthCodeURL_Optional(t *testing.T) {
7172
}
7273
}
7374

75+
func TestAuthCodeURL_UserInfoClaims(t *testing.T) {
76+
claimSet := &ClaimSet{}
77+
claimSet.AddVoluntaryClaim(UserInfoClaim, "email")
78+
conf := &Config{
79+
ClientID: "CLIENT_ID",
80+
Endpoint: Endpoint{
81+
AuthURL: "/auth-url",
82+
TokenURL: "/token-url",
83+
},
84+
ClaimSet: claimSet,
85+
}
86+
url := conf.AuthCodeURL("")
87+
const want = "/auth-url?claims=%7B%22userinfo%22%3A%7B%22email%22%3Anull%7D%7D&client_id=CLIENT_ID&response_type=code"
88+
if got := url; got != want {
89+
t.Fatalf("got auth code = %q; want %q", got, want)
90+
}
91+
}
92+
93+
func TestAuthCodeURL_IdTokenClaims(t *testing.T) {
94+
claimSet := &ClaimSet{}
95+
claimSet.AddClaimWithValues(IdTokenClaim, "name", false, "nameValue1", "nameValue2")
96+
conf := &Config{
97+
ClientID: "CLIENT_ID",
98+
Endpoint: Endpoint{
99+
AuthURL: "/auth-url",
100+
TokenURL: "/token-url",
101+
},
102+
ClaimSet: claimSet,
103+
}
104+
url := conf.AuthCodeURL("")
105+
const want = "/auth-url?claims=%7B%22id_token%22%3A%7B%22name%22%3A%7B%22values%22%3A%5B%22nameValue1%22%2C%22nameValue2%22%5D%7D%7D%7D&client_id=CLIENT_ID&response_type=code"
106+
if got := url; got != want {
107+
t.Fatalf("got auth code = %q; want %q", got, want)
108+
}
109+
}
110+
111+
func TestAuthCodeURL_MultipleClaims(t *testing.T) {
112+
claimSet := &ClaimSet{}
113+
claimSet.AddVoluntaryClaim(UserInfoClaim, "email")
114+
claimSet.AddClaimWithValue(IdTokenClaim, "email", true, "emailValue")
115+
claimSet.AddClaimWithValues(IdTokenClaim, "name", false, "nameValue1", "nameValue2")
116+
conf := &Config{
117+
ClientID: "CLIENT_ID",
118+
Endpoint: Endpoint{
119+
AuthURL: "/auth-url",
120+
TokenURL: "/token-url",
121+
},
122+
ClaimSet: claimSet,
123+
}
124+
url := conf.AuthCodeURL("")
125+
const want = "/auth-url?claims=%7B%22id_token%22%3A%7B%22email%22%3A%7B%22essential%22%3Atrue%2C%22value%22%3A%22emailValue%22%7D%2C%22name%22%3A%7B%22values%22%3A%5B%22nameValue1%22%2C%22nameValue2%22%5D%7D%7D%2C%22userinfo%22%3A%7B%22email%22%3Anull%7D%7D&client_id=CLIENT_ID&response_type=code"
126+
if got := url; got != want {
127+
t.Fatalf("got auth code = %q; want %q", got, want)
128+
}
129+
}
130+
74131
func TestURLUnsafeClientConfig(t *testing.T) {
75132
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
76133
if got, want := r.Header.Get("Authorization"), "Basic Q0xJRU5UX0lEJTNGJTNGOkNMSUVOVF9TRUNSRVQlM0YlM0Y="; got != want {
@@ -231,10 +288,10 @@ func TestExchangeRequest_JSONResponse(t *testing.T) {
231288
func TestExtraValueRetrieval(t *testing.T) {
232289
values := url.Values{}
233290
kvmap := map[string]string{
234-
"scope": "user", "token_type": "bearer", "expires_in": "86400.92",
291+
"scope": "user", "token_type": "bearer", "expires_in": "86400.92",
235292
"server_time": "1443571905.5606415", "referer_ip": "10.0.0.1",
236-
"etag": "\"afZYj912P4alikMz_P11982\"", "request_id": "86400",
237-
"untrimmed": " untrimmed ",
293+
"etag": "\"afZYj912P4alikMz_P11982\"", "request_id": "86400",
294+
"untrimmed": " untrimmed ",
238295
}
239296
for key, value := range kvmap {
240297
values.Set(key, value)
@@ -548,3 +605,78 @@ func TestConfigClientWithToken(t *testing.T) {
548605
t.Error(err)
549606
}
550607
}
608+
609+
func TestClaimSet_AddVoluntaryClaim(t *testing.T) {
610+
claimSet := &ClaimSet{}
611+
claimSet.AddVoluntaryClaim(UserInfoClaim, "name")
612+
613+
body, err := json.Marshal(claimSet)
614+
if err != nil {
615+
t.Error(err)
616+
}
617+
618+
expected := `{"userinfo":{"name":null}}`
619+
if string(body) != expected {
620+
t.Errorf("Claims request parameter = %q; want %q", string(body), expected)
621+
}
622+
}
623+
624+
func TestClaimSet_AddClaimWithValue_Essential(t *testing.T) {
625+
claimSet := &ClaimSet{}
626+
claimSet.AddClaimWithValue(UserInfoClaim, "name", true, "nameValue")
627+
628+
body, err := json.Marshal(claimSet)
629+
if err != nil {
630+
t.Error(err)
631+
}
632+
633+
expected := `{"userinfo":{"name":{"essential":true,"value":"nameValue"}}}`
634+
if string(body) != expected {
635+
t.Errorf("Claims request parameter = %q; want %q", string(body), expected)
636+
}
637+
}
638+
639+
func TestClaimSet_AddClaimWithValue_Voluntary(t *testing.T) {
640+
claimSet := &ClaimSet{}
641+
claimSet.AddClaimWithValue(UserInfoClaim, "name", false, "nameValue")
642+
643+
body, err := json.Marshal(claimSet)
644+
if err != nil {
645+
t.Error(err)
646+
}
647+
648+
expected := `{"userinfo":{"name":{"value":"nameValue"}}}`
649+
if string(body) != expected {
650+
t.Errorf("Claims request parameter = %q; want %q", string(body), expected)
651+
}
652+
}
653+
654+
func TestClaimSet_AddClaimWithValues_Essential(t *testing.T) {
655+
claimSet := &ClaimSet{}
656+
claimSet.AddClaimWithValues(IdTokenClaim, "email", true, "emailValue", "mailValue")
657+
658+
body, err := json.Marshal(claimSet)
659+
if err != nil {
660+
t.Error(err)
661+
}
662+
663+
expected := `{"id_token":{"email":{"essential":true,"values":["emailValue","mailValue"]}}}`
664+
if string(body) != expected {
665+
t.Errorf("Claims request parameter = %q; want %q", string(body), expected)
666+
}
667+
}
668+
669+
func TestClaimSet_AddClaimWithValues_Voluntary(t *testing.T) {
670+
claimSet := &ClaimSet{}
671+
claimSet.AddClaimWithValues(IdTokenClaim, "email", false, "emailValue", "mailValue")
672+
673+
body, err := json.Marshal(claimSet)
674+
if err != nil {
675+
t.Error(err)
676+
}
677+
678+
expected := `{"id_token":{"email":{"values":["emailValue","mailValue"]}}}`
679+
if string(body) != expected {
680+
t.Errorf("Claims request parameter = %q; want %q", string(body), expected)
681+
}
682+
}

0 commit comments

Comments
 (0)