Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for claims request parameter #332

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package oauth2 // import "golang.org/x/oauth2"
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/url"
Expand Down Expand Up @@ -57,6 +58,13 @@ type Config struct {

// Scope specifies optional requested permissions.
Scopes []string

// ClaimSet optionally enables requesting of individual Claims.
// It is the only way to request Claims outside the standard set.
// It is also the only way to request specific combinations
// of the standard Claims that cannot be specified using scope values.
// See https://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter.
ClaimSet *ClaimSet
}

// A TokenSource is anything that can return a token.
Expand All @@ -79,6 +87,26 @@ type Endpoint struct {
AuthStyle AuthStyle
}


// JSON keys for top-level members of the "claims" request parameter value.
const (
UserInfoClaim = "userinfo"
IdTokenClaim = "id_token"
)

// claim describes JSON object used to specify additional information
// about the Claim being requested.
type claim struct {
Essential bool `json:"essential,omitempty"`
Value string `json:"value,omitempty"`
Values []string `json:"values,omitempty"`
}

// ClaimSet contains all requested individual Claims.
type ClaimSet struct {
claims map[string]map[string]*claim
}

// AuthStyle represents how requests for tokens are authenticated
// to the server.
type AuthStyle int
Expand All @@ -99,6 +127,7 @@ const (
AuthStyleInHeader AuthStyle = 2
)


var (
// AccessTypeOnline and AccessTypeOffline are options passed
// to the Options.AuthCodeURL method. They modify the
Expand Down Expand Up @@ -172,6 +201,13 @@ func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
} else {
buf.WriteByte('?')
}
if c.ClaimSet != nil {
body, err := json.Marshal(c.ClaimSet)
if err != nil {
panic(err)
}
v.Set("claims", string(body))
}
buf.WriteString(v.Encode())
return buf.String()
}
Expand Down Expand Up @@ -250,6 +286,53 @@ func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource {
}
}

// MarshalJSON is used to encode a private fields from ClaimSet.
// Final output represents actual value for "claims" request parameter.
func (c *ClaimSet) MarshalJSON() ([]byte, error) {
return json.Marshal(c.claims)
}

// AddVoluntaryClaim adds the Claim being requested as a Voluntary Claim.
// Authorization Server is not required to provide this Claim in its response.
func (c *ClaimSet) AddVoluntaryClaim(topLevelName string, claimName string) {
c.initializeTopLevelMember(topLevelName)
c.claims[topLevelName][claimName] = nil
}

// AddClaimWithValue adds the Claim being requested to return a particular value.
// If Claim is defined as an Essential Claim, the Authorization Server is required
// to provide this Claim in its response.
func (c *ClaimSet) AddClaimWithValue(topLevelName string, claimName string, essential bool, value string) {
c.initializeTopLevelMember(topLevelName)
c.claims[topLevelName][claimName] = &claim{
Essential: essential,
Value: value,
}
}

// AddClaimWithValues adds the Claim being requested to return
// one of a set of values, with the values appearing in order of preference.
// If Claim is defined as an Essential Claim, the Authorization Server is required
// to provide this Claim in its response.
func (c *ClaimSet) AddClaimWithValues(topLevelName string, claimName string, essential bool, values ...string) {
c.initializeTopLevelMember(topLevelName)
c.claims[topLevelName][claimName] = &claim{
Essential: essential,
Values: values,
}
}

// initializeTopLevelMember checks if top-level member is initialized
// as a map of JSON Claims objects and initializes it, if it's needed.
func (c *ClaimSet) initializeTopLevelMember(topLevelName string) {
if c.claims == nil {
c.claims = map[string]map[string]*claim{}
}
if c.claims[topLevelName] == nil {
c.claims[topLevelName] = map[string]*claim{}
}
}

// tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token"
// HTTP requests to renew a token using a RefreshToken.
type tokenRefresher struct {
Expand Down
138 changes: 135 additions & 3 deletions oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package oauth2

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -74,6 +75,62 @@ func TestAuthCodeURL_Optional(t *testing.T) {
}
}

func TestAuthCodeURL_UserInfoClaims(t *testing.T) {
claimSet := &ClaimSet{}
claimSet.AddVoluntaryClaim(UserInfoClaim, "email")
conf := &Config{
ClientID: "CLIENT_ID",
Endpoint: Endpoint{
AuthURL: "/auth-url",
TokenURL: "/token-url",
},
ClaimSet: claimSet,
}
url := conf.AuthCodeURL("")
const want = "/auth-url?claims=%7B%22userinfo%22%3A%7B%22email%22%3Anull%7D%7D&client_id=CLIENT_ID&response_type=code"
if got := url; got != want {
t.Fatalf("got auth code = %q; want %q", got, want)
}
}

func TestAuthCodeURL_IdTokenClaims(t *testing.T) {
claimSet := &ClaimSet{}
claimSet.AddClaimWithValues(IdTokenClaim, "name", false, "nameValue1", "nameValue2")
conf := &Config{
ClientID: "CLIENT_ID",
Endpoint: Endpoint{
AuthURL: "/auth-url",
TokenURL: "/token-url",
},
ClaimSet: claimSet,
}
url := conf.AuthCodeURL("")
const want = "/auth-url?claims=%7B%22id_token%22%3A%7B%22name%22%3A%7B%22values%22%3A%5B%22nameValue1%22%2C%22nameValue2%22%5D%7D%7D%7D&client_id=CLIENT_ID&response_type=code"
if got := url; got != want {
t.Fatalf("got auth code = %q; want %q", got, want)
}
}

func TestAuthCodeURL_MultipleClaims(t *testing.T) {
claimSet := &ClaimSet{}
claimSet.AddVoluntaryClaim(UserInfoClaim, "email")
claimSet.AddClaimWithValue(IdTokenClaim, "email", true, "emailValue")
claimSet.AddClaimWithValues(IdTokenClaim, "name", false, "nameValue1", "nameValue2")
conf := &Config{
ClientID: "CLIENT_ID",
Endpoint: Endpoint{
AuthURL: "/auth-url",
TokenURL: "/token-url",
},
ClaimSet: claimSet,
}
url := conf.AuthCodeURL("")
const want = "/auth-url?claims=%7B%22id_token%22%3A%7B%22email%22%3A%7B%22essential%22%3Atrue%2C%22value%22%3A%22emailValue%22%7D%2C%22name%22%3A%7B%22values%22%3A%5B%22nameValue1%22%2C%22nameValue2%22%5D%7D%7D%2C%22userinfo%22%3A%7B%22email%22%3Anull%7D%7D&client_id=CLIENT_ID&response_type=code"
if got := url; got != want {
t.Fatalf("got auth code = %q; want %q", got, want)
}
}

func TestURLUnsafeClientConfig(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.Header.Get("Authorization"), "Basic Q0xJRU5UX0lEJTNGJTNGOkNMSUVOVF9TRUNSRVQlM0YlM0Y="; got != want {
Expand Down Expand Up @@ -234,10 +291,10 @@ func TestExchangeRequest_JSONResponse(t *testing.T) {
func TestExtraValueRetrieval(t *testing.T) {
values := url.Values{}
kvmap := map[string]string{
"scope": "user", "token_type": "bearer", "expires_in": "86400.92",
"scope": "user", "token_type": "bearer", "expires_in": "86400.92",
"server_time": "1443571905.5606415", "referer_ip": "10.0.0.1",
"etag": "\"afZYj912P4alikMz_P11982\"", "request_id": "86400",
"untrimmed": " untrimmed ",
"etag": "\"afZYj912P4alikMz_P11982\"", "request_id": "86400",
"untrimmed": " untrimmed ",
}
for key, value := range kvmap {
values.Set(key, value)
Expand Down Expand Up @@ -561,3 +618,78 @@ func TestConfigClientWithToken(t *testing.T) {
t.Error(err)
}
}

func TestClaimSet_AddVoluntaryClaim(t *testing.T) {
claimSet := &ClaimSet{}
claimSet.AddVoluntaryClaim(UserInfoClaim, "name")

body, err := json.Marshal(claimSet)
if err != nil {
t.Error(err)
}

expected := `{"userinfo":{"name":null}}`
if string(body) != expected {
t.Errorf("Claims request parameter = %q; want %q", string(body), expected)
}
}

func TestClaimSet_AddClaimWithValue_Essential(t *testing.T) {
claimSet := &ClaimSet{}
claimSet.AddClaimWithValue(UserInfoClaim, "name", true, "nameValue")

body, err := json.Marshal(claimSet)
if err != nil {
t.Error(err)
}

expected := `{"userinfo":{"name":{"essential":true,"value":"nameValue"}}}`
if string(body) != expected {
t.Errorf("Claims request parameter = %q; want %q", string(body), expected)
}
}

func TestClaimSet_AddClaimWithValue_Voluntary(t *testing.T) {
claimSet := &ClaimSet{}
claimSet.AddClaimWithValue(UserInfoClaim, "name", false, "nameValue")

body, err := json.Marshal(claimSet)
if err != nil {
t.Error(err)
}

expected := `{"userinfo":{"name":{"value":"nameValue"}}}`
if string(body) != expected {
t.Errorf("Claims request parameter = %q; want %q", string(body), expected)
}
}

func TestClaimSet_AddClaimWithValues_Essential(t *testing.T) {
claimSet := &ClaimSet{}
claimSet.AddClaimWithValues(IdTokenClaim, "email", true, "emailValue", "mailValue")

body, err := json.Marshal(claimSet)
if err != nil {
t.Error(err)
}

expected := `{"id_token":{"email":{"essential":true,"values":["emailValue","mailValue"]}}}`
if string(body) != expected {
t.Errorf("Claims request parameter = %q; want %q", string(body), expected)
}
}

func TestClaimSet_AddClaimWithValues_Voluntary(t *testing.T) {
claimSet := &ClaimSet{}
claimSet.AddClaimWithValues(IdTokenClaim, "email", false, "emailValue", "mailValue")

body, err := json.Marshal(claimSet)
if err != nil {
t.Error(err)
}

expected := `{"id_token":{"email":{"values":["emailValue","mailValue"]}}}`
if string(body) != expected {
t.Errorf("Claims request parameter = %q; want %q", string(body), expected)
}
}