diff --git a/app/auth/plugins/oauth2/outgoing.go b/app/auth/plugins/oauth2/outgoing.go new file mode 100644 index 0000000..e762549 --- /dev/null +++ b/app/auth/plugins/oauth2/outgoing.go @@ -0,0 +1,396 @@ +package oauth2 + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "sort" + "strconv" + "strings" + "sync" + "time" + + authplugins "github.com/winhowes/AuthTranslator/app/auth" + "github.com/winhowes/AuthTranslator/app/secrets" +) + +const refreshSkew = time.Minute + +// oauth2Params configures generic OAuth2 access-token acquisition for outgoing +// requests. client_secret and refresh_token are secret references, not raw values. +type oauth2Params struct { + TokenURL string `json:"token_url"` + GrantType string `json:"grant_type"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + RefreshToken string `json:"refresh_token"` + Scope string `json:"scope"` + Audience string `json:"audience"` + ClientAuth string `json:"client_auth"` + Header string `json:"header"` + Prefix string `json:"prefix"` + ExtraParams map[string]string `json:"extra_params"` +} + +// OAuth2 obtains OAuth2 access tokens from a configurable token endpoint and +// attaches them to outgoing requests. +type OAuth2 struct{} + +// HTTPClient performs token endpoint HTTP requests. It can be swapped in tests. +var HTTPClient = &http.Client{Timeout: 5 * time.Second} + +type cachedToken struct { + accessToken string + refreshToken string + exp time.Time + refreshAt time.Time +} + +var tokenCache = struct { + sync.Mutex + m map[string]cachedToken + refreshLocks map[string]*sync.Mutex +}{m: make(map[string]cachedToken), refreshLocks: make(map[string]*sync.Mutex)} + +type tokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn json.RawMessage `json:"expires_in"` + RefreshToken string `json:"refresh_token"` +} + +func (o *OAuth2) Name() string { return "oauth2" } + +func (o *OAuth2) RequiredParams() []string { return []string{"token_url"} } + +func (o *OAuth2) OptionalParams() []string { + return []string{ + "grant_type", + "client_id", + "client_secret", + "refresh_token", + "scope", + "audience", + "client_auth", + "header", + "prefix", + "extra_params", + } +} + +func (o *OAuth2) ParseParams(m map[string]interface{}) (interface{}, error) { + _, prefixSet := m["prefix"] + + p, err := authplugins.ParseParams[oauth2Params](m) + if err != nil { + return nil, err + } + if p.TokenURL == "" { + return nil, fmt.Errorf("missing token_url") + } + u, err := url.Parse(p.TokenURL) + if err != nil || u.Scheme == "" || u.Host == "" { + return nil, fmt.Errorf("invalid token_url") + } + if p.GrantType == "" { + if p.RefreshToken != "" { + p.GrantType = "refresh_token" + } else { + p.GrantType = "client_credentials" + } + } + if p.ClientAuth == "" { + p.ClientAuth = "body" + } + if p.Header == "" { + p.Header = "Authorization" + } + if !prefixSet { + p.Prefix = "Bearer " + } + if err := validateOAuth2Params(p); err != nil { + return nil, err + } + return p, nil +} + +func validateOAuth2Params(p *oauth2Params) error { + switch p.GrantType { + case "client_credentials": + if p.ClientID == "" || p.ClientSecret == "" { + return fmt.Errorf("client_credentials grant requires client_id and client_secret") + } + case "refresh_token": + if p.RefreshToken == "" { + return fmt.Errorf("refresh_token grant requires refresh_token") + } + default: + return fmt.Errorf("unsupported grant_type %q", p.GrantType) + } + + switch p.ClientAuth { + case "body": + if p.ClientSecret != "" && p.ClientID == "" { + return fmt.Errorf("client_id is required with client_secret") + } + case "basic": + if p.ClientID == "" || p.ClientSecret == "" { + return fmt.Errorf("basic client_auth requires client_id and client_secret") + } + case "none": + if p.ClientSecret != "" { + return fmt.Errorf("client_secret cannot be used with client_auth none") + } + default: + return fmt.Errorf("unsupported client_auth %q", p.ClientAuth) + } + + for k := range p.ExtraParams { + key := strings.ToLower(strings.TrimSpace(k)) + if key == "" { + return fmt.Errorf("extra_params cannot contain empty keys") + } + switch key { + case "grant_type", "client_id", "client_secret", "refresh_token", "scope", "audience": + return fmt.Errorf("extra_params cannot override %q", key) + } + } + return nil +} + +func (p *oauth2Params) SecretRefs() []string { + var refs []string + if p.ClientSecret != "" { + refs = append(refs, p.ClientSecret) + } + if p.RefreshToken != "" { + refs = append(refs, p.RefreshToken) + } + return refs +} + +func (o *OAuth2) AddAuth(ctx context.Context, r *http.Request, params interface{}) error { + cfg, ok := params.(*oauth2Params) + if !ok { + return fmt.Errorf("invalid config") + } + + key := cfg.cacheKey() + ct := getCachedToken(key) + if tokenNeedsRefresh(ct) { + refreshLock := getRefreshLock(key) + refreshLock.Lock() + defer refreshLock.Unlock() + + ct = getCachedToken(key) + if !tokenNeedsRefresh(ct) { + r.Header.Set(cfg.Header, cfg.Prefix+ct.accessToken) + return nil + } + + next, err := fetchToken(ctx, cfg, ct.refreshToken) + if err != nil { + if tokenUsable(ct, time.Now()) { + r.Header.Set(cfg.Header, cfg.Prefix+ct.accessToken) + return nil + } + return err + } + if next.refreshToken == "" { + next.refreshToken = ct.refreshToken + } + setCachedToken(key, next) + ct = next + } + + r.Header.Set(cfg.Header, cfg.Prefix+ct.accessToken) + return nil +} + +func tokenUsable(ct cachedToken, now time.Time) bool { + return ct.accessToken != "" && now.Before(ct.exp) +} + +func tokenNeedsRefresh(ct cachedToken) bool { + if ct.accessToken == "" { + return true + } + now := time.Now() + refreshAt := ct.refreshAt + if refreshAt.IsZero() { + refreshAt = tokenRefreshTime(now, ct.exp) + } + return !now.Before(refreshAt) +} + +func tokenRefreshTime(now, exp time.Time) time.Time { + ttl := exp.Sub(now) + if ttl <= 0 { + return now + } + skew := ttl / 10 + if skew > refreshSkew { + skew = refreshSkew + } + if skew <= 0 { + return exp + } + return exp.Add(-skew) +} + +func fetchToken(ctx context.Context, cfg *oauth2Params, cachedRefreshToken string) (cachedToken, error) { + form := url.Values{} + form.Set("grant_type", cfg.GrantType) + if cfg.Scope != "" { + form.Set("scope", cfg.Scope) + } + if cfg.Audience != "" { + form.Set("audience", cfg.Audience) + } + for k, v := range cfg.ExtraParams { + form.Set(k, v) + } + + clientSecret, err := loadSecretRef(ctx, cfg.ClientSecret) + if err != nil { + return cachedToken{}, err + } + + if cfg.GrantType == "refresh_token" { + refreshToken := cachedRefreshToken + if refreshToken == "" { + refreshToken, err = loadSecretRef(ctx, cfg.RefreshToken) + if err != nil { + return cachedToken{}, err + } + } + form.Set("refresh_token", refreshToken) + } + + switch cfg.ClientAuth { + case "body": + if cfg.ClientID != "" { + form.Set("client_id", cfg.ClientID) + } + if cfg.ClientSecret != "" { + form.Set("client_secret", clientSecret) + } + case "basic": + // Added below after the request is built. + case "none": + if cfg.ClientID != "" { + form.Set("client_id", cfg.ClientID) + } + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.TokenURL, strings.NewReader(form.Encode())) + if err != nil { + return cachedToken{}, err + } + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + if cfg.ClientAuth == "basic" { + req.SetBasicAuth(cfg.ClientID, clientSecret) + } + + resp, err := HTTPClient.Do(req) + if err != nil { + return cachedToken{}, err + } + defer resp.Body.Close() + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + return cachedToken{}, fmt.Errorf("token request failed: %s: %s", resp.Status, strings.TrimSpace(string(body))) + } + + var tr tokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tr); err != nil { + return cachedToken{}, err + } + if tr.AccessToken == "" { + return cachedToken{}, fmt.Errorf("empty access token") + } + + now := time.Now() + exp := now.Add(parseExpiresIn(tr.ExpiresIn)) + return cachedToken{ + accessToken: tr.AccessToken, + refreshToken: tr.RefreshToken, + exp: exp, + refreshAt: tokenRefreshTime(now, exp), + }, nil +} + +func loadSecretRef(ctx context.Context, ref string) (string, error) { + if ref == "" { + return "", nil + } + return secrets.LoadSecret(ctx, ref) +} + +func parseExpiresIn(raw json.RawMessage) time.Duration { + if len(raw) == 0 { + return time.Minute + } + var seconds float64 + if err := json.Unmarshal(raw, &seconds); err == nil && seconds > 0 { + return time.Duration(seconds * float64(time.Second)) + } + var text string + if err := json.Unmarshal(raw, &text); err == nil { + if parsed, err := strconv.ParseFloat(text, 64); err == nil && parsed > 0 { + return time.Duration(parsed * float64(time.Second)) + } + } + return time.Minute +} + +func (p *oauth2Params) cacheKey() string { + parts := []string{ + p.TokenURL, + p.GrantType, + p.ClientID, + p.ClientSecret, + p.RefreshToken, + p.Scope, + p.Audience, + p.ClientAuth, + } + keys := make([]string, 0, len(p.ExtraParams)) + for k := range p.ExtraParams { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + parts = append(parts, k+"="+p.ExtraParams[k]) + } + return strings.Join(parts, "\x00") +} + +func getCachedToken(key string) cachedToken { + tokenCache.Lock() + defer tokenCache.Unlock() + return tokenCache.m[key] +} + +func setCachedToken(key string, tok cachedToken) { + tokenCache.Lock() + tokenCache.m[key] = tok + tokenCache.Unlock() +} + +func getRefreshLock(key string) *sync.Mutex { + tokenCache.Lock() + defer tokenCache.Unlock() + refreshLock := tokenCache.refreshLocks[key] + if refreshLock == nil { + refreshLock = &sync.Mutex{} + tokenCache.refreshLocks[key] = refreshLock + } + return refreshLock +} + +func init() { authplugins.RegisterOutgoing(&OAuth2{}) } diff --git a/app/auth/plugins/oauth2/outgoing_test.go b/app/auth/plugins/oauth2/outgoing_test.go new file mode 100644 index 0000000..4180609 --- /dev/null +++ b/app/auth/plugins/oauth2/outgoing_test.go @@ -0,0 +1,861 @@ +package oauth2 + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/winhowes/AuthTranslator/app/secrets" + _ "github.com/winhowes/AuthTranslator/app/secrets/plugins" +) + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } + +func resetCache() { + tokenCache.Lock() + tokenCache.m = make(map[string]cachedToken) + tokenCache.refreshLocks = make(map[string]*sync.Mutex) + tokenCache.Unlock() + secrets.ClearCache() +} + +func withTestClient(t *testing.T, client *http.Client) { + t.Helper() + oldClient := HTTPClient + HTTPClient = client + t.Cleanup(func() { + HTTPClient = oldClient + resetCache() + }) +} + +func TestOAuth2RefreshTokenAddAuth(t *testing.T) { + resetCache() + t.Setenv("CLIENT_SECRET", "secret") + t.Setenv("REFRESH_TOKEN", "refresh-1") + + var hits int32 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&hits, 1) + if r.Method != http.MethodPost { + t.Fatalf("unexpected method %s", r.Method) + } + if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" { + t.Fatalf("unexpected content type %q", ct) + } + if err := r.ParseForm(); err != nil { + t.Fatal(err) + } + assertForm(t, r, "grant_type", "refresh_token") + assertForm(t, r, "client_id", "client") + assertForm(t, r, "client_secret", "secret") + assertForm(t, r, "refresh_token", "refresh-1") + assertForm(t, r, "scope", "read write") + assertForm(t, r, "audience", "https://api.example.com") + assertForm(t, r, "resource", "https://resource.example.com") + fmt.Fprint(w, `{"access_token":"access-1","expires_in":3600,"refresh_token":"refresh-2"}`) + })) + defer ts.Close() + withTestClient(t, ts.Client()) + + p := OAuth2{} + cfg, err := p.ParseParams(map[string]interface{}{ + "token_url": ts.URL, + "grant_type": "refresh_token", + "client_id": "client", + "client_secret": "env:CLIENT_SECRET", + "refresh_token": "env:REFRESH_TOKEN", + "scope": "read write", + "audience": "https://api.example.com", + "extra_params": map[string]string{ + "resource": "https://resource.example.com", + }, + }) + if err != nil { + t.Fatal(err) + } + + r := &http.Request{Header: http.Header{}} + if err := p.AddAuth(context.Background(), r, cfg); err != nil { + t.Fatal(err) + } + if got := r.Header.Get("Authorization"); got != "Bearer access-1" { + t.Fatalf("unexpected auth header %q", got) + } + if got := atomic.LoadInt32(&hits); got != 1 { + t.Fatalf("expected one token request, got %d", got) + } +} + +func TestOAuth2CachesAccessToken(t *testing.T) { + resetCache() + t.Setenv("CLIENT_SECRET", "secret") + + var hits int32 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&hits, 1) + fmt.Fprint(w, `{"access_token":"cached","expires_in":3600}`) + })) + defer ts.Close() + withTestClient(t, ts.Client()) + + p := OAuth2{} + cfg, err := p.ParseParams(map[string]interface{}{ + "token_url": ts.URL, + "client_id": "client", + "client_secret": "env:CLIENT_SECRET", + }) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 2; i++ { + r := &http.Request{Header: http.Header{}} + if err := p.AddAuth(context.Background(), r, cfg); err != nil { + t.Fatal(err) + } + if got := r.Header.Get("Authorization"); got != "Bearer cached" { + t.Fatalf("unexpected auth header %q", got) + } + } + if got := atomic.LoadInt32(&hits); got != 1 { + t.Fatalf("expected cached token to be reused, got %d token requests", got) + } +} + +func TestOAuth2RefreshesEarlyAndUsesRotatedRefreshToken(t *testing.T) { + resetCache() + t.Setenv("REFRESH_TOKEN", "refresh-1") + + var hits int32 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + t.Fatal(err) + } + switch atomic.AddInt32(&hits, 1) { + case 1: + assertForm(t, r, "refresh_token", "refresh-1") + fmt.Fprint(w, `{"access_token":"old","expires_in":30,"refresh_token":"refresh-2"}`) + case 2: + assertForm(t, r, "refresh_token", "refresh-2") + fmt.Fprint(w, `{"access_token":"new","expires_in":3600}`) + default: + t.Fatalf("unexpected token request %d", hits) + } + })) + defer ts.Close() + withTestClient(t, ts.Client()) + + p := OAuth2{} + cfg, err := p.ParseParams(map[string]interface{}{ + "token_url": ts.URL, + "refresh_token": "env:REFRESH_TOKEN", + "client_auth": "none", + "client_id": "public-client", + }) + if err != nil { + t.Fatal(err) + } + parsed := cfg.(*oauth2Params) + if parsed.GrantType != "refresh_token" { + t.Fatalf("expected refresh_token default grant, got %q", parsed.GrantType) + } + + r1 := &http.Request{Header: http.Header{}} + if err := p.AddAuth(context.Background(), r1, cfg); err != nil { + t.Fatal(err) + } + if got := r1.Header.Get("Authorization"); got != "Bearer old" { + t.Fatalf("unexpected first header %q", got) + } + + r2 := &http.Request{Header: http.Header{}} + if err := p.AddAuth(context.Background(), r2, cfg); err != nil { + t.Fatal(err) + } + if got := r2.Header.Get("Authorization"); got != "Bearer old" { + t.Fatalf("short-lived token should be cached before refresh window, got %q", got) + } + if got := atomic.LoadInt32(&hits); got != 1 { + t.Fatalf("expected short-lived token to be cached, got %d token requests", got) + } + + key := parsed.cacheKey() + ct := getCachedToken(key) + ct.refreshAt = time.Now().Add(-time.Second) + setCachedToken(key, ct) + + r3 := &http.Request{Header: http.Header{}} + if err := p.AddAuth(context.Background(), r3, cfg); err != nil { + t.Fatal(err) + } + if got := r3.Header.Get("Authorization"); got != "Bearer new" { + t.Fatalf("unexpected refreshed header %q", got) + } +} + +func TestOAuth2SerializesConcurrentRefresh(t *testing.T) { + resetCache() + t.Setenv("REFRESH_TOKEN", "refresh-1") + + var hits int32 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + t.Error(err) + http.Error(w, "bad form", http.StatusBadRequest) + return + } + switch hit := atomic.AddInt32(&hits, 1); hit { + case 1: + if got := r.Form.Get("refresh_token"); got != "refresh-1" { + t.Errorf("expected first refresh token %q, got %q", "refresh-1", got) + } + fmt.Fprint(w, `{"access_token":"old","expires_in":30,"refresh_token":"refresh-2"}`) + case 2: + if got := r.Form.Get("refresh_token"); got != "refresh-2" { + t.Errorf("expected rotated refresh token %q, got %q", "refresh-2", got) + } + time.Sleep(25 * time.Millisecond) + fmt.Fprint(w, `{"access_token":"new","expires_in":3600,"refresh_token":"refresh-3"}`) + default: + http.Error(w, fmt.Sprintf("unexpected concurrent refresh %d", hit), http.StatusBadRequest) + } + })) + defer ts.Close() + withTestClient(t, ts.Client()) + + p := OAuth2{} + cfg, err := p.ParseParams(map[string]interface{}{ + "token_url": ts.URL, + "grant_type": "refresh_token", + "refresh_token": "env:REFRESH_TOKEN", + "client_auth": "none", + }) + if err != nil { + t.Fatal(err) + } + + first := &http.Request{Header: http.Header{}} + if err := p.AddAuth(context.Background(), first, cfg); err != nil { + t.Fatal(err) + } + if got := first.Header.Get("Authorization"); got != "Bearer old" { + t.Fatalf("unexpected first token %q", got) + } + + parsed := cfg.(*oauth2Params) + key := parsed.cacheKey() + ct := getCachedToken(key) + ct.refreshAt = time.Now().Add(-time.Second) + setCachedToken(key, ct) + + const workers = 8 + errs := make(chan error, workers) + start := make(chan struct{}) + var wg sync.WaitGroup + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start + r := &http.Request{Header: http.Header{}} + if err := p.AddAuth(context.Background(), r, cfg); err != nil { + errs <- err + return + } + if got := r.Header.Get("Authorization"); got != "Bearer new" { + errs <- fmt.Errorf("unexpected auth header %q", got) + } + }() + } + close(start) + wg.Wait() + close(errs) + for err := range errs { + if err != nil { + t.Fatal(err) + } + } + if got := atomic.LoadInt32(&hits); got != 2 { + t.Fatalf("expected exactly two token endpoint calls, got %d", got) + } +} + +func TestOAuth2PreRefreshFailureUsesCachedToken(t *testing.T) { + resetCache() + t.Setenv("CLIENT_SECRET", "secret") + + var hits int32 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch atomic.AddInt32(&hits, 1) { + case 1: + fmt.Fprint(w, `{"access_token":"cached","expires_in":3600}`) + default: + http.Error(w, "temporary token endpoint failure", http.StatusServiceUnavailable) + } + })) + defer ts.Close() + withTestClient(t, ts.Client()) + + p := OAuth2{} + cfg, err := p.ParseParams(map[string]interface{}{ + "token_url": ts.URL, + "client_id": "client", + "client_secret": "env:CLIENT_SECRET", + }) + if err != nil { + t.Fatal(err) + } + + first := &http.Request{Header: http.Header{}} + if err := p.AddAuth(context.Background(), first, cfg); err != nil { + t.Fatal(err) + } + if got := first.Header.Get("Authorization"); got != "Bearer cached" { + t.Fatalf("unexpected first token %q", got) + } + + parsed := cfg.(*oauth2Params) + key := parsed.cacheKey() + ct := getCachedToken(key) + ct.refreshAt = time.Now().Add(-time.Second) + setCachedToken(key, ct) + + fallback := &http.Request{Header: http.Header{}} + if err := p.AddAuth(context.Background(), fallback, cfg); err != nil { + t.Fatalf("expected cached token fallback, got %v", err) + } + if got := fallback.Header.Get("Authorization"); got != "Bearer cached" { + t.Fatalf("expected cached token fallback, got %q", got) + } + if got := atomic.LoadInt32(&hits); got != 2 { + t.Fatalf("expected failed pre-refresh attempt, got %d token requests", got) + } + + ct = getCachedToken(key) + ct.exp = time.Now().Add(-time.Second) + ct.refreshAt = time.Now().Add(-time.Second) + setCachedToken(key, ct) + + expired := &http.Request{Header: http.Header{}} + if err := p.AddAuth(context.Background(), expired, cfg); err == nil { + t.Fatal("expected refresh error after cached token expiry") + } + if got := expired.Header.Get("Authorization"); got != "" { + t.Fatalf("expected no header after expired-token refresh failure, got %q", got) + } +} + +func TestOAuth2ClientCredentialsBasicAuth(t *testing.T) { + resetCache() + t.Setenv("CLIENT_SECRET", "secret") + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user, pass, ok := r.BasicAuth() + if !ok || user != "client" || pass != "secret" { + t.Fatalf("unexpected basic auth user=%q pass=%q ok=%v", user, pass, ok) + } + if err := r.ParseForm(); err != nil { + t.Fatal(err) + } + assertForm(t, r, "grant_type", "client_credentials") + if got := r.Form.Get("client_secret"); got != "" { + t.Fatalf("client_secret should not be sent in body for basic auth, got %q", got) + } + fmt.Fprint(w, `{"access_token":"access-basic","expires_in":"3600"}`) + })) + defer ts.Close() + withTestClient(t, ts.Client()) + + p := OAuth2{} + cfg, err := p.ParseParams(map[string]interface{}{ + "token_url": ts.URL, + "grant_type": "client_credentials", + "client_id": "client", + "client_secret": "env:CLIENT_SECRET", + "client_auth": "basic", + "header": "X-Token", + "prefix": "Token ", + }) + if err != nil { + t.Fatal(err) + } + + r := &http.Request{Header: http.Header{}} + if err := p.AddAuth(context.Background(), r, cfg); err != nil { + t.Fatal(err) + } + if got := r.Header.Get("X-Token"); got != "Token access-basic" { + t.Fatalf("unexpected auth header %q", got) + } +} + +func TestOAuth2ParseParamsValidation(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + want string + }{ + { + name: "missing token url", + params: map[string]interface{}{}, + want: "missing token_url", + }, + { + name: "invalid token url", + params: map[string]interface{}{ + "token_url": "://bad", + }, + want: "invalid token_url", + }, + { + name: "unsupported grant", + params: map[string]interface{}{ + "token_url": "https://auth.example.com/token", + "grant_type": "password", + }, + want: "unsupported grant_type", + }, + { + name: "missing refresh token", + params: map[string]interface{}{ + "token_url": "https://auth.example.com/token", + "grant_type": "refresh_token", + }, + want: "requires refresh_token", + }, + { + name: "client credentials missing secret", + params: map[string]interface{}{ + "token_url": "https://auth.example.com/token", + "client_id": "client", + }, + want: "requires client_id and client_secret", + }, + { + name: "basic missing secret", + params: map[string]interface{}{ + "token_url": "https://auth.example.com/token", + "grant_type": "refresh_token", + "refresh_token": "env:REFRESH_TOKEN", + "client_id": "client", + "client_auth": "basic", + }, + want: "basic client_auth requires", + }, + { + name: "body secret without client id", + params: map[string]interface{}{ + "token_url": "https://auth.example.com/token", + "grant_type": "refresh_token", + "refresh_token": "env:REFRESH_TOKEN", + "client_secret": "env:CLIENT_SECRET", + }, + want: "client_id is required", + }, + { + name: "none with secret", + params: map[string]interface{}{ + "token_url": "https://auth.example.com/token", + "grant_type": "refresh_token", + "refresh_token": "env:REFRESH_TOKEN", + "client_secret": "env:CLIENT_SECRET", + "client_auth": "none", + }, + want: "client_secret cannot be used", + }, + { + name: "unsupported client auth", + params: map[string]interface{}{ + "token_url": "https://auth.example.com/token", + "grant_type": "refresh_token", + "refresh_token": "env:REFRESH_TOKEN", + "client_auth": "signed", + }, + want: "unsupported client_auth", + }, + { + name: "empty extra param key", + params: map[string]interface{}{ + "token_url": "https://auth.example.com/token", + "client_id": "client", + "client_secret": "env:CLIENT_SECRET", + "extra_params": map[string]string{ + " ": "x", + }, + }, + want: "extra_params cannot contain empty keys", + }, + { + name: "extra param override", + params: map[string]interface{}{ + "token_url": "https://auth.example.com/token", + "client_id": "client", + "client_secret": "env:CLIENT_SECRET", + "extra_params": map[string]string{ + "grant_type": "refresh_token", + }, + }, + want: "extra_params cannot override", + }, + { + name: "unknown field", + params: map[string]interface{}{ + "token_url": "https://auth.example.com/token", + "client_id": "client", + "client_secret": "env:CLIENT_SECRET", + "unknown": true, + }, + want: "unknown field", + }, + } + + p := OAuth2{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if _, err := p.ParseParams(tt.params); err == nil || !strings.Contains(err.Error(), tt.want) { + t.Fatalf("expected error containing %q, got %v", tt.want, err) + } + }) + } +} + +func TestOAuth2ParamLists(t *testing.T) { + p := OAuth2{} + if p.Name() != "oauth2" { + t.Fatalf("unexpected name %q", p.Name()) + } + if got := p.RequiredParams(); len(got) != 1 || got[0] != "token_url" { + t.Fatalf("unexpected required params: %v", got) + } + opts := p.OptionalParams() + for _, want := range []string{"grant_type", "client_id", "client_secret", "refresh_token", "scope", "audience", "client_auth", "header", "prefix", "extra_params"} { + found := false + for _, got := range opts { + if got == want { + found = true + break + } + } + if !found { + t.Fatalf("missing optional param %q in %v", want, opts) + } + } +} + +func TestOAuth2AddAuthInvalidParams(t *testing.T) { + p := OAuth2{} + r := &http.Request{Header: http.Header{}} + if err := p.AddAuth(context.Background(), r, struct{}{}); err == nil { + t.Fatal("expected invalid config error") + } +} + +func TestOAuth2SecretLoadErrors(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + }{ + { + name: "client secret", + params: map[string]interface{}{ + "token_url": "https://auth.example.com/token", + "client_id": "client", + "client_secret": "env:MISSING_CLIENT_SECRET", + }, + }, + { + name: "refresh token", + params: map[string]interface{}{ + "token_url": "https://auth.example.com/token", + "grant_type": "refresh_token", + "refresh_token": "env:MISSING_REFRESH_TOKEN", + "client_auth": "none", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetCache() + p := OAuth2{} + cfg, err := p.ParseParams(tt.params) + if err != nil { + t.Fatal(err) + } + r := &http.Request{Header: http.Header{}} + if err := p.AddAuth(context.Background(), r, cfg); err == nil { + t.Fatal("expected secret load error") + } + if got := r.Header.Get("Authorization"); got != "" { + t.Fatalf("expected empty auth header, got %q", got) + } + }) + } +} + +func TestOAuth2FetchTokenRequestBuildError(t *testing.T) { + _, err := fetchToken(context.Background(), &oauth2Params{ + TokenURL: "http://[::1", + GrantType: "refresh_token", + ClientAuth: "none", + }, "") + if err == nil { + t.Fatal("expected request construction error") + } +} + +func TestOAuth2FetchTokenTransportError(t *testing.T) { + resetCache() + oldClient := HTTPClient + HTTPClient = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return nil, fmt.Errorf("boom") + })} + t.Cleanup(func() { + HTTPClient = oldClient + resetCache() + }) + + _, err := fetchToken(context.Background(), &oauth2Params{ + TokenURL: "https://auth.example.com/token", + GrantType: "client_credentials", + ClientID: "client", + ClientSecret: "env:CLIENT_SECRET", + ClientAuth: "body", + }, "") + if err == nil || !strings.Contains(err.Error(), "CLIENT_SECRET") { + t.Fatalf("expected client secret load error before transport, got %v", err) + } + + t.Setenv("CLIENT_SECRET", "secret") + _, err = fetchToken(context.Background(), &oauth2Params{ + TokenURL: "https://auth.example.com/token", + GrantType: "client_credentials", + ClientID: "client", + ClientSecret: "env:CLIENT_SECRET", + ClientAuth: "body", + }, "") + if err == nil || !strings.Contains(err.Error(), "boom") { + t.Fatalf("expected transport error, got %v", err) + } +} + +func TestOAuth2TokenEndpointErrors(t *testing.T) { + tests := []struct { + name string + status int + body string + want string + }{ + {name: "bad status", status: http.StatusBadRequest, body: "bad token", want: "token request failed"}, + {name: "malformed json", status: http.StatusOK, body: "{broken", want: "invalid character"}, + {name: "empty token", status: http.StatusOK, body: `{"expires_in":3600}`, want: "empty access token"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetCache() + t.Setenv("CLIENT_SECRET", "secret") + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.status) + fmt.Fprint(w, tt.body) + })) + defer ts.Close() + withTestClient(t, ts.Client()) + + p := OAuth2{} + cfg, err := p.ParseParams(map[string]interface{}{ + "token_url": ts.URL, + "client_id": "client", + "client_secret": "env:CLIENT_SECRET", + }) + if err != nil { + t.Fatal(err) + } + r := &http.Request{Header: http.Header{}} + err = p.AddAuth(context.Background(), r, cfg) + if err == nil || !strings.Contains(err.Error(), tt.want) { + t.Fatalf("expected error containing %q, got %v", tt.want, err) + } + if got := r.Header.Get("Authorization"); got != "" { + t.Fatalf("expected auth header to remain empty, got %q", got) + } + }) + } +} + +func TestOAuth2SecretRefsAndDefaults(t *testing.T) { + p := OAuth2{} + cfg, err := p.ParseParams(map[string]interface{}{ + "token_url": "https://auth.example.com/token", + "grant_type": "refresh_token", + "client_id": "client", + "client_secret": "env:CLIENT_SECRET", + "refresh_token": "env:REFRESH_TOKEN", + }) + if err != nil { + t.Fatal(err) + } + parsed := cfg.(*oauth2Params) + if parsed.Header != "Authorization" || parsed.Prefix != "Bearer " || parsed.ClientAuth != "body" { + t.Fatalf("unexpected defaults: %+v", parsed) + } + refs := parsed.SecretRefs() + if len(refs) != 2 || refs[0] != "env:CLIENT_SECRET" || refs[1] != "env:REFRESH_TOKEN" { + t.Fatalf("unexpected refs: %v", refs) + } +} + +func TestOAuth2DefaultExpiry(t *testing.T) { + if got := parseExpiresIn(nil); got != time.Minute { + t.Fatalf("unexpected default expiry: %s", got) + } + if got := parseExpiresIn([]byte(`"120"`)); got != 120*time.Second { + t.Fatalf("unexpected string expiry: %s", got) + } + if got := parseExpiresIn([]byte(`"nope"`)); got != time.Minute { + t.Fatalf("unexpected fallback expiry: %s", got) + } + if got := parseExpiresIn([]byte(`{}`)); got != time.Minute { + t.Fatalf("unexpected object fallback expiry: %s", got) + } +} + +func TestOAuth2TokenRefreshTime(t *testing.T) { + now := time.Unix(1000, 0) + + if got := tokenRefreshTime(now, now.Add(-time.Second)); !got.Equal(now) { + t.Fatalf("expired token should refresh now, got %s", got) + } + if got := tokenRefreshTime(now, now.Add(time.Nanosecond)); !got.Equal(now.Add(time.Nanosecond)) { + t.Fatalf("sub-nanosecond skew should refresh at expiry, got %s", got) + } + if got := tokenRefreshTime(now, now.Add(30*time.Second)); !got.Equal(now.Add(27 * time.Second)) { + t.Fatalf("short-lived token refreshAt should use proportional skew, got %s", got) + } + if got := tokenRefreshTime(now, now.Add(time.Hour)); !got.Equal(now.Add(59 * time.Minute)) { + t.Fatalf("long-lived token refreshAt should cap skew, got %s", got) + } +} + +func TestOAuth2TokenNeedsRefreshWithComputedRefreshAt(t *testing.T) { + if !tokenNeedsRefresh(cachedToken{}) { + t.Fatal("empty cached token should refresh") + } + if tokenNeedsRefresh(cachedToken{accessToken: "tok", exp: time.Now().Add(30 * time.Second)}) { + t.Fatal("short-lived token should not refresh immediately when refreshAt is absent") + } + if !tokenNeedsRefresh(cachedToken{accessToken: "tok", exp: time.Now().Add(-time.Second)}) { + t.Fatal("expired token should refresh when refreshAt is absent") + } +} + +func TestOAuth2TokenUsable(t *testing.T) { + now := time.Now() + if tokenUsable(cachedToken{}, now) { + t.Fatal("empty token should not be usable") + } + if tokenUsable(cachedToken{accessToken: "tok", exp: now}, now) { + t.Fatal("token at expiry should not be usable") + } + if !tokenUsable(cachedToken{accessToken: "tok", exp: now.Add(time.Second)}, now) { + t.Fatal("unexpired token should be usable") + } +} + +func TestOAuth2ClientAuthNoneIncludesClientID(t *testing.T) { + resetCache() + t.Setenv("REFRESH_TOKEN", "refresh") + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + t.Fatal(err) + } + assertForm(t, r, "client_id", "public-client") + assertForm(t, r, "refresh_token", "refresh") + if auth := r.Header.Get("Authorization"); auth != "" { + t.Fatalf("expected no client auth header, got %q", auth) + } + fmt.Fprint(w, `{"access_token":"public","expires_in":3600}`) + })) + defer ts.Close() + withTestClient(t, ts.Client()) + + p := OAuth2{} + cfg, err := p.ParseParams(map[string]interface{}{ + "token_url": ts.URL, + "grant_type": "refresh_token", + "refresh_token": "env:REFRESH_TOKEN", + "client_auth": "none", + "client_id": "public-client", + }) + if err != nil { + t.Fatal(err) + } + r := &http.Request{Header: http.Header{}} + if err := p.AddAuth(context.Background(), r, cfg); err != nil { + t.Fatal(err) + } + if got := r.Header.Get("Authorization"); got != "Bearer public" { + t.Fatalf("unexpected auth header %q", got) + } +} + +func TestOAuth2FetchTokenResponseCloseOnStatusError(t *testing.T) { + resetCache() + t.Setenv("CLIENT_SECRET", "secret") + + body := &closeRecorder{Reader: strings.NewReader("bad")} + oldClient := HTTPClient + HTTPClient = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusBadRequest, + Status: "400 Bad Request", + Body: body, + }, nil + })} + t.Cleanup(func() { + HTTPClient = oldClient + resetCache() + }) + + _, err := fetchToken(context.Background(), &oauth2Params{ + TokenURL: "https://auth.example.com/token", + GrantType: "client_credentials", + ClientID: "client", + ClientSecret: "env:CLIENT_SECRET", + ClientAuth: "body", + }, "") + if err == nil { + t.Fatal("expected status error") + } + if !body.closed { + t.Fatal("expected response body to be closed") + } +} + +func assertForm(t *testing.T, r *http.Request, key, want string) { + t.Helper() + if got := r.Form.Get(key); got != want { + t.Fatalf("expected form %s=%q, got %q", key, want, got) + } +} + +type closeRecorder struct { + io.Reader + closed bool +} + +func (c *closeRecorder) Close() error { + c.closed = true + return nil +} diff --git a/app/auth/plugins/plugins.go b/app/auth/plugins/plugins.go index 0279f56..11bcf2c 100644 --- a/app/auth/plugins/plugins.go +++ b/app/auth/plugins/plugins.go @@ -11,6 +11,7 @@ import ( _ "github.com/winhowes/AuthTranslator/app/auth/plugins/hmac" _ "github.com/winhowes/AuthTranslator/app/auth/plugins/jwt" _ "github.com/winhowes/AuthTranslator/app/auth/plugins/mtls" + _ "github.com/winhowes/AuthTranslator/app/auth/plugins/oauth2" _ "github.com/winhowes/AuthTranslator/app/auth/plugins/passthrough" _ "github.com/winhowes/AuthTranslator/app/auth/plugins/slack_signature" _ "github.com/winhowes/AuthTranslator/app/auth/plugins/token" diff --git a/app/integration.go b/app/integration.go index 0487584..8316cc7 100644 --- a/app/integration.go +++ b/app/integration.go @@ -29,6 +29,10 @@ type paramRules interface { OptionalParams() []string } +type secretRefsProvider interface { + SecretRefs() []string +} + // validateRequired checks that all fields named in rules.RequiredParams() // are non-zero in v. It assumes v has already been produced by the plugin’s // ParseParams, which should take care of “unknown field” errors by calling @@ -83,18 +87,23 @@ func validateRequired(v interface{}, rules paramRules) error { return nil } -// collectSecretRefs returns any fields named "secrets" (case-insensitive) that -// are slices of strings. It assumes cfg is a struct or pointer to struct. +// collectSecretRefs returns plugin-declared secret refs plus any fields named +// "secrets" (case-insensitive) that are slices of strings. It assumes cfg is a +// struct or pointer to struct for reflective collection. func collectSecretRefs(cfg interface{}) []string { + var refs []string + if p, ok := cfg.(secretRefsProvider); ok { + refs = append(refs, p.SecretRefs()...) + } + rv := reflect.ValueOf(cfg) if rv.Kind() == reflect.Pointer { rv = rv.Elem() } if rv.Kind() != reflect.Struct { - return nil + return refs } rt := rv.Type() - var refs []string for i := 0; i < rt.NumField(); i++ { sf := rt.Field(i) name := sf.Tag.Get("json") diff --git a/app/integration_test.go b/app/integration_test.go index 58b7401..f89e357 100644 --- a/app/integration_test.go +++ b/app/integration_test.go @@ -89,6 +89,24 @@ func (s secretFailOutgoingPlugin) AddAuth(ctx context.Context, r *http.Request, func (s secretFailOutgoingPlugin) RequiredParams() []string { return nil } func (s secretFailOutgoingPlugin) OptionalParams() []string { return nil } +type providedSecretOutgoingPlugin struct{} + +type providedSecretConfig struct { + Ref string `json:"ref"` +} + +func (p *providedSecretConfig) SecretRefs() []string { return []string{p.Ref} } + +func (p providedSecretOutgoingPlugin) Name() string { return "provided-secret-out" } +func (p providedSecretOutgoingPlugin) ParseParams(map[string]interface{}) (interface{}, error) { + return &providedSecretConfig{Ref: "bogus:VAL"}, nil +} +func (p providedSecretOutgoingPlugin) AddAuth(ctx context.Context, r *http.Request, params interface{}) error { + return nil +} +func (p providedSecretOutgoingPlugin) RequiredParams() []string { return nil } +func (p providedSecretOutgoingPlugin) OptionalParams() []string { return nil } + func TestAddIntegrationMissingParam(t *testing.T) { i := &Integration{ Name: "test", @@ -1139,6 +1157,18 @@ func TestPrepareIntegrationOutgoingSecretValidation(t *testing.T) { } } +func TestPrepareIntegrationOutgoingProvidedSecretValidation(t *testing.T) { + authplugins.RegisterOutgoing(providedSecretOutgoingPlugin{}) + integ := &Integration{ + Name: "provided-secret-out", + Destination: "http://example.com", + OutgoingAuth: []AuthPluginConfig{{Type: "provided-secret-out", Params: map[string]interface{}{}}}, + } + if err := prepareIntegration(integ); err == nil || !strings.Contains(err.Error(), "unknown secret source") { + t.Fatalf("expected provided secret validation error, got %v", err) + } +} + func TestPrepareIntegrationWildcardPortHook(t *testing.T) { orig := parseURL parseURL = func(string) (*url.URL, error) { diff --git a/docs/auth-plugins.md b/docs/auth-plugins.md index 8f8cf65..c295eaf 100644 --- a/docs/auth-plugins.md +++ b/docs/auth-plugins.md @@ -34,6 +34,7 @@ AuthTranslator’s behaviour is extended by **plugins** – small Go packages th | Outbound | `google_oidc` | Attaches a Google identity token from the metadata service. | | Outbound | `gcp_token` | Uses a metadata service access token. | | Outbound | `azure_managed_identity` | Retrieves an Azure access token from the Instance Metadata Service. | +| Outbound | `oauth2` | Exchanges client credentials or a refresh token for an access token. | | Outbound | `hmac_signature` | Computes an HMAC for the request. | | Outbound | `jwt` | Adds a signed JWT to the request. | | Outbound | `mtls` | Sends a client certificate and exposes the CN via header. | @@ -123,6 +124,30 @@ outgoing_auth: Obtains an access token from the Azure Instance Metadata Service for the specified `resource`, caches it, and attaches it to the configured header on each outgoing request. +### Outbound `oauth2` + +```yaml +outgoing_auth: + - type: oauth2 + params: + token_url: https://auth.example.com/oauth/token + grant_type: refresh_token # or client_credentials + client_id: my-client-id # required for client_credentials + client_secret: env:OAUTH_SECRET # secret reference + refresh_token: env:OAUTH_REFRESH # secret reference for refresh_token grant + scope: "read write" # optional + audience: https://api.example.com # optional + client_auth: body # body, basic, or none + header: Authorization # optional (default: Authorization) + prefix: "Bearer " # optional (default: "Bearer ") + extra_params: # optional provider-specific token params + resource: https://resource.example.com +``` + +Requests an access token from the configured token endpoint, caches it, and refreshes it one minute before `expires_in`. +For refresh-token flows, any rotated `refresh_token` returned by the provider is reused in memory until restart. +`client_secret` and `refresh_token` must be secret references, not raw values. + --- ## Writing your own plugin