Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions agent/api/task/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ func (task *Task) PostUnmarshalTask(cfg *config.Config,
return apierrors.NewResourceInitError(task.Arn, err)
}

task.initSecretResources(credentialsManager, resourceFields)
task.initSecretResources(cfg, credentialsManager, resourceFields)

task.initializeCredentialsEndpoint(credentialsManager)

Expand Down Expand Up @@ -651,14 +651,14 @@ func (task *Task) populateTaskARN() {
}
}

func (task *Task) initSecretResources(credentialsManager credentials.Manager,
func (task *Task) initSecretResources(cfg *config.Config, credentialsManager credentials.Manager,
resourceFields *taskresource.ResourceFields) {
if task.requiresASMDockerAuthData() {
task.initializeASMAuthResource(credentialsManager, resourceFields)
}

if task.requiresSSMSecret() {
task.initializeSSMSecretResource(credentialsManager, resourceFields)
task.initializeSSMSecretResource(cfg, credentialsManager, resourceFields)
}

if task.requiresASMSecret() {
Expand Down Expand Up @@ -1109,10 +1109,11 @@ func (task *Task) requiresSSMSecret() bool {
}

// initializeSSMSecretResource builds the resource dependency map for the SSM ssmsecret resource
func (task *Task) initializeSSMSecretResource(credentialsManager credentials.Manager,
func (task *Task) initializeSSMSecretResource(cfg *config.Config,
credentialsManager credentials.Manager,
resourceFields *taskresource.ResourceFields) {
ssmSecretResource := ssmsecret.NewSSMSecretResource(task.Arn, task.getAllSSMSecretRequirements(),
task.ExecutionCredentialsID, credentialsManager, resourceFields.SSMClientCreator)
task.ExecutionCredentialsID, credentialsManager, resourceFields.SSMClientCreator, cfg.InstanceIPCompatibility)
task.AddResource(ssmsecret.ResourceName, ssmSecretResource)

// for every container that needs ssm secret vending as env, it needs to wait all secrets got retrieved
Expand Down
2 changes: 1 addition & 1 deletion agent/api/task/task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3037,7 +3037,7 @@ func TestInitializeAndGetSSMSecretResource(t *testing.T) {
},
}

task.initializeSSMSecretResource(credentialsManager, resFields)
task.initializeSSMSecretResource(&config.Config{InstanceIPCompatibility: testIPCompatibility}, credentialsManager, resFields)

resourceDep := apicontainer.ResourceDependency{
Name: ssmsecret.ResourceName,
Expand Down
3 changes: 2 additions & 1 deletion agent/api/task/task_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ func (task *Task) addFSxWindowsFileServerResource(
credentialsManager,
resourceFields.SSMClientCreator,
resourceFields.ASMClientCreator,
resourceFields.FSxClientCreator)
resourceFields.FSxClientCreator,
cfg.InstanceIPCompatibility)
if err != nil {
return err
}
Expand Down
6 changes: 4 additions & 2 deletions agent/engine/docker_task_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2879,12 +2879,14 @@ func TestTaskSecretsEnvironmentVariables(t *testing.T) {
},
}

testIPCompatibility := ipcompatibility.NewIPCompatibility(true, true)
ssmSecretRes := ssmsecret.NewSSMSecretResource(
testTask.Arn,
ssmRequirements,
credentialsID,
credentialsManager,
ssmClientCreator)
ssmClientCreator,
testIPCompatibility)

// required for validating asm workflows
asmClientCreator := mock_asm_factory.NewMockClientCreator(ctrl)
Expand Down Expand Up @@ -2923,7 +2925,7 @@ func TestTaskSecretsEnvironmentVariables(t *testing.T) {
reqSecretNames := []string{ssmSecretValueFrom}

credentialsManager.EXPECT().GetTaskCredentials(credentialsID).Return(taskIAMcreds, true).Times(2)
ssmClientCreator.EXPECT().NewSSMClient(region, executionRoleCredentials).Return(mockSSMClient, nil)
ssmClientCreator.EXPECT().NewSSMClient(region, executionRoleCredentials, testIPCompatibility).Return(mockSSMClient, nil)
asmClientCreator.EXPECT().NewASMClient(region, executionRoleCredentials).Return(mockASMClient, nil)

mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Do(func(ctx context.Context, in *ssm.GetParametersInput, optFns ...func(*ssm.Options)) {
Expand Down
21 changes: 16 additions & 5 deletions agent/ssm/factory/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@ import (
"time"

"github.com/aws/amazon-ecs-agent/agent/config"
"github.com/aws/amazon-ecs-agent/agent/config/ipcompatibility"
ssmclient "github.com/aws/amazon-ecs-agent/agent/ssm"
agentversion "github.com/aws/amazon-ecs-agent/agent/version"
"github.com/aws/amazon-ecs-agent/ecs-agent/credentials"
"github.com/aws/amazon-ecs-agent/ecs-agent/httpclient"
"github.com/aws/amazon-ecs-agent/ecs-agent/logger"

"github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
awscreds "github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/ssm"
Expand All @@ -33,7 +36,7 @@ const (
)

type SSMClientCreator interface {
NewSSMClient(region string, creds credentials.IAMRoleCredentials) (ssmclient.SSMClient, error)
NewSSMClient(region string, creds credentials.IAMRoleCredentials, ipCompatibility ipcompatibility.IPCompatibility) (ssmclient.SSMClient, error)
}

func NewSSMClientCreator() SSMClientCreator {
Expand All @@ -44,16 +47,24 @@ type ssmClientCreator struct{}

// SSM Client will automatically retry 3 times when has throttling error
func (*ssmClientCreator) NewSSMClient(region string,
creds credentials.IAMRoleCredentials) (ssmclient.SSMClient, error) {
cfg, err := awsconfig.LoadDefaultConfig(
context.TODO(),
creds credentials.IAMRoleCredentials,
ipCompatibility ipcompatibility.IPCompatibility) (ssmclient.SSMClient, error) {

opts := []func(*awsconfig.LoadOptions) error{
awsconfig.WithHTTPClient(httpclient.New(roundtripTimeout, false, agentversion.String(), config.OSType)),
awsconfig.WithRegion(region),
awsconfig.WithCredentialsProvider(
awscreds.NewStaticCredentialsProvider(creds.AccessKeyID, creds.SecretAccessKey,
creds.SessionToken),
),
)
}

if ipCompatibility.IsIPv6Only() {
logger.Debug("Configuring SSM Client DualStack endpoint")
opts = append(opts, awsconfig.WithUseDualStackEndpoint(aws.DualStackEndpointStateEnabled))
}

cfg, err := awsconfig.LoadDefaultConfig(context.TODO(), opts...)

if err != nil {
return nil, err
Expand Down
9 changes: 5 additions & 4 deletions agent/ssm/factory/mocks/factory_mocks.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ func (cs *CredentialSpecResource) handleSSMCredentialspecFile(originalCredential
}
ssmParams := []string{ssmParam[1]}

ssmClient, err := cs.ssmClientCreator.NewSSMClient(cs.region, iamCredentials)
ssmClient, err := cs.ssmClientCreator.NewSSMClient(cs.region, iamCredentials, cs.ipCompatibility)
if err != nil {
errorEvents <- fmt.Errorf("unable to create SSM client: %v", err)
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ func TestHandleSSMCredentialspecFile(t *testing.T) {
expectedKerberosTicketPath := "/var/credentials-fetcher/krbdir/123456/webapp01"

gomock.InOrder(
ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any()).Return(mockSSMClient, nil),
ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any(), testConfig.InstanceIPCompatibility).Return(mockSSMClient, nil),
mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Return(ssmClientOutput, nil).Times(1),
)

Expand Down Expand Up @@ -243,7 +243,7 @@ func TestHandleSSMDomainlessCredentialspecFile(t *testing.T) {
expectedKerberosTicketPath := "/var/credentials-fetcher/krbdir/123456/webapp01"

gomock.InOrder(
ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any()).Return(mockSSMClient, nil),
ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any(), testConfig.InstanceIPCompatibility).Return(mockSSMClient, nil),
mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Return(ssmClientOutput, nil).Times(1),
)

Expand Down Expand Up @@ -333,7 +333,7 @@ func TestHandleSSMCredentialspecFileGetSSMParamErr(t *testing.T) {
}, apitaskstatus.TaskStatusNone, apitaskstatus.TaskRunning)

gomock.InOrder(
ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any()).Return(mockSSMClient, nil),
ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any(), testConfig.InstanceIPCompatibility).Return(mockSSMClient, nil),
mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Return(nil, errors.New("test-error")).Times(1),
)

Expand Down Expand Up @@ -954,7 +954,7 @@ func TestSkipCredentialFetcherInvocation(t *testing.T) {

gomock.InOrder(
credentialsManager.EXPECT().GetTaskCredentials(gomock.Any()).Return(taskRoleCredentials, true).Times(1),
ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any()).Return(mockSSMClient, nil),
ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any(), testConfig.InstanceIPCompatibility).Return(mockSSMClient, nil),
mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Return(ssmClientOutput, nil).Times(1),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ func (cs *CredentialSpecResource) handleSSMCredentialspecFile(originalCredential
return err
}

ssmClient, err := cs.ssmClientCreator.NewSSMClient(cs.region, iamCredentials)
ssmClient, err := cs.ssmClientCreator.NewSSMClient(cs.region, iamCredentials, cs.ipCompatibility)
if err != nil {
cs.setTerminalReason(err.Error())
return err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ func TestHandleSSMCredentialspecFile(t *testing.T) {
}

gomock.InOrder(
ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any()).Return(mockSSMClient, nil),
ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any(), testConfig.InstanceIPCompatibility).Return(mockSSMClient, nil),
mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Return(ssmClientOutput, nil).Times(1),
mockIO.EXPECT().WriteFile(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil),
)
Expand Down Expand Up @@ -582,7 +582,7 @@ func TestHandleSSMCredentialspecFileGetSSMParamErr(t *testing.T) {
}, apitaskstatus.TaskStatusNone, apitaskstatus.TaskRunning)

gomock.InOrder(
ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any()).Return(mockSSMClient, nil),
ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any(), testConfig.InstanceIPCompatibility).Return(mockSSMClient, nil),
mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Return(nil, errors.New("test-error")).Times(1),
)

Expand Down Expand Up @@ -640,7 +640,7 @@ func TestHandleSSMCredentialspecFileIOErr(t *testing.T) {
}

gomock.InOrder(
ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any()).Return(mockSSMClient, nil),
ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any(), testConfig.InstanceIPCompatibility).Return(mockSSMClient, nil),
mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Return(ssmClientOutput, nil).Times(1),
mockIO.EXPECT().WriteFile(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("test-error")),
)
Expand Down Expand Up @@ -995,7 +995,7 @@ func TestCreateSSM(t *testing.T) {

gomock.InOrder(
credentialsManager.EXPECT().GetTaskCredentials(gomock.Any()).Return(creds, true),
ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any()).Return(mockSSMClient, nil),
ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any(), testConfig.InstanceIPCompatibility).Return(mockSSMClient, nil),
mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Return(ssmClientOutput, nil).Times(1),
mockIO.EXPECT().WriteFile(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/aws/amazon-ecs-agent/agent/asm"
asmfactory "github.com/aws/amazon-ecs-agent/agent/asm/factory"
"github.com/aws/amazon-ecs-agent/agent/config"
"github.com/aws/amazon-ecs-agent/agent/config/ipcompatibility"
"github.com/aws/amazon-ecs-agent/agent/fsx"
fsxfactory "github.com/aws/amazon-ecs-agent/agent/fsx/factory"
"github.com/aws/amazon-ecs-agent/agent/ssm"
Expand Down Expand Up @@ -66,6 +67,7 @@ type FSxWindowsFileServerResource struct {
asmClientCreator asmfactory.ClientCreator
// fsxClientCreator is a factory interface that creates new FSx clients.
fsxClientCreator fsxfactory.FSxClientCreator
ipCompatibility ipcompatibility.IPCompatibility

// fields that are set later during resource creation
FSxWindowsFileServerDNSName string
Expand Down Expand Up @@ -114,7 +116,8 @@ func NewFSxWindowsFileServerResource(
credentialsManager credentials.Manager,
ssmClientCreator ssmfactory.SSMClientCreator,
asmClientCreator asmfactory.ClientCreator,
fsxClientCreator fsxfactory.FSxClientCreator) (*FSxWindowsFileServerResource, error) {
fsxClientCreator fsxfactory.FSxClientCreator,
ipCompatibility ipcompatibility.IPCompatibility) (*FSxWindowsFileServerResource, error) {

fv := &FSxWindowsFileServerResource{
Name: name,
Expand All @@ -131,6 +134,7 @@ func NewFSxWindowsFileServerResource(
ssmClientCreator: ssmClientCreator,
asmClientCreator: asmClientCreator,
fsxClientCreator: fsxClientCreator,
ipCompatibility: ipCompatibility,
}

fv.initStatusToTransition()
Expand All @@ -147,6 +151,7 @@ func (fv *FSxWindowsFileServerResource) Initialize(
fv.ssmClientCreator = resourceFields.SSMClientCreator
fv.asmClientCreator = resourceFields.ASMClientCreator
fv.fsxClientCreator = resourceFields.FSxClientCreator
fv.ipCompatibility = config.InstanceIPCompatibility
fv.initStatusToTransition()
}

Expand Down Expand Up @@ -480,7 +485,7 @@ func (fv *FSxWindowsFileServerResource) retrieveSSMCredentials(credentialsParame
return err
}

ssmClient, err := fv.ssmClientCreator.NewSSMClient(fv.region, iamCredentials)
ssmClient, err := fv.ssmClientCreator.NewSSMClient(fv.region, iamCredentials, fv.ipCompatibility)
if err != nil {
return err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"time"

"github.com/aws/amazon-ecs-agent/agent/config"
"github.com/aws/amazon-ecs-agent/agent/config/ipcompatibility"
"github.com/aws/amazon-ecs-agent/agent/utils"

mock_asm_factory "github.com/aws/amazon-ecs-agent/agent/asm/factory/mocks"
Expand Down Expand Up @@ -61,6 +62,10 @@ const (
hostPath = `Z:\`
)

var testConfig = &config.Config{
InstanceIPCompatibility: ipcompatibility.NewIPCompatibility(true, true),
}

func setup(t *testing.T) (
*FSxWindowsFileServerResource, *mock_credentials.MockManager, *mock_ssm_factory.MockSSMClientCreator,
*mock_asm_factory.MockClientCreator, *mock_fsx_factory.MockFSxClientCreator, *mock_ssmiface.MockSSMClient,
Expand All @@ -83,7 +88,7 @@ func setup(t *testing.T) (
taskARN: taskARN,
}
fv.Initialize(
&config.Config{},
testConfig,
&taskresource.ResourceFields{
ResourceFieldsCommon: &taskresource.ResourceFieldsCommon{
SSMClientCreator: ssmClientCreator,
Expand All @@ -102,6 +107,7 @@ func TestInitialize(t *testing.T) {
assert.NotNil(t, fv.asmClientCreator)
assert.NotNil(t, fv.fsxClientCreator)
assert.NotNil(t, fv.statusToTransitions)
assert.NotNil(t, fv.ipCompatibility)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit (non-blocking): Could do a more specific check using assert.Equal here, given that config (or testConfig in this case) contains what we expect to be the source of truth for IP compatibility.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. Ignoring this for now as we implicitly check it in the ssmClientCreator.NewSSMClient input args

fv, _, ssmClientCreator, _, _, mockSSMClient, _, _ := setup(t)
...
ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any(), testConfig.InstanceIPCompatibility).Return(mockSSMClient, nil),

Thanks

}

func TestMarshalUnmarshalJSON(t *testing.T) {
Expand Down Expand Up @@ -161,7 +167,7 @@ func TestRetrieveCredentials(t *testing.T) {
}

gomock.InOrder(
ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any()).Return(mockSSMClient, nil),
ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any(), testConfig.InstanceIPCompatibility).Return(mockSSMClient, nil),
mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Return(ssmClientOutput, nil).Times(1),
)

Expand Down Expand Up @@ -226,7 +232,7 @@ func TestRetrieveSSMCredentials(t *testing.T) {
}

gomock.InOrder(
ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any()).Return(mockSSMClient, nil),
ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any(), testConfig.InstanceIPCompatibility).Return(mockSSMClient, nil),
mockSSMClient.EXPECT().GetParameters(gomock.Any(), &ssm.GetParametersInput{
Names: []string{tc.CredentialsParameterName},
WithDecryption: aws.Bool(false),
Expand Down Expand Up @@ -526,7 +532,7 @@ func TestCreateUnavailableLocalPath(t *testing.T) {
executionCredentialsID: executionCredentialsID,
}
fv.Initialize(
&config.Config{},
testConfig,
&taskresource.ResourceFields{
ResourceFieldsCommon: &taskresource.ResourceFieldsCommon{
SSMClientCreator: ssmClientCreator,
Expand Down Expand Up @@ -566,7 +572,7 @@ func TestCreateUnavailableLocalPath(t *testing.T) {

gomock.InOrder(
credentialsManager.EXPECT().GetTaskCredentials(gomock.Any()).Return(creds, true),
ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any()).Return(mockSSMClient, nil),
ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any(), testConfig.InstanceIPCompatibility).Return(mockSSMClient, nil),
mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Return(ssmClientOutput, nil).Times(1),
fsxClientCreator.EXPECT().NewFSxClient(gomock.Any(), gomock.Any()).Return(mockFSxClient, nil),
mockFSxClient.EXPECT().DescribeFileSystems(gomock.Any(), gomock.Any()).Return(fsxClientOutput, nil).Times(1),
Expand Down Expand Up @@ -613,7 +619,7 @@ func TestCreateSSM(t *testing.T) {
executionCredentialsID: executionCredentialsID,
}
fv.Initialize(
&config.Config{},
testConfig,
&taskresource.ResourceFields{
ResourceFieldsCommon: &taskresource.ResourceFieldsCommon{
SSMClientCreator: ssmClientCreator,
Expand Down Expand Up @@ -653,7 +659,7 @@ func TestCreateSSM(t *testing.T) {

gomock.InOrder(
credentialsManager.EXPECT().GetTaskCredentials(gomock.Any()).Return(creds, true),
ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any()).Return(mockSSMClient, nil),
ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any(), testConfig.InstanceIPCompatibility).Return(mockSSMClient, nil),
mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Return(ssmClientOutput, nil).Times(1),
fsxClientCreator.EXPECT().NewFSxClient(gomock.Any(), gomock.Any()).Return(mockFSxClient, nil),
mockFSxClient.EXPECT().DescribeFileSystems(gomock.Any(), gomock.Any()).Return(fsxClientOutput, nil).Times(1),
Expand Down
Loading
Loading