diff --git a/mysql/awsmysql/awsmysql.go b/mysql/awsmysql/awsmysql.go index 619ee8757..fa08050e9 100644 --- a/mysql/awsmysql/awsmysql.go +++ b/mysql/awsmysql/awsmysql.go @@ -32,6 +32,7 @@ import ( "database/sql" "database/sql/driver" "fmt" + "net/http" "net/url" "sync/atomic" @@ -50,7 +51,7 @@ import ( // Set is a Wire provider set that provides a *sql.DB given // *Params and an HTTP client. var Set = wire.NewSet( - wire.Struct(new(URLOpener), "CertSource"), + wire.Struct(new(URLOpener), "CertSource", "HTTPClient"), rds.CertFetcherSet, ) @@ -66,6 +67,9 @@ var Set = wire.NewSet( // - aws_profile: the AWS shared config profile to use // - aws_role_arn: the ARN of the role to assume type URLOpener struct { + // HTTPClient is the HTTP client used to fetch RDS certificates, + // and IAM authentication tokens. + HTTPClient *http.Client // CertSource specifies how the opener will obtain the RDS Certificate // Authority. If nil, it will use the default *rds.CertFetcher. CertSource rds.CertPoolProvider @@ -85,7 +89,7 @@ func init() { func (uo *URLOpener) OpenMySQLURL(ctx context.Context, u *url.URL) (*sql.DB, error) { source := uo.CertSource if source == nil { - source = new(rds.CertFetcher) + source = &rds.CertFetcher{Client: uo.HTTPClient} } if u.Host == "" { return nil, fmt.Errorf("open OpenMySQLURL: empty endpoint") @@ -100,6 +104,7 @@ func (uo *URLOpener) OpenMySQLURL(ctx context.Context, u *url.URL) (*sql.DB, err ) q.Del("aws_profile") cfg, err := config.LoadDefaultConfig(ctx, + config.WithHTTPClient(uo.HTTPClient), // Ignored if nil. config.WithSharedConfigProfile(profile)) // Ignored if empty. if err != nil { return nil, fmt.Errorf("open OpenMySQLURL: load AWS config: %v", err)