From b762685799422ab779adefde348535e7a204c363 Mon Sep 17 00:00:00 2001
From: Marko Milojevic <marko.milojevic@aoe.com>
Date: Mon, 15 Oct 2018 11:17:29 +0200
Subject: [PATCH 1/2] support for claims request parameter:
 https://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter

---
 oauth2.go      |  77 +++++++++++++++++++++++++++
 oauth2_test.go | 138 +++++++++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 212 insertions(+), 3 deletions(-)

diff --git a/oauth2.go b/oauth2.go
index 3de63315b..bda075f0d 100644
--- a/oauth2.go
+++ b/oauth2.go
@@ -11,6 +11,7 @@ package oauth2 // import "golang.org/x/oauth2"
 import (
 	"bytes"
 	"context"
+	"encoding/json"
 	"errors"
 	"net/http"
 	"net/url"
@@ -61,6 +62,10 @@ type Config struct {
 
 	// Scope specifies optional requested permissions.
 	Scopes []string
+
+	// ClaimSet is optional requested parameter used to that
+	// specific Claims be returned.
+	ClaimSet *ClaimSet
 }
 
 // A TokenSource is anything that can return a token.
@@ -78,6 +83,26 @@ type Endpoint struct {
 	TokenURL string
 }
 
+// JSON keys for top-level member of the Claims request JSON.
+const (
+	UserInfoClaim = "userinfo"
+	IdTokenClaim  = "id_token"
+)
+
+// claim describes JSON object used to specify additional information
+// about the Claim being requested.
+type claim struct {
+	Essential bool     `json:"essential,omitempty"`
+	Value     string   `json:"value,omitempty"`
+	Values    []string `json:"values,omitempty"`
+}
+
+// ClaimSet contains map with members of the Claims request.
+// It provides methods to add specific Claims.
+type ClaimSet struct {
+	claims map[string]map[string]*claim
+}
+
 var (
 	// AccessTypeOnline and AccessTypeOffline are options passed
 	// to the Options.AuthCodeURL method. They modify the
@@ -151,6 +176,13 @@ func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
 	} else {
 		buf.WriteByte('?')
 	}
+	if c.ClaimSet != nil {
+		body, err := json.Marshal(c.ClaimSet)
+		if err != nil {
+			panic(err)
+		}
+		v.Set("claims", string(body))
+	}
 	buf.WriteString(v.Encode())
 	return buf.String()
 }
@@ -229,6 +261,51 @@ func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource {
 	}
 }
 
+// MarshalJSON is part of the json.Marshaler interface.
+// It's used to encode hidden map that contains Claims objects.
+func (c *ClaimSet) MarshalJSON() ([]byte, error) {
+	return json.Marshal(c.claims)
+}
+
+// AddVoluntaryClaim adds the Claim being requested in default manner,
+// as a Voluntary Claim.
+func (c *ClaimSet) AddVoluntaryClaim(topLevelName string, claimName string) {
+	c.initializeTopLevelMember(topLevelName)
+	c.claims[topLevelName][claimName] = nil
+}
+
+// AddClaimWithValue adds the Claim being requested to return a particular value.
+// The Claim can be defined as an Essential Claim.
+func (c *ClaimSet) AddClaimWithValue(topLevelName string, claimName string, essential bool, value string) {
+	c.initializeTopLevelMember(topLevelName)
+	c.claims[topLevelName][claimName] = &claim{
+		Essential: essential,
+		Value:     value,
+	}
+}
+
+// AddClaimWithValues adds the Claim being requested to return
+// one of a set of values, with the values appearing in order of preference.
+// The Claim can be defined as an Essential Claim.
+func (c *ClaimSet) AddClaimWithValues(topLevelName string, claimName string, essential bool, values ...string) {
+	c.initializeTopLevelMember(topLevelName)
+	c.claims[topLevelName][claimName] = &claim{
+		Essential: essential,
+		Values:    values,
+	}
+}
+
+// initializeTopLevelMember checks if top-level member is initialized
+// as a map of JSON Claims objects and initializes it, if it's needed.
+func (c *ClaimSet) initializeTopLevelMember(topLevelName string) {
+	if c.claims == nil {
+		c.claims = map[string]map[string]*claim{}
+	}
+	if c.claims[topLevelName] == nil {
+		c.claims[topLevelName] = map[string]*claim{}
+	}
+}
+
 // tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token"
 // HTTP requests to renew a token using a RefreshToken.
 type tokenRefresher struct {
diff --git a/oauth2_test.go b/oauth2_test.go
index 19aaf6b2b..f215ebccb 100644
--- a/oauth2_test.go
+++ b/oauth2_test.go
@@ -6,6 +6,7 @@ package oauth2
 
 import (
 	"context"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"io/ioutil"
@@ -71,6 +72,62 @@ func TestAuthCodeURL_Optional(t *testing.T) {
 	}
 }
 
+func TestAuthCodeURL_UserInfoClaims(t *testing.T) {
+	claimSet := &ClaimSet{}
+	claimSet.AddVoluntaryClaim(UserInfoClaim, "email")
+	conf := &Config{
+		ClientID: "CLIENT_ID",
+		Endpoint: Endpoint{
+			AuthURL:  "/auth-url",
+			TokenURL: "/token-url",
+		},
+		ClaimSet: claimSet,
+	}
+	url := conf.AuthCodeURL("")
+	const want = "/auth-url?claims=%7B%22userinfo%22%3A%7B%22email%22%3Anull%7D%7D&client_id=CLIENT_ID&response_type=code"
+	if got := url; got != want {
+		t.Fatalf("got auth code = %q; want %q", got, want)
+	}
+}
+
+func TestAuthCodeURL_IdTokenClaims(t *testing.T) {
+	claimSet := &ClaimSet{}
+	claimSet.AddClaimWithValues(IdTokenClaim, "name", false, "nameValue1", "nameValue2")
+	conf := &Config{
+		ClientID: "CLIENT_ID",
+		Endpoint: Endpoint{
+			AuthURL:  "/auth-url",
+			TokenURL: "/token-url",
+		},
+		ClaimSet: claimSet,
+	}
+	url := conf.AuthCodeURL("")
+	const want = "/auth-url?claims=%7B%22id_token%22%3A%7B%22name%22%3A%7B%22values%22%3A%5B%22nameValue1%22%2C%22nameValue2%22%5D%7D%7D%7D&client_id=CLIENT_ID&response_type=code"
+	if got := url; got != want {
+		t.Fatalf("got auth code = %q; want %q", got, want)
+	}
+}
+
+func TestAuthCodeURL_MultipleClaims(t *testing.T) {
+	claimSet := &ClaimSet{}
+	claimSet.AddVoluntaryClaim(UserInfoClaim, "email")
+	claimSet.AddClaimWithValue(IdTokenClaim, "email", true, "emailValue")
+	claimSet.AddClaimWithValues(IdTokenClaim, "name", false, "nameValue1", "nameValue2")
+	conf := &Config{
+		ClientID: "CLIENT_ID",
+		Endpoint: Endpoint{
+			AuthURL:  "/auth-url",
+			TokenURL: "/token-url",
+		},
+		ClaimSet: claimSet,
+	}
+	url := conf.AuthCodeURL("")
+	const want = "/auth-url?claims=%7B%22id_token%22%3A%7B%22email%22%3A%7B%22essential%22%3Atrue%2C%22value%22%3A%22emailValue%22%7D%2C%22name%22%3A%7B%22values%22%3A%5B%22nameValue1%22%2C%22nameValue2%22%5D%7D%7D%2C%22userinfo%22%3A%7B%22email%22%3Anull%7D%7D&client_id=CLIENT_ID&response_type=code"
+	if got := url; got != want {
+		t.Fatalf("got auth code = %q; want %q", got, want)
+	}
+}
+
 func TestURLUnsafeClientConfig(t *testing.T) {
 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		if got, want := r.Header.Get("Authorization"), "Basic Q0xJRU5UX0lEJTNGJTNGOkNMSUVOVF9TRUNSRVQlM0YlM0Y="; got != want {
@@ -231,10 +288,10 @@ func TestExchangeRequest_JSONResponse(t *testing.T) {
 func TestExtraValueRetrieval(t *testing.T) {
 	values := url.Values{}
 	kvmap := map[string]string{
-		"scope": "user", "token_type": "bearer", "expires_in": "86400.92",
+		"scope":       "user", "token_type": "bearer", "expires_in": "86400.92",
 		"server_time": "1443571905.5606415", "referer_ip": "10.0.0.1",
-		"etag": "\"afZYj912P4alikMz_P11982\"", "request_id": "86400",
-		"untrimmed": "  untrimmed  ",
+		"etag":        "\"afZYj912P4alikMz_P11982\"", "request_id": "86400",
+		"untrimmed":   "  untrimmed  ",
 	}
 	for key, value := range kvmap {
 		values.Set(key, value)
@@ -548,3 +605,78 @@ func TestConfigClientWithToken(t *testing.T) {
 		t.Error(err)
 	}
 }
+
+func TestClaimSet_AddVoluntaryClaim(t *testing.T) {
+	claimSet := &ClaimSet{}
+	claimSet.AddVoluntaryClaim(UserInfoClaim, "name")
+
+	body, err := json.Marshal(claimSet)
+	if err != nil {
+		t.Error(err)
+	}
+
+	expected := `{"userinfo":{"name":null}}`
+	if string(body) != expected {
+		t.Errorf("Claims request parameter = %q; want %q", string(body), expected)
+	}
+}
+
+func TestClaimSet_AddClaimWithValue_Essential(t *testing.T) {
+	claimSet := &ClaimSet{}
+	claimSet.AddClaimWithValue(UserInfoClaim, "name", true, "nameValue")
+
+	body, err := json.Marshal(claimSet)
+	if err != nil {
+		t.Error(err)
+	}
+
+	expected := `{"userinfo":{"name":{"essential":true,"value":"nameValue"}}}`
+	if string(body) != expected {
+		t.Errorf("Claims request parameter = %q; want %q", string(body), expected)
+	}
+}
+
+func TestClaimSet_AddClaimWithValue_Voluntary(t *testing.T) {
+	claimSet := &ClaimSet{}
+	claimSet.AddClaimWithValue(UserInfoClaim, "name", false, "nameValue")
+
+	body, err := json.Marshal(claimSet)
+	if err != nil {
+		t.Error(err)
+	}
+
+	expected := `{"userinfo":{"name":{"value":"nameValue"}}}`
+	if string(body) != expected {
+		t.Errorf("Claims request parameter = %q; want %q", string(body), expected)
+	}
+}
+
+func TestClaimSet_AddClaimWithValues_Essential(t *testing.T) {
+	claimSet := &ClaimSet{}
+	claimSet.AddClaimWithValues(IdTokenClaim, "email", true, "emailValue", "mailValue")
+
+	body, err := json.Marshal(claimSet)
+	if err != nil {
+		t.Error(err)
+	}
+
+	expected := `{"id_token":{"email":{"essential":true,"values":["emailValue","mailValue"]}}}`
+	if string(body) != expected {
+		t.Errorf("Claims request parameter = %q; want %q", string(body), expected)
+	}
+}
+
+func TestClaimSet_AddClaimWithValues_Voluntary(t *testing.T) {
+	claimSet := &ClaimSet{}
+	claimSet.AddClaimWithValues(IdTokenClaim, "email", false, "emailValue", "mailValue")
+
+	body, err := json.Marshal(claimSet)
+	if err != nil {
+		t.Error(err)
+	}
+
+	expected := `{"id_token":{"email":{"values":["emailValue","mailValue"]}}}`
+	if string(body) != expected {
+		t.Errorf("Claims request parameter = %q; want %q", string(body), expected)
+	}
+}

From 3a2f27cdbda43ba4594e232a3460e6b4286b59eb Mon Sep 17 00:00:00 2001
From: Marko Milojevic <marko.milojevic@aoe.com>
Date: Tue, 29 Jan 2019 17:44:13 +0100
Subject: [PATCH 2/2] address patchset3 comments

---
 oauth2.go | 26 +++++++++++++++-----------
 1 file changed, 15 insertions(+), 11 deletions(-)

diff --git a/oauth2.go b/oauth2.go
index bda075f0d..5307a7a15 100644
--- a/oauth2.go
+++ b/oauth2.go
@@ -63,8 +63,11 @@ type Config struct {
 	// Scope specifies optional requested permissions.
 	Scopes []string
 
-	// ClaimSet is optional requested parameter used to that
-	// specific Claims be returned.
+	// ClaimSet optionally enables requesting of individual Claims.
+	// It is the only way to request Claims outside the standard set.
+	// It is also the only way to request specific combinations
+	// of the standard Claims that cannot be specified using scope values.
+	// See https://openid.net/specs/openid-connect-core-1_0.html#ClaimsParameter.
 	ClaimSet *ClaimSet
 }
 
@@ -83,7 +86,7 @@ type Endpoint struct {
 	TokenURL string
 }
 
-// JSON keys for top-level member of the Claims request JSON.
+// JSON keys for top-level members of the "claims" request parameter value.
 const (
 	UserInfoClaim = "userinfo"
 	IdTokenClaim  = "id_token"
@@ -97,8 +100,7 @@ type claim struct {
 	Values    []string `json:"values,omitempty"`
 }
 
-// ClaimSet contains map with members of the Claims request.
-// It provides methods to add specific Claims.
+// ClaimSet contains all requested individual Claims.
 type ClaimSet struct {
 	claims map[string]map[string]*claim
 }
@@ -261,21 +263,22 @@ func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource {
 	}
 }
 
-// MarshalJSON is part of the json.Marshaler interface.
-// It's used to encode hidden map that contains Claims objects.
+// MarshalJSON is used to encode a private fields from ClaimSet.
+// Final output  represents actual value for "claims" request parameter.
 func (c *ClaimSet) MarshalJSON() ([]byte, error) {
 	return json.Marshal(c.claims)
 }
 
-// AddVoluntaryClaim adds the Claim being requested in default manner,
-// as a Voluntary Claim.
+// AddVoluntaryClaim adds the Claim being requested as a Voluntary Claim.
+// Authorization Server is not required to provide this Claim in its response.
 func (c *ClaimSet) AddVoluntaryClaim(topLevelName string, claimName string) {
 	c.initializeTopLevelMember(topLevelName)
 	c.claims[topLevelName][claimName] = nil
 }
 
 // AddClaimWithValue adds the Claim being requested to return a particular value.
-// The Claim can be defined as an Essential Claim.
+// If Claim is defined as an Essential Claim, the Authorization Server is required
+// to provide this Claim in its response.
 func (c *ClaimSet) AddClaimWithValue(topLevelName string, claimName string, essential bool, value string) {
 	c.initializeTopLevelMember(topLevelName)
 	c.claims[topLevelName][claimName] = &claim{
@@ -286,7 +289,8 @@ func (c *ClaimSet) AddClaimWithValue(topLevelName string, claimName string, esse
 
 // AddClaimWithValues adds the Claim being requested to return
 // one of a set of values, with the values appearing in order of preference.
-// The Claim can be defined as an Essential Claim.
+// If Claim is defined as an Essential Claim, the Authorization Server is required
+// to provide this Claim in its response.
 func (c *ClaimSet) AddClaimWithValues(topLevelName string, claimName string, essential bool, values ...string) {
 	c.initializeTopLevelMember(topLevelName)
 	c.claims[topLevelName][claimName] = &claim{