-
Notifications
You must be signed in to change notification settings - Fork 606
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add Azure Active Directory V2.0 support
- Loading branch information
Showing
6 changed files
with
1,112 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,3 +27,4 @@ generated/ | |
*.swp | ||
.vscode/launch.json | ||
.vscode/settings.json | ||
.idea |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,224 @@ | ||
package azureadv2 | ||
|
||
import ( | ||
"encoding/json" | ||
"fmt" | ||
"io" | ||
"net/http" | ||
|
||
"github.com/markbates/goth" | ||
"golang.org/x/oauth2" | ||
) | ||
|
||
// also https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-v2-protocols#endpoints | ||
const ( | ||
authURLTemplate string = "https://login.microsoftonline.com/%s/oauth2/v2.0/authorize" | ||
tokenURLTemplate string = "https://login.microsoftonline.com/%s/oauth2/v2.0/token" | ||
graphAPIResource string = "https://graph.microsoft.com/v1.0/" | ||
) | ||
|
||
type ( | ||
// TenantType are the well known tenant types to scope the users that can authenticate. TenantType is not an | ||
// exclusive list of Azure Tenants which can be used. A consumer can also use their own Tenant ID to scope | ||
// authentication to their specific Tenant either through the Tenant ID or the friendly domain name. | ||
// | ||
// see also https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-v2-protocols#endpoints | ||
TenantType string | ||
|
||
// Provider is the implementation of `goth.Provider` for accessing AzureAD V2. | ||
Provider struct { | ||
ClientKey string | ||
Secret string | ||
CallbackURL string | ||
HTTPClient *http.Client | ||
config *oauth2.Config | ||
providerName string | ||
} | ||
|
||
// ProviderOptions are the collection of optional configuration to provide when constructing a Provider | ||
ProviderOptions struct { | ||
Scopes []ScopeType | ||
Tenant TenantType | ||
} | ||
) | ||
|
||
// These are the well known Azure AD Tenants. These are not an exclusive list of all Tenants | ||
// | ||
// See also https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-v2-protocols#endpoints | ||
const ( | ||
// CommonTenant allows users with both personal Microsoft accounts and work/school accounts from Azure Active | ||
// Directory to sign into the application. | ||
CommonTenant TenantType = "common" | ||
|
||
// OrganizationsTenant allows only users with work/school accounts from Azure Active Directory to sign into the application. | ||
OrganizationsTenant TenantType = "organizations" | ||
|
||
// ConsumersTenant allows only users with personal Microsoft accounts (MSA) to sign into the application. | ||
ConsumersTenant TenantType = "consumers" | ||
) | ||
|
||
// New creates a new AzureAD provider, and sets up important connection details. | ||
// You should always call `AzureAD.New` to get a new Provider. Never try to create | ||
// one manually. | ||
func New(clientKey, secret, callbackURL string, opts ProviderOptions) *Provider { | ||
p := &Provider{ | ||
ClientKey: clientKey, | ||
Secret: secret, | ||
CallbackURL: callbackURL, | ||
providerName: "azureadv2", | ||
} | ||
|
||
p.config = newConfig(p, opts) | ||
return p | ||
} | ||
|
||
func newConfig(provider *Provider, opts ProviderOptions) *oauth2.Config { | ||
tenant := opts.Tenant | ||
if tenant == "" { | ||
tenant = CommonTenant | ||
} | ||
|
||
c := &oauth2.Config{ | ||
ClientID: provider.ClientKey, | ||
ClientSecret: provider.Secret, | ||
RedirectURL: provider.CallbackURL, | ||
Endpoint: oauth2.Endpoint{ | ||
AuthURL: fmt.Sprintf(authURLTemplate, tenant), | ||
TokenURL: fmt.Sprintf(tokenURLTemplate, tenant), | ||
}, | ||
Scopes: []string{}, | ||
} | ||
|
||
if len(opts.Scopes) > 0 { | ||
c.Scopes = append(c.Scopes, scopesToStrings(opts.Scopes...)...) | ||
} else { | ||
defaultScopes := scopesToStrings(OpenIDScope, ProfileScope, EmailScope, UserReadScope) | ||
c.Scopes = append(c.Scopes, defaultScopes...) | ||
} | ||
|
||
return c | ||
} | ||
|
||
// Name is the name used to retrieve this provider later. | ||
func (p *Provider) Name() string { | ||
return p.providerName | ||
} | ||
|
||
// SetName is to update the name of the provider (needed in case of multiple providers of 1 type) | ||
func (p *Provider) SetName(name string) { | ||
p.providerName = name | ||
} | ||
|
||
// Client is HTTP client to be used in all fetch operations. | ||
func (p *Provider) Client() *http.Client { | ||
return goth.HTTPClientWithFallBack(p.HTTPClient) | ||
} | ||
|
||
// Debug is a no-op for the package | ||
func (p *Provider) Debug(debug bool) {} | ||
|
||
// BeginAuth asks for an authentication end-point for AzureAD. | ||
func (p *Provider) BeginAuth(state string) (goth.Session, error) { | ||
authURL := p.config.AuthCodeURL(state) | ||
|
||
return &Session{ | ||
AuthURL: authURL, | ||
}, nil | ||
} | ||
|
||
// FetchUser will go to AzureAD and access basic information about the user. | ||
func (p *Provider) FetchUser(session goth.Session) (goth.User, error) { | ||
msSession := session.(*Session) | ||
user := goth.User{ | ||
AccessToken: msSession.AccessToken, | ||
Provider: p.Name(), | ||
ExpiresAt: msSession.ExpiresAt, | ||
} | ||
|
||
if user.AccessToken == "" { | ||
return user, fmt.Errorf("%s cannot get user information without accessToken", p.providerName) | ||
} | ||
|
||
req, err := http.NewRequest("GET", graphAPIResource+"me", nil) | ||
if err != nil { | ||
return user, err | ||
} | ||
|
||
req.Header.Set(authorizationHeader(msSession)) | ||
|
||
response, err := p.Client().Do(req) | ||
if err != nil { | ||
return user, err | ||
} | ||
defer response.Body.Close() | ||
|
||
if response.StatusCode != http.StatusOK { | ||
return user, fmt.Errorf("%s responded with a %d trying to fetch user information", p.providerName, response.StatusCode) | ||
} | ||
|
||
err = userFromReader(response.Body, &user) | ||
user.AccessToken = msSession.AccessToken | ||
user.RefreshToken = msSession.RefreshToken | ||
user.ExpiresAt = msSession.ExpiresAt | ||
return user, err | ||
} | ||
|
||
//RefreshTokenAvailable refresh token is provided by auth provider or not | ||
func (p *Provider) RefreshTokenAvailable() bool { | ||
return true | ||
} | ||
|
||
//RefreshToken get new access token based on the refresh token | ||
func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) { | ||
token := &oauth2.Token{RefreshToken: refreshToken} | ||
ts := p.config.TokenSource(goth.ContextForClient(p.Client()), token) | ||
newToken, err := ts.Token() | ||
if err != nil { | ||
return nil, err | ||
} | ||
return newToken, err | ||
} | ||
|
||
func authorizationHeader(session *Session) (string, string) { | ||
return "Authorization", fmt.Sprintf("Bearer %s", session.AccessToken) | ||
} | ||
|
||
func userFromReader(r io.Reader, user *goth.User) error { | ||
u := struct { | ||
ID string `json:"id"` // The unique identifier for the user. | ||
BusinessPhone string `json:"businessPhone"` // The user's phone numbers. | ||
DisplayName string `json:"displayName"` // The name displayed in the address book for the user. | ||
FirstName string `json:"givenName"` // The first name of the user. | ||
JobTitle string `json:"jobTitle"` // The user's job title. | ||
Email string `json:"mail"` // The user's email address. | ||
MobilePhone string `json:"mobilePhone"` // The user's cellphone number. | ||
OfficeLocation string `json:"officeLocation"` // The user's physical office location. | ||
PreferredLanguage string `json:"preferredLanguage"` // The user's language of preference. | ||
LastName string `json:"surname"` // The last name of the user. | ||
UserPrincipalName string `json:"userPrincipalName"` // The user's principal name. | ||
}{} | ||
|
||
err := json.NewDecoder(r).Decode(&u) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
user.Email = u.Email | ||
user.Name = u.DisplayName | ||
user.FirstName = u.FirstName | ||
user.LastName = u.LastName | ||
user.NickName = u.DisplayName | ||
user.Location = u.OfficeLocation | ||
user.UserID = u.ID | ||
user.AvatarURL = graphAPIResource + fmt.Sprintf("users/%s/photo/$value", u.ID) | ||
|
||
return nil | ||
} | ||
|
||
func scopesToStrings(scopes ...ScopeType) []string { | ||
strs := make([]string, len(scopes)) | ||
for i := 0; i < len(scopes); i++ { | ||
strs[i] = string(scopes[i]) | ||
} | ||
return strs | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
package azureadv2_test | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/markbates/goth" | ||
"github.com/markbates/goth/providers/azureadv2" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
const ( | ||
applicationID = "6731de76-14a6-49ae-97bc-6eba6914391e" | ||
secret = "foo" | ||
redirectUri = "https://localhost:3000" | ||
) | ||
|
||
func Test_New(t *testing.T) { | ||
t.Parallel() | ||
a := assert.New(t) | ||
provider := azureadProvider() | ||
|
||
a.Equal(provider.Name(), "azureadv2") | ||
a.Equal(provider.ClientKey, applicationID) | ||
a.Equal(provider.Secret, secret) | ||
a.Equal(provider.CallbackURL, redirectUri) | ||
} | ||
|
||
func Test_Implements_Provider(t *testing.T) { | ||
t.Parallel() | ||
a := assert.New(t) | ||
p := azureadProvider() | ||
a.Implements((*goth.Provider)(nil), p) | ||
} | ||
|
||
func Test_BeginAuth(t *testing.T) { | ||
t.Parallel() | ||
a := assert.New(t) | ||
provider := azureadProvider() | ||
session, err := provider.BeginAuth("test_state") | ||
a.NoError(err) | ||
s := session.(*azureadv2.Session) | ||
a.Contains(s.AuthURL, "login.microsoftonline.com/common/oauth2/v2.0/authorize") | ||
a.Contains(s.AuthURL, "redirect_uri=https%3A%2F%2Flocalhost%3A3000") | ||
a.Contains(s.AuthURL, "scope=openid+profile+email") | ||
} | ||
|
||
func Test_SessionFromJSON(t *testing.T) { | ||
t.Parallel() | ||
a := assert.New(t) | ||
|
||
provider := azureadProvider() | ||
session, err := provider.UnmarshalSession(`{"au":"http://foo","at":"1234567890"}`) | ||
a.NoError(err) | ||
|
||
s := session.(*azureadv2.Session) | ||
a.Equal(s.AuthURL, "http://foo") | ||
a.Equal(s.AccessToken, "1234567890") | ||
} | ||
|
||
func azureadProvider() *azureadv2.Provider { | ||
return azureadv2.New(applicationID, secret, redirectUri, azureadv2.ProviderOptions{}) | ||
} |
Oops, something went wrong.