diff --git a/pkg/cfg/oauth.go b/pkg/cfg/oauth.go index 5b4e90b3..b67dede9 100644 --- a/pkg/cfg/oauth.go +++ b/pkg/cfg/oauth.go @@ -85,6 +85,7 @@ type oauthConfig struct { PreferredDomain string `mapstructure:"preferredDomain"` AzureToken string `mapstructure:"azure_token" envconfig:"azure_token"` CodeChallengeMethod string `mapstructure:"code_challenge_method" envconfig:"code_challenge_method"` + TeamWhiteListClaim string `mapstructure:"team_whitelist_claim" envconfig:"team_whitelist_claim"` // DiscordUseIDs defaults to false, maintaining the more common username checking behavior // If set to true, match the Discord user's ID instead of their username DiscordUseIDs bool `mapstructure:"discord_use_ids" envconfig:"discord_use_ids"` diff --git a/pkg/providers/openid/openid.go b/pkg/providers/openid/openid.go index d5c9c5e7..a20d93f2 100644 --- a/pkg/providers/openid/openid.go +++ b/pkg/providers/openid/openid.go @@ -33,6 +33,25 @@ func (Provider) Configure() { log = cfg.Logging.Logger } +func GenerateTeamsOfUser(customClaims *structs.CustomClaims, claimName string) map[string]bool { + teamOutput := make(map[string]bool) + if val, ok := customClaims.Claims[claimName]; ok { + + customClaimsSlice := val.([]interface{}) + + for _, teamValue := range customClaimsSlice { + team, isMyType := teamValue.(string) + if isMyType { + teamOutput[team] = true + } + } + + return teamOutput + } + log.Debugf("Claim %s missing from UserInfo response. Make sure you include the correct scope", claimName) + return teamOutput +} + // GetUserInfo provider specific call to get userinfomation func (Provider) GetUserInfo(r *http.Request, user *structs.User, customClaims *structs.CustomClaims, ptokens *structs.PTokens, opts ...oauth2.AuthCodeOption) (rerr error) { client, _, err := common.PrepareTokensAndClient(r, ptokens, true, opts...) @@ -59,5 +78,17 @@ func (Provider) GetUserInfo(r *http.Request, user *structs.User, customClaims *s return err } user.PrepareUserData() + + if len(cfg.Cfg.TeamWhiteList) != 0 && len(cfg.GenOAuth.TeamWhiteListClaim) != 0 { + allTeamsOfUser := GenerateTeamsOfUser(customClaims, cfg.GenOAuth.TeamWhiteListClaim) + + for _, whiteListedTeam := range cfg.Cfg.TeamWhiteList { + if allTeamsOfUser[whiteListedTeam] { + user.TeamMemberships = append(user.TeamMemberships, whiteListedTeam) + } + } + } + log.Debug("getUserInfoFromOAuth") + log.Debug(user) return nil }