Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 218 additions & 0 deletions app/auth/plugins/aws_imds/outgoing.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
package awsimds

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"

authplugins "github.com/winhowes/AuthTranslator/app/auth"
)

// awsIMDSParams configures the AWS IMDS plugin.
type awsIMDSParams struct {
Header string `json:"header"`
Prefix string `json:"prefix"`
}

// AWSIMDS fetches the IAM role session token from the AWS Instance Metadata
// Service (IMDSv2) and adds it to outgoing requests.
type AWSIMDS struct{}

// MetadataHost is the base URL for the AWS metadata service. It can be
// overridden in tests.
var MetadataHost = "http://169.254.169.254"

// HTTPClient is used for all metadata requests.
var HTTPClient = &http.Client{Timeout: 5 * time.Second}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Disable proxies for IMDS metadata fetches

The metadata client is instantiated with the default transport, which honors HTTP(S)_PROXY (lines 36-37). In environments where those env vars are set—a common case on EC2—the IMDS calls will be routed through the proxy, leaking the session token/credentials and often failing to reach 169.254.169.254 at all. AWS SDKs explicitly disable proxies for IMDS requests; the plugin should do the same (e.g., set a transport with Proxy: nil) to avoid credential exposure and spurious failures.

Useful? React with 👍 / 👎.


var tokenCache = struct {
sync.Mutex
ct cachedToken
}{ct: cachedToken{}}

type cachedToken struct {
token string
exp time.Time
}

// AWSOIDC is kept as a backward-compatible alias for configurations still
// referencing the old plugin name. It delegates all behavior to AWSIMDS but
// advertises the legacy `aws_oidc` name.
type AWSOIDC struct{ AWSIMDS }

func (a *AWSOIDC) Name() string { return "aws_oidc" }

func (a *AWSIMDS) Name() string { return "aws_imds" }

func (a *AWSIMDS) RequiredParams() []string { return nil }

func (a *AWSIMDS) OptionalParams() []string { return []string{"header", "prefix"} }

func (a *AWSIMDS) ParseParams(m map[string]interface{}) (interface{}, error) {
p, err := authplugins.ParseParams[awsIMDSParams](m)
if err != nil {
return nil, err
}
if p.Header == "" {
p.Header = "Authorization"
}
if p.Prefix == "" {
p.Prefix = "Bearer "
}
return p, nil
}

func (a *AWSIMDS) AddAuth(ctx context.Context, r *http.Request, params interface{}) error {
cfg, ok := params.(*awsIMDSParams)
if !ok {
return fmt.Errorf("invalid config")
}
tok, exp := getCachedToken()
if tok == "" || time.Now().After(exp.Add(-1*time.Minute)) {
var err error
tok, exp, err = fetchToken(ctx)
if err != nil {
return err
}
setCachedToken(tok, exp)
}
r.Header.Set(cfg.Header, cfg.Prefix+tok)
return nil
}

func fetchToken(ctx context.Context) (string, time.Time, error) {
metaToken, err := fetchMetadataToken(ctx)
if err != nil {
return "", time.Time{}, err
}

roleName, err := fetchRoleName(ctx, metaToken)
if err != nil {
return "", time.Time{}, err
}

credentials, err := fetchRoleCredentials(ctx, metaToken, roleName)
if err != nil {
return "", time.Time{}, err
}

if credentials.Token == "" {
return "", time.Time{}, fmt.Errorf("empty session token from IMDS for role %s", roleName)
}

exp, err := time.Parse(time.RFC3339, credentials.Expiration)
if err != nil {
return "", time.Time{}, fmt.Errorf("parse expiration: %w", err)
}

return credentials.Token, exp, nil
}

func fetchMetadataToken(ctx context.Context) (string, error) {
tokenURL := fmt.Sprintf("%s/latest/api/token", MetadataHost)
req, err := http.NewRequestWithContext(ctx, http.MethodPut, tokenURL, nil)
if err != nil {
return "", err
}
req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", "21600")

resp, err := HTTPClient.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("token fetch status %s: %s", resp.Status, body)
}

tokenBytes, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
return string(tokenBytes), nil
}

func fetchRoleName(ctx context.Context, metaToken string) (string, error) {
roleURL := fmt.Sprintf("%s/latest/meta-data/iam/security-credentials/", MetadataHost)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, roleURL, nil)
if err != nil {
return "", err
}
req.Header.Set("X-aws-ec2-metadata-token", metaToken)

resp, err := HTTPClient.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("role name status %s: %s", resp.Status, body)
}

roleBytes, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
roleName := strings.TrimSpace(string(roleBytes))
if roleName == "" {
return "", fmt.Errorf("empty role name from IMDS")
}
return roleName, nil
}

type roleCredentials struct {
Expiration string `json:"Expiration"`
Token string `json:"Token"`
}
Comment on lines +164 to +169
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Using IMDS session token without signing credentials

roleCredentials only captures the metadata response’s Token and Expiration, and AddAuth later sends that token directly in a header. IMDS role credentials also include AccessKeyId and SecretAccessKey, which are required to produce the SigV4 signature and to accompany the session token as X-Amz-Security-Token. Because the plugin discards the signing keys and treats the session token as a bearer value, any request to AWS APIs using this auth type will be rejected with 401/403—it never produces a valid AWS credential.

Useful? React with 👍 / 👎.


func fetchRoleCredentials(ctx context.Context, metaToken, roleName string) (*roleCredentials, error) {
credsURL := fmt.Sprintf("%s/latest/meta-data/iam/security-credentials/%s", MetadataHost, roleName)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, credsURL, nil)
if err != nil {
return nil, err
}
req.Header.Set("X-aws-ec2-metadata-token", metaToken)

resp, err := HTTPClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("role credentials status %s: %s", resp.Status, body)
}

var rc roleCredentials
if err := json.NewDecoder(resp.Body).Decode(&rc); err != nil {
return nil, err
}
if rc.Expiration == "" {
return nil, fmt.Errorf("missing expiration in role credentials")
}
return &rc, nil
}

func getCachedToken() (string, time.Time) {
tokenCache.Lock()
defer tokenCache.Unlock()
return tokenCache.ct.token, tokenCache.ct.exp
}

func setCachedToken(tok string, exp time.Time) {
tokenCache.Lock()
tokenCache.ct = cachedToken{token: tok, exp: exp}
tokenCache.Unlock()
}

func init() {
authplugins.RegisterOutgoing(&AWSIMDS{})
authplugins.RegisterOutgoing(&AWSOIDC{})
}
162 changes: 162 additions & 0 deletions app/auth/plugins/aws_imds/outgoing_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
package awsimds

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
)

func TestAddAuthFetchesAndCachesToken(t *testing.T) {
expires := time.Now().Add(2 * time.Minute).UTC().Truncate(time.Second)
sessionToken := "sts-session-token"
metaToken := "meta123"
roleName := "example-role"
var requestCount int

srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/latest/api/token":
requestCount++
if r.Method != http.MethodPut {
t.Fatalf("expected PUT for token, got %s", r.Method)
}
if ttl := r.Header.Get("X-aws-ec2-metadata-token-ttl-seconds"); ttl == "" {
t.Fatalf("missing TTL header")
}
w.Write([]byte(metaToken))
case "/latest/meta-data/iam/security-credentials/":
requestCount++
if got := r.Header.Get("X-aws-ec2-metadata-token"); got != metaToken {
t.Fatalf("expected metadata token %q, got %q", metaToken, got)
}
w.Write([]byte(roleName))
case "/latest/meta-data/iam/security-credentials/" + roleName:
requestCount++
if got := r.Header.Get("X-aws-ec2-metadata-token"); got != metaToken {
t.Fatalf("expected metadata token %q, got %q", metaToken, got)
}
json.NewEncoder(w).Encode(map[string]interface{}{
"Token": sessionToken,
"Expiration": expires.Format(time.RFC3339),
})
default:
t.Fatalf("unexpected path %s", r.URL.Path)
}
}))
defer srv.Close()

MetadataHost = srv.URL
HTTPClient = srv.Client()
tokenCache.ct = cachedToken{}

plugin := &AWSIMDS{}
paramsRaw, err := plugin.ParseParams(map[string]interface{}{})
if err != nil {
t.Fatalf("parse params: %v", err)
}
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
if err := plugin.AddAuth(context.Background(), req, paramsRaw); err != nil {
t.Fatalf("AddAuth: %v", err)
}

if got := req.Header.Get("Authorization"); got != "Bearer "+sessionToken {
t.Fatalf("unexpected header: %s", got)
}
if requestCount != 3 {
t.Fatalf("expected 3 metadata requests, got %d", requestCount)
}

// Second call should use cache.
req2, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
if err := plugin.AddAuth(context.Background(), req2, paramsRaw); err != nil {
t.Fatalf("AddAuth second: %v", err)
}
if requestCount != 3 {
t.Fatalf("expected cached token, still %d requests", requestCount)
}
}

func TestExpiresSoonTriggersRefresh(t *testing.T) {
expSoon := time.Now().Add(30 * time.Second).UTC().Truncate(time.Second)
expLater := time.Now().Add(10 * time.Minute).UTC().Truncate(time.Second)
metaToken := "meta123"
roleName := "role"
sessionTokens := []string{"first", "second"}
var credIndex int

srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/latest/api/token":
w.Write([]byte(metaToken))
case "/latest/meta-data/iam/security-credentials/":
w.Write([]byte(roleName))
case "/latest/meta-data/iam/security-credentials/" + roleName:
exp := expSoon
if credIndex > 0 {
exp = expLater
}
json.NewEncoder(w).Encode(map[string]interface{}{
"Token": sessionTokens[credIndex],
"Expiration": exp.Format(time.RFC3339),
})
credIndex++
default:
t.Fatalf("unexpected path %s", r.URL.Path)
}
}))
defer srv.Close()

MetadataHost = srv.URL
HTTPClient = srv.Client()
tokenCache.ct = cachedToken{}

plugin := &AWSIMDS{}
paramsRaw, err := plugin.ParseParams(map[string]interface{}{})
if err != nil {
t.Fatalf("parse params: %v", err)
}
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
if err := plugin.AddAuth(context.Background(), req, paramsRaw); err != nil {
t.Fatalf("AddAuth: %v", err)
}
req2, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
if err := plugin.AddAuth(context.Background(), req2, paramsRaw); err != nil {
t.Fatalf("AddAuth second: %v", err)
}
if credIndex != 2 {
t.Fatalf("expected token refresh, stage %d", credIndex)
}
}

func TestErrorResponses(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/latest/api/token":
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("no token"))
case "/latest/meta-data/iam/security-credentials/":
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("no role"))
default:
t.Fatalf("unexpected path %s", r.URL.Path)
}
}))
defer srv.Close()

MetadataHost = srv.URL
HTTPClient = srv.Client()
tokenCache.ct = cachedToken{}

plugin := &AWSIMDS{}
paramsRaw, err := plugin.ParseParams(map[string]interface{}{})
if err != nil {
t.Fatalf("parse params: %v", err)
}
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
if err := plugin.AddAuth(context.Background(), req, paramsRaw); err == nil {
t.Fatalf("expected error from metadata token fetch")
}
}
1 change: 1 addition & 0 deletions app/auth/plugins/plugins.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package plugins

import (
_ "github.com/winhowes/AuthTranslator/app/auth/plugins/aws_imds"
_ "github.com/winhowes/AuthTranslator/app/auth/plugins/azure_oidc"
_ "github.com/winhowes/AuthTranslator/app/auth/plugins/basic"
_ "github.com/winhowes/AuthTranslator/app/auth/plugins/findreplace"
Expand Down
Loading
Loading