Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions mysql/awsmysql/awsmysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"database/sql"
"database/sql/driver"
"fmt"
"net/http"
"net/url"
"sync/atomic"

Expand All @@ -50,7 +51,7 @@ import (
// Set is a Wire provider set that provides a *sql.DB given
// *Params and an HTTP client.
var Set = wire.NewSet(
wire.Struct(new(URLOpener), "CertSource"),
wire.Struct(new(URLOpener), "CertSource", "HTTPClient"),
rds.CertFetcherSet,
)

Expand All @@ -66,6 +67,9 @@ var Set = wire.NewSet(
// - aws_profile: the AWS shared config profile to use
// - aws_role_arn: the ARN of the role to assume
type URLOpener struct {
// HTTPClient is the HTTP client used to fetch RDS certificates.
// and IAM authentication tokens.
HTTPClient *http.Client
// CertSource specifies how the opener will obtain the RDS Certificate
// Authority. If nil, it will use the default *rds.CertFetcher.
CertSource rds.CertPoolProvider
Expand All @@ -85,7 +89,7 @@ func init() {
func (uo *URLOpener) OpenMySQLURL(ctx context.Context, u *url.URL) (*sql.DB, error) {
source := uo.CertSource
if source == nil {
source = new(rds.CertFetcher)
source = &rds.CertFetcher{Client: uo.HTTPClient}
}
if u.Host == "" {
return nil, fmt.Errorf("open OpenMySQLURL: empty endpoint")
Expand All @@ -100,6 +104,7 @@ func (uo *URLOpener) OpenMySQLURL(ctx context.Context, u *url.URL) (*sql.DB, err
)
q.Del("aws_profile")
cfg, err := config.LoadDefaultConfig(ctx,
config.WithHTTPClient(uo.HTTPClient), // Ignored if nil.
config.WithSharedConfigProfile(profile)) // Ignored if empty.
if err != nil {
return nil, fmt.Errorf("open OpenMySQLURL: load AWS config: %v", err)
Expand Down