diff --git a/conn.go b/conn.go index 50a6ffb0..43cde230 100644 --- a/conn.go +++ b/conn.go @@ -135,6 +135,7 @@ type Conn struct { pendingReadyForQueryCount int // number of ReadyForQuery messages expected cancelQueryCompleted chan struct{} + lastStmtSent bool // context support ctxInProgress bool @@ -1731,6 +1732,7 @@ func (c *Conn) Ping(ctx context.Context) error { } func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (CommandTag, error) { + c.lastStmtSent = false err := c.waitForPreviousCancelQuery(ctx) if err != nil { return "", err @@ -1770,6 +1772,7 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, }() if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) { + c.lastStmtSent = true err = c.sanitizeAndSendSimpleQuery(sql, arguments...) if err != nil { return "", err @@ -1786,6 +1789,7 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, buf = appendSync(buf) + c.lastStmtSent = true n, err := c.conn.Write(buf) if err != nil && fatalWriteErr(n, err) { c.die(err) @@ -1803,11 +1807,13 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, } } + c.lastStmtSent = true err = c.sendPreparedQuery(ps, arguments...) if err != nil { return "", err } } else { + c.lastStmtSent = true if err = c.sendQuery(sql, arguments...); err != nil { return } @@ -1978,3 +1984,14 @@ func connInfoFromRows(rows *Rows, err error) (map[string]pgtype.OID, error) { return nameOIDs, err } + +// LastStmtSent returns true if the last call to Query(Ex)/Exec(Ex) attempted to +// send the statement over the wire. Each call to a Query(Ex)/Exec(Ex) resets +// the value to false initially until the statement has been sent. This does +// NOT mean that the statement was successful or even received, it just means +// that a write was attempted and therefore it could have been executed. Calls +// to prepare a statement are ignored, only when the prepared statement is +// attempted to be executed will this return true. +func (c *Conn) LastStmtSent() bool { + return c.lastStmtSent +} diff --git a/conn_test.go b/conn_test.go index c0419d90..c745d392 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1131,12 +1131,32 @@ func TestExecFailure(t *testing.T) { if _, err := conn.Exec("selct;"); err == nil { t.Fatal("Expected SQL syntax error") } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } rows, _ := conn.Query("select 1") rows.Close() if rows.Err() != nil { t.Fatalf("Exec failure appears to have broken connection: %v", rows.Err()) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } +} + +func TestExecFailureWithArguments(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + if _, err := conn.Exec("selct $1;", 1); err == nil { + t.Fatal("Expected SQL syntax error") + } + if conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return false") + } } func TestExecExContextWithoutCancelation(t *testing.T) { @@ -1155,6 +1175,9 @@ func TestExecExContextWithoutCancelation(t *testing.T) { if commandTag != "CREATE TABLE" { t.Fatalf("Unexpected results from ExecEx: %v", commandTag) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } } func TestExecExContextFailureWithoutCancelation(t *testing.T) { @@ -1169,12 +1192,35 @@ func TestExecExContextFailureWithoutCancelation(t *testing.T) { if _, err := conn.ExecEx(ctx, "selct;", nil); err == nil { t.Fatal("Expected SQL syntax error") } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } rows, _ := conn.Query("select 1") rows.Close() if rows.Err() != nil { t.Fatalf("ExecEx failure appears to have broken connection: %v", rows.Err()) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } +} + +func TestExecExContextFailureWithoutCancelationWithArguments(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + if _, err := conn.ExecEx(ctx, "selct $1;", nil, 1); err == nil { + t.Fatal("Expected SQL syntax error") + } + if conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return false") + } } func TestExecExContextCancelationCancelsQuery(t *testing.T) { @@ -1193,10 +1239,27 @@ func TestExecExContextCancelationCancelsQuery(t *testing.T) { if err != context.Canceled { t.Fatalf("Expected context.Canceled err, got %v", err) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } ensureConnValid(t, conn) } +func TestExecFailureCloseBefore(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + closeConn(t, conn) + + if _, err := conn.Exec("select 1"); err == nil { + t.Fatal("Expected network error") + } + if conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return false") + } +} + func TestExecExExtendedProtocol(t *testing.T) { t.Parallel() @@ -1246,6 +1309,9 @@ func TestExecExSimpleProtocol(t *testing.T) { if commandTag != "CREATE TABLE" { t.Fatalf("Unexpected results from ExecEx: %v", commandTag) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } commandTag, err = conn.ExecEx( ctx, @@ -1259,6 +1325,9 @@ func TestExecExSimpleProtocol(t *testing.T) { if commandTag != "INSERT 0 1" { t.Fatalf("Unexpected results from ExecEx: %v", commandTag) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } } func TestConnExecExSuppliedCorrectParameterOIDs(t *testing.T) { @@ -1281,6 +1350,9 @@ func TestConnExecExSuppliedCorrectParameterOIDs(t *testing.T) { if commandTag != "INSERT 0 1" { t.Fatalf("Unexpected results from ExecEx: %v", commandTag) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } } func TestConnExecExSuppliedIncorrectParameterOIDs(t *testing.T) { @@ -1300,6 +1372,9 @@ func TestConnExecExSuppliedIncorrectParameterOIDs(t *testing.T) { if err == nil { t.Fatal("expected error but got none") } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } } func TestConnExecExIncorrectParameterOIDsAfterAnotherQuery(t *testing.T) { @@ -1328,6 +1403,23 @@ func TestConnExecExIncorrectParameterOIDsAfterAnotherQuery(t *testing.T) { if err == nil { t.Fatal("expected error but got none") } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } +} + +func TestExecExFailureCloseBefore(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + closeConn(t, conn) + + if _, err := conn.ExecEx(context.Background(), "select 1", nil); err == nil { + t.Fatal("Expected network error") + } + if conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return false") + } } func TestPrepare(t *testing.T) { diff --git a/query.go b/query.go index c014cacd..ad3ed84b 100644 --- a/query.go +++ b/query.go @@ -368,6 +368,7 @@ type QueryExOptions struct { } func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (rows *Rows, err error) { + c.lastStmtSent = false c.lastActivityTime = time.Now() rows = c.getRows(sql, args) @@ -395,6 +396,7 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, } if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) { + c.lastStmtSent = true err = c.sanitizeAndSendSimpleQuery(sql, args...) if err != nil { rows.fatal(err) @@ -414,6 +416,7 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, buf = appendSync(buf) + c.lastStmtSent = true n, err := c.conn.Write(buf) if err != nil && fatalWriteErr(n, err) { rows.fatal(err) @@ -460,6 +463,7 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, rows.sql = ps.SQL rows.fields = ps.FieldDescriptions + c.lastStmtSent = true err = c.sendPreparedQuery(ps, args...) if err != nil { rows.fatal(err) diff --git a/query_test.go b/query_test.go index 6b6b5fac..06b7b8b7 100644 --- a/query_test.go +++ b/query_test.go @@ -283,6 +283,9 @@ func TestConnQueryCloseEarlyWithErrorOnWire(t *testing.T) { if err != nil { t.Fatalf("conn.Query failed: %v", err) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } rows.Close() ensureConnValid(t, conn) @@ -431,6 +434,9 @@ func TestQueryEncodeError(t *testing.T) { if err != nil { t.Errorf("conn.Query failure: %v", err) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } defer rows.Close() rows.Next() @@ -1186,6 +1192,9 @@ func TestQueryExContextSuccess(t *testing.T) { if err != nil { t.Fatal(err) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } var result, rowCount int for rows.Next() { @@ -1263,6 +1272,9 @@ func TestQueryExContextCancelationCancelsQuery(t *testing.T) { if err != nil { t.Fatal(err) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } for rows.Next() { t.Fatal("No rows should ever be ready -- context cancel apparently did not happen") @@ -1292,6 +1304,9 @@ func TestQueryRowExContextSuccess(t *testing.T) { if result != 42 { t.Fatalf("Expected result 42, got %d", result) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } ensureConnValid(t, conn) } @@ -1331,6 +1346,9 @@ func TestQueryRowExContextCancelationCancelsQuery(t *testing.T) { if err != context.Canceled { t.Fatalf("Expected context.Canceled error, got %v", err) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } ensureConnValid(t, conn) } @@ -1384,6 +1402,9 @@ func TestConnSimpleProtocol(t *testing.T) { if expected != actual { t.Errorf("expected %v got %v", expected, actual) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } } { @@ -1401,6 +1422,9 @@ func TestConnSimpleProtocol(t *testing.T) { if expected != actual { t.Errorf("expected %v got %v", expected, actual) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } } { @@ -1418,6 +1442,9 @@ func TestConnSimpleProtocol(t *testing.T) { if expected != actual { t.Errorf("expected %v got %v", expected, actual) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } } { @@ -1435,6 +1462,9 @@ func TestConnSimpleProtocol(t *testing.T) { if bytes.Compare(actual, expected) != 0 { t.Errorf("expected %v got %v", expected, actual) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } } { @@ -1452,6 +1482,9 @@ func TestConnSimpleProtocol(t *testing.T) { if expected != actual { t.Errorf("expected %v got %v", expected, actual) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } } // Test high-level type @@ -1471,6 +1504,9 @@ func TestConnSimpleProtocol(t *testing.T) { if expected != actual { t.Errorf("expected %v got %v", expected, actual) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } } // Test multiple args in single query @@ -1510,6 +1546,9 @@ func TestConnSimpleProtocol(t *testing.T) { if expectedString != actualString { t.Errorf("expected %v got %v", expectedString, actualString) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } } // Test dangerous cases @@ -1529,6 +1568,9 @@ func TestConnSimpleProtocol(t *testing.T) { if expected != actual { t.Errorf("expected %v got %v", expected, actual) } + if !conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return true") + } } ensureConnValid(t, conn) @@ -1577,3 +1619,17 @@ func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) { ensureConnValid(t, conn) } + +func TestQueryExCloseBefore(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + closeConn(t, conn) + + if _, err := conn.QueryEx(context.Background(), "select 1", nil); err == nil { + t.Fatal("Expected network error") + } + if conn.LastStmtSent() { + t.Error("Expected LastStmtSent to return false") + } +} diff --git a/stdlib/sql.go b/stdlib/sql.go index 4d69d259..b83e527b 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -75,6 +75,7 @@ import ( "encoding/binary" "fmt" "io" + "net" "reflect" "strings" "sync" @@ -292,6 +293,12 @@ func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) { args := valueToInterface(argsV) commandTag, err := c.conn.Exec(query, args...) + // if we got a network error before we had a chance to send the query, retry + if err != nil && !c.conn.LastStmtSent() { + if _, is := err.(net.Error); is { + return nil, driver.ErrBadConn + } + } return driver.RowsAffected(commandTag.RowsAffected()), err } @@ -303,6 +310,12 @@ func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.Nam args := namedValueToInterface(argsV) commandTag, err := c.conn.ExecEx(ctx, query, nil, args...) + // if we got a network error before we had a chance to send the query, retry + if err != nil && !c.conn.LastStmtSent() { + if _, is := err.(net.Error); is { + return nil, driver.ErrBadConn + } + } return driver.RowsAffected(commandTag.RowsAffected()), err } @@ -323,6 +336,12 @@ func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) { rows, err := c.conn.Query(query, valueToInterface(argsV)...) if err != nil { + // if we got a network error before we had a chance to send the query, retry + if !c.conn.LastStmtSent() { + if _, is := err.(net.Error); is { + return nil, driver.ErrBadConn + } + } return nil, err } @@ -339,6 +358,11 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na if !c.connConfig.PreferSimpleProtocol { ps, err := c.conn.PrepareEx(ctx, "", query, nil) if err != nil { + // since PrepareEx failed, we didn't actually get to send the values, so + // we can safely retry + if _, is := err.(net.Error); is { + return nil, driver.ErrBadConn + } return nil, err } @@ -348,6 +372,12 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na rows, err := c.conn.QueryEx(ctx, query, nil, namedValueToInterface(argsV)...) if err != nil { + // if we got a network error before we had a chance to send the query, retry + if !c.conn.LastStmtSent() { + if _, is := err.(net.Error); is { + return nil, driver.ErrBadConn + } + } return nil, err } diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 78f3e6d4..cf2b91b1 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "database/sql" + "database/sql/driver" "encoding/json" "fmt" "math" @@ -989,6 +990,28 @@ func TestConnExecContextCancel(t *testing.T) { } } +func TestConnExecContextFailureRetry(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + // we get a connection, immediately close it, and then get it back + { + conn, err := stdlib.AcquireConn(db) + if err != nil { + t.Fatalf("stdlib.AcquireConn unexpectedly failed: %v", err) + } + conn.Close() + stdlib.ReleaseConn(db, conn) + } + conn, err := db.Conn(context.Background()) + if err != nil { + t.Fatalf("db.Conn unexpectedly failed: %v", err) + } + if _, err := conn.ExecContext(context.Background(), "select 1"); err != driver.ErrBadConn { + t.Fatalf("Expected conn.ExecContext to return driver.ErrBadConn, but instead received: %v", err) + } +} + func TestConnQueryContextSuccess(t *testing.T) { db := openDB(t) defer closeDB(t, db) @@ -1083,6 +1106,28 @@ func TestConnQueryContextCancel(t *testing.T) { } } +func TestConnQueryContextFailureRetry(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + // we get a connection, immediately close it, and then get it back + { + conn, err := stdlib.AcquireConn(db) + if err != nil { + t.Fatalf("stdlib.AcquireConn unexpectedly failed: %v", err) + } + conn.Close() + stdlib.ReleaseConn(db, conn) + } + conn, err := db.Conn(context.Background()) + if err != nil { + t.Fatalf("db.Conn unexpectedly failed: %v", err) + } + if _, err := conn.QueryContext(context.Background(), "select 1"); err != driver.ErrBadConn { + t.Fatalf("Expected conn.QueryContext to return driver.ErrBadConn, but instead received: %v", err) + } +} + func TestRowsColumnTypeDatabaseTypeName(t *testing.T) { db := openDB(t) defer closeDB(t, db)