diff --git a/Dockerfile.all b/Dockerfile.all index f3bf279a..83406c7a 100644 --- a/Dockerfile.all +++ b/Dockerfile.all @@ -14,6 +14,7 @@ RUN mkdir -p build && \ go build -mod=vendor -o build/kafka-proxy -ldflags "${LDFLAGS}" . && \ go build -mod=vendor -o build/auth-user -ldflags "${LDFLAGS}" cmd/plugin-auth-user/main.go && \ go build -mod=vendor -o build/auth-ldap -ldflags "${LDFLAGS}" cmd/plugin-auth-ldap/main.go && \ + go build -mod=vendor -o build/aws-msk-iam-provider -ldflags "${LDFLAGS}" cmd/plugin-aws-msk-iam-provider/main.go && \ go build -mod=vendor -o build/google-id-provider -ldflags "${LDFLAGS}" cmd/plugin-googleid-provider/main.go && \ go build -mod=vendor -o build/google-id-info -ldflags "${LDFLAGS}" cmd/plugin-googleid-info/main.go && \ go build -mod=vendor -o build/unsecured-jwt-info -ldflags "${LDFLAGS}" cmd/plugin-unsecured-jwt-info/main.go && \ @@ -34,6 +35,7 @@ COPY --from=builder /go/src/github.com/grepplabs/kafka-proxy/build /opt/kafka-pr RUN setcap 'cap_net_bind_service=+ep' /opt/kafka-proxy/bin/kafka-proxy && \ setcap 'cap_net_bind_service=+ep' /opt/kafka-proxy/bin/auth-user && \ setcap 'cap_net_bind_service=+ep' /opt/kafka-proxy/bin/auth-ldap && \ + setcap 'cap_net_bind_service=+ep' /opt/kafka-proxy/bin/aws-msk-iam-provider && \ setcap 'cap_net_bind_service=+ep' /opt/kafka-proxy/bin/google-id-provider && \ setcap 'cap_net_bind_service=+ep' /opt/kafka-proxy/bin/google-id-info && \ setcap 'cap_net_bind_service=+ep' /opt/kafka-proxy/bin/unsecured-jwt-info && \ diff --git a/Makefile b/Makefile index 85caa2bc..31e64f98 100644 --- a/Makefile +++ b/Makefile @@ -99,6 +99,9 @@ plugin.auth-user: plugin.auth-ldap: CGO_ENABLED=0 go build -o build/auth-ldap $(BUILD_FLAGS) -ldflags "$(LDFLAGS)" cmd/plugin-auth-ldap/main.go +plugin.aws-msk-iam-provider: + CGO_ENABLED=0 go build -o build/aws-msk-iam-provider $(BUILD_FLAGS) -ldflags "$(LDFLAGS)" cmd/plugin-aws-msk-iam-provider/main.go + plugin.google-id-provider: CGO_ENABLED=0 go build -o build/google-id-provider $(BUILD_FLAGS) -ldflags "$(LDFLAGS)" cmd/plugin-googleid-provider/main.go @@ -114,7 +117,7 @@ plugin.unsecured-jwt-provider: plugin.oidc-provider: CGO_ENABLED=0 go build -o build/oidc-provider $(BUILD_FLAGS) -ldflags "$(LDFLAGS)" cmd/plugin-oidc-provider/main.go -all: build plugin.auth-user plugin.auth-ldap plugin.google-id-provider plugin.google-id-info plugin.unsecured-jwt-info plugin.unsecured-jwt-provider plugin.oidc-provider +all: build plugin.auth-user plugin.auth-ldap plugin.aws-msk-iam-provider plugin.google-id-provider plugin.google-id-info plugin.unsecured-jwt-info plugin.unsecured-jwt-provider plugin.oidc-provider clean: rm -rf $(ROOT_DIR)/build diff --git a/cmd/plugin-aws-msk-iam-provider/README.md b/cmd/plugin-aws-msk-iam-provider/README.md new file mode 100644 index 00000000..fc5a5058 --- /dev/null +++ b/cmd/plugin-aws-msk-iam-provider/README.md @@ -0,0 +1,106 @@ +# AWS MSK IAM Authentication Plugin + +This plugin provides AWS MSK IAM authentication for the Kafka Proxy using SASL/OAUTHBEARER mechanism. It automatically handles token refresh to ensure authentication doesn't expire. + +## Features + +- **Multiple Authentication Methods**: Supports default credentials, named profiles, and IAM role assumption +- **Automatic Token Refresh**: Tokens are automatically refreshed 5 minutes before expiry +- **Role Assumption**: Supports assuming IAM roles with optional external ID +- **Error Handling**: Robust error handling with exponential backoff for token generation + +## Configuration Options + +| Option | Required | Description | Default | +|--------|----------|-------------|---------| +| `--region` | Yes | AWS region where MSK cluster is located | - | +| `--profile` | No | AWS named profile to use for authentication | - | +| `--role-arn` | No | IAM role ARN to assume | - | +| `--external-id` | No | External ID for role assumption | - | +| `--session-name` | No | Session name for role assumption | `kafka-proxy` | +| `--timeout` | No | Request timeout in seconds | `10` | + +## Authentication Methods + +### 1. Default Credentials +Uses the default AWS credential provider chain (environment variables, IAM roles, etc.). + +```bash +./plugin-aws-msk-iam-provider --region us-east-1 +``` + +### 2. Named Profile +Uses a specific AWS named profile. + +```bash +./plugin-aws-msk-iam-provider --region us-east-1 --profile my-profile +``` + +### 3. IAM Role Assumption +Assumes an IAM role for authentication. + +```bash +./plugin-aws-msk-iam-provider --region us-east-1 --role-arn arn:aws:iam::123456789012:role/MyRole --session-name kafka-proxy +``` + +### 4. IAM Role with External ID +Assumes an IAM role with external ID for enhanced security. + +```bash +./plugin-aws-msk-iam-provider --region us-east-1 --role-arn arn:aws:iam::123456789012:role/MyRole --external-id my-external-id --session-name kafka-proxy +``` + +## Integration with Kafka Proxy + +To use this plugin with the Kafka Proxy, configure it in your proxy configuration: + +```yaml +token-provider: + plugin: aws-msk-iam-provider + args: + - "--region=us-east-1" + - "--profile=my-profile" + # or for role assumption: + # - "--role-arn=arn:aws:iam::123456789012:role/MyRole" + # - "--session-name=kafka-proxy" +``` + +## Token Refresh + +The plugin automatically refreshes tokens 5 minutes before they expire. AWS MSK IAM tokens are typically valid for 15 minutes, and the plugin sets an internal expiry of 14 minutes to ensure safe refresh timing. + +## Error Handling + +- **Initial Token Generation**: Uses exponential backoff with up to 3 retries +- **Token Refresh**: Uses exponential backoff with up to 60 seconds max elapsed time +- **Graceful Degradation**: Continues operation even if refresh fails, will retry on next request + +## Dependencies + +This plugin uses the [AWS MSK IAM SASL Signer for Go](https://github.com/aws/aws-msk-iam-sasl-signer-go) library to generate authentication tokens. + +## Troubleshooting + +### Common Issues + +1. **Access Denied**: Ensure the IAM user/role has the necessary permissions for MSK cluster access +2. **Region Mismatch**: Verify the region matches your MSK cluster location +3. **Profile Not Found**: Check that the AWS profile exists in your credentials file +4. **Role Assumption Failed**: Verify the role ARN and ensure your credentials can assume the role + +### Debug Mode + +To enable debug logging for AWS credentials, set the environment variable: + +```bash +export AWS_DEBUG_CREDS=true +``` + +This will log the IAM identity being used for authentication. + +## Security Considerations + +- **External ID**: Use external ID when assuming roles for enhanced security +- **Session Names**: Use descriptive session names for better audit trails +- **Credentials**: Ensure credentials are properly secured and rotated regularly +- **Network**: Use TLS encryption for all communications with MSK clusters diff --git a/cmd/plugin-aws-msk-iam-provider/main.go b/cmd/plugin-aws-msk-iam-provider/main.go new file mode 100644 index 00000000..27eaaf0c --- /dev/null +++ b/cmd/plugin-aws-msk-iam-provider/main.go @@ -0,0 +1,35 @@ +package main + +import ( + "os" + + awsmskiamprovider "github.com/grepplabs/kafka-proxy/pkg/libs/aws-msk-iam-provider" + "github.com/grepplabs/kafka-proxy/plugin/token-provider/shared" + "github.com/hashicorp/go-plugin" + "github.com/sirupsen/logrus" +) + +func main() { + // Create the AWS MSK IAM token signer + signer := awsmskiamprovider.NewAwsMskIamTokenSigner() + + // Create the factory with the signer + factory := awsmskiamprovider.NewFactory(signer) + + // Create the token provider + tokenProvider, err := factory.New(os.Args[1:]) + + if err != nil { + logrus.Errorf("cannot initialize aws-msk-iam-token provider: %v", err) + os.Exit(1) + } + + plugin.Serve(&plugin.ServeConfig{ + HandshakeConfig: shared.Handshake, + Plugins: map[string]plugin.Plugin{ + "tokenProvider": &shared.TokenProviderPlugin{Impl: tokenProvider}, + }, + // A non-nil value here enables gRPC serving for this plugin... + GRPCServer: plugin.DefaultGRPCServer, + }) +} diff --git a/go.mod b/go.mod index 0ff2de45..b3024103 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/grepplabs/kafka-proxy -go 1.23 +go 1.23.0 require ( github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 @@ -39,6 +39,7 @@ require ( require ( cloud.google.com/go/compute/metadata v0.5.2 // indirect github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c // indirect + github.com/aws/aws-msk-iam-sasl-signer-go v1.0.4 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.28 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.32 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.32 // indirect diff --git a/go.sum b/go.sum index 97c78010..fb3e5903 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk5 github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= +github.com/aws/aws-msk-iam-sasl-signer-go v1.0.4 h1:2jAwFwA0Xgcx94dUId+K24yFabsKYDtAhCgyMit6OqE= +github.com/aws/aws-msk-iam-sasl-signer-go v1.0.4/go.mod h1:MVYeeOhILFFemC/XlYTClvBjYZrg/EPd3ts885KrNTI= github.com/aws/aws-sdk-go-v2 v1.36.1 h1:iTDl5U6oAhkNPba0e1t1hrwAo02ZMqbrGq4k5JBWM5E= github.com/aws/aws-sdk-go-v2 v1.36.1/go.mod h1:5PMILGVKiW32oDzjj6RU52yrNrDPUHcbZQYr1sM7qmM= github.com/aws/aws-sdk-go-v2/config v1.29.6 h1:fqgqEKK5HaZVWLQoLiC9Q+xDlSp+1LYidp6ybGE2OGg= diff --git a/pkg/libs/aws-msk-iam-provider/factory.go b/pkg/libs/aws-msk-iam-provider/factory.go new file mode 100644 index 00000000..ae0e85cb --- /dev/null +++ b/pkg/libs/aws-msk-iam-provider/factory.go @@ -0,0 +1,71 @@ +package awsmskiamprovider + +import ( + "flag" + + "github.com/grepplabs/kafka-proxy/pkg/apis" + "github.com/grepplabs/kafka-proxy/pkg/registry" +) + +func init() { + registry.NewComponentInterface(new(apis.TokenProviderFactory)) + registry.Register(new(Factory), "aws-msk-iam-provider") +} + +func (f *pluginMeta) flagSet() *flag.FlagSet { + fs := flag.NewFlagSet("aws msk iam provider settings", flag.ContinueOnError) + return fs +} + +type pluginMeta struct { + region string + profile string + roleARN string + externalID string + sessionName string + timeout int +} + +// Factory type +type Factory struct { + signer TokenSigner +} + +// NewFactory creates a new factory with the given signer +func NewFactory(signer TokenSigner) *Factory { + return &Factory{ + signer: signer, + } +} + +// New implements apis.TokenProviderFactory +func (t *Factory) New(params []string) (apis.TokenProvider, error) { + pluginMeta := &pluginMeta{} + fs := pluginMeta.flagSet() + fs.StringVar(&pluginMeta.region, "region", "", "AWS region (required)") + fs.StringVar(&pluginMeta.profile, "profile", "", "AWS named profile (optional)") + fs.StringVar(&pluginMeta.roleARN, "role-arn", "", "IAM role ARN to assume (optional)") + fs.StringVar(&pluginMeta.externalID, "external-id", "", "External ID for role assumption (optional)") + fs.StringVar(&pluginMeta.sessionName, "session-name", "kafka-proxy", "Session name for role assumption") + fs.IntVar(&pluginMeta.timeout, "timeout", 10, "Request timeout in seconds") + + err := fs.Parse(params) + if err != nil { + return nil, err + } + + if pluginMeta.region == "" { + return nil, flag.ErrHelp + } + + options := TokenProviderOptions{ + Region: pluginMeta.region, + Profile: pluginMeta.profile, + RoleARN: pluginMeta.roleARN, + ExternalID: pluginMeta.externalID, + SessionName: pluginMeta.sessionName, + Timeout: pluginMeta.timeout, + } + + return NewTokenProvider(options, t.signer) +} diff --git a/pkg/libs/aws-msk-iam-provider/factory_test.go b/pkg/libs/aws-msk-iam-provider/factory_test.go new file mode 100644 index 00000000..d5af3350 --- /dev/null +++ b/pkg/libs/aws-msk-iam-provider/factory_test.go @@ -0,0 +1,60 @@ +package awsmskiamprovider + +import ( + "testing" + + "github.com/grepplabs/kafka-proxy/pkg/apis" + "github.com/stretchr/testify/assert" +) + +func TestFactory_New(t *testing.T) { + tests := []struct { + name string + params []string + wantErr bool + errorSubstr string + }{ + { + name: "missing region", + params: []string{}, + wantErr: true, + errorSubstr: "help requested", + }, + { + name: "invalid flag", + params: []string{"--invalid-flag"}, + wantErr: true, + errorSubstr: "flag provided but not defined", + }, + { + name: "happy path", + params: []string{"--region", "us-east-1"}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockSigner := NewMockTokenSigner() + factory := NewFactory(mockSigner) + provider, err := factory.New(tt.params) + + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, provider) + if tt.errorSubstr != "" { + assert.Contains(t, err.Error(), tt.errorSubstr) + } + } else { + assert.NoError(t, err) + assert.NotNil(t, provider) + } + }) + } +} + +func TestFactory_ImplementsInterface(t *testing.T) { + mockSigner := NewMockTokenSigner() + factory := NewFactory(mockSigner) + assert.Implements(t, (*apis.TokenProviderFactory)(nil), factory) +} diff --git a/pkg/libs/aws-msk-iam-provider/mock.go b/pkg/libs/aws-msk-iam-provider/mock.go new file mode 100644 index 00000000..39760968 --- /dev/null +++ b/pkg/libs/aws-msk-iam-provider/mock.go @@ -0,0 +1,158 @@ +package awsmskiamprovider + +import ( + "context" + "time" +) + +// MockTokenSigner provides a mock implementation for testing +type MockTokenSigner struct { + // Mock responses + MockToken string + MockError error + MockExpiry time.Time + + // Call tracking + GenerateCallCount int + LastRegion string + LastProfile string + LastRoleARN string + LastExternalID string + LastSessionName string +} + +// NewMockTokenSigner creates a new mock token signer +func NewMockTokenSigner() *MockTokenSigner { + return &MockTokenSigner{ + MockToken: "mock-aws-msk-iam-token", + MockExpiry: time.Now().Add(15 * time.Minute), + } +} + +// GenerateAuthToken mocks the AWS MSK IAM token generation +func (m *MockTokenSigner) GenerateAuthToken(ctx context.Context, region string) (string, int64, error) { + m.GenerateCallCount++ + m.LastRegion = region + + // Check if context is cancelled + select { + case <-ctx.Done(): + return "", 0, ctx.Err() + default: + } + + if m.MockError != nil { + return "", 0, m.MockError + } + + return m.MockToken, m.MockExpiry.Unix(), nil +} + +// GenerateAuthTokenFromProfile mocks profile-based token generation +func (m *MockTokenSigner) GenerateAuthTokenFromProfile(ctx context.Context, region string, profile string) (string, int64, error) { + m.GenerateCallCount++ + m.LastRegion = region + m.LastProfile = profile + + // Check if context is cancelled + select { + case <-ctx.Done(): + return "", 0, ctx.Err() + default: + } + + if m.MockError != nil { + return "", 0, m.MockError + } + + return m.MockToken, m.MockExpiry.Unix(), nil +} + +// GenerateAuthTokenFromRole mocks role-based token generation +func (m *MockTokenSigner) GenerateAuthTokenFromRole(ctx context.Context, region string, roleARN string, sessionName string) (string, int64, error) { + m.GenerateCallCount++ + m.LastRegion = region + m.LastRoleARN = roleARN + m.LastSessionName = sessionName + + // Check if context is cancelled + select { + case <-ctx.Done(): + return "", 0, ctx.Err() + default: + } + + if m.MockError != nil { + return "", 0, m.MockError + } + + return m.MockToken, m.MockExpiry.Unix(), nil +} + +// GenerateAuthTokenFromRoleWithExternalId mocks role-based token generation with external ID +func (m *MockTokenSigner) GenerateAuthTokenFromRoleWithExternalId(ctx context.Context, region string, roleARN string, sessionName string, externalID string) (string, int64, error) { + m.GenerateCallCount++ + m.LastRegion = region + m.LastRoleARN = roleARN + m.LastSessionName = sessionName + m.LastExternalID = externalID + + // Check if context is cancelled + select { + case <-ctx.Done(): + return "", 0, ctx.Err() + default: + } + + if m.MockError != nil { + return "", 0, m.MockError + } + + return m.MockToken, m.MockExpiry.Unix(), nil +} + +// Reset resets the mock state +func (m *MockTokenSigner) Reset() { + m.GenerateCallCount = 0 + m.LastRegion = "" + m.LastProfile = "" + m.LastRoleARN = "" + m.LastExternalID = "" + m.LastSessionName = "" +} + +// SetMockResponse sets the mock response +func (m *MockTokenSigner) SetMockResponse(token string, expiry time.Time, err error) { + m.MockToken = token + m.MockExpiry = expiry + m.MockError = err +} + +// SetMockError sets just the mock error +func (m *MockTokenSigner) SetMockError(err error) { + m.MockError = err +} + +// SetMockToken sets just the mock token +func (m *MockTokenSigner) SetMockToken(token string) { + m.MockToken = token +} + +// SetMockExpiry sets just the mock expiry +func (m *MockTokenSigner) SetMockExpiry(expiry time.Time) { + m.MockExpiry = expiry +} + +// GetCallCount returns the number of times GenerateAuthToken was called +func (m *MockTokenSigner) GetCallCount() int { + return m.GenerateCallCount +} + +// WasCalledWith checks if the mock was called with specific parameters +func (m *MockTokenSigner) WasCalledWith(region, profile, roleARN, externalID, sessionName string) bool { + return m.LastRegion == region && + m.LastProfile == profile && + m.LastRoleARN == roleARN && + m.LastExternalID == externalID && + m.LastSessionName == sessionName +} diff --git a/pkg/libs/aws-msk-iam-provider/plugin.go b/pkg/libs/aws-msk-iam-provider/plugin.go new file mode 100644 index 00000000..d511b7e3 --- /dev/null +++ b/pkg/libs/aws-msk-iam-provider/plugin.go @@ -0,0 +1,181 @@ +package awsmskiamprovider + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/cenkalti/backoff" + "github.com/grepplabs/kafka-proxy/pkg/apis" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +const ( + StatusOK = 0 + StatusGetTokenFailed = 1 +) + +var ( + // Token expiry buffer - refresh token 5 minutes before expiry + tokenExpiryBuffer = 5 * time.Minute + nowFn = time.Now +) + +// TokenProvider - represents AWS MSK IAM TokenProvider +type TokenProvider struct { + timeout time.Duration + region string + + // Authentication configuration + profile string + roleARN string + externalID string + sessionName string + + // Token generation + signer TokenSigner + + currentToken string + currentExpiry time.Time + l sync.RWMutex +} + +// TokenProviderOptions - options specific for AWS MSK IAM +type TokenProviderOptions struct { + Region string + Profile string + RoleARN string + ExternalID string + SessionName string + Timeout int +} + +// NewTokenProvider - Generate new AWS MSK IAM token provider +func NewTokenProvider(options TokenProviderOptions, signer TokenSigner) (*TokenProvider, error) { + tokenProvider := &TokenProvider{ + timeout: time.Duration(options.Timeout) * time.Second, + region: options.Region, + profile: options.Profile, + roleARN: options.RoleARN, + externalID: options.ExternalID, + sessionName: options.SessionName, + signer: signer, + } + + // Initialize with first token + op := func() error { + return initToken(tokenProvider) + } + + err := backoff.Retry( + op, + backoff.WithMaxTries(backoff.NewConstantBackOff(1*time.Second), 3)) + + if err != nil { + return nil, errors.Wrap(err, "getting of initial AWS MSK IAM token failed") + } + + tokenRefresher := &TokenRefresher{ + tokenProvider: tokenProvider, + stopChannel: make(chan bool, 1), + } + + go tokenRefresher.refreshLoop() + + return tokenProvider, nil +} + +func initToken(tokenProvider *TokenProvider) error { + ctx, cancel := context.WithTimeout(context.Background(), tokenProvider.timeout) + defer cancel() + + resp, err := tokenProvider.GetToken(ctx, apis.TokenRequest{}) + + if err != nil { + return err + } + + if !resp.Success { + return fmt.Errorf("get token failed with status: %d", resp.Status) + } + + return nil +} + +func (p *TokenProvider) getCurrentToken() string { + p.l.RLock() + defer p.l.RUnlock() + + if p.currentToken == "" { + return "" + } + + if p.shouldRenew() { + return "" + } + + return p.currentToken +} + +func (p *TokenProvider) shouldRenew() bool { + return nowFn().After(p.currentExpiry.Add(-tokenExpiryBuffer)) +} + +func (p *TokenProvider) setCurrentToken(token string, expiry time.Time) { + p.l.Lock() + defer p.l.Unlock() + + p.currentToken = token + p.currentExpiry = expiry +} + +// GetToken implements apis.TokenProvider.GetToken method +func (p *TokenProvider) GetToken(parent context.Context, _ apis.TokenRequest) (apis.TokenResponse, error) { + currentToken := p.getCurrentToken() + + if currentToken != "" { + return getTokenResponse(currentToken, StatusOK) + } + + ctx, cancel := context.WithTimeout(parent, p.timeout) + defer cancel() + + var token string + var err error + + // Generate token based on configuration + if p.roleARN != "" { + if p.externalID != "" { + token, _, err = p.signer.GenerateAuthTokenFromRoleWithExternalId( + ctx, p.region, p.roleARN, p.sessionName, p.externalID) + } else { + token, _, err = p.signer.GenerateAuthTokenFromRole( + ctx, p.region, p.roleARN, p.sessionName) + } + } else if p.profile != "" { + token, _, err = p.signer.GenerateAuthTokenFromProfile(ctx, p.region, p.profile) + } else { + token, _, err = p.signer.GenerateAuthToken(ctx, p.region) + } + + if err != nil { + logrus.Errorf("GenerateAuthToken failed: %v", err) + return getTokenResponse("", StatusGetTokenFailed) + } + + // AWS MSK IAM tokens are typically valid for 15 minutes + // We'll set expiry to 14 minutes to be safe + expiry := nowFn().Add(14 * time.Minute) + p.setCurrentToken(token, expiry) + + logrus.Infof("New AWS MSK IAM token generated, expires at %v", expiry) + + return getTokenResponse(token, StatusOK) +} + +func getTokenResponse(token string, status int) (apis.TokenResponse, error) { + success := status == StatusOK + return apis.TokenResponse{Success: success, Status: int32(status), Token: token}, nil +} diff --git a/pkg/libs/aws-msk-iam-provider/plugin_test.go b/pkg/libs/aws-msk-iam-provider/plugin_test.go new file mode 100644 index 00000000..41439479 --- /dev/null +++ b/pkg/libs/aws-msk-iam-provider/plugin_test.go @@ -0,0 +1,508 @@ +package awsmskiamprovider + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/grepplabs/kafka-proxy/pkg/apis" + "github.com/stretchr/testify/assert" +) + +func TestTokenProvider_NewTokenProvider(t *testing.T) { + tests := []struct { + name string + options TokenProviderOptions + wantErr bool + }{ + { + name: "valid options with region only", + options: TokenProviderOptions{ + Region: "us-east-1", + Timeout: 10, + SessionName: "test-session", + }, + wantErr: false, + }, + { + name: "valid options with profile", + options: TokenProviderOptions{ + Region: "us-west-2", + Profile: "test-profile", + Timeout: 30, + SessionName: "test-session", + }, + wantErr: false, + }, + { + name: "valid options with role ARN", + options: TokenProviderOptions{ + Region: "eu-west-1", + RoleARN: "arn:aws:iam::123456789012:role/TestRole", + Timeout: 15, + SessionName: "test-session", + }, + wantErr: false, + }, + { + name: "valid options with role ARN and external ID", + options: TokenProviderOptions{ + Region: "ap-southeast-1", + RoleARN: "arn:aws:iam::123456789012:role/TestRole", + ExternalID: "test-external-id", + Timeout: 20, + SessionName: "test-session", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Use mock signer instead of real AWS calls + mockSigner := NewMockTokenSigner() + provider, err := NewTokenProvider(tt.options, mockSigner) + + assert.NoError(t, err) + assert.NotNil(t, provider) + assert.Equal(t, tt.options.Region, provider.region) + assert.Equal(t, time.Duration(tt.options.Timeout)*time.Second, provider.timeout) + assert.Equal(t, tt.options.SessionName, provider.sessionName) + }) + } +} + +func TestTokenProvider_GetToken(t *testing.T) { + tests := []struct { + name string + options TokenProviderOptions + wantErr bool + expectedToken string + }{ + { + name: "provider with region only", + options: TokenProviderOptions{ + Region: "us-east-1", + Timeout: 10, + SessionName: "test", + }, + wantErr: false, + expectedToken: "mock-aws-msk-iam-token", + }, + { + name: "provider with profile", + options: TokenProviderOptions{ + Region: "us-west-2", + Profile: "test-profile", + Timeout: 10, + SessionName: "test", + }, + wantErr: false, + expectedToken: "mock-aws-msk-iam-token", + }, + { + name: "provider with role ARN", + options: TokenProviderOptions{ + Region: "eu-west-1", + RoleARN: "arn:aws:iam::123456789012:role/TestRole", + Timeout: 10, + SessionName: "test", + }, + wantErr: false, + expectedToken: "mock-aws-msk-iam-token", + }, + { + name: "provider with role ARN and external ID", + options: TokenProviderOptions{ + Region: "ap-southeast-1", + RoleARN: "arn:aws:iam::123456789012:role/TestRole", + ExternalID: "test-external-id", + Timeout: 10, + SessionName: "test", + }, + wantErr: false, + expectedToken: "mock-aws-msk-iam-token", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockSigner := NewMockTokenSigner() + provider, err := NewTokenProvider(tt.options, mockSigner) + assert.NoError(t, err) + + ctx := context.Background() + resp, err := provider.GetToken(ctx, apis.TokenRequest{}) + + if tt.wantErr { + assert.Error(t, err) + assert.False(t, resp.Success) + assert.Equal(t, int32(StatusGetTokenFailed), resp.Status) + } else { + assert.NoError(t, err) + assert.True(t, resp.Success) + assert.Equal(t, int32(StatusOK), resp.Status) + assert.Equal(t, tt.expectedToken, resp.Token) + assert.Equal(t, 1, mockSigner.GetCallCount()) + } + }) + } +} + +func TestTokenProvider_GetToken_WithCurrentToken(t *testing.T) { + options := TokenProviderOptions{ + Region: "us-east-1", + Timeout: 10, + SessionName: "test", + } + + mockSigner := NewMockTokenSigner() + provider, err := NewTokenProvider(options, mockSigner) + assert.NoError(t, err) + + // The provider should already have a token from initialization + // Let's verify it's using the cached token on subsequent calls + initialCallCount := mockSigner.GetCallCount() + + ctx := context.Background() + resp, err := provider.GetToken(ctx, apis.TokenRequest{}) + + // Should return the current token without trying to generate a new one + assert.NoError(t, err) + assert.True(t, resp.Success) + assert.Equal(t, int32(StatusOK), resp.Status) + assert.NotEmpty(t, resp.Token) + + // Should not have called the mock signer again since we had a valid token + assert.Equal(t, initialCallCount, mockSigner.GetCallCount()) +} + +func TestTokenProvider_shouldRenew(t *testing.T) { + tests := []struct { + name string + provider *TokenProvider + expectedResult bool + }{ + { + name: "no current token", + provider: &TokenProvider{}, + expectedResult: true, + }, + { + name: "current token but no expiry set", + provider: &TokenProvider{ + currentToken: "test-token", + }, + expectedResult: true, + }, + { + name: "current token with future expiry", + provider: func() *TokenProvider { + p := &TokenProvider{ + currentToken: "test-token", + } + p.setCurrentToken("test-token", time.Now().Add(10*time.Minute)) + return p + }(), + expectedResult: false, + }, + { + name: "current token with past expiry", + provider: func() *TokenProvider { + p := &TokenProvider{ + currentToken: "test-token", + } + p.setCurrentToken("test-token", time.Now().Add(-10*time.Minute)) + return p + }(), + expectedResult: true, + }, + { + name: "current token expiring soon (within buffer)", + provider: func() *TokenProvider { + p := &TokenProvider{ + currentToken: "test-token", + } + // Set expiry to 3 minutes from now (less than 5-minute buffer) + p.setCurrentToken("test-token", time.Now().Add(3*time.Minute)) + return p + }(), + expectedResult: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.provider.shouldRenew() + assert.Equal(t, tt.expectedResult, result) + }) + } +} + +func TestTokenProvider_getCurrentToken(t *testing.T) { + tests := []struct { + name string + provider *TokenProvider + expectedToken string + }{ + { + name: "no current token", + provider: &TokenProvider{}, + expectedToken: "", + }, + { + name: "current token but no expiry set", + provider: &TokenProvider{ + currentToken: "test-token", + }, + expectedToken: "", + }, + { + name: "current token with future expiry", + provider: func() *TokenProvider { + p := &TokenProvider{ + currentToken: "test-token", + } + p.setCurrentToken("test-token", time.Now().Add(10*time.Minute)) + return p + }(), + expectedToken: "test-token", + }, + { + name: "current token with past expiry", + provider: func() *TokenProvider { + p := &TokenProvider{ + currentToken: "test-token", + } + p.setCurrentToken("test-token", time.Now().Add(-10*time.Minute)) + return p + }(), + expectedToken: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token := tt.provider.getCurrentToken() + assert.Equal(t, tt.expectedToken, token) + }) + } +} + +func TestTokenProvider_setCurrentToken(t *testing.T) { + provider := &TokenProvider{} + expectedToken := "test-token" + expectedExpiry := time.Now().Add(10 * time.Minute) + + provider.setCurrentToken(expectedToken, expectedExpiry) + + assert.Equal(t, expectedToken, provider.currentToken) + assert.Equal(t, expectedExpiry, provider.currentExpiry) +} + +func TestGetTokenResponse(t *testing.T) { + tests := []struct { + name string + token string + status int + expectedSuccess bool + expectedStatus int32 + }{ + { + name: "successful response", + token: "test-token", + status: StatusOK, + expectedSuccess: true, + expectedStatus: int32(StatusOK), + }, + { + name: "failed response", + token: "", + status: StatusGetTokenFailed, + expectedSuccess: false, + expectedStatus: int32(StatusGetTokenFailed), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp, err := getTokenResponse(tt.token, tt.status) + + assert.NoError(t, err) + assert.Equal(t, tt.expectedSuccess, resp.Success) + assert.Equal(t, tt.expectedStatus, resp.Status) + assert.Equal(t, tt.token, resp.Token) + }) + } +} + +func TestTokenProvider_NewTokenProvider_WithMockError(t *testing.T) { + options := TokenProviderOptions{ + Region: "us-east-1", + Timeout: 10, + SessionName: "test", + } + + mockSigner := NewMockTokenSigner() + mockSigner.SetMockError(fmt.Errorf("mock AWS error")) + provider, err := NewTokenProvider(options, mockSigner) + + // NewTokenProvider should fail when initial token generation fails + assert.Error(t, err) + assert.Nil(t, provider) + assert.Contains(t, err.Error(), "getting of initial AWS MSK IAM token failed") + // The retry logic will call the signer multiple times (3 retries + 1 initial attempt) + assert.GreaterOrEqual(t, mockSigner.GetCallCount(), 1) +} + +func TestTokenProvider_GetToken_WithMockError(t *testing.T) { + options := TokenProviderOptions{ + Region: "us-east-1", + Timeout: 10, + SessionName: "test", + } + + mockSigner := NewMockTokenSigner() + provider, err := NewTokenProvider(options, mockSigner) + assert.NoError(t, err) + + mockSigner.SetMockError(fmt.Errorf("mock AWS error")) + provider.setCurrentToken("test-token", time.Now().Add(-10*time.Minute)) + + ctx := context.Background() + resp, err := provider.GetToken(ctx, apis.TokenRequest{}) + + assert.NoError(t, err) + assert.False(t, resp.Success) + assert.Equal(t, int32(StatusGetTokenFailed), resp.Status) + assert.Equal(t, "", resp.Token) + assert.Equal(t, 2, mockSigner.GetCallCount()) +} + +func TestTokenProvider_GetToken_MultipleCalls(t *testing.T) { + options := TokenProviderOptions{ + Region: "us-east-1", + Timeout: 10, + SessionName: "test", + } + + mockSigner := NewMockTokenSigner() + provider, err := NewTokenProvider(options, mockSigner) + assert.NoError(t, err) + + ctx := context.Background() + + // First call - should generate a new token + resp1, err := provider.GetToken(ctx, apis.TokenRequest{}) + assert.NoError(t, err) + assert.True(t, resp1.Success) + assert.Equal(t, 1, mockSigner.GetCallCount()) + + // Second call - should return the cached token + resp2, err := provider.GetToken(ctx, apis.TokenRequest{}) + assert.NoError(t, err) + assert.True(t, resp2.Success) + assert.Equal(t, resp1.Token, resp2.Token) + assert.Equal(t, 1, mockSigner.GetCallCount()) // Should not have called again +} + +func TestTokenProvider_GetToken_WithExpiredToken(t *testing.T) { + options := TokenProviderOptions{ + Region: "us-east-1", + Timeout: 10, + SessionName: "test", + } + + mockSigner := NewMockTokenSigner() + provider, err := NewTokenProvider(options, mockSigner) + assert.NoError(t, err) + + // Set a token that's already expired + provider.setCurrentToken("expired-token", time.Now().Add(-1*time.Hour)) + + ctx := context.Background() + resp, err := provider.GetToken(ctx, apis.TokenRequest{}) + + // Should generate a new token since the current one is expired + assert.NoError(t, err) + assert.True(t, resp.Success) + assert.Equal(t, int32(StatusOK), resp.Status) + assert.NotEqual(t, "expired-token", resp.Token) + assert.Equal(t, 2, mockSigner.GetCallCount()) // 1 from init + 1 from this call +} + +func TestTokenProvider_GetToken_WithTokenExpiringSoon(t *testing.T) { + options := TokenProviderOptions{ + Region: "us-east-1", + Timeout: 10, + SessionName: "test", + } + + mockSigner := NewMockTokenSigner() + provider, err := NewTokenProvider(options, mockSigner) + assert.NoError(t, err) + + // Set a token that expires in 3 minutes (within the 5-minute buffer) + provider.setCurrentToken("expiring-soon-token", time.Now().Add(3*time.Minute)) + + ctx := context.Background() + resp, err := provider.GetToken(ctx, apis.TokenRequest{}) + + // Should generate a new token since the current one expires soon + assert.NoError(t, err) + assert.True(t, resp.Success) + assert.Equal(t, int32(StatusOK), resp.Status) + assert.NotEqual(t, "expiring-soon-token", resp.Token) + assert.Equal(t, 2, mockSigner.GetCallCount()) // 1 from init + 1 from this call +} + +func TestTokenProvider_GetToken_WithValidToken(t *testing.T) { + options := TokenProviderOptions{ + Region: "us-east-1", + Timeout: 10, + SessionName: "test", + } + + mockSigner := NewMockTokenSigner() + provider, err := NewTokenProvider(options, mockSigner) + assert.NoError(t, err) + + // Set a token that's valid for a long time + provider.setCurrentToken("valid-token", time.Now().Add(30*time.Minute)) + + ctx := context.Background() + resp, err := provider.GetToken(ctx, apis.TokenRequest{}) + + // Should return the existing token without generating a new one + assert.NoError(t, err) + assert.True(t, resp.Success) + assert.Equal(t, int32(StatusOK), resp.Status) + assert.Equal(t, "valid-token", resp.Token) + assert.Equal(t, 1, mockSigner.GetCallCount()) // Only from init, not from this call +} + +func TestTokenProvider_GetToken_WithEmptyToken(t *testing.T) { + options := TokenProviderOptions{ + Region: "us-east-1", + Timeout: 10, + SessionName: "test", + } + + mockSigner := NewMockTokenSigner() + provider, err := NewTokenProvider(options, mockSigner) + assert.NoError(t, err) + + // Clear the current token + provider.setCurrentToken("", time.Time{}) + + ctx := context.Background() + resp, err := provider.GetToken(ctx, apis.TokenRequest{}) + + // Should generate a new token since there's no current token + assert.NoError(t, err) + assert.True(t, resp.Success) + assert.Equal(t, int32(StatusOK), resp.Status) + assert.NotEmpty(t, resp.Token) + assert.Equal(t, 2, mockSigner.GetCallCount()) // 1 from init + 1 from this call +} diff --git a/pkg/libs/aws-msk-iam-provider/signer.go b/pkg/libs/aws-msk-iam-provider/signer.go new file mode 100644 index 00000000..1a05b363 --- /dev/null +++ b/pkg/libs/aws-msk-iam-provider/signer.go @@ -0,0 +1,43 @@ +package awsmskiamprovider + +import ( + "context" + + "github.com/aws/aws-msk-iam-sasl-signer-go/signer" +) + +// TokenSigner defines the interface for AWS MSK IAM token generation +type TokenSigner interface { + GenerateAuthToken(ctx context.Context, region string) (string, int64, error) + GenerateAuthTokenFromProfile(ctx context.Context, region string, profile string) (string, int64, error) + GenerateAuthTokenFromRole(ctx context.Context, region string, roleARN string, sessionName string) (string, int64, error) + GenerateAuthTokenFromRoleWithExternalId(ctx context.Context, region string, roleARN string, sessionName string, externalID string) (string, int64, error) +} + +// AwsMskIamTokenSigner implements TokenSigner using the actual AWS MSK IAM SASL signer +type AwsMskIamTokenSigner struct{} + +// NewAwsMskIamTokenSigner creates a new AWS MSK IAM token signer +func NewAwsMskIamTokenSigner() *AwsMskIamTokenSigner { + return &AwsMskIamTokenSigner{} +} + +// GenerateAuthToken generates a token using default credentials +func (r *AwsMskIamTokenSigner) GenerateAuthToken(ctx context.Context, region string) (string, int64, error) { + return signer.GenerateAuthToken(ctx, region) +} + +// GenerateAuthTokenFromProfile generates a token using a named profile +func (r *AwsMskIamTokenSigner) GenerateAuthTokenFromProfile(ctx context.Context, region string, profile string) (string, int64, error) { + return signer.GenerateAuthTokenFromProfile(ctx, region, profile) +} + +// GenerateAuthTokenFromRole generates a token using role assumption +func (r *AwsMskIamTokenSigner) GenerateAuthTokenFromRole(ctx context.Context, region string, roleARN string, sessionName string) (string, int64, error) { + return signer.GenerateAuthTokenFromRole(ctx, region, roleARN, sessionName) +} + +// GenerateAuthTokenFromRoleWithExternalId generates a token using role assumption with external ID +func (r *AwsMskIamTokenSigner) GenerateAuthTokenFromRoleWithExternalId(ctx context.Context, region string, roleARN string, sessionName string, externalID string) (string, int64, error) { + return signer.GenerateAuthTokenFromRoleWithExternalId(ctx, region, roleARN, sessionName, externalID) +} diff --git a/pkg/libs/aws-msk-iam-provider/ticker.go b/pkg/libs/aws-msk-iam-provider/ticker.go new file mode 100644 index 00000000..30858aaf --- /dev/null +++ b/pkg/libs/aws-msk-iam-provider/ticker.go @@ -0,0 +1,79 @@ +package awsmskiamprovider + +import ( + "context" + "fmt" + "time" + + "github.com/cenkalti/backoff" + "github.com/grepplabs/kafka-proxy/pkg/apis" + "github.com/sirupsen/logrus" +) + +// TokenRefresher - struct providing refreshing of AWS MSK IAM tokens +type TokenRefresher struct { + tokenProvider *TokenProvider + stopChannel chan bool +} + +func (p *TokenRefresher) refreshLoop() { + defer func() { + if r := recover(); r != nil { + var ok bool + err, ok := r.(error) + + if ok { + logrus.Error(fmt.Sprintf("AWS MSK IAM token refresh loop error %v", err)) + } + } + }() + + // Check every 30 seconds if token needs refresh + syncTicker := time.NewTicker(30 * time.Second) + + for { + select { + case <-syncTicker.C: + p.refreshTick() + case <-p.stopChannel: + return + } + } +} + +func (p *TokenRefresher) tryRefresh() error { + ctx, cancel := context.WithTimeout(context.Background(), p.tokenProvider.timeout) + defer cancel() + + resp, err := p.tokenProvider.GetToken(ctx, apis.TokenRequest{}) + + if err != nil { + return err + } + + if !resp.Success { + return fmt.Errorf("get token failed with status: %d", resp.Status) + } + + return nil +} + +func (p *TokenRefresher) refreshTick() { + // Check if we need to refresh the token + if p.tokenProvider.shouldRenew() { + op := func() error { + return p.tryRefresh() + } + + backOff := backoff.NewExponentialBackOff() + backOff.MaxElapsedTime = 60 * time.Second + backOff.MaxInterval = 10 * time.Second + err := backoff.Retry(op, backOff) + + if err != nil { + logrus.Errorf("refreshing of AWS MSK IAM token failed: %v", err) + } else { + logrus.Debugf("AWS MSK IAM token refreshed successfully") + } + } +} diff --git a/pkg/libs/aws-msk-iam-provider/ticker_test.go b/pkg/libs/aws-msk-iam-provider/ticker_test.go new file mode 100644 index 00000000..56a9f5d8 --- /dev/null +++ b/pkg/libs/aws-msk-iam-provider/ticker_test.go @@ -0,0 +1,45 @@ +package awsmskiamprovider + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestTokenRefresher_creation(t *testing.T) { + provider := &TokenProvider{ + timeout: 10 * time.Second, + region: "us-east-1", + } + + refresher := &TokenRefresher{ + tokenProvider: provider, + stopChannel: make(chan bool, 1), + } + + assert.NotNil(t, refresher) + assert.Equal(t, provider, refresher.tokenProvider) + assert.NotNil(t, refresher.stopChannel) +} + +func TestTokenRefresher_stopChannel(t *testing.T) { + provider := &TokenProvider{ + timeout: 10 * time.Second, + region: "us-east-1", + } + + refresher := &TokenRefresher{ + tokenProvider: provider, + stopChannel: make(chan bool, 1), + } + + // Test that stop channel can receive a signal + refresher.stopChannel <- true + select { + case <-refresher.stopChannel: + // Successfully received the stop signal + default: + t.Error("Stop channel should be able to receive signals") + } +} diff --git a/vendor/github.com/aws/aws-msk-iam-sasl-signer-go/LICENSE b/vendor/github.com/aws/aws-msk-iam-sasl-signer-go/LICENSE new file mode 100644 index 00000000..67db8588 --- /dev/null +++ b/vendor/github.com/aws/aws-msk-iam-sasl-signer-go/LICENSE @@ -0,0 +1,175 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. diff --git a/vendor/github.com/aws/aws-msk-iam-sasl-signer-go/NOTICE b/vendor/github.com/aws/aws-msk-iam-sasl-signer-go/NOTICE new file mode 100644 index 00000000..616fc588 --- /dev/null +++ b/vendor/github.com/aws/aws-msk-iam-sasl-signer-go/NOTICE @@ -0,0 +1 @@ +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. diff --git a/vendor/github.com/aws/aws-msk-iam-sasl-signer-go/signer/msk_auth_token_provider.go b/vendor/github.com/aws/aws-msk-iam-sasl-signer-go/signer/msk_auth_token_provider.go new file mode 100644 index 00000000..0ac4d962 --- /dev/null +++ b/vendor/github.com/aws/aws-msk-iam-sasl-signer-go/signer/msk_auth_token_provider.go @@ -0,0 +1,380 @@ +package signer + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "fmt" + "log" + "net/http" + "net/url" + "runtime" + "strconv" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/sts" +) + +const ( + ActionType = "Action" // ActionType represents the key for the action type in the request. + ActionName = "kafka-cluster:Connect" // ActionName represents the specific action name for connecting to a Kafka cluster. + SigningName = "kafka-cluster" // SigningName represents the signing name for the Kafka cluster. + UserAgentKey = "User-Agent" // UserAgentKey represents the key for the User-Agent parameter in the request. + LibName = "aws-msk-iam-sasl-signer-go" // LibName represents the name of the library. + DateQueryKey = "X-Amz-Date" // DateQueryKey represents the key for the date in the query parameters. + ExpiresQueryKey = "X-Amz-Expires" // ExpiresQueryKey represents the key for the expiration time in the query parameters. + DefaultSessionName = "MSKSASLDefaultSession" // DefaultSessionName represents the default session name for assuming a role. + DefaultExpirySeconds = 900 // DefaultExpirySeconds represents the default expiration time in seconds. +) + +var ( + endpointURLTemplate = "kafka.%s.amazonaws.com" // endpointURLTemplate represents the template for the Kafka endpoint URL + AwsDebugCreds = false // AwsDebugCreds flag indicates whether credentials should be debugged +) + +// GenerateAuthToken generates base64 encoded signed url as auth token from default credentials. +// Loads the IAM credentials from default credentials provider chain. +func GenerateAuthToken(ctx context.Context, region string) (string, int64, error) { + credentials, err := loadDefaultCredentials(ctx, region) + + if err != nil { + return "", 0, fmt.Errorf("failed to load credentials: %w", err) + } + + return constructAuthToken(ctx, region, credentials) +} + +// GenerateAuthTokenFromProfile generates base64 encoded signed url as auth token by loading IAM credentials from an AWS named profile. +func GenerateAuthTokenFromProfile(ctx context.Context, region string, awsProfile string) (string, int64, error) { + return GenerateAuthTokenFromProfileWithSharedConfigFiles(ctx, region, awsProfile, nil) +} + +// GenerateAuthTokenFromProfileWithSharedConfigFiles generates base64 encoded signed url as auth token by loading IAM credentials from an AWS named profile with shared config files. +func GenerateAuthTokenFromProfileWithSharedConfigFiles(ctx context.Context, region string, awsProfile string, sharedConfigFiles []string) (string, int64, error) { + credentials, err := loadCredentialsFromProfile(ctx, region, awsProfile, sharedConfigFiles) + + if err != nil { + return "", 0, fmt.Errorf("failed to load credentials: %w", err) + } + + return constructAuthToken(ctx, region, credentials) +} + +// GenerateAuthTokenFromRole generates base64 encoded signed url as auth token by loading IAM credentials from an aws role Arn +func GenerateAuthTokenFromRole( + ctx context.Context, region string, roleArn string, stsSessionName string, +) (string, int64, error) { + return GenerateAuthTokenFromRoleWithExternalId(ctx, region, roleArn, stsSessionName, "") +} + +// GenerateAuthTokenFromRoleWithExternalId generates base64 encoded signed url as auth token by loading IAM credentials from an aws role Arn +// +// If the provided externalId is empty, it behaves exactly like GenerateAuthTokenFromRole. +func GenerateAuthTokenFromRoleWithExternalId( + ctx context.Context, region string, roleArn string, stsSessionName string, externalId string, +) (string, int64, error) { + if stsSessionName == "" { + stsSessionName = DefaultSessionName + } + credentials, err := loadCredentialsFromRoleArn(ctx, region, roleArn, stsSessionName, externalId) + + if err != nil { + return "", 0, fmt.Errorf("failed to load credentials: %w", err) + } + + return constructAuthToken(ctx, region, credentials) +} + +// GenerateAuthTokenFromWebIdentity generates base64 encoded signed url as auth token by loading IAM credentials from a web identity role Arn. +func GenerateAuthTokenFromWebIdentity( + ctx context.Context, region string, roleArn string, webIdentityToken string, stsSessionName string, +) (string, int64, error) { + if stsSessionName == "" { + stsSessionName = DefaultSessionName + } + + credentials, err := loadCredentialsFromWebIdentityParameters(ctx, region, roleArn, webIdentityToken, stsSessionName) + + if err != nil { + return "", 0, fmt.Errorf("failed to load credentials: %w", err) + } + + return constructAuthToken(ctx, region, credentials) +} + +// GenerateAuthTokenFromCredentialsProvider generates base64 encoded signed url as auth token by loading IAM credentials +// from an aws credentials provider +func GenerateAuthTokenFromCredentialsProvider( + ctx context.Context, region string, credentialsProvider aws.CredentialsProvider, +) (string, int64, error) { + credentials, err := loadCredentialsFromCredentialsProvider(ctx, credentialsProvider) + + if err != nil { + return "", 0, fmt.Errorf("failed to load credentials: %w", err) + } + + return constructAuthToken(ctx, region, credentials) +} + +// Loads credentials from the default credential chain. +func loadDefaultCredentials(ctx context.Context, region string) (*aws.Credentials, error) { + cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) + + if err != nil { + return nil, fmt.Errorf("unable to load SDK config: %w", err) + } + + return loadCredentialsFromCredentialsProvider(ctx, cfg.Credentials) +} + +// Loads credentials from a named aws profile. +func loadCredentialsFromProfile(ctx context.Context, region string, awsProfile string, sharedConfigFiles []string) (*aws.Credentials, error) { + optsFns := []func(*config.LoadOptions) error{ + config.WithRegion(region), + config.WithSharedConfigProfile(awsProfile), + } + if len(sharedConfigFiles) > 0 { + optsFns = append(optsFns, config.WithSharedConfigFiles(sharedConfigFiles)) + } + cfg, err := config.LoadDefaultConfig(ctx, + optsFns..., + ) + + if err != nil { + return nil, fmt.Errorf("unable to load SDK config: %w", err) + } + + return loadCredentialsFromCredentialsProvider(ctx, cfg.Credentials) +} + +// Loads credentials from a named by assuming the passed role. +// This implementation creates a new sts client for every call to get or refresh token. In order to avoid this, please +// use your own credentials provider. +// If you wish to use regional endpoint, please pass your own credentials provider. +func loadCredentialsFromRoleArn( + ctx context.Context, region string, roleArn string, stsSessionName string, externalId string, +) (*aws.Credentials, error) { + cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) + + if err != nil { + return nil, fmt.Errorf("unable to load SDK config: %w", err) + } + + stsClient := sts.NewFromConfig(cfg) + + assumeRoleInput := &sts.AssumeRoleInput{ + RoleArn: aws.String(roleArn), + RoleSessionName: aws.String(stsSessionName), + } + if externalId != "" { + assumeRoleInput.ExternalId = aws.String(externalId) + } + assumeRoleOutput, err := stsClient.AssumeRole(ctx, assumeRoleInput) + if err != nil { + return nil, fmt.Errorf("unable to assume role, %s: %w", roleArn, err) + } + + //Create new aws.Credentials instance using the credentials from AssumeRoleOutput.Credentials + creds := aws.Credentials{ + AccessKeyID: *assumeRoleOutput.Credentials.AccessKeyId, + SecretAccessKey: *assumeRoleOutput.Credentials.SecretAccessKey, + SessionToken: *assumeRoleOutput.Credentials.SessionToken, + } + + return &creds, nil +} + +// Loads credentials from a named by assuming the passed web identity role and id token. +// This implementation creates a new sts client for every call to get or refresh token. In order to avoid this, please +// use your own credentials' provider. +// If you wish to use regional endpoint, please pass your own credentials' provider. +func loadCredentialsFromWebIdentityParameters( + ctx context.Context, region, roleArn, webIdentityToken, stsSessionName string, +) (*aws.Credentials, error) { + cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) + + if err != nil { + return nil, fmt.Errorf("unable to load SDK config: %w", err) + } + + stsClient := sts.NewFromConfig(cfg) + + assumeRoleWithWebIdentityInput := &sts.AssumeRoleWithWebIdentityInput{ + RoleArn: aws.String(roleArn), + RoleSessionName: aws.String(stsSessionName), + WebIdentityToken: aws.String(webIdentityToken), + } + assumeRoleWithWebIdentityOutput, err := stsClient.AssumeRoleWithWebIdentity(ctx, assumeRoleWithWebIdentityInput) + if err != nil { + return nil, fmt.Errorf("unable to assume role with web identity, %s: %w", roleArn, err) + } + + //Create new aws.Credentials instance using the credentials from AssumeRoleWithWebIdentityOutput.Credentials + creds := aws.Credentials{ + AccessKeyID: *assumeRoleWithWebIdentityOutput.Credentials.AccessKeyId, + SecretAccessKey: *assumeRoleWithWebIdentityOutput.Credentials.SecretAccessKey, + SessionToken: *assumeRoleWithWebIdentityOutput.Credentials.SessionToken, + } + + return &creds, nil +} + +// Loads credentials from the credentials provider +func loadCredentialsFromCredentialsProvider( + ctx context.Context, credentialsProvider aws.CredentialsProvider, +) (*aws.Credentials, error) { + creds, err := credentialsProvider.Retrieve(ctx) + return &creds, err +} + +// Constructs Auth Token. +func constructAuthToken(ctx context.Context, region string, credentials *aws.Credentials) (string, int64, error) { + endpointURL := fmt.Sprintf(endpointURLTemplate, region) + + if credentials == nil || credentials.AccessKeyID == "" || credentials.SecretAccessKey == "" { + return "", 0, fmt.Errorf("aws credentials cannot be empty") + } + + if AwsDebugCreds { + logCallerIdentity(ctx, region, *credentials) + } + + req, err := buildRequest(DefaultExpirySeconds, endpointURL) + if err != nil { + return "", 0, fmt.Errorf("failed to build request for signing: %w", err) + } + + signedURL, err := signRequest(ctx, req, region, credentials) + if err != nil { + return "", 0, fmt.Errorf("failed to sign request with aws sig v4: %w", err) + } + + expirationTimeMs, err := getExpirationTimeMs(signedURL) + if err != nil { + return "", 0, fmt.Errorf("failed to extract expiration from signed url: %w", err) + } + + signedURLWithUserAgent, err := addUserAgent(signedURL) + if err != nil { + return "", 0, fmt.Errorf("failed to add user agent to the signed url: %w", err) + } + + return base64Encode(signedURLWithUserAgent), expirationTimeMs, nil +} + +// Build https request with query parameters in order to sign. +func buildRequest(expirySeconds int, endpointURL string) (*http.Request, error) { + query := url.Values{ + ActionType: {ActionName}, + ExpiresQueryKey: {strconv.FormatInt(int64(expirySeconds), 10)}, + } + + authURL := url.URL{ + Host: endpointURL, + Scheme: "https", + Path: "/", + RawQuery: query.Encode(), + } + + return http.NewRequest(http.MethodGet, authURL.String(), nil) +} + +// Sign request with aws sig v4. +func signRequest(ctx context.Context, req *http.Request, region string, credentials *aws.Credentials) (string, error) { + signer := v4.NewSigner() + signedURL, _, err := signer.PresignHTTP(ctx, *credentials, req, + calculateSHA256Hash(""), + SigningName, + region, + time.Now().UTC(), + ) + + return signedURL, err +} + +// Parses the URL and gets the expiration time in millis associated with the signed url +func getExpirationTimeMs(signedURL string) (int64, error) { + parsedURL, err := url.Parse(signedURL) + + if err != nil { + return 0, fmt.Errorf("failed to parse the signed url: %w", err) + } + + params := parsedURL.Query() + date, err := time.Parse("20060102T150405Z", params.Get(DateQueryKey)) + + if err != nil { + return 0, fmt.Errorf("failed to parse the 'X-Amz-Date' param from signed url: %w", err) + } + + signingTimeMs := date.UnixNano() / int64(time.Millisecond) + expiryDurationSeconds, err := strconv.ParseInt(params.Get(ExpiresQueryKey), 10, 64) + + if err != nil { + return 0, fmt.Errorf("failed to parse the 'X-Amz-Expires' param from signed url: %w", err) + } + + expiryDurationMs := expiryDurationSeconds * 1000 + expiryMs := signingTimeMs + expiryDurationMs + return expiryMs, nil +} + +// Calculate sha256Hash and hex encode it. +func calculateSHA256Hash(input string) string { + hash := sha256.Sum256([]byte(input)) + return hex.EncodeToString(hash[:]) +} + +// Base64 encode with raw url encoding. +func base64Encode(signedURL string) string { + signedURLBytes := []byte(signedURL) + return base64.RawURLEncoding.EncodeToString(signedURLBytes) +} + +// Add user agent to the signed url +func addUserAgent(signedURL string) (string, error) { + parsedSignedURL, err := url.Parse(signedURL) + + if err != nil { + return "", fmt.Errorf("failed to parse signed url: %w", err) + } + + query := parsedSignedURL.Query() + userAgent := strings.Join([]string{LibName, version, runtime.Version()}, "/") + query.Set(UserAgentKey, userAgent) + parsedSignedURL.RawQuery = query.Encode() + + return parsedSignedURL.String(), nil +} + +// Log caller identity to debug which credentials are being picked up +func logCallerIdentity(ctx context.Context, region string, awsCredentials aws.Credentials) { + cfg, err := config.LoadDefaultConfig(ctx, + config.WithRegion(region), + config.WithCredentialsProvider(credentials.StaticCredentialsProvider{ + Value: awsCredentials, + }), + ) + if err != nil { + log.Printf("failed to load AWS configuration: %v", err) + } + + stsClient := sts.NewFromConfig(cfg) + + callerIdentity, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) + + if err != nil { + log.Printf("failed to get caller identity: %v", err) + } + + log.Printf("Credentials Identity: {UserId: %s, Account: %s, Arn: %s}\n", + *callerIdentity.UserId, + *callerIdentity.Account, + *callerIdentity.Arn) +} diff --git a/vendor/github.com/aws/aws-msk-iam-sasl-signer-go/signer/version.go b/vendor/github.com/aws/aws-msk-iam-sasl-signer-go/signer/version.go new file mode 100644 index 00000000..1b70a0f1 --- /dev/null +++ b/vendor/github.com/aws/aws-msk-iam-sasl-signer-go/signer/version.go @@ -0,0 +1,3 @@ +package signer + +const version = "1.0.4" diff --git a/vendor/modules.txt b/vendor/modules.txt index 30411369..3ba1e799 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -7,6 +7,9 @@ github.com/Azure/go-ntlmssp # github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 ## explicit github.com/armon/go-socks5 +# github.com/aws/aws-msk-iam-sasl-signer-go v1.0.4 +## explicit; go 1.21 +github.com/aws/aws-msk-iam-sasl-signer-go/signer # github.com/aws/aws-sdk-go-v2 v1.36.1 ## explicit; go 1.21 github.com/aws/aws-sdk-go-v2/aws