Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions pkg/cfg/cfg.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,10 +414,12 @@ func basicTest() error {
return errors.New("configuration error: required configuration option 'oauth.client_id' is not set")
}

// Domains is required _unless_ Cfg.AllowAllUsers is set
if (!Cfg.AllowAllUsers && len(Cfg.Domains) == 0) ||
(Cfg.AllowAllUsers && len(Cfg.Domains) > 0) {
return fmt.Errorf("configuration error: either one of %s or %s needs to be set (but not both)", Branding.LCName+".domains", Branding.LCName+".allowAllUsers")
// Domains or a whitelist is required _unless_ Cfg.AllowAllUsers is set
whitelistLength := len(Cfg.Domains) + len(Cfg.WhiteList) + len(Cfg.TeamWhiteList)
if (!Cfg.AllowAllUsers && whitelistLength == 0) ||
(Cfg.AllowAllUsers && whitelistLength > 0) {
return fmt.Errorf("configuration error: either %s.allowAllUsers or a whitelist (%s.domains, %s.whitelist, %s.teamWhitelist) needs to be set (but not both)",
Branding.LCName, Branding.LCName, Branding.LCName, Branding.LCName)
}

// issue a warning if the secret is too small
Expand Down
2 changes: 2 additions & 0 deletions pkg/providers/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ func Configure() {
log = cfg.Logging.Logger
}

type PrepareTokensAndClientT func(r *http.Request, ptokens *structs.PTokens, setProviderToken bool, opts ...oauth2.AuthCodeOption) (*http.Client, *oauth2.Token, error)

// PrepareTokensAndClient setup the client, usually for a UserInfo request
func PrepareTokensAndClient(r *http.Request, ptokens *structs.PTokens, setProviderToken bool, opts ...oauth2.AuthCodeOption) (*http.Client, *oauth2.Token, error) {
providerToken, err := cfg.OAuthClient.Exchange(context.TODO(), r.URL.Query().Get("code"), opts...)
Expand Down
2 changes: 1 addition & 1 deletion pkg/providers/github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (

// Provider provider specific functions
type Provider struct {
PrepareTokensAndClient func(r *http.Request, ptokens *structs.PTokens, setProviderToken bool, opts ...oauth2.AuthCodeOption) (*http.Client, *oauth2.Token, error)
PrepareTokensAndClient common.PrepareTokensAndClientT
}

var log *zap.SugaredLogger
Expand Down
42 changes: 39 additions & 3 deletions pkg/providers/openid/openid.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ package openid

import (
"encoding/json"
"golang.org/x/oauth2"
"io/ioutil"
"net/http"

"golang.org/x/oauth2"

"github.com/vouch/vouch-proxy/pkg/cfg"
"github.com/vouch/vouch-proxy/pkg/providers/common"
"github.com/vouch/vouch-proxy/pkg/structs"
Expand All @@ -25,16 +26,22 @@ import (
// Provider provider specific functions
type Provider struct{}

var log *zap.SugaredLogger
var (
log *zap.SugaredLogger
prepareTokensAndClient = common.PrepareTokensAndClient
)

// Configure see main.go configure()
func (Provider) Configure() {
log = cfg.Logging.Logger
if prepareTokensAndClient == nil {
prepareTokensAndClient = common.PrepareTokensAndClient
}
}

// GetUserInfo provider specific call to get userinfomation
func (Provider) GetUserInfo(r *http.Request, user *structs.User, customClaims *structs.CustomClaims, ptokens *structs.PTokens, opts ...oauth2.AuthCodeOption) (rerr error) {
client, _, err := common.PrepareTokensAndClient(r, ptokens, true, opts...)
client, _, err := prepareTokensAndClient(r, ptokens, true, opts...)
if err != nil {
return err
}
Expand All @@ -57,6 +64,35 @@ func (Provider) GetUserInfo(r *http.Request, user *structs.User, customClaims *s
log.Error(err)
return err
}

if cfg.GenOAuth.UserTeamURL != "" {
log.Infof("OpenID teams URL: %s", cfg.GenOAuth.UserTeamURL)
teams, err := client.Get(cfg.GenOAuth.UserTeamURL)
if err != nil {
return err
}
defer func() {
if err := teams.Body.Close(); err != nil {
rerr = err
}
}()
teamsdata, _ := ioutil.ReadAll(teams.Body)
log.Infof("OpenID teams body (%v): %v", teams.StatusCode, string(teamsdata))

teamMemberships := &structs.GenericTeamMembershipList{}
if err := json.Unmarshal(teamsdata, teamMemberships); err != nil {
return err
}
for _, m := range *teamMemberships {
// filter to requested/whitelisted memberships only
for _, wl := range cfg.Cfg.TeamWhiteList {
if wl == m.ID {
user.TeamMemberships = append(user.TeamMemberships, m.ID)
}
}
}
}

user.PrepareUserData()
return nil
}
109 changes: 109 additions & 0 deletions pkg/providers/openid/openid_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*

Copyright 2023 The Vouch Proxy Authors.
Use of this source code is governed by The MIT License (MIT) that
can be found in the LICENSE file. Software distributed under The
MIT License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
OR CONDITIONS OF ANY KIND, either express or implied.

*/

package openid

import (
"net/http"
"testing"

mockhttp "github.com/karupanerura/go-mock-http-response"
"github.com/stretchr/testify/assert"
"github.com/vouch/vouch-proxy/pkg/cfg"
"github.com/vouch/vouch-proxy/pkg/domains"
"github.com/vouch/vouch-proxy/pkg/structs"
"golang.org/x/oauth2"
)

type ReqMatcher func(*http.Request) bool

type FunResponsePair struct {
matcher ReqMatcher
response *mockhttp.ResponseMock
}

type Transport struct {
MockError error
}

func (c *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
if c.MockError != nil {
return nil, c.MockError
}
for _, p := range mockedResponses {
if p.matcher(req) {
requests = append(requests, req.URL.String())
return p.response.MakeResponse(req), nil
}
}
return nil, nil
}

func mockResponse(fun ReqMatcher, statusCode int, headers map[string]string, body []byte) {
mockedResponses = append(mockedResponses, FunResponsePair{matcher: fun, response: mockhttp.NewResponseMock(statusCode, headers, body)})
}

func urlEquals(value string) ReqMatcher {
return func(r *http.Request) bool {
return r.URL.String() == value
}
}

var (
user *structs.User
token = &oauth2.Token{AccessToken: "123"}
mockedResponses = []FunResponsePair{}
requests []string
client = &http.Client{Transport: &Transport{}}
)

func setUp(t *testing.T) {
log = cfg.Logging.Logger
cfg.InitForTestPurposesWithProvider("openid")

cfg.Cfg.AllowAllUsers = false
cfg.Cfg.WhiteList = make([]string, 0)
cfg.Cfg.TeamWhiteList = make([]string, 0)
cfg.Cfg.Domains = []string{"domain1"}

domains.Configure()

mockedResponses = []FunResponsePair{}
requests = make([]string, 0)

user = &structs.User{Username: "testuser", Email: "[email protected]"}

origPrepareTokensAndClient := prepareTokensAndClient
t.Cleanup(func() { prepareTokensAndClient = origPrepareTokensAndClient })
prepareTokensAndClient = func(_ *http.Request, _ *structs.PTokens, _ bool, opts ...oauth2.AuthCodeOption) (*http.Client, *oauth2.Token, error) {
return client, token, nil
}
}

func TestGetUserInfo(t *testing.T) {
setUp(t)

cfg.GenOAuth.UserInfoURL = "https://some/api/for/info"
userInfoContent := []byte(`{"id": "1234", "username": "myusername", "email": "[email protected]"}`)
mockResponse(urlEquals(cfg.GenOAuth.UserInfoURL), http.StatusOK, map[string]string{}, userInfoContent)

cfg.GenOAuth.UserTeamURL = "https://some/api/for/teams"
userTeamContent := []byte(`[{"id": "1234567890", "name": "some room name"}, {"id": "xxx-not-relevant", "name": "some other room"}]`)
mockResponse(urlEquals(cfg.GenOAuth.UserTeamURL), http.StatusOK, map[string]string{}, userTeamContent)

cfg.Cfg.TeamWhiteList = append(cfg.Cfg.TeamWhiteList, "1234567890", "some-other-team")

provider := Provider{}
err := provider.GetUserInfo(nil, user, &structs.CustomClaims{}, &structs.PTokens{})

assert.Nil(t, err)
assert.Equal(t, "myusername", user.Username)
assert.Equal(t, []string{"1234567890"}, user.TeamMemberships)
}
8 changes: 7 additions & 1 deletion pkg/structs/structs.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ func (u *ADFSUser) PrepareUserData() {
u.Username = u.UPN
}

type GenericTeamMembershipList []GenericTeamMembership

type GenericTeamMembership struct {
ID string `json:"id"`
}

// GitHubUser is a retrieved and authentiacted user from GitHub.
type GitHubUser struct {
User
Expand Down Expand Up @@ -148,7 +154,7 @@ type Contact struct {
Verified bool `json:"is_verified"`
}

//OpenStaxUser is a retrieved and authenticated user from OpenStax Accounts
// OpenStaxUser is a retrieved and authenticated user from OpenStax Accounts
type OpenStaxUser struct {
User
Contacts []Contact `json:"contact_infos"`
Expand Down