diff --git a/README.md b/README.md index 6d6a5e2..1c6c2c8 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ Usage of pdd: --database string Database name (default "postgres") --user string Database user (default "postgres") --pass string Database password (default "postgres") + --enable-ssl Use TSL/SSL for connection --dial-timeout duration Dial timeout for establishing new connections (default 5s) --read-timeout duration Timeout for socket reads. If reached, commands will fail (default 30s) --max-retry int Maximum number of retries before giving up. diff --git a/cmd/main.go b/cmd/main.go index 8a6c786..8fe374e 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -60,6 +60,7 @@ func main() { flag.StringVar(&dbc.Database, "database", database.DefaultDatabase, "Database name") flag.StringVar(&dbc.User, "user", database.DefaultUser, "Database user") flag.StringVar(&dbc.Password, "pass", database.DefaultPassword, "Database password") + flag.BoolVar(&dbc.EnableSSL, "enable-ssl", false, "Use TSL/SSL for connection") flag.DurationVar(&dbc.DialTimeout, "dial-timeout", database.DefaultDialTimeout, "Dial timeout for establishing new connections") flag.DurationVar(&dbc.ReadTimeout, "read-timeout", database.DefaultReadTimeout, "Timeout for socket reads. If reached, commands will fail") flag.IntVar(&dbc.MaxRetries, "max-retry", database.DefaultMaxRetries, "Maximum number of retries before giving up.") diff --git a/database/config.go b/database/config.go index ec34b72..813cbe7 100644 --- a/database/config.go +++ b/database/config.go @@ -19,6 +19,7 @@ type Config struct { Database string User string Password string + EnableSSL bool MaxRetries int DialTimeout time.Duration ReadTimeout time.Duration diff --git a/database/database.go b/database/database.go index 625959b..96908b5 100644 --- a/database/database.go +++ b/database/database.go @@ -2,6 +2,7 @@ package database import ( "context" + "crypto/tls" "fmt" "io" "strings" @@ -31,6 +32,13 @@ type db struct { // ConnectDB connects to a database using provided options. func ConnectDB(logger log.Logger, cfg *Config) (DB, error) { + var tlsConfig tls.Config + + if cfg.EnableSSL { + serverName := strings.Split(cfg.Addr, ":")[0] + tlsConfig = tls.Config{ServerName: serverName} + } + pgdb := pg.Connect( &pg.Options{ Addr: cfg.Addr, @@ -40,6 +48,7 @@ func ConnectDB(logger log.Logger, cfg *Config) (DB, error) { MaxRetries: cfg.MaxRetries, DialTimeout: cfg.DialTimeout, ReadTimeout: cfg.ReadTimeout, + TLSConfig: &tlsConfig, }, )