Skip to content

Commit 5815ce5

Browse files
Adds gsuite provider specific extension for fetching groups and user information (#123)
1 parent 83f6a58 commit 5815ce5

File tree

437 files changed

+80682
-88340
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

437 files changed

+80682
-88340
lines changed

go.mod

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,14 @@ require (
1414
github.com/hashicorp/vault/api v1.0.5-0.20200215224050-f6547fa8e820
1515
github.com/hashicorp/vault/sdk v0.1.14-0.20200215224050-f6547fa8e820
1616
github.com/hashicorp/yamux v0.0.0-20181012175058-2f1d1f20f75d // indirect
17+
github.com/mitchellh/mapstructure v1.1.2
1718
github.com/mitchellh/pointerstructure v1.0.0
1819
github.com/patrickmn/go-cache v2.1.0+incompatible
1920
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
2021
github.com/ryanuber/go-glob v1.0.0
21-
github.com/stretchr/testify v1.3.0
22+
github.com/stretchr/testify v1.4.0
2223
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d
23-
golang.org/x/sync v0.0.0-20190423024810-112230192c58
24-
golang.org/x/text v0.3.2 // indirect
25-
google.golang.org/appengine v1.5.0 // indirect
26-
google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64 // indirect
24+
golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a
25+
google.golang.org/api v0.29.0
2726
gopkg.in/square/go-jose.v2 v2.4.1
2827
)

go.sum

Lines changed: 238 additions & 0 deletions
Large diffs are not rendered by default.

path_login.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,11 @@ func (b *jwtAuthBackend) createIdentity(allClaims map[string]interface{}, role *
318318
return nil, nil, fmt.Errorf("claim %q could not be converted to string", role.UserClaim)
319319
}
320320

321+
err := b.fetchUserInfo(allClaims, role)
322+
if err != nil {
323+
return nil, nil, err
324+
}
325+
321326
metadata, err := extractMetadata(b.Logger(), allClaims, role.ClaimMappings)
322327
if err != nil {
323328
return nil, nil, err
@@ -360,6 +365,22 @@ func (b *jwtAuthBackend) createIdentity(allClaims map[string]interface{}, role *
360365
return alias, groupAliases, nil
361366
}
362367

368+
// Checks if there's a custom provider_config and calls FetchUserInfo() if implemented.
369+
func (b *jwtAuthBackend) fetchUserInfo(allClaims map[string]interface{}, role *jwtRole) error {
370+
pConfig, err := NewProviderConfig(b.cachedConfig, ProviderMap())
371+
if err != nil {
372+
return fmt.Errorf("failed to load custom provider config: %s", err)
373+
}
374+
// Fetch user info from custom provider if it's implemented
375+
if pConfig != nil {
376+
if uif, ok := pConfig.(UserInfoFetcher); ok {
377+
return uif.FetchUserInfo(b, allClaims, role)
378+
}
379+
}
380+
381+
return nil
382+
}
383+
363384
// Checks if there's a custom provider_config and calls FetchGroups() if implemented
364385
func (b *jwtAuthBackend) fetchGroups(allClaims map[string]interface{}, role *jwtRole) (interface{}, error) {
365386
pConfig, err := NewProviderConfig(b.cachedConfig, ProviderMap())

provider_config.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ import (
1111
// ProviderMap returns a map of provider names to custom types
1212
func ProviderMap() map[string]CustomProvider {
1313
return map[string]CustomProvider{
14-
"azure": &AzureProvider{},
14+
"azure": &AzureProvider{},
15+
"gsuite": &GSuiteProvider{},
1516
}
1617
}
1718

@@ -48,6 +49,11 @@ func NewProviderConfig(jc *jwtConfig, providerMap map[string]CustomProvider) (Cu
4849
return newCustomProvider, nil
4950
}
5051

52+
// UserInfoFetcher - Optional support for custom user info handling
53+
type UserInfoFetcher interface {
54+
FetchUserInfo(*jwtAuthBackend, map[string]interface{}, *jwtRole) error
55+
}
56+
5157
// GroupsFetcher - Optional support for custom groups handling
5258
type GroupsFetcher interface {
5359
// FetchGroups queries for groups claims during login

provider_gsuite.go

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
package jwtauth
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"errors"
7+
"fmt"
8+
"io/ioutil"
9+
10+
"github.com/mitchellh/mapstructure"
11+
"golang.org/x/oauth2/google"
12+
"golang.org/x/oauth2/jwt"
13+
admin "google.golang.org/api/admin/directory/v1"
14+
"google.golang.org/api/option"
15+
)
16+
17+
// GSuiteProvider provides G Suite-specific configuration and behavior.
18+
type GSuiteProvider struct {
19+
config GSuiteProviderConfig // Configuration for the provider
20+
jwtConfig *jwt.Config // Google JWT configuration
21+
adminSvc *admin.Service // Google admin service
22+
}
23+
24+
// GSuiteProviderConfig represents the configuration for a GSuiteProvider.
25+
type GSuiteProviderConfig struct {
26+
// Path to a Google service account key file. Required.
27+
ServiceAccountFilePath string `mapstructure:"gsuite_service_account"`
28+
29+
// Email address of a G Suite admin to impersonate. Required.
30+
AdminImpersonateEmail string `mapstructure:"gsuite_admin_impersonate"`
31+
32+
// If set to true, groups will be fetched from G Suite.
33+
FetchGroups bool `mapstructure:"fetch_groups"`
34+
35+
// If set to true, user info will be fetched from G Suite using UserCustomSchemas.
36+
FetchUserInfo bool `mapstructure:"fetch_user_info"`
37+
38+
// Group membership recursion max depth (0 = do not recurse).
39+
GroupsRecurseMaxDepth int `mapstructure:"groups_recurse_max_depth"`
40+
41+
// Comma-separated list of G Suite custom schemas to fetch as claims.
42+
UserCustomSchemas string `mapstructure:"user_custom_schemas"`
43+
44+
// JSON contents of a Google service account key file.
45+
serviceAccountKeyJSON []byte
46+
}
47+
48+
// Initialize initializes the GSuiteProvider by validating and creating configuration.
49+
func (g *GSuiteProvider) Initialize(jc *jwtConfig) error {
50+
// Decode the provider config
51+
var config GSuiteProviderConfig
52+
if err := mapstructure.Decode(jc.ProviderConfig, &config); err != nil {
53+
return err
54+
}
55+
56+
// Read the Google service account key file
57+
keyJSON, err := ioutil.ReadFile(config.ServiceAccountFilePath)
58+
if err != nil {
59+
return err
60+
}
61+
config.serviceAccountKeyJSON = keyJSON
62+
63+
return g.initialize(config)
64+
}
65+
66+
func (g *GSuiteProvider) initialize(config GSuiteProviderConfig) error {
67+
var err error
68+
69+
// Validate configuration
70+
if config.ServiceAccountFilePath == "" {
71+
return errors.New("'gsuite_service_account' must be set to the file path for a " +
72+
"service account key")
73+
}
74+
if config.AdminImpersonateEmail == "" {
75+
return errors.New("'gsuite_admin_impersonate' must be set to an email address of a " +
76+
"G Suite user with 'Read' permission to access the G Suite Admin User and Group APIs")
77+
}
78+
if config.GroupsRecurseMaxDepth < 0 {
79+
return errors.New("'gsuite_recurse_max_depth' must be a positive integer")
80+
}
81+
82+
// Create the google JWT config from the service account key file
83+
if g.jwtConfig, err = google.JWTConfigFromJSON(config.serviceAccountKeyJSON,
84+
admin.AdminDirectoryGroupReadonlyScope, admin.AdminDirectoryUserReadonlyScope); err != nil {
85+
return err
86+
}
87+
88+
// Set the subject to impersonate and config
89+
g.jwtConfig.Subject = config.AdminImpersonateEmail
90+
g.config = config
91+
return nil
92+
}
93+
94+
// SensitiveKeys returns keys that should be redacted when reading the config of this provider
95+
func (g *GSuiteProvider) SensitiveKeys() []string {
96+
return []string{}
97+
}
98+
99+
// FetchGroups fetches and returns groups from G Suite.
100+
func (g *GSuiteProvider) FetchGroups(b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole) (interface{}, error) {
101+
if !g.config.FetchGroups {
102+
return nil, nil
103+
}
104+
105+
userName, err := g.getUserClaim(b, allClaims, role)
106+
if err != nil {
107+
return nil, err
108+
}
109+
110+
// Set context and create a new admin service for requests to Google admin APIs
111+
g.adminSvc, err = admin.NewService(b.providerCtx, option.WithHTTPClient(g.jwtConfig.Client(b.providerCtx)))
112+
if err != nil {
113+
return nil, err
114+
}
115+
116+
// Get the G Suite groups
117+
userGroupsMap := make(map[string]bool)
118+
if err := g.search(b.providerCtx, userGroupsMap, userName, g.config.GroupsRecurseMaxDepth); err != nil {
119+
return nil, err
120+
}
121+
122+
// Convert set of groups to list
123+
var userGroups = make([]interface{}, 0, len(userGroupsMap))
124+
for email := range userGroupsMap {
125+
userGroups = append(userGroups, email)
126+
}
127+
128+
b.Logger().Debug("fetched G Suite groups", "groups", userGroups)
129+
return userGroups, nil
130+
}
131+
132+
// search recursively searches for G Suite groups based on a configured depth for this provider.
133+
func (g *GSuiteProvider) search(ctx context.Context, visited map[string]bool, userName string, depth int) error {
134+
call := g.adminSvc.Groups.List().UserKey(userName).Fields("nextPageToken", "groups(email)")
135+
if err := call.Pages(ctx, func(groups *admin.Groups) error {
136+
var newGroups []string
137+
for _, group := range groups.Groups {
138+
if _, ok := visited[group.Email]; ok {
139+
continue
140+
}
141+
visited[group.Email] = true
142+
newGroups = append(newGroups, group.Email)
143+
}
144+
// Only recursively search for new groups that haven't been seen
145+
if depth > 0 {
146+
for _, email := range newGroups {
147+
if err := g.search(ctx, visited, email, depth-1); err != nil {
148+
return err
149+
}
150+
}
151+
}
152+
return nil
153+
}); err != nil {
154+
return err
155+
}
156+
return nil
157+
}
158+
159+
// FetchUserInfo fetches additional user information from G Suite using custom schemas.
160+
func (g *GSuiteProvider) FetchUserInfo(b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole) error {
161+
if !g.config.FetchUserInfo || g.config.UserCustomSchemas == "" {
162+
if g.config.UserCustomSchemas != "" {
163+
b.Logger().Warn(fmt.Sprintf("must set 'fetch_user_info=true' to fetch 'user_custom_schemas': %s", g.config.UserCustomSchemas))
164+
}
165+
166+
return nil
167+
}
168+
169+
userName, err := g.getUserClaim(b, allClaims, role)
170+
if err != nil {
171+
return err
172+
}
173+
174+
// Set context and create a new admin service for requests to Google admin APIs
175+
g.adminSvc, err = admin.NewService(b.providerCtx, option.WithHTTPClient(g.jwtConfig.Client(b.providerCtx)))
176+
if err != nil {
177+
return err
178+
}
179+
180+
return g.fillCustomSchemas(b.providerCtx, userName, allClaims)
181+
}
182+
183+
// fillCustomSchemas fetches G Suite user information associated with the custom schemas
184+
// configured for this provider. It inserts the schema -> value pairs into the passed
185+
// allClaims so that the values can be used for claim mapping to token and identity metadata.
186+
func (g *GSuiteProvider) fillCustomSchemas(ctx context.Context, userName string, allClaims map[string]interface{}) error {
187+
userResponse, err := g.adminSvc.Users.Get(userName).Context(ctx).Projection("custom").
188+
CustomFieldMask(g.config.UserCustomSchemas).Fields("customSchemas").Do()
189+
if err != nil {
190+
return err
191+
}
192+
193+
for schema, rawValue := range userResponse.CustomSchemas {
194+
// note: metadata extraction via claim_mappings only supports strings
195+
// as values, but filtering happens later so we must use interface{}
196+
var value map[string]interface{}
197+
if err := json.Unmarshal(rawValue, &value); err != nil {
198+
return err
199+
}
200+
201+
allClaims[schema] = value
202+
}
203+
204+
return nil
205+
}
206+
207+
// getUserClaim returns the user claim value configured in the passed role.
208+
// If the user claim is not found or is not a string, an error is returned.
209+
func (g *GSuiteProvider) getUserClaim(b *jwtAuthBackend, allClaims map[string]interface{}, role *jwtRole) (string, error) {
210+
userClaimRaw := getClaim(b.Logger(), allClaims, role.UserClaim)
211+
if userClaimRaw == nil {
212+
return "", fmt.Errorf("unable to locate %q in claims", role.UserClaim)
213+
}
214+
userClaim, ok := userClaimRaw.(string)
215+
if !ok {
216+
return "", fmt.Errorf("claim %q could not be converted to string", role.UserClaim)
217+
}
218+
219+
return userClaim, nil
220+
}

0 commit comments

Comments
 (0)