Merge pull request #598 from jbarone/master

Fixes #376
pull/615/head
Jack Christensen 2019-09-14 18:13:01 -05:00 committed by GitHub
commit b3305b36c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 149 additions and 18 deletions

79
conn.go
View File

@ -17,7 +17,6 @@ import (
"os/user"
"path/filepath"
"reflect"
"regexp"
"strconv"
"strings"
"sync"
@ -1062,7 +1061,7 @@ func ParseURI(uri string) (ConnConfig, error) {
return cp, nil
}
var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`)
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
// ParseDSN parses a database DSN (data source name) into a ConnConfig
//
@ -1078,35 +1077,79 @@ var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`)
func ParseDSN(s string) (ConnConfig, error) {
var cp ConnConfig
m := dsnRegexp.FindAllStringSubmatch(s, -1)
tlsArgs := configTLSArgs{}
cp.RuntimeParams = make(map[string]string)
var hostval, portval string
for _, b := range m {
switch b[1] {
for len(s) > 0 {
var key, val string
eqIdx := strings.IndexRune(s, '=')
if eqIdx < 0 {
return cp, errors.New("invalid dsn")
}
key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f")
if s[0] != '\'' {
end := 0
for ; end < len(s); end++ {
if asciiSpace[s[end]] == 1 {
break
}
if s[end] == '\\' {
end++
}
}
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
if end == len(s) {
s = ""
} else {
s = s[end+1:]
}
} else { // quoted string
s = s[1:]
end := 0
for ; end < len(s); end++ {
if s[end] == '\'' {
break
}
if s[end] == '\\' {
end++
}
}
if end == len(s) {
return cp, errors.New("unterminated quoted string in connection info string")
}
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
if end == len(s) {
s = ""
} else {
s = s[end+1:]
}
}
switch key {
case "user":
cp.User = b[2]
cp.User = val
case "password":
cp.Password = b[2]
cp.Password = val
case "host":
hostval = b[2]
hostval = val
case "port":
portval = b[2]
portval = val
case "dbname":
cp.Database = b[2]
cp.Database = val
case "sslmode":
tlsArgs.sslMode = b[2]
tlsArgs.sslMode = val
case "sslrootcert":
tlsArgs.sslRootCert = b[2]
tlsArgs.sslRootCert = val
case "sslcert":
tlsArgs.sslCert = b[2]
tlsArgs.sslCert = val
case "sslkey":
tlsArgs.sslKey = b[2]
tlsArgs.sslKey = val
case "connect_timeout":
timeout, err := strconv.ParseInt(b[2], 10, 64)
timeout, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return cp, err
}
@ -1114,12 +1157,12 @@ func ParseDSN(s string) (ConnConfig, error) {
d.Timeout = time.Duration(timeout) * time.Second
cp.Dial = d.Dial
case "target_session_attrs":
cp.TargetSessionAttrs = TargetSessionType(b[2])
cp.TargetSessionAttrs = TargetSessionType(val)
if err := cp.TargetSessionAttrs.isValid(); err != nil {
return cp, err
}
default:
cp.RuntimeParams[b[1]] = b[2]
cp.RuntimeParams[key] = val
}
}

View File

@ -717,6 +717,38 @@ func TestParseDSN(t *testing.T) {
RuntimeParams: map[string]string{},
},
},
{
url: "user=jack\\'s password=secret host=localhost port=5432 dbname=mydb",
connParams: pgx.ConnConfig{
User: "jack's",
Password: "secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
UseFallbackTLS: true,
FallbackTLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
url: "user=jack password=sooper\\\\secret host=localhost port=5432 dbname=mydb",
connParams: pgx.ConnConfig{
User: "jack",
Password: "sooper\\secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
UseFallbackTLS: true,
FallbackTLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
url: "user=jack host=localhost port=5432 dbname=mydb",
connParams: pgx.ConnConfig{
@ -822,6 +854,62 @@ func TestParseDSN(t *testing.T) {
TargetSessionAttrs: pgx.ReadWriteTargetSession,
},
},
{
url: "user='jack' host='localhost' dbname='mydb'",
connParams: pgx.ConnConfig{
User: "jack",
Host: "localhost",
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
UseFallbackTLS: true,
FallbackTLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
url: "user='jack\\'s' host='localhost' dbname='mydb'",
connParams: pgx.ConnConfig{
User: "jack's",
Host: "localhost",
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
UseFallbackTLS: true,
FallbackTLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
url: "user='jack' password='' host='localhost' dbname='mydb'",
connParams: pgx.ConnConfig{
User: "jack",
Host: "localhost",
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
UseFallbackTLS: true,
FallbackTLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
url: "user = 'jack' password = '' host = 'localhost' dbname = 'mydb'",
connParams: pgx.ConnConfig{
User: "jack",
Host: "localhost",
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
UseFallbackTLS: true,
FallbackTLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
}
for i, tt := range tests {