diff --git a/stdlib/sql.go b/stdlib/sql.go index 067c114f..71380e98 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -385,6 +385,13 @@ func (c *Conn) CheckNamedValue(*driver.NamedValue) error { return nil } +func (c *Conn) ResetSession(ctx context.Context) error { + if c.conn.IsClosed() { + return driver.ErrBadConn + } + return nil +} + type Stmt struct { sd *pgconn.StatementDescription conn *Conn diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 6dc1d2f6..1365f230 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "database/sql" - "database/sql/driver" "encoding/json" "math" "os" @@ -719,7 +718,8 @@ func TestConnExecContextSuccess(t *testing.T) { func TestConnExecContextFailureRetry(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { - // we get a connection, immediately close it, and then get it back + // We get a connection, immediately close it, and then get it back; + // DB.Conn along with Conn.ResetSession does the retry for us. { conn, err := stdlib.AcquireConn(db) require.NoError(t, err) @@ -729,7 +729,7 @@ func TestConnExecContextFailureRetry(t *testing.T) { conn, err := db.Conn(context.Background()) require.NoError(t, err) _, err = conn.ExecContext(context.Background(), "select 1") - require.EqualValues(t, driver.ErrBadConn, err) + require.NoError(t, err) }) } @@ -749,7 +749,8 @@ func TestConnQueryContextSuccess(t *testing.T) { func TestConnQueryContextFailureRetry(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { - // we get a connection, immediately close it, and then get it back + // We get a connection, immediately close it, and then get it back; + // DB.Conn along with Conn.ResetSession does the retry for us. { conn, err := stdlib.AcquireConn(db) require.NoError(t, err) @@ -760,7 +761,7 @@ func TestConnQueryContextFailureRetry(t *testing.T) { require.NoError(t, err) _, err = conn.QueryContext(context.Background(), "select 1") - require.EqualValues(t, driver.ErrBadConn, err) + require.NoError(t, err) }) }