Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 100 additions & 28 deletions connector/google/google.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@ import (
"log/slog"
"net/http"
"os"
"sort"
"strings"
"sync"
"time"

"cloud.google.com/go/compute/metadata"
"github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/exp/slices"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"golang.org/x/sync/errgroup"
admin "google.golang.org/api/admin/directory/v1"
"google.golang.org/api/impersonate"
"google.golang.org/api/option"
Expand All @@ -27,6 +30,10 @@ import (
const (
issuerURL = "https://accounts.google.com"
wildcardDomainToAdminEmail = "*"

// defaultConcurrentGroupLookups is the limit used when Config.MaxConcurrentGroupLookups
// is zero or negative.
defaultConcurrentGroupLookups = 10
)

// Config holds configuration options for Google logins.
Expand Down Expand Up @@ -61,6 +68,10 @@ type Config struct {
// If this field is true, fetch direct group membership and transitive group membership
FetchTransitiveGroupMembership bool `json:"fetchTransitiveGroupMembership"`

// MaxConcurrentGroupLookups limits concurrent Admin Directory API calls when resolving
// transitive group membership. If zero or negative, the connector default limit applies.
MaxConcurrentGroupLookups int `json:"maxConcurrentGroupLookups"`

// Optional value for the prompt parameter, defaults to consent when offline_access
// scope is requested
PromptType *string `json:"promptType"`
Expand Down Expand Up @@ -119,6 +130,11 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector,
}

clientID := c.ClientID
maxConcurrent := c.MaxConcurrentGroupLookups
if maxConcurrent <= 0 {
maxConcurrent = defaultConcurrentGroupLookups
}

return &googleConnector{
redirectURI: c.RedirectURI,
oauth2Config: &oauth2.Config{
Expand All @@ -138,6 +154,7 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector,
serviceAccountFilePath: c.ServiceAccountFilePath,
domainToAdminEmail: c.DomainToAdminEmail,
fetchTransitiveGroupMembership: c.FetchTransitiveGroupMembership,
maxConcurrentGroupLookups: maxConcurrent,
adminSrv: adminSrv,
promptType: promptType,
}, nil
Expand All @@ -159,6 +176,7 @@ type googleConnector struct {
serviceAccountFilePath string
domainToAdminEmail map[string]string
fetchTransitiveGroupMembership bool
maxConcurrentGroupLookups int
adminSrv map[string]*admin.Service
promptType string
}
Expand Down Expand Up @@ -272,8 +290,7 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector

var groups []string
if s.Groups && len(c.adminSrv) > 0 {
checkedGroups := make(map[string]struct{})
groups, err = c.getGroups(claims.Email, c.fetchTransitiveGroupMembership, checkedGroups)
groups, err = c.getGroups(ctx, claims.Email, c.fetchTransitiveGroupMembership)
if err != nil {
return identity, fmt.Errorf("google: could not retrieve groups: %v", err)
}
Expand All @@ -298,52 +315,107 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector
}

// getGroups creates a connection to the admin directory service and lists
// all groups the user is a member of
func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership bool, checkedGroups map[string]struct{}) ([]string, error) {
var userGroups []string
var err error
groupsList := &admin.Groups{}
domain := c.extractDomainFromEmail(email)
adminSrv, err := c.findAdminService(domain)
// all groups the user is a member of.
func (c *googleConnector) getGroups(ctx context.Context, email string, fetchTransitiveGroupMembership bool) ([]string, error) {
directGroups, err := c.listGroupEmails(ctx, email)
if err != nil {
return nil, err
}

for {
groupsList, err = adminSrv.Groups.List().
UserKey(email).PageToken(groupsList.NextPageToken).Do()
if err != nil {
return nil, fmt.Errorf("could not list groups: %v", err)
var seenMu sync.Mutex
seen := make(map[string]struct{})
userGroups := make([]string, 0, len(directGroups))
addGroup := func(groupEmail string) bool {
seenMu.Lock()
defer seenMu.Unlock()
if _, exists := seen[groupEmail]; exists {
return false
}
seen[groupEmail] = struct{}{}
// TODO (joelspeed): Make desired group key configurable
userGroups = append(userGroups, groupEmail)
return true
}

for _, group := range groupsList.Groups {
if _, exists := checkedGroups[group.Email]; exists {
continue
}
seeds := make([]string, 0, len(directGroups))
for _, groupEmail := range directGroups {
if addGroup(groupEmail) {
seeds = append(seeds, groupEmail)
}
}

if !fetchTransitiveGroupMembership || len(seeds) == 0 {
Comment thread
codrutpanea marked this conversation as resolved.
sort.Strings(userGroups)
return userGroups, nil
}

checkedGroups[group.Email] = struct{}{}
// TODO (joelspeed): Make desired group key configurable
userGroups = append(userGroups, group.Email)
apiSem := make(chan struct{}, c.maxConcurrentGroupLookups)
g, gctx := errgroup.WithContext(ctx)

if !fetchTransitiveGroupMembership {
continue
var enqueue func(string)
enqueue = func(groupEmail string) {
g.Go(func() error {
if err := gctx.Err(); err != nil {
return err
}
select {
case <-gctx.Done():
return gctx.Err()
case apiSem <- struct{}{}:
}
defer func() { <-apiSem }()

// getGroups takes a user's email/alias as well as a group's email/alias
transitiveGroups, err := c.getGroups(group.Email, fetchTransitiveGroupMembership, checkedGroups)
parentGroups, err := c.listGroupEmails(gctx, groupEmail)
if err != nil {
return nil, fmt.Errorf("could not list transitive groups: %v", err)
return fmt.Errorf("could not list transitive groups: %w", err)
}
for _, parent := range parentGroups {
if addGroup(parent) {
enqueue(parent)
}
}
return nil
})
}

for _, groupEmail := range seeds {
enqueue(groupEmail)
}

if err := g.Wait(); err != nil {
return nil, err
}

userGroups = append(userGroups, transitiveGroups...)
sort.Strings(userGroups)
return userGroups, nil
}

func (c *googleConnector) listGroupEmails(ctx context.Context, userKey string) ([]string, error) {
domain := c.extractDomainFromEmail(userKey)
adminSrv, err := c.findAdminService(domain)
if err != nil {
return nil, err
}

groupEmails := []string{}
groupsList := &admin.Groups{}
for {
groupsList, err = adminSrv.Groups.List().
UserKey(userKey).PageToken(groupsList.NextPageToken).Context(ctx).Do()
if err != nil {
return nil, fmt.Errorf("could not list groups: %v", err)
}

for _, group := range groupsList.Groups {
groupEmails = append(groupEmails, group.Email)
}

if groupsList.NextPageToken == "" {
break
}
}

return userGroups, nil
return groupEmails, nil
}

func (c *googleConnector) findAdminService(domain string) (*admin.Service, error) {
Expand Down
78 changes: 67 additions & 11 deletions connector/google/google_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ import (
"net/url"
"os"
"strings"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
admin "google.golang.org/api/admin/directory/v1"
Expand All @@ -32,7 +34,8 @@ var (
"groups_2@dexidp.com": {{Email: "groups_0@dexidp.com"}},
"groups_0@dexidp.com": {},
}
callCounter = make(map[string]int)
callCounterMu sync.Mutex
Comment thread
codrutpanea marked this conversation as resolved.
callCounter = make(map[string]int)
)

func testSetup() *httptest.Server {
Expand All @@ -43,7 +46,9 @@ func testSetup() *httptest.Server {
userKey := r.URL.Query().Get("userKey")
if groups, ok := testGroups[userKey]; ok {
json.NewEncoder(w).Encode(admin.Groups{Groups: groups})
callCounterMu.Lock()
callCounter[userKey]++
callCounterMu.Unlock()
}
})

Expand Down Expand Up @@ -224,23 +229,71 @@ func TestGetGroups(t *testing.T) {
},
} {
testCase := testCase
callCounter = map[string]int{}
callCounterMu.Lock()
callCounter = make(map[string]int)
callCounterMu.Unlock()
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
lookup := make(map[string]struct{})

groups, err := conn.getGroups(testCase.userKey, testCase.fetchTransitiveGroupMembership, lookup)
groups, err := conn.getGroups(context.Background(), testCase.userKey, testCase.fetchTransitiveGroupMembership)
if testCase.shouldErr {
assert.NotNil(err)
} else {
assert.Nil(err)
}
assert.ElementsMatch(testCase.expectedGroups, groups)
t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), callCounter)
callCounterMu.Lock()
s := fmt.Sprintf("%+v", callCounter)
callCounterMu.Unlock()
t.Logf("[%s] Amount of API calls per userKey: %s\n", t.Name(), s)
})
}
}

// Regression test for MaxConcurrentGroupLookups=1 with a user -> A -> B membership chain.
func TestGetGroups_transitiveNoDeadlockAtConcurrentLimitOne(t *testing.T) {
chain := map[string][]*admin.Group{
"user_chain@dexidp.com": {{Email: "group_a@dexidp.com"}},
"group_a@dexidp.com": {{Email: "group_b@dexidp.com"}},
"group_b@dexidp.com": {},
}

mux := http.NewServeMux()
mux.HandleFunc("/admin/directory/v1/groups/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "application/json")
userKey := r.URL.Query().Get("userKey")
if groups, ok := chain[userKey]; ok {
_ = json.NewEncoder(w).Encode(admin.Groups{Groups: groups})
}
})
ts := httptest.NewServer(mux)
defer ts.Close()

serviceAccountFilePath, err := tempServiceAccountKey()
assert.Nil(t, err)

os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", serviceAccountFilePath)
conn, err := newConnector(&Config{
ClientID: "testClient",
ClientSecret: "testSecret",
RedirectURI: ts.URL + "/callback",
Scopes: []string{"openid", "groups"},
DomainToAdminEmail: map[string]string{"*": "admin@dexidp.com"},
MaxConcurrentGroupLookups: 1,
})
assert.Nil(t, err)

conn.adminSrv[wildcardDomainToAdminEmail], err = admin.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL))
assert.Nil(t, err)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

groups, err := conn.getGroups(ctx, "user_chain@dexidp.com", true)
assert.Nil(t, err)
assert.Equal(t, []string{"group_a@dexidp.com", "group_b@dexidp.com"}, groups)
}

func TestDomainToAdminEmailConfig(t *testing.T) {
ts := testSetup()
defer ts.Close()
Expand Down Expand Up @@ -280,18 +333,22 @@ func TestDomainToAdminEmailConfig(t *testing.T) {
},
} {
testCase := testCase
callCounter = map[string]int{}
callCounterMu.Lock()
callCounter = make(map[string]int)
callCounterMu.Unlock()
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
lookup := make(map[string]struct{})

_, err := conn.getGroups(testCase.userKey, true, lookup)
_, err := conn.getGroups(context.Background(), testCase.userKey, true)
if testCase.expectedErr != "" {
assert.ErrorContains(err, testCase.expectedErr)
} else {
assert.Nil(err)
}
t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), callCounter)
callCounterMu.Lock()
s := fmt.Sprintf("%+v", callCounter)
callCounterMu.Unlock()
t.Logf("[%s] Amount of API calls per userKey: %s\n", t.Name(), s)
})
}
}
Expand Down Expand Up @@ -381,9 +438,8 @@ func TestGCEWorkloadIdentity(t *testing.T) {
} {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
lookup := make(map[string]struct{})

_, err := conn.getGroups(testCase.userKey, true, lookup)
_, err := conn.getGroups(context.Background(), testCase.userKey, true)
if testCase.expectedErr != "" {
assert.ErrorContains(err, testCase.expectedErr)
} else {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ require (
golang.org/x/exp v0.0.0-20240823005443-9b4947da3948
golang.org/x/net v0.53.0
golang.org/x/oauth2 v0.36.0
golang.org/x/sync v0.20.0
google.golang.org/api v0.277.0
google.golang.org/grpc v1.80.0
google.golang.org/protobuf v1.36.11
Expand Down Expand Up @@ -141,7 +142,6 @@ require (
go.uber.org/zap v1.27.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/mod v0.34.0 // indirect
golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.43.0 // indirect
golang.org/x/text v0.36.0 // indirect
golang.org/x/time v0.15.0 // indirect
Expand Down