diff --git a/agent/api/task/task.go b/agent/api/task/task.go index 8e4604f7c6f..27300bb9d48 100644 --- a/agent/api/task/task.go +++ b/agent/api/task/task.go @@ -416,7 +416,7 @@ func (task *Task) PostUnmarshalTask(cfg *config.Config, return apierrors.NewResourceInitError(task.Arn, err) } - task.initSecretResources(credentialsManager, resourceFields) + task.initSecretResources(cfg, credentialsManager, resourceFields) task.initializeCredentialsEndpoint(credentialsManager) @@ -651,14 +651,14 @@ func (task *Task) populateTaskARN() { } } -func (task *Task) initSecretResources(credentialsManager credentials.Manager, +func (task *Task) initSecretResources(cfg *config.Config, credentialsManager credentials.Manager, resourceFields *taskresource.ResourceFields) { if task.requiresASMDockerAuthData() { task.initializeASMAuthResource(credentialsManager, resourceFields) } if task.requiresSSMSecret() { - task.initializeSSMSecretResource(credentialsManager, resourceFields) + task.initializeSSMSecretResource(cfg, credentialsManager, resourceFields) } if task.requiresASMSecret() { @@ -1109,10 +1109,11 @@ func (task *Task) requiresSSMSecret() bool { } // initializeSSMSecretResource builds the resource dependency map for the SSM ssmsecret resource -func (task *Task) initializeSSMSecretResource(credentialsManager credentials.Manager, +func (task *Task) initializeSSMSecretResource(cfg *config.Config, + credentialsManager credentials.Manager, resourceFields *taskresource.ResourceFields) { ssmSecretResource := ssmsecret.NewSSMSecretResource(task.Arn, task.getAllSSMSecretRequirements(), - task.ExecutionCredentialsID, credentialsManager, resourceFields.SSMClientCreator) + task.ExecutionCredentialsID, credentialsManager, resourceFields.SSMClientCreator, cfg.InstanceIPCompatibility) task.AddResource(ssmsecret.ResourceName, ssmSecretResource) // for every container that needs ssm secret vending as env, it needs to wait all secrets got retrieved diff --git a/agent/api/task/task_test.go b/agent/api/task/task_test.go index 651c1cd97f5..db7b48f9256 100644 --- a/agent/api/task/task_test.go +++ b/agent/api/task/task_test.go @@ -3037,7 +3037,7 @@ func TestInitializeAndGetSSMSecretResource(t *testing.T) { }, } - task.initializeSSMSecretResource(credentialsManager, resFields) + task.initializeSSMSecretResource(&config.Config{InstanceIPCompatibility: testIPCompatibility}, credentialsManager, resFields) resourceDep := apicontainer.ResourceDependency{ Name: ssmsecret.ResourceName, diff --git a/agent/api/task/task_windows.go b/agent/api/task/task_windows.go index f4137d67c1c..c836654514e 100644 --- a/agent/api/task/task_windows.go +++ b/agent/api/task/task_windows.go @@ -203,7 +203,8 @@ func (task *Task) addFSxWindowsFileServerResource( credentialsManager, resourceFields.SSMClientCreator, resourceFields.ASMClientCreator, - resourceFields.FSxClientCreator) + resourceFields.FSxClientCreator, + cfg.InstanceIPCompatibility) if err != nil { return err } diff --git a/agent/engine/docker_task_engine_test.go b/agent/engine/docker_task_engine_test.go index 9465a32c634..288d52be502 100644 --- a/agent/engine/docker_task_engine_test.go +++ b/agent/engine/docker_task_engine_test.go @@ -2879,12 +2879,14 @@ func TestTaskSecretsEnvironmentVariables(t *testing.T) { }, } + testIPCompatibility := ipcompatibility.NewIPCompatibility(true, true) ssmSecretRes := ssmsecret.NewSSMSecretResource( testTask.Arn, ssmRequirements, credentialsID, credentialsManager, - ssmClientCreator) + ssmClientCreator, + testIPCompatibility) // required for validating asm workflows asmClientCreator := mock_asm_factory.NewMockClientCreator(ctrl) @@ -2923,7 +2925,7 @@ func TestTaskSecretsEnvironmentVariables(t *testing.T) { reqSecretNames := []string{ssmSecretValueFrom} credentialsManager.EXPECT().GetTaskCredentials(credentialsID).Return(taskIAMcreds, true).Times(2) - ssmClientCreator.EXPECT().NewSSMClient(region, executionRoleCredentials).Return(mockSSMClient, nil) + ssmClientCreator.EXPECT().NewSSMClient(region, executionRoleCredentials, testIPCompatibility).Return(mockSSMClient, nil) asmClientCreator.EXPECT().NewASMClient(region, executionRoleCredentials).Return(mockASMClient, nil) mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Do(func(ctx context.Context, in *ssm.GetParametersInput, optFns ...func(*ssm.Options)) { diff --git a/agent/ssm/factory/factory.go b/agent/ssm/factory/factory.go index 24a2f387bc1..8794d60cc9f 100644 --- a/agent/ssm/factory/factory.go +++ b/agent/ssm/factory/factory.go @@ -18,11 +18,14 @@ import ( "time" "github.com/aws/amazon-ecs-agent/agent/config" + "github.com/aws/amazon-ecs-agent/agent/config/ipcompatibility" ssmclient "github.com/aws/amazon-ecs-agent/agent/ssm" agentversion "github.com/aws/amazon-ecs-agent/agent/version" "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" "github.com/aws/amazon-ecs-agent/ecs-agent/httpclient" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/aws-sdk-go-v2/aws" awsconfig "github.com/aws/aws-sdk-go-v2/config" awscreds "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/ssm" @@ -33,7 +36,7 @@ const ( ) type SSMClientCreator interface { - NewSSMClient(region string, creds credentials.IAMRoleCredentials) (ssmclient.SSMClient, error) + NewSSMClient(region string, creds credentials.IAMRoleCredentials, ipCompatibility ipcompatibility.IPCompatibility) (ssmclient.SSMClient, error) } func NewSSMClientCreator() SSMClientCreator { @@ -44,16 +47,24 @@ type ssmClientCreator struct{} // SSM Client will automatically retry 3 times when has throttling error func (*ssmClientCreator) NewSSMClient(region string, - creds credentials.IAMRoleCredentials) (ssmclient.SSMClient, error) { - cfg, err := awsconfig.LoadDefaultConfig( - context.TODO(), + creds credentials.IAMRoleCredentials, + ipCompatibility ipcompatibility.IPCompatibility) (ssmclient.SSMClient, error) { + + opts := []func(*awsconfig.LoadOptions) error{ awsconfig.WithHTTPClient(httpclient.New(roundtripTimeout, false, agentversion.String(), config.OSType)), awsconfig.WithRegion(region), awsconfig.WithCredentialsProvider( awscreds.NewStaticCredentialsProvider(creds.AccessKeyID, creds.SecretAccessKey, creds.SessionToken), ), - ) + } + + if ipCompatibility.IsIPv6Only() { + logger.Debug("Configuring SSM Client DualStack endpoint") + opts = append(opts, awsconfig.WithUseDualStackEndpoint(aws.DualStackEndpointStateEnabled)) + } + + cfg, err := awsconfig.LoadDefaultConfig(context.TODO(), opts...) if err != nil { return nil, err diff --git a/agent/ssm/factory/mocks/factory_mocks.go b/agent/ssm/factory/mocks/factory_mocks.go index 1f1c1c1a0b0..8be677a9a22 100644 --- a/agent/ssm/factory/mocks/factory_mocks.go +++ b/agent/ssm/factory/mocks/factory_mocks.go @@ -21,6 +21,7 @@ package mock_factory import ( reflect "reflect" + ipcompatibility "github.com/aws/amazon-ecs-agent/agent/config/ipcompatibility" ssm "github.com/aws/amazon-ecs-agent/agent/ssm" credentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" gomock "github.com/golang/mock/gomock" @@ -50,16 +51,16 @@ func (m *MockSSMClientCreator) EXPECT() *MockSSMClientCreatorMockRecorder { } // NewSSMClient mocks base method. -func (m *MockSSMClientCreator) NewSSMClient(arg0 string, arg1 credentials.IAMRoleCredentials) (ssm.SSMClient, error) { +func (m *MockSSMClientCreator) NewSSMClient(arg0 string, arg1 credentials.IAMRoleCredentials, arg2 ipcompatibility.IPCompatibility) (ssm.SSMClient, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NewSSMClient", arg0, arg1) + ret := m.ctrl.Call(m, "NewSSMClient", arg0, arg1, arg2) ret0, _ := ret[0].(ssm.SSMClient) ret1, _ := ret[1].(error) return ret0, ret1 } // NewSSMClient indicates an expected call of NewSSMClient. -func (mr *MockSSMClientCreatorMockRecorder) NewSSMClient(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockSSMClientCreatorMockRecorder) NewSSMClient(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewSSMClient", reflect.TypeOf((*MockSSMClientCreator)(nil).NewSSMClient), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewSSMClient", reflect.TypeOf((*MockSSMClientCreator)(nil).NewSSMClient), arg0, arg1, arg2) } diff --git a/agent/taskresource/credentialspec/credentialspec_linux.go b/agent/taskresource/credentialspec/credentialspec_linux.go index 39ab301ecae..106780cb1e1 100644 --- a/agent/taskresource/credentialspec/credentialspec_linux.go +++ b/agent/taskresource/credentialspec/credentialspec_linux.go @@ -501,7 +501,7 @@ func (cs *CredentialSpecResource) handleSSMCredentialspecFile(originalCredential } ssmParams := []string{ssmParam[1]} - ssmClient, err := cs.ssmClientCreator.NewSSMClient(cs.region, iamCredentials) + ssmClient, err := cs.ssmClientCreator.NewSSMClient(cs.region, iamCredentials, cs.ipCompatibility) if err != nil { errorEvents <- fmt.Errorf("unable to create SSM client: %v", err) return diff --git a/agent/taskresource/credentialspec/credentialspec_linux_test.go b/agent/taskresource/credentialspec/credentialspec_linux_test.go index a0becac73b2..e65250eaaef 100644 --- a/agent/taskresource/credentialspec/credentialspec_linux_test.go +++ b/agent/taskresource/credentialspec/credentialspec_linux_test.go @@ -168,7 +168,7 @@ func TestHandleSSMCredentialspecFile(t *testing.T) { expectedKerberosTicketPath := "/var/credentials-fetcher/krbdir/123456/webapp01" gomock.InOrder( - ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any()).Return(mockSSMClient, nil), + ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any(), testConfig.InstanceIPCompatibility).Return(mockSSMClient, nil), mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Return(ssmClientOutput, nil).Times(1), ) @@ -243,7 +243,7 @@ func TestHandleSSMDomainlessCredentialspecFile(t *testing.T) { expectedKerberosTicketPath := "/var/credentials-fetcher/krbdir/123456/webapp01" gomock.InOrder( - ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any()).Return(mockSSMClient, nil), + ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any(), testConfig.InstanceIPCompatibility).Return(mockSSMClient, nil), mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Return(ssmClientOutput, nil).Times(1), ) @@ -333,7 +333,7 @@ func TestHandleSSMCredentialspecFileGetSSMParamErr(t *testing.T) { }, apitaskstatus.TaskStatusNone, apitaskstatus.TaskRunning) gomock.InOrder( - ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any()).Return(mockSSMClient, nil), + ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any(), testConfig.InstanceIPCompatibility).Return(mockSSMClient, nil), mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Return(nil, errors.New("test-error")).Times(1), ) @@ -954,7 +954,7 @@ func TestSkipCredentialFetcherInvocation(t *testing.T) { gomock.InOrder( credentialsManager.EXPECT().GetTaskCredentials(gomock.Any()).Return(taskRoleCredentials, true).Times(1), - ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any()).Return(mockSSMClient, nil), + ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any(), testConfig.InstanceIPCompatibility).Return(mockSSMClient, nil), mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Return(ssmClientOutput, nil).Times(1), ) diff --git a/agent/taskresource/credentialspec/credentialspec_windows.go b/agent/taskresource/credentialspec/credentialspec_windows.go index 813c795eeab..7f6944e9943 100644 --- a/agent/taskresource/credentialspec/credentialspec_windows.go +++ b/agent/taskresource/credentialspec/credentialspec_windows.go @@ -305,7 +305,7 @@ func (cs *CredentialSpecResource) handleSSMCredentialspecFile(originalCredential return err } - ssmClient, err := cs.ssmClientCreator.NewSSMClient(cs.region, iamCredentials) + ssmClient, err := cs.ssmClientCreator.NewSSMClient(cs.region, iamCredentials, cs.ipCompatibility) if err != nil { cs.setTerminalReason(err.Error()) return err diff --git a/agent/taskresource/credentialspec/credentialspec_windows_test.go b/agent/taskresource/credentialspec/credentialspec_windows_test.go index 4ffbcec876c..9911d6d24a1 100644 --- a/agent/taskresource/credentialspec/credentialspec_windows_test.go +++ b/agent/taskresource/credentialspec/credentialspec_windows_test.go @@ -492,7 +492,7 @@ func TestHandleSSMCredentialspecFile(t *testing.T) { } gomock.InOrder( - ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any()).Return(mockSSMClient, nil), + ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any(), testConfig.InstanceIPCompatibility).Return(mockSSMClient, nil), mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Return(ssmClientOutput, nil).Times(1), mockIO.EXPECT().WriteFile(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil), ) @@ -582,7 +582,7 @@ func TestHandleSSMCredentialspecFileGetSSMParamErr(t *testing.T) { }, apitaskstatus.TaskStatusNone, apitaskstatus.TaskRunning) gomock.InOrder( - ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any()).Return(mockSSMClient, nil), + ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any(), testConfig.InstanceIPCompatibility).Return(mockSSMClient, nil), mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Return(nil, errors.New("test-error")).Times(1), ) @@ -640,7 +640,7 @@ func TestHandleSSMCredentialspecFileIOErr(t *testing.T) { } gomock.InOrder( - ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any()).Return(mockSSMClient, nil), + ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any(), testConfig.InstanceIPCompatibility).Return(mockSSMClient, nil), mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Return(ssmClientOutput, nil).Times(1), mockIO.EXPECT().WriteFile(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("test-error")), ) @@ -995,7 +995,7 @@ func TestCreateSSM(t *testing.T) { gomock.InOrder( credentialsManager.EXPECT().GetTaskCredentials(gomock.Any()).Return(creds, true), - ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any()).Return(mockSSMClient, nil), + ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any(), testConfig.InstanceIPCompatibility).Return(mockSSMClient, nil), mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Return(ssmClientOutput, nil).Times(1), mockIO.EXPECT().WriteFile(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil), ) diff --git a/agent/taskresource/fsxwindowsfileserver/fsxwindowsfileserver_windows.go b/agent/taskresource/fsxwindowsfileserver/fsxwindowsfileserver_windows.go index 280d5b128e4..8d15b00dcc8 100644 --- a/agent/taskresource/fsxwindowsfileserver/fsxwindowsfileserver_windows.go +++ b/agent/taskresource/fsxwindowsfileserver/fsxwindowsfileserver_windows.go @@ -28,6 +28,7 @@ import ( "github.com/aws/amazon-ecs-agent/agent/asm" asmfactory "github.com/aws/amazon-ecs-agent/agent/asm/factory" "github.com/aws/amazon-ecs-agent/agent/config" + "github.com/aws/amazon-ecs-agent/agent/config/ipcompatibility" "github.com/aws/amazon-ecs-agent/agent/fsx" fsxfactory "github.com/aws/amazon-ecs-agent/agent/fsx/factory" "github.com/aws/amazon-ecs-agent/agent/ssm" @@ -66,6 +67,7 @@ type FSxWindowsFileServerResource struct { asmClientCreator asmfactory.ClientCreator // fsxClientCreator is a factory interface that creates new FSx clients. fsxClientCreator fsxfactory.FSxClientCreator + ipCompatibility ipcompatibility.IPCompatibility // fields that are set later during resource creation FSxWindowsFileServerDNSName string @@ -114,7 +116,8 @@ func NewFSxWindowsFileServerResource( credentialsManager credentials.Manager, ssmClientCreator ssmfactory.SSMClientCreator, asmClientCreator asmfactory.ClientCreator, - fsxClientCreator fsxfactory.FSxClientCreator) (*FSxWindowsFileServerResource, error) { + fsxClientCreator fsxfactory.FSxClientCreator, + ipCompatibility ipcompatibility.IPCompatibility) (*FSxWindowsFileServerResource, error) { fv := &FSxWindowsFileServerResource{ Name: name, @@ -131,6 +134,7 @@ func NewFSxWindowsFileServerResource( ssmClientCreator: ssmClientCreator, asmClientCreator: asmClientCreator, fsxClientCreator: fsxClientCreator, + ipCompatibility: ipCompatibility, } fv.initStatusToTransition() @@ -147,6 +151,7 @@ func (fv *FSxWindowsFileServerResource) Initialize( fv.ssmClientCreator = resourceFields.SSMClientCreator fv.asmClientCreator = resourceFields.ASMClientCreator fv.fsxClientCreator = resourceFields.FSxClientCreator + fv.ipCompatibility = config.InstanceIPCompatibility fv.initStatusToTransition() } @@ -480,7 +485,7 @@ func (fv *FSxWindowsFileServerResource) retrieveSSMCredentials(credentialsParame return err } - ssmClient, err := fv.ssmClientCreator.NewSSMClient(fv.region, iamCredentials) + ssmClient, err := fv.ssmClientCreator.NewSSMClient(fv.region, iamCredentials, fv.ipCompatibility) if err != nil { return err } diff --git a/agent/taskresource/fsxwindowsfileserver/fsxwindowsfileserver_windows_test.go b/agent/taskresource/fsxwindowsfileserver/fsxwindowsfileserver_windows_test.go index 1a0e7086210..0987e079b64 100644 --- a/agent/taskresource/fsxwindowsfileserver/fsxwindowsfileserver_windows_test.go +++ b/agent/taskresource/fsxwindowsfileserver/fsxwindowsfileserver_windows_test.go @@ -26,6 +26,7 @@ import ( "time" "github.com/aws/amazon-ecs-agent/agent/config" + "github.com/aws/amazon-ecs-agent/agent/config/ipcompatibility" "github.com/aws/amazon-ecs-agent/agent/utils" mock_asm_factory "github.com/aws/amazon-ecs-agent/agent/asm/factory/mocks" @@ -61,6 +62,10 @@ const ( hostPath = `Z:\` ) +var testConfig = &config.Config{ + InstanceIPCompatibility: ipcompatibility.NewIPCompatibility(true, true), +} + func setup(t *testing.T) ( *FSxWindowsFileServerResource, *mock_credentials.MockManager, *mock_ssm_factory.MockSSMClientCreator, *mock_asm_factory.MockClientCreator, *mock_fsx_factory.MockFSxClientCreator, *mock_ssmiface.MockSSMClient, @@ -83,7 +88,7 @@ func setup(t *testing.T) ( taskARN: taskARN, } fv.Initialize( - &config.Config{}, + testConfig, &taskresource.ResourceFields{ ResourceFieldsCommon: &taskresource.ResourceFieldsCommon{ SSMClientCreator: ssmClientCreator, @@ -102,6 +107,7 @@ func TestInitialize(t *testing.T) { assert.NotNil(t, fv.asmClientCreator) assert.NotNil(t, fv.fsxClientCreator) assert.NotNil(t, fv.statusToTransitions) + assert.NotNil(t, fv.ipCompatibility) } func TestMarshalUnmarshalJSON(t *testing.T) { @@ -161,7 +167,7 @@ func TestRetrieveCredentials(t *testing.T) { } gomock.InOrder( - ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any()).Return(mockSSMClient, nil), + ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any(), testConfig.InstanceIPCompatibility).Return(mockSSMClient, nil), mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Return(ssmClientOutput, nil).Times(1), ) @@ -226,7 +232,7 @@ func TestRetrieveSSMCredentials(t *testing.T) { } gomock.InOrder( - ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any()).Return(mockSSMClient, nil), + ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any(), testConfig.InstanceIPCompatibility).Return(mockSSMClient, nil), mockSSMClient.EXPECT().GetParameters(gomock.Any(), &ssm.GetParametersInput{ Names: []string{tc.CredentialsParameterName}, WithDecryption: aws.Bool(false), @@ -526,7 +532,7 @@ func TestCreateUnavailableLocalPath(t *testing.T) { executionCredentialsID: executionCredentialsID, } fv.Initialize( - &config.Config{}, + testConfig, &taskresource.ResourceFields{ ResourceFieldsCommon: &taskresource.ResourceFieldsCommon{ SSMClientCreator: ssmClientCreator, @@ -566,7 +572,7 @@ func TestCreateUnavailableLocalPath(t *testing.T) { gomock.InOrder( credentialsManager.EXPECT().GetTaskCredentials(gomock.Any()).Return(creds, true), - ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any()).Return(mockSSMClient, nil), + ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any(), testConfig.InstanceIPCompatibility).Return(mockSSMClient, nil), mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Return(ssmClientOutput, nil).Times(1), fsxClientCreator.EXPECT().NewFSxClient(gomock.Any(), gomock.Any()).Return(mockFSxClient, nil), mockFSxClient.EXPECT().DescribeFileSystems(gomock.Any(), gomock.Any()).Return(fsxClientOutput, nil).Times(1), @@ -613,7 +619,7 @@ func TestCreateSSM(t *testing.T) { executionCredentialsID: executionCredentialsID, } fv.Initialize( - &config.Config{}, + testConfig, &taskresource.ResourceFields{ ResourceFieldsCommon: &taskresource.ResourceFieldsCommon{ SSMClientCreator: ssmClientCreator, @@ -653,7 +659,7 @@ func TestCreateSSM(t *testing.T) { gomock.InOrder( credentialsManager.EXPECT().GetTaskCredentials(gomock.Any()).Return(creds, true), - ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any()).Return(mockSSMClient, nil), + ssmClientCreator.EXPECT().NewSSMClient(gomock.Any(), gomock.Any(), testConfig.InstanceIPCompatibility).Return(mockSSMClient, nil), mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Return(ssmClientOutput, nil).Times(1), fsxClientCreator.EXPECT().NewFSxClient(gomock.Any(), gomock.Any()).Return(mockFSxClient, nil), mockFSxClient.EXPECT().DescribeFileSystems(gomock.Any(), gomock.Any()).Return(fsxClientOutput, nil).Times(1), diff --git a/agent/taskresource/ssmsecret/ssmsecret.go b/agent/taskresource/ssmsecret/ssmsecret.go index bcfaebcf7c6..2d0609d592b 100644 --- a/agent/taskresource/ssmsecret/ssmsecret.go +++ b/agent/taskresource/ssmsecret/ssmsecret.go @@ -24,6 +24,7 @@ import ( apicontainer "github.com/aws/amazon-ecs-agent/agent/api/container" "github.com/aws/amazon-ecs-agent/agent/config" + "github.com/aws/amazon-ecs-agent/agent/config/ipcompatibility" "github.com/aws/amazon-ecs-agent/agent/ssm" "github.com/aws/amazon-ecs-agent/agent/ssm/factory" "github.com/aws/amazon-ecs-agent/agent/taskresource" @@ -65,6 +66,7 @@ type SSMSecretResource struct { // ssmClientCreator is a factory interface that creates new SSM clients. This is // needed mostly for testing. ssmClientCreator factory.SSMClientCreator + ipCompatibility ipcompatibility.IPCompatibility // terminalReason should be set for resource creation failures. This ensures // the resource object carries some context for why provisioning failed. @@ -80,7 +82,8 @@ func NewSSMSecretResource(taskARN string, ssmSecrets map[string][]apicontainer.Secret, executionCredentialsID string, credentialsManager credentials.Manager, - ssmClientCreator factory.SSMClientCreator) *SSMSecretResource { + ssmClientCreator factory.SSMClientCreator, + ipCompatibility ipcompatibility.IPCompatibility) *SSMSecretResource { s := &SSMSecretResource{ taskARN: taskARN, @@ -88,6 +91,7 @@ func NewSSMSecretResource(taskARN string, credentialsManager: credentialsManager, executionCredentialsID: executionCredentialsID, ssmClientCreator: ssmClientCreator, + ipCompatibility: ipCompatibility, } s.initStatusToTransition() @@ -335,7 +339,7 @@ func (secret *SSMSecretResource) retrieveSSMSecretValuesByRegion(region string, func (secret *SSMSecretResource) retrieveSSMSecretValues(region string, names []string, iamCredentials credentials.IAMRoleCredentials, wg *sync.WaitGroup, errorEvents chan error) { defer wg.Done() - ssmClient, err := secret.ssmClientCreator.NewSSMClient(region, iamCredentials) + ssmClient, err := secret.ssmClientCreator.NewSSMClient(region, iamCredentials, secret.ipCompatibility) if err != nil { errorEvents <- fmt.Errorf("unable to create SSM client in %s: %v", region, err) return @@ -419,6 +423,7 @@ func (secret *SSMSecretResource) Initialize( secret.initStatusToTransition() secret.credentialsManager = resourceFields.CredentialsManager secret.ssmClientCreator = resourceFields.SSMClientCreator + secret.ipCompatibility = config.InstanceIPCompatibility // if task hasn't turn to 'created' status, and it's desire status is 'running' // the resource status needs to be reset to 'NONE' status so the secret value diff --git a/agent/taskresource/ssmsecret/ssmsecret_test.go b/agent/taskresource/ssmsecret/ssmsecret_test.go index ca872dbe694..7f9888b4a48 100644 --- a/agent/taskresource/ssmsecret/ssmsecret_test.go +++ b/agent/taskresource/ssmsecret/ssmsecret_test.go @@ -25,6 +25,7 @@ import ( apicontainer "github.com/aws/amazon-ecs-agent/agent/api/container" "github.com/aws/amazon-ecs-agent/agent/config" + "github.com/aws/amazon-ecs-agent/agent/config/ipcompatibility" mock_factory "github.com/aws/amazon-ecs-agent/agent/ssm/factory/mocks" mock_ssm "github.com/aws/amazon-ecs-agent/agent/ssm/mocks" "github.com/aws/amazon-ecs-agent/agent/taskresource" @@ -57,6 +58,8 @@ const ( taskARN = "task1" ) +var testIPCompatibility = ipcompatibility.NewIPCompatibility(true, true) + func TestCreateAndGetWithOneCall(t *testing.T) { requiredSecretData := make(map[string][]apicontainer.Secret) secretsInRegion1 := []apicontainer.Secret{ @@ -105,7 +108,7 @@ func TestCreateAndGetWithOneCall(t *testing.T) { allNames := []string{valueFrom1, valueFrom2} credentialsManager.EXPECT().GetTaskCredentials(executionCredentialsID).Return(creds, true) - ssmClientCreator.EXPECT().NewSSMClient(region1, iamRoleCreds).Return(mockSSMClient, nil) + ssmClientCreator.EXPECT().NewSSMClient(region1, iamRoleCreds, testIPCompatibility).Return(mockSSMClient, nil) mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Do(func(ctx context.Context, in *ssm.GetParametersInput, optFns ...func(*ssm.Options)) { assert.Equal(t, in.Names, allNames) }).Return(ssmOutput, nil).Times(1) @@ -115,6 +118,7 @@ func TestCreateAndGetWithOneCall(t *testing.T) { requiredSecrets: requiredSecretData, credentialsManager: credentialsManager, ssmClientCreator: ssmClientCreator, + ipCompatibility: testIPCompatibility, } require.NoError(t, ssmRes.Create()) @@ -173,8 +177,8 @@ func TestCreateAndGetWithTwoCallsAcrossRegions(t *testing.T) { allNames := []string{valueFrom1} credentialsManager.EXPECT().GetTaskCredentials(executionCredentialsID).Return(creds, true) - ssmClientCreator.EXPECT().NewSSMClient(region1, iamRoleCreds).Return(mockSSMClient, nil) - ssmClientCreator.EXPECT().NewSSMClient(region2, iamRoleCreds).Return(mockSSMClient, nil) + ssmClientCreator.EXPECT().NewSSMClient(region1, iamRoleCreds, testIPCompatibility).Return(mockSSMClient, nil) + ssmClientCreator.EXPECT().NewSSMClient(region2, iamRoleCreds, testIPCompatibility).Return(mockSSMClient, nil) mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Do(func(ctx context.Context, in *ssm.GetParametersInput, optFns ...func(*ssm.Options)) { assert.Equal(t, in.Names, allNames) }).Return(ssmOutput, nil).Times(2) @@ -184,6 +188,7 @@ func TestCreateAndGetWithTwoCallsAcrossRegions(t *testing.T) { requiredSecrets: requiredSecretData, credentialsManager: credentialsManager, ssmClientCreator: ssmClientCreator, + ipCompatibility: testIPCompatibility, } require.NoError(t, ssmRes.Create()) @@ -270,7 +275,7 @@ func TestCreateAndGetWithTwoCallsInSameRegion(t *testing.T) { } credentialsManager.EXPECT().GetTaskCredentials(executionCredentialsID).Return(creds, true) - ssmClientCreator.EXPECT().NewSSMClient(region1, iamRoleCreds).Return(mockSSMClient, nil).Times(2) + ssmClientCreator.EXPECT().NewSSMClient(region1, iamRoleCreds, testIPCompatibility).Return(mockSSMClient, nil).Times(2) mockSSMClient.EXPECT().GetParameters(gomock.Any(), ssmInput1).Return(ssmOutput1, nil) mockSSMClient.EXPECT().GetParameters(gomock.Any(), ssmInput2).Return(ssmOutput2, nil) @@ -279,6 +284,7 @@ func TestCreateAndGetWithTwoCallsInSameRegion(t *testing.T) { requiredSecrets: requiredSecretData, credentialsManager: credentialsManager, ssmClientCreator: ssmClientCreator, + ipCompatibility: testIPCompatibility, } require.NoError(t, ssmRes.Create()) @@ -331,8 +337,8 @@ func TestCreateReturnMultipleErrors(t *testing.T) { allNames := []string{valueFrom1} credentialsManager.EXPECT().GetTaskCredentials(executionCredentialsID).Return(creds, true) - ssmClientCreator.EXPECT().NewSSMClient(region1, iamRoleCreds).Return(mockSSMClient, nil) - ssmClientCreator.EXPECT().NewSSMClient(region2, iamRoleCreds).Return(mockSSMClient, nil) + ssmClientCreator.EXPECT().NewSSMClient(region1, iamRoleCreds, testIPCompatibility).Return(mockSSMClient, nil) + ssmClientCreator.EXPECT().NewSSMClient(region2, iamRoleCreds, testIPCompatibility).Return(mockSSMClient, nil) mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Do(func(ctx context.Context, in *ssm.GetParametersInput, optFns ...func(*ssm.Options)) { assert.Equal(t, in.Names, allNames) }).Return(ssmOutput, nil).Times(2) @@ -342,6 +348,7 @@ func TestCreateReturnMultipleErrors(t *testing.T) { requiredSecrets: requiredSecretData, credentialsManager: credentialsManager, ssmClientCreator: ssmClientCreator, + ipCompatibility: testIPCompatibility, } assert.Error(t, ssmRes.Create()) @@ -381,7 +388,7 @@ func TestCreateReturnError(t *testing.T) { allNames := []string{valueFrom1} gomock.InOrder( credentialsManager.EXPECT().GetTaskCredentials(executionCredentialsID).Return(creds, true), - ssmClientCreator.EXPECT().NewSSMClient(region1, iamRoleCreds).Return(mockSSMClient, nil), + ssmClientCreator.EXPECT().NewSSMClient(region1, iamRoleCreds, testIPCompatibility).Return(mockSSMClient, nil), mockSSMClient.EXPECT().GetParameters(gomock.Any(), gomock.Any()).Do(func(ctx context.Context, in *ssm.GetParametersInput, optFns ...func(*ssm.Options)) { assert.Equal(t, in.Names, allNames) }).Return(ssmOutput, nil), @@ -391,6 +398,7 @@ func TestCreateReturnError(t *testing.T) { requiredSecrets: requiredSecretData, credentialsManager: credentialsManager, ssmClientCreator: ssmClientCreator, + ipCompatibility: testIPCompatibility, } assert.Error(t, ssmRes.Create()) @@ -501,7 +509,7 @@ func TestInitialize(t *testing.T) { desiredStatusUnsafe: resourcestatus.ResourceCreated, } ssmRes.Initialize( - &config.Config{}, + &config.Config{InstanceIPCompatibility: testIPCompatibility}, &taskresource.ResourceFields{ ResourceFieldsCommon: &taskresource.ResourceFieldsCommon{ SSMClientCreator: ssmClientCreator, @@ -510,6 +518,7 @@ func TestInitialize(t *testing.T) { }, apitaskstatus.TaskStatusNone, apitaskstatus.TaskRunning) assert.Equal(t, resourcestatus.ResourceStatusNone, ssmRes.GetKnownStatus()) assert.Equal(t, resourcestatus.ResourceCreated, ssmRes.GetDesiredStatus()) + assert.Equal(t, testIPCompatibility, ssmRes.ipCompatibility) }