diff --git a/CHANGELOG.md b/CHANGELOG.md index 0dbef44a..305c2a36 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,29 @@ +# 3.3.0 (December 1, 2018) + +## Features + +* Add CopyFromReader and CopyToWriter (Murat Kabilov) +* Add MacaddrArray (Anthony Regeda) +* Add float types to FieldDescription.Type (David Yamnitsky) +* Add CheckedOutConnections helper method (MOZGIII) +* Add host query parameter to support Unix sockets (Jörg Thalheim) +* Custom cancelation hook for use with PostgreSQL-like databases (James Hartig) +* Added LastStmtSent for safe retry logic (James Hartig) + +## Fixes + +* Do not silently ignore assign NULL to \*string +* Fix issue with JSON and driver.Valuer conversion +* Fix race with stdlib Driver.configs Open (Greg Curtis) + +## Changes + +* Connection pool uses connections in queue order instead of stack. This + minimized the time any connection is idle vs. any other connection. + (Anthony Regeda) +* FieldDescription.Modifier is int32 instead of uint32 +* tls: stop sending ssl_renegotiation_limit in startup message (Tejas Manohar) + # 3.2.0 (August 7, 2018) ## Features diff --git a/conn.go b/conn.go index c6249c60..447af243 100644 --- a/conn.go +++ b/conn.go @@ -113,6 +113,7 @@ type Conn struct { pendingReadyForQueryCount int // number of ReadyForQuery messages expected cancelQueryCompleted chan struct{} + lastStmtSent bool // context support ctxInProgress bool @@ -944,7 +945,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { // PrepareEx creates a prepared statement with name and sql. sql can contain placeholders // for bound parameters. These placeholders are referenced positional as $1, $2, etc. -// It defers from Prepare as it allows additional options (such as parameter OIDs) to be passed via struct +// It differs from Prepare as it allows additional options (such as parameter OIDs) to be passed via struct // // PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same // name and sql arguments. This allows a code path to PrepareEx and Query/Exec without @@ -1616,6 +1617,7 @@ func (c *Conn) Ping(ctx context.Context) error { } func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (CommandTag, error) { + c.lastStmtSent = false err := c.waitForPreviousCancelQuery(ctx) if err != nil { return "", err @@ -1654,6 +1656,7 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, }() if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) { + c.lastStmtSent = true err = c.sanitizeAndSendSimpleQuery(sql, arguments...) if err != nil { return "", err @@ -1671,6 +1674,7 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, buf = appendSync(buf) n, err := c.BaseConn.NetConn.Write(buf) + c.lastStmtSent = true if err != nil && fatalWriteErr(n, err) { c.die(err) return "", err @@ -1687,11 +1691,13 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, } } + c.lastStmtSent = true err = c.sendPreparedQuery(ps, arguments...) if err != nil { return "", err } } else { + c.lastStmtSent = true if err = c.sendQuery(sql, arguments...); err != nil { return } @@ -1862,3 +1868,14 @@ func connInfoFromRows(rows *Rows, err error) (map[string]pgtype.OID, error) { return nameOIDs, err } + +// LastStmtSent returns true if the last call to Query(Ex)/Exec(Ex) attempted to +// send the statement over the wire. Each call to a Query(Ex)/Exec(Ex) resets +// the value to false initially until the statement has been sent. This does +// NOT mean that the statement was successful or even received, it just means +// that a write was attempted and therefore it could have been executed. Calls +// to prepare a statement are ignored, only when the prepared statement is +// attempted to be executed will this return true. +func (c *Conn) LastStmtSent() bool { + return c.lastStmtSent +} diff --git a/conn_pool.go b/conn_pool.go index fda874ba..947ebe1c 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -449,7 +449,7 @@ func (p *ConnPool) Prepare(name, sql string) (*PreparedStatement, error) { // // PrepareEx creates a prepared statement with name and sql. sql can contain placeholders // for bound parameters. These placeholders are referenced positional as $1, $2, etc. -// It defers from Prepare as it allows additional options (such as parameter OIDs) to be passed via struct +// It differs from Prepare as it allows additional options (such as parameter OIDs) to be passed via struct // // PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same // name and sql arguments. This allows a code path to PrepareEx and Query/Exec/Prepare without diff --git a/conn_test.go b/conn_test.go index b245af2e..a9aaae21 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1131,12 +1131,32 @@ func TestExecFailure(t *testing.T) { if _, err := conn.Exec("selct;"); err == nil { t.Fatal("Expected SQL syntax error") } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } rows, _ := conn.Query("select 1") rows.Close() if rows.Err() != nil { t.Fatalf("Exec failure appears to have broken connection: %v", rows.Err()) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } +} + +func TestExecFailureWithArguments(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + if _, err := conn.Exec("selct $1;", 1); err == nil { + t.Fatal("Expected SQL syntax error") + } + if conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return false") + } } func TestExecExContextWithoutCancelation(t *testing.T) { @@ -1155,6 +1175,9 @@ func TestExecExContextWithoutCancelation(t *testing.T) { if commandTag != "CREATE TABLE" { t.Fatalf("Unexpected results from ExecEx: %v", commandTag) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } } func TestExecExContextFailureWithoutCancelation(t *testing.T) { @@ -1169,12 +1192,35 @@ func TestExecExContextFailureWithoutCancelation(t *testing.T) { if _, err := conn.ExecEx(ctx, "selct;", nil); err == nil { t.Fatal("Expected SQL syntax error") } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } rows, _ := conn.Query("select 1") rows.Close() if rows.Err() != nil { t.Fatalf("ExecEx failure appears to have broken connection: %v", rows.Err()) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } +} + +func TestExecExContextFailureWithoutCancelationWithArguments(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + if _, err := conn.ExecEx(ctx, "selct $1;", nil, 1); err == nil { + t.Fatal("Expected SQL syntax error") + } + if conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return false") + } } func TestExecExContextCancelationCancelsQuery(t *testing.T) { @@ -1193,10 +1239,27 @@ func TestExecExContextCancelationCancelsQuery(t *testing.T) { if err != context.Canceled { t.Fatalf("Expected context.Canceled err, got %v", err) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } ensureConnValid(t, conn) } +func TestExecFailureCloseBefore(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + closeConn(t, conn) + + if _, err := conn.Exec("select 1"); err == nil { + t.Fatal("Expected network error") + } + if conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return false") + } +} + func TestExecExExtendedProtocol(t *testing.T) { t.Parallel() @@ -1246,6 +1309,9 @@ func TestExecExSimpleProtocol(t *testing.T) { if commandTag != "CREATE TABLE" { t.Fatalf("Unexpected results from ExecEx: %v", commandTag) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } commandTag, err = conn.ExecEx( ctx, @@ -1259,6 +1325,9 @@ func TestExecExSimpleProtocol(t *testing.T) { if commandTag != "INSERT 0 1" { t.Fatalf("Unexpected results from ExecEx: %v", commandTag) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } } func TestConnExecExSuppliedCorrectParameterOIDs(t *testing.T) { @@ -1281,6 +1350,9 @@ func TestConnExecExSuppliedCorrectParameterOIDs(t *testing.T) { if commandTag != "INSERT 0 1" { t.Fatalf("Unexpected results from ExecEx: %v", commandTag) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } } func TestConnExecExSuppliedIncorrectParameterOIDs(t *testing.T) { @@ -1300,6 +1372,9 @@ func TestConnExecExSuppliedIncorrectParameterOIDs(t *testing.T) { if err == nil { t.Fatal("expected error but got none") } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } } func TestConnExecExIncorrectParameterOIDsAfterAnotherQuery(t *testing.T) { @@ -1328,6 +1403,23 @@ func TestConnExecExIncorrectParameterOIDsAfterAnotherQuery(t *testing.T) { if err == nil { t.Fatal("expected error but got none") } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } +} + +func TestExecExFailureCloseBefore(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + closeConn(t, conn) + + if _, err := conn.ExecEx(context.Background(), "select 1", nil); err == nil { + t.Fatal("Expected network error") + } + if conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return false") + } } func TestPrepare(t *testing.T) { diff --git a/pgtype/uuid.go b/pgtype/uuid.go index f8297b39..5e1eead5 100644 --- a/pgtype/uuid.go +++ b/pgtype/uuid.go @@ -87,6 +87,9 @@ func (src *UUID) AssignTo(dst interface{}) error { // parseUUID converts a string UUID in standard form to a byte array. func parseUUID(src string) (dst [16]byte, err error) { + if len(src) < 36 { + return dst, errors.Errorf("cannot parse UUID %v", src) + } src = src[0:8] + src[9:13] + src[14:18] + src[19:23] + src[24:] buf, err := hex.DecodeString(src) if err != nil { diff --git a/query.go b/query.go index c79540fa..d807e22c 100644 --- a/query.go +++ b/query.go @@ -368,6 +368,7 @@ type QueryExOptions struct { } func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (rows *Rows, err error) { + c.lastStmtSent = false rows = c.getRows(sql, args) err = c.waitForPreviousCancelQuery(ctx) @@ -394,6 +395,7 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, } if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) { + c.lastStmtSent = true err = c.sanitizeAndSendSimpleQuery(sql, args...) if err != nil { rows.fatal(err) @@ -414,6 +416,7 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, buf = appendSync(buf) n, err := c.BaseConn.NetConn.Write(buf) + c.lastStmtSent = true if err != nil && fatalWriteErr(n, err) { rows.fatal(err) c.die(err) @@ -459,6 +462,7 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, rows.sql = ps.SQL rows.fields = ps.FieldDescriptions + c.lastStmtSent = true err = c.sendPreparedQuery(ps, args...) if err != nil { rows.fatal(err) diff --git a/query_test.go b/query_test.go index 6b6b5fac..06b7b8b7 100644 --- a/query_test.go +++ b/query_test.go @@ -283,6 +283,9 @@ func TestConnQueryCloseEarlyWithErrorOnWire(t *testing.T) { if err != nil { t.Fatalf("conn.Query failed: %v", err) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } rows.Close() ensureConnValid(t, conn) @@ -431,6 +434,9 @@ func TestQueryEncodeError(t *testing.T) { if err != nil { t.Errorf("conn.Query failure: %v", err) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } defer rows.Close() rows.Next() @@ -1186,6 +1192,9 @@ func TestQueryExContextSuccess(t *testing.T) { if err != nil { t.Fatal(err) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } var result, rowCount int for rows.Next() { @@ -1263,6 +1272,9 @@ func TestQueryExContextCancelationCancelsQuery(t *testing.T) { if err != nil { t.Fatal(err) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } for rows.Next() { t.Fatal("No rows should ever be ready -- context cancel apparently did not happen") @@ -1292,6 +1304,9 @@ func TestQueryRowExContextSuccess(t *testing.T) { if result != 42 { t.Fatalf("Expected result 42, got %d", result) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } ensureConnValid(t, conn) } @@ -1331,6 +1346,9 @@ func TestQueryRowExContextCancelationCancelsQuery(t *testing.T) { if err != context.Canceled { t.Fatalf("Expected context.Canceled error, got %v", err) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } ensureConnValid(t, conn) } @@ -1384,6 +1402,9 @@ func TestConnSimpleProtocol(t *testing.T) { if expected != actual { t.Errorf("expected %v got %v", expected, actual) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } } { @@ -1401,6 +1422,9 @@ func TestConnSimpleProtocol(t *testing.T) { if expected != actual { t.Errorf("expected %v got %v", expected, actual) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } } { @@ -1418,6 +1442,9 @@ func TestConnSimpleProtocol(t *testing.T) { if expected != actual { t.Errorf("expected %v got %v", expected, actual) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } } { @@ -1435,6 +1462,9 @@ func TestConnSimpleProtocol(t *testing.T) { if bytes.Compare(actual, expected) != 0 { t.Errorf("expected %v got %v", expected, actual) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } } { @@ -1452,6 +1482,9 @@ func TestConnSimpleProtocol(t *testing.T) { if expected != actual { t.Errorf("expected %v got %v", expected, actual) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } } // Test high-level type @@ -1471,6 +1504,9 @@ func TestConnSimpleProtocol(t *testing.T) { if expected != actual { t.Errorf("expected %v got %v", expected, actual) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } } // Test multiple args in single query @@ -1510,6 +1546,9 @@ func TestConnSimpleProtocol(t *testing.T) { if expectedString != actualString { t.Errorf("expected %v got %v", expectedString, actualString) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } } // Test dangerous cases @@ -1529,6 +1568,9 @@ func TestConnSimpleProtocol(t *testing.T) { if expected != actual { t.Errorf("expected %v got %v", expected, actual) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } } ensureConnValid(t, conn) @@ -1577,3 +1619,17 @@ func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) { ensureConnValid(t, conn) } + +func TestQueryExCloseBefore(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + closeConn(t, conn) + + if _, err := conn.QueryEx(context.Background(), "select 1", nil); err == nil { + t.Fatal("Expected network error") + } + if conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return false") + } +} diff --git a/stdlib/sql.go b/stdlib/sql.go index 3ec58552..b83e527b 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -75,6 +75,7 @@ import ( "encoding/binary" "fmt" "io" + "net" "reflect" "strings" "sync" @@ -140,8 +141,10 @@ func (d *Driver) Open(name string) (driver.Conn, error) { if len(name) >= 9 && name[0] == 0 { idBuf := []byte(name)[1:9] id := int64(binary.BigEndian.Uint64(idBuf)) + d.configMutex.Lock() connConfig = d.configs[id].ConnConfig afterConnect = d.configs[id].AfterConnect + d.configMutex.Unlock() name = name[9:] } @@ -290,6 +293,12 @@ func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) { args := valueToInterface(argsV) commandTag, err := c.conn.Exec(query, args...) + // if we got a network error before we had a chance to send the query, retry + if err != nil && !c.conn.LastStmtSent() { + if _, is := err.(net.Error); is { + return nil, driver.ErrBadConn + } + } return driver.RowsAffected(commandTag.RowsAffected()), err } @@ -301,6 +310,12 @@ func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.Nam args := namedValueToInterface(argsV) commandTag, err := c.conn.ExecEx(ctx, query, nil, args...) + // if we got a network error before we had a chance to send the query, retry + if err != nil && !c.conn.LastStmtSent() { + if _, is := err.(net.Error); is { + return nil, driver.ErrBadConn + } + } return driver.RowsAffected(commandTag.RowsAffected()), err } @@ -321,6 +336,12 @@ func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) { rows, err := c.conn.Query(query, valueToInterface(argsV)...) if err != nil { + // if we got a network error before we had a chance to send the query, retry + if !c.conn.LastStmtSent() { + if _, is := err.(net.Error); is { + return nil, driver.ErrBadConn + } + } return nil, err } @@ -337,6 +358,11 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na if !c.connConfig.PreferSimpleProtocol { ps, err := c.conn.PrepareEx(ctx, "", query, nil) if err != nil { + // since PrepareEx failed, we didn't actually get to send the values, so + // we can safely retry + if _, is := err.(net.Error); is { + return nil, driver.ErrBadConn + } return nil, err } @@ -346,6 +372,12 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na rows, err := c.conn.QueryEx(ctx, query, nil, namedValueToInterface(argsV)...) if err != nil { + // if we got a network error before we had a chance to send the query, retry + if !c.conn.LastStmtSent() { + if _, is := err.(net.Error); is { + return nil, driver.ErrBadConn + } + } return nil, err } diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 78f3e6d4..cf2b91b1 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "database/sql" + "database/sql/driver" "encoding/json" "fmt" "math" @@ -989,6 +990,28 @@ func TestConnExecContextCancel(t *testing.T) { } } +func TestConnExecContextFailureRetry(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + // we get a connection, immediately close it, and then get it back + { + conn, err := stdlib.AcquireConn(db) + if err != nil { + t.Fatalf("stdlib.AcquireConn unexpectedly failed: %v", err) + } + conn.Close() + stdlib.ReleaseConn(db, conn) + } + conn, err := db.Conn(context.Background()) + if err != nil { + t.Fatalf("db.Conn unexpectedly failed: %v", err) + } + if _, err := conn.ExecContext(context.Background(), "select 1"); err != driver.ErrBadConn { + t.Fatalf("Expected conn.ExecContext to return driver.ErrBadConn, but instead received: %v", err) + } +} + func TestConnQueryContextSuccess(t *testing.T) { db := openDB(t) defer closeDB(t, db) @@ -1083,6 +1106,28 @@ func TestConnQueryContextCancel(t *testing.T) { } } +func TestConnQueryContextFailureRetry(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + // we get a connection, immediately close it, and then get it back + { + conn, err := stdlib.AcquireConn(db) + if err != nil { + t.Fatalf("stdlib.AcquireConn unexpectedly failed: %v", err) + } + conn.Close() + stdlib.ReleaseConn(db, conn) + } + conn, err := db.Conn(context.Background()) + if err != nil { + t.Fatalf("db.Conn unexpectedly failed: %v", err) + } + if _, err := conn.QueryContext(context.Background(), "select 1"); err != driver.ErrBadConn { + t.Fatalf("Expected conn.QueryContext to return driver.ErrBadConn, but instead received: %v", err) + } +} + func TestRowsColumnTypeDatabaseTypeName(t *testing.T) { db := openDB(t) defer closeDB(t, db)