diff --git a/conn.go b/conn.go index b4469af9..b67eb98a 100644 --- a/conn.go +++ b/conn.go @@ -17,7 +17,6 @@ import ( "os/user" "path/filepath" "reflect" - "regexp" "strconv" "strings" "sync" @@ -1062,7 +1061,7 @@ func ParseURI(uri string) (ConnConfig, error) { return cp, nil } -var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`) +var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1} // ParseDSN parses a database DSN (data source name) into a ConnConfig // @@ -1078,35 +1077,79 @@ var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`) func ParseDSN(s string) (ConnConfig, error) { var cp ConnConfig - m := dsnRegexp.FindAllStringSubmatch(s, -1) - tlsArgs := configTLSArgs{} cp.RuntimeParams = make(map[string]string) var hostval, portval string - for _, b := range m { - switch b[1] { + for len(s) > 0 { + var key, val string + eqIdx := strings.IndexRune(s, '=') + if eqIdx < 0 { + return cp, errors.New("invalid dsn") + } + + key = strings.Trim(s[:eqIdx], " \t\n\r\v\f") + s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f") + if s[0] != '\'' { + end := 0 + for ; end < len(s); end++ { + if asciiSpace[s[end]] == 1 { + break + } + if s[end] == '\\' { + end++ + } + } + val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) + if end == len(s) { + s = "" + } else { + s = s[end+1:] + } + } else { // quoted string + s = s[1:] + end := 0 + for ; end < len(s); end++ { + if s[end] == '\'' { + break + } + if s[end] == '\\' { + end++ + } + } + if end == len(s) { + return cp, errors.New("unterminated quoted string in connection info string") + } + val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) + if end == len(s) { + s = "" + } else { + s = s[end+1:] + } + } + + switch key { case "user": - cp.User = b[2] + cp.User = val case "password": - cp.Password = b[2] + cp.Password = val case "host": - hostval = b[2] + hostval = val case "port": - portval = b[2] + portval = val case "dbname": - cp.Database = b[2] + cp.Database = val case "sslmode": - tlsArgs.sslMode = b[2] + tlsArgs.sslMode = val case "sslrootcert": - tlsArgs.sslRootCert = b[2] + tlsArgs.sslRootCert = val case "sslcert": - tlsArgs.sslCert = b[2] + tlsArgs.sslCert = val case "sslkey": - tlsArgs.sslKey = b[2] + tlsArgs.sslKey = val case "connect_timeout": - timeout, err := strconv.ParseInt(b[2], 10, 64) + timeout, err := strconv.ParseInt(val, 10, 64) if err != nil { return cp, err } @@ -1114,12 +1157,12 @@ func ParseDSN(s string) (ConnConfig, error) { d.Timeout = time.Duration(timeout) * time.Second cp.Dial = d.Dial case "target_session_attrs": - cp.TargetSessionAttrs = TargetSessionType(b[2]) + cp.TargetSessionAttrs = TargetSessionType(val) if err := cp.TargetSessionAttrs.isValid(); err != nil { return cp, err } default: - cp.RuntimeParams[b[1]] = b[2] + cp.RuntimeParams[key] = val } } diff --git a/conn_test.go b/conn_test.go index c6ce50cc..42e9c00b 100644 --- a/conn_test.go +++ b/conn_test.go @@ -717,6 +717,38 @@ func TestParseDSN(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + url: "user=jack\\'s password=secret host=localhost port=5432 dbname=mydb", + connParams: pgx.ConnConfig{ + User: "jack's", + Password: "secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack password=sooper\\\\secret host=localhost port=5432 dbname=mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "sooper\\secret", + Host: "localhost", + Port: 5432, + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, { url: "user=jack host=localhost port=5432 dbname=mydb", connParams: pgx.ConnConfig{ @@ -822,6 +854,62 @@ func TestParseDSN(t *testing.T) { TargetSessionAttrs: pgx.ReadWriteTargetSession, }, }, + { + url: "user='jack' host='localhost' dbname='mydb'", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user='jack\\'s' host='localhost' dbname='mydb'", + connParams: pgx.ConnConfig{ + User: "jack's", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user='jack' password='' host='localhost' dbname='mydb'", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user = 'jack' password = '' host = 'localhost' dbname = 'mydb'", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, } for i, tt := range tests {