diff --git a/e2e_test/e2e_test.go b/e2e_test/e2e_test.go index 20449e0..f5290b3 100644 --- a/e2e_test/e2e_test.go +++ b/e2e_test/e2e_test.go @@ -164,6 +164,83 @@ func TestRedirectURLHostname(t *testing.T) { wg.Wait() } +func TestRedirectURL(t *testing.T) { + ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second) + defer cancel() + openBrowserCh := make(chan string) + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + defer close(openBrowserCh) + redirectURL := "http://localhost:8888" + + // Start a local server and get a token. + s := httptest.NewServer(&authserver.Handler{ + T: t, + NewAuthorizationResponse: func(r authserver.AuthorizationRequest) string { + if w := "email profile"; r.Scope != w { + t.Errorf("scope wants %s but %s", w, r.Scope) + return fmt.Sprintf("%s?error=invalid_scope", r.RedirectURI) + } + if r.RedirectURI != redirectURL { + t.Errorf("redirect_uri should be %s but was %s", redirectURL, r.RedirectURI) + return fmt.Sprintf("%s?error=invalid_redirect_uri", r.RedirectURI) + } + return fmt.Sprintf("%s?state=%s&code=%s", r.RedirectURI, r.State, "AUTH_CODE") + }, + NewTokenResponse: func(r authserver.TokenRequest) (int, string) { + if w := "AUTH_CODE"; r.Code != w { + t.Errorf("code wants %s but %s", w, r.Code) + return 400, invalidGrantResponse + } + return 200, validTokenResponse + }, + }) + defer s.Close() + cfg := oauth2cli.Config{ + OAuth2Config: oauth2.Config{ + ClientID: "YOUR_CLIENT_ID", + ClientSecret: "YOUR_CLIENT_SECRET", + Scopes: []string{"email", "profile"}, + Endpoint: oauth2.Endpoint{ + AuthURL: s.URL + "/auth", + TokenURL: s.URL + "/token", + }, + }, + LocalServerBindAddress: []string{"127.0.0.1:8888"}, + RedirectURLHostname: "127.0.0.1", // should be ignored as RedirectURL will override + RedirectURL: redirectURL, + LocalServerReadyChan: openBrowserCh, + LocalServerMiddleware: loggingMiddleware(t), + Logf: t.Logf, + } + token, err := oauth2cli.GetToken(ctx, cfg) + if err != nil { + t.Errorf("could not get a token: %s", err) + return + } + if "ACCESS_TOKEN" != token.AccessToken { + t.Errorf("AccessToken wants %s but %s", "ACCESS_TOKEN", token.AccessToken) + } + if "REFRESH_TOKEN" != token.RefreshToken { + t.Errorf("RefreshToken wants %s but %s", "REFRESH_TOKEN", token.AccessToken) + } + }() + wg.Add(1) + go func() { + defer wg.Done() + toURL, ok := <-openBrowserCh + if !ok { + t.Errorf("server already closed") + return + } + client.GetAndVerify(t, toURL, 200, oauth2cli.DefaultLocalServerSuccessHTML) + }() + wg.Wait() +} + func TestSuccessRedirect(t *testing.T) { ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second) defer cancel() diff --git a/oauth2cli.go b/oauth2cli.go index 296a2bb..da4f13b 100644 --- a/oauth2cli.go +++ b/oauth2cli.go @@ -56,6 +56,9 @@ type Config struct { // You can set this if your provider does not accept localhost. // Default to localhost. RedirectURLHostname string + // FQDN of redirect url + // This option will override `RedirectURLHostname` + port + RedirectURL string // Options for an authorization request. // You can set oauth2.AccessTypeOffline and the PKCE options here. AuthCodeOptions []oauth2.AuthCodeOption diff --git a/server.go b/server.go index b8c56ae..a5846bc 100644 --- a/server.go +++ b/server.go @@ -94,6 +94,9 @@ func receiveCodeViaLocalServer(ctx context.Context, c *Config) (string, error) { } func computeRedirectURL(l net.Listener, c *Config) string { + if c.RedirectURL != "" { + return c.RedirectURL + } hostPort := fmt.Sprintf("%s:%d", c.RedirectURLHostname, l.Addr().(*net.TCPAddr).Port) if c.LocalServerCertFile != "" { return "https://" + hostPort