diff --git a/oauth2.go b/oauth2.go index 428283f0b..e7577ceee 100644 --- a/oauth2.go +++ b/oauth2.go @@ -11,6 +11,7 @@ package oauth2 // import "golang.org/x/oauth2" import ( "bytes" "context" + "encoding/json" "errors" "net/http" "net/url" @@ -57,6 +58,13 @@ type Config struct { // Scope specifies optional requested permissions. Scopes []string + + // ClaimSet optionally enables requesting of individual Claims. + // It is the only way to request Claims outside the standard set. + // It is also the only way to request specific combinations + // of the standard Claims that cannot be specified using scope values. + // See https://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter. + ClaimSet *ClaimSet } // A TokenSource is anything that can return a token. @@ -79,6 +87,26 @@ type Endpoint struct { AuthStyle AuthStyle } + +// JSON keys for top-level members of the "claims" request parameter value. +const ( + UserInfoClaim = "userinfo" + IdTokenClaim = "id_token" +) + +// claim describes JSON object used to specify additional information +// about the Claim being requested. +type claim struct { + Essential bool `json:"essential,omitempty"` + Value string `json:"value,omitempty"` + Values []string `json:"values,omitempty"` +} + +// ClaimSet contains all requested individual Claims. +type ClaimSet struct { + claims map[string]map[string]*claim +} + // AuthStyle represents how requests for tokens are authenticated // to the server. type AuthStyle int @@ -99,6 +127,7 @@ const ( AuthStyleInHeader AuthStyle = 2 ) + var ( // AccessTypeOnline and AccessTypeOffline are options passed // to the Options.AuthCodeURL method. They modify the @@ -172,6 +201,13 @@ func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string { } else { buf.WriteByte('?') } + if c.ClaimSet != nil { + body, err := json.Marshal(c.ClaimSet) + if err != nil { + panic(err) + } + v.Set("claims", string(body)) + } buf.WriteString(v.Encode()) return buf.String() } @@ -250,6 +286,53 @@ func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource { } } +// MarshalJSON is used to encode a private fields from ClaimSet. +// Final output represents actual value for "claims" request parameter. +func (c *ClaimSet) MarshalJSON() ([]byte, error) { + return json.Marshal(c.claims) +} + +// AddVoluntaryClaim adds the Claim being requested as a Voluntary Claim. +// Authorization Server is not required to provide this Claim in its response. +func (c *ClaimSet) AddVoluntaryClaim(topLevelName string, claimName string) { + c.initializeTopLevelMember(topLevelName) + c.claims[topLevelName][claimName] = nil +} + +// AddClaimWithValue adds the Claim being requested to return a particular value. +// If Claim is defined as an Essential Claim, the Authorization Server is required +// to provide this Claim in its response. +func (c *ClaimSet) AddClaimWithValue(topLevelName string, claimName string, essential bool, value string) { + c.initializeTopLevelMember(topLevelName) + c.claims[topLevelName][claimName] = &claim{ + Essential: essential, + Value: value, + } +} + +// AddClaimWithValues adds the Claim being requested to return +// one of a set of values, with the values appearing in order of preference. +// If Claim is defined as an Essential Claim, the Authorization Server is required +// to provide this Claim in its response. +func (c *ClaimSet) AddClaimWithValues(topLevelName string, claimName string, essential bool, values ...string) { + c.initializeTopLevelMember(topLevelName) + c.claims[topLevelName][claimName] = &claim{ + Essential: essential, + Values: values, + } +} + +// initializeTopLevelMember checks if top-level member is initialized +// as a map of JSON Claims objects and initializes it, if it's needed. +func (c *ClaimSet) initializeTopLevelMember(topLevelName string) { + if c.claims == nil { + c.claims = map[string]map[string]*claim{} + } + if c.claims[topLevelName] == nil { + c.claims[topLevelName] = map[string]*claim{} + } +} + // tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token" // HTTP requests to renew a token using a RefreshToken. type tokenRefresher struct { diff --git a/oauth2_test.go b/oauth2_test.go index b76eaae95..494d7079f 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -6,6 +6,7 @@ package oauth2 import ( "context" + "encoding/json" "errors" "fmt" "io" @@ -74,6 +75,62 @@ func TestAuthCodeURL_Optional(t *testing.T) { } } +func TestAuthCodeURL_UserInfoClaims(t *testing.T) { + claimSet := &ClaimSet{} + claimSet.AddVoluntaryClaim(UserInfoClaim, "email") + conf := &Config{ + ClientID: "CLIENT_ID", + Endpoint: Endpoint{ + AuthURL: "/auth-url", + TokenURL: "/token-url", + }, + ClaimSet: claimSet, + } + url := conf.AuthCodeURL("") + const want = "/auth-url?claims=%7B%22userinfo%22%3A%7B%22email%22%3Anull%7D%7D&client_id=CLIENT_ID&response_type=code" + if got := url; got != want { + t.Fatalf("got auth code = %q; want %q", got, want) + } +} + +func TestAuthCodeURL_IdTokenClaims(t *testing.T) { + claimSet := &ClaimSet{} + claimSet.AddClaimWithValues(IdTokenClaim, "name", false, "nameValue1", "nameValue2") + conf := &Config{ + ClientID: "CLIENT_ID", + Endpoint: Endpoint{ + AuthURL: "/auth-url", + TokenURL: "/token-url", + }, + ClaimSet: claimSet, + } + url := conf.AuthCodeURL("") + const want = "/auth-url?claims=%7B%22id_token%22%3A%7B%22name%22%3A%7B%22values%22%3A%5B%22nameValue1%22%2C%22nameValue2%22%5D%7D%7D%7D&client_id=CLIENT_ID&response_type=code" + if got := url; got != want { + t.Fatalf("got auth code = %q; want %q", got, want) + } +} + +func TestAuthCodeURL_MultipleClaims(t *testing.T) { + claimSet := &ClaimSet{} + claimSet.AddVoluntaryClaim(UserInfoClaim, "email") + claimSet.AddClaimWithValue(IdTokenClaim, "email", true, "emailValue") + claimSet.AddClaimWithValues(IdTokenClaim, "name", false, "nameValue1", "nameValue2") + conf := &Config{ + ClientID: "CLIENT_ID", + Endpoint: Endpoint{ + AuthURL: "/auth-url", + TokenURL: "/token-url", + }, + ClaimSet: claimSet, + } + url := conf.AuthCodeURL("") + const want = "/auth-url?claims=%7B%22id_token%22%3A%7B%22email%22%3A%7B%22essential%22%3Atrue%2C%22value%22%3A%22emailValue%22%7D%2C%22name%22%3A%7B%22values%22%3A%5B%22nameValue1%22%2C%22nameValue2%22%5D%7D%7D%2C%22userinfo%22%3A%7B%22email%22%3Anull%7D%7D&client_id=CLIENT_ID&response_type=code" + if got := url; got != want { + t.Fatalf("got auth code = %q; want %q", got, want) + } +} + func TestURLUnsafeClientConfig(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if got, want := r.Header.Get("Authorization"), "Basic Q0xJRU5UX0lEJTNGJTNGOkNMSUVOVF9TRUNSRVQlM0YlM0Y="; got != want { @@ -234,10 +291,10 @@ func TestExchangeRequest_JSONResponse(t *testing.T) { func TestExtraValueRetrieval(t *testing.T) { values := url.Values{} kvmap := map[string]string{ - "scope": "user", "token_type": "bearer", "expires_in": "86400.92", + "scope": "user", "token_type": "bearer", "expires_in": "86400.92", "server_time": "1443571905.5606415", "referer_ip": "10.0.0.1", - "etag": "\"afZYj912P4alikMz_P11982\"", "request_id": "86400", - "untrimmed": " untrimmed ", + "etag": "\"afZYj912P4alikMz_P11982\"", "request_id": "86400", + "untrimmed": " untrimmed ", } for key, value := range kvmap { values.Set(key, value) @@ -561,3 +618,78 @@ func TestConfigClientWithToken(t *testing.T) { t.Error(err) } } + +func TestClaimSet_AddVoluntaryClaim(t *testing.T) { + claimSet := &ClaimSet{} + claimSet.AddVoluntaryClaim(UserInfoClaim, "name") + + body, err := json.Marshal(claimSet) + if err != nil { + t.Error(err) + } + + expected := `{"userinfo":{"name":null}}` + if string(body) != expected { + t.Errorf("Claims request parameter = %q; want %q", string(body), expected) + } +} + +func TestClaimSet_AddClaimWithValue_Essential(t *testing.T) { + claimSet := &ClaimSet{} + claimSet.AddClaimWithValue(UserInfoClaim, "name", true, "nameValue") + + body, err := json.Marshal(claimSet) + if err != nil { + t.Error(err) + } + + expected := `{"userinfo":{"name":{"essential":true,"value":"nameValue"}}}` + if string(body) != expected { + t.Errorf("Claims request parameter = %q; want %q", string(body), expected) + } +} + +func TestClaimSet_AddClaimWithValue_Voluntary(t *testing.T) { + claimSet := &ClaimSet{} + claimSet.AddClaimWithValue(UserInfoClaim, "name", false, "nameValue") + + body, err := json.Marshal(claimSet) + if err != nil { + t.Error(err) + } + + expected := `{"userinfo":{"name":{"value":"nameValue"}}}` + if string(body) != expected { + t.Errorf("Claims request parameter = %q; want %q", string(body), expected) + } +} + +func TestClaimSet_AddClaimWithValues_Essential(t *testing.T) { + claimSet := &ClaimSet{} + claimSet.AddClaimWithValues(IdTokenClaim, "email", true, "emailValue", "mailValue") + + body, err := json.Marshal(claimSet) + if err != nil { + t.Error(err) + } + + expected := `{"id_token":{"email":{"essential":true,"values":["emailValue","mailValue"]}}}` + if string(body) != expected { + t.Errorf("Claims request parameter = %q; want %q", string(body), expected) + } +} + +func TestClaimSet_AddClaimWithValues_Voluntary(t *testing.T) { + claimSet := &ClaimSet{} + claimSet.AddClaimWithValues(IdTokenClaim, "email", false, "emailValue", "mailValue") + + body, err := json.Marshal(claimSet) + if err != nil { + t.Error(err) + } + + expected := `{"id_token":{"email":{"values":["emailValue","mailValue"]}}}` + if string(body) != expected { + t.Errorf("Claims request parameter = %q; want %q", string(body), expected) + } +}