Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PostgreSQL dialog URL parser #346

Merged
merged 7 commits into from
Feb 17, 2019
163 changes: 163 additions & 0 deletions dialect_postgresql.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@ import (
"fmt"
"io"
"os/exec"
"strings"
"sync"
"unicode"

"github.com/gobuffalo/fizz"
"github.com/gobuffalo/fizz/translators"
"github.com/gobuffalo/pop/columns"
"github.com/gobuffalo/pop/logging"
"github.com/jmoiron/sqlx"
pg "github.com/lib/pq"
"github.com/markbates/going/defaults"
"github.com/pkg/errors"
)
Expand All @@ -23,6 +26,7 @@ func init() {
AvailableDialects = append(AvailableDialects, namePostgreSQL)
dialectSynonyms["postgresql"] = namePostgreSQL
dialectSynonyms["pg"] = namePostgreSQL
urlParser[namePostgreSQL] = urlParserPostgreSQL
finalizer[namePostgreSQL] = finalizerPostgreSQL
newConnection[namePostgreSQL] = newPostgreSQL
}
Expand Down Expand Up @@ -208,6 +212,51 @@ func newPostgreSQL(deets *ConnectionDetails) (dialect, error) {
return cd, nil
}

// urlParserPostgreSQL parses the options the same way official lib/pg does:
// https://godoc.org/github.com/lib/pq#hdr-Connection_String_Parameters
// After parsed, they are set to ConnectionDetails instance
func urlParserPostgreSQL(cd *ConnectionDetails) error {
var err error
name := cd.URL
if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") {
name, err = pg.ParseURL(name)
if err != nil {
return err
}
}

o := make(values)
if err := parseOpts(name, o); err != nil {
return err
}

if dbname, ok := o["dbname"]; ok {
cd.Database = dbname
}
if host, ok := o["host"]; ok {
cd.Host = host
}
if password, ok := o["password"]; ok {
cd.Password = password
}
if user, ok := o["user"]; ok {
cd.User = user
}
if port, ok := o["port"]; ok {
cd.Port = port
}

options := []string{"sslmode", "fallback_application_name", "connect_timeout", "sslcert", "sslkey", "sslrootcert"}

for i := range options {
if opt, ok := o[options[i]]; ok {
cd.Options[options[i]] = opt
}
}

return nil
}

func finalizerPostgreSQL(cd *ConnectionDetails) {
cd.Options["sslmode"] = defaults.String(cd.Options["sslmode"], "disable")
cd.Port = defaults.String(cd.Port, portPostgreSQL)
Expand All @@ -230,3 +279,117 @@ BEGIN
END LOOP;
END
$func$;`

// Code below is ported from: https://github.com/lib/pq/blob/master/conn.go
type values map[string]string

// scanner implements a tokenizer for libpq-style option strings.
type scanner struct {
s []rune
i int
}

// newScanner returns a new scanner initialized with the option string s.
func newScanner(s string) *scanner {
return &scanner{[]rune(s), 0}
}

// Next returns the next rune.
// It returns 0, false if the end of the text has been reached.
func (s *scanner) Next() (rune, bool) {
if s.i >= len(s.s) {
return 0, false
}
r := s.s[s.i]
s.i++
return r, true
}

// SkipSpaces returns the next non-whitespace rune.
// It returns 0, false if the end of the text has been reached.
func (s *scanner) SkipSpaces() (rune, bool) {
r, ok := s.Next()
for unicode.IsSpace(r) && ok {
r, ok = s.Next()
}
return r, ok
}

// parseOpts parses the options from name and adds them to the values.
//
// The parsing code is based on conninfo_parse from libpq's fe-connect.c
func parseOpts(name string, o values) error {
s := newScanner(name)

for {
var (
keyRunes, valRunes []rune
r rune
ok bool
)

if r, ok = s.SkipSpaces(); !ok {
break
}

// Scan the key
for !unicode.IsSpace(r) && r != '=' {
keyRunes = append(keyRunes, r)
if r, ok = s.Next(); !ok {
break
}
}

// Skip any whitespace if we're not at the = yet
if r != '=' {
r, ok = s.SkipSpaces()
}

// The current character should be =
if r != '=' || !ok {
return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes))
}

// Skip any whitespace after the =
if r, ok = s.SkipSpaces(); !ok {
// If we reach the end here, the last value is just an empty string as per libpq.
o[string(keyRunes)] = ""
break
}

if r != '\'' {
for !unicode.IsSpace(r) {
if r == '\\' {
if r, ok = s.Next(); !ok {
return fmt.Errorf(`missing character after backslash`)
}
}
valRunes = append(valRunes, r)

if r, ok = s.Next(); !ok {
break
}
}
} else {
quote:
for {
if r, ok = s.Next(); !ok {
return fmt.Errorf(`unterminated quoted string literal in connection string`)
}
switch r {
case '\'':
break quote
case '\\':
r, _ = s.Next()
fallthrough
default:
valRunes = append(valRunes, r)
}
}
}

o[string(keyRunes)] = string(valRunes)
}

return nil
}
81 changes: 81 additions & 0 deletions dialect_postgresql_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package pop

import (
"testing"

"github.com/stretchr/testify/require"
)

func Test_PostgreSQL_Connection_String(t *testing.T) {
r := require.New(t)

url := "host=host port=port dbname=database user=user password=pass"
cd := &ConnectionDetails{
Dialect: "postgres",
URL: url,
}
err := cd.Finalize()
r.NoError(err)

r.Equal(url, cd.URL)
r.Equal("postgres", cd.Dialect)
r.Equal("host", cd.Host)
r.Equal("pass", cd.Password)
r.Equal("port", cd.Port)
r.Equal("user", cd.User)
r.Equal("database", cd.Database)
}

func Test_PostgreSQL_Connection_String_Options(t *testing.T) {
r := require.New(t)

url := "host=host port=port dbname=database user=user password=pass sslmode=disable fallback_application_name=test_app connect_timeout=10 sslcert=/some/location sslkey=/some/other/location sslrootcert=/root/location"
cd := &ConnectionDetails{
Dialect: "postgres",
URL: url,
}
err := cd.Finalize()
r.NoError(err)

r.Equal(url, cd.URL)

r.Equal("disable", cd.Options["sslmode"])
r.Equal("test_app", cd.Options["fallback_application_name"])
r.Equal("10", cd.Options["connect_timeout"])
r.Equal("/some/location", cd.Options["sslcert"])
r.Equal("/some/other/location", cd.Options["sslkey"])
r.Equal("/root/location", cd.Options["sslrootcert"])
}

func Test_PostgreSQL_Connection_String_Without_User(t *testing.T) {
r := require.New(t)

url := "dbname=database"
cd := &ConnectionDetails{
Dialect: "postgres",
URL: url,
}
err := cd.Finalize()
r.NoError(err)

r.Equal(url, cd.URL)
r.Equal("postgres", cd.Dialect)
r.Equal("", cd.Host)
r.Equal("", cd.Password)
r.Equal(portPostgreSQL, cd.Port) // fallback
r.Equal("", cd.User)
r.Equal("database", cd.Database)
}

func Test_PostgreSQL_Connection_String_Failure(t *testing.T) {
r := require.New(t)

url := "abc"
cd := &ConnectionDetails{
Dialect: "postgres",
URL: url,
}
err := cd.Finalize()
r.Error(err)
r.Equal("postgres", cd.Dialect)
}