diff --git a/conn.go b/conn.go index 03f9d190..121297b8 100644 --- a/conn.go +++ b/conn.go @@ -105,9 +105,9 @@ type ConnConfig struct { // If multiple hosts were given in the Host parameter, then // this parameter may specify a single port number to be used for all hosts, // or for those that haven't port explicitly defined. - Port uint16 - Database string - User string // default: OS user name + Port uint16 + Database string + User string // default: OS user name Password string TLSConfig *tls.Config // config for TLS connection -- nil disables TLS UseFallbackTLS bool // Try FallbackTLSConfig if connecting with TLSConfig fails. Used for preferring TLS, but allowing unencrypted, or vice-versa @@ -307,7 +307,8 @@ type Identifier []string func (ident Identifier) Sanitize() string { parts := make([]string, len(ident)) for i := range ident { - parts[i] = `"` + strings.Replace(ident[i], `"`, `""`, -1) + `"` + s := strings.Replace(ident[i], string([]byte{0}), "", -1) + parts[i] = `"` + strings.Replace(s, `"`, `""`, -1) + `"` } return strings.Join(parts, ".") } diff --git a/conn_test.go b/conn_test.go index 7719bec7..fea3b659 100644 --- a/conn_test.go +++ b/conn_test.go @@ -84,7 +84,6 @@ func TestConnect(t *testing.T) { } } - func TestConnectWithMultiHost(t *testing.T) { t.Parallel() @@ -129,7 +128,6 @@ func TestConnectWithMultiHost(t *testing.T) { } } - func TestConnectWithMultiHostWritable(t *testing.T) { t.Parallel() @@ -818,9 +816,9 @@ func TestParseDSN(t *testing.T) { TLSConfig: &tls.Config{ InsecureSkipVerify: true, }, - UseFallbackTLS: true, - FallbackTLSConfig: nil, - RuntimeParams: map[string]string{}, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, TargetSessionAttrs: pgx.ReadWriteTargetSession, }, }, @@ -2319,6 +2317,24 @@ func TestSetLogLevel(t *testing.T) { } } +func TestIdentifierSanitizeNullSentToServer(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ident := pgx.Identifier{"foo" + string([]byte{0}) + "bar"} + + var n int64 + err := conn.QueryRow(`select 1 as ` + ident.Sanitize()).Scan(&n) + if err != nil { + t.Fatal(err) + } + if n != 1 { + t.Fatal("unexpected n") + } +} + func TestIdentifierSanitize(t *testing.T) { t.Parallel() @@ -2346,6 +2362,10 @@ func TestIdentifierSanitize(t *testing.T) { ident: pgx.Identifier{`you should " not do this`, `please don't`}, expected: `"you should "" not do this"."please don't"`, }, + { + ident: pgx.Identifier{`you should ` + string([]byte{0}) + `not do this`}, + expected: `"you should not do this"`, + }, } for i, tt := range tests {