stdlib: implement Conn.ResetSession

This prevents closed connections from being returned
by `database.sql.DB.Conn`.

Fixes #974.
pull/982/head
Jonathan Amsterdam 2021-03-26 08:32:04 -04:00 committed by Jack Christensen
parent 909b81a163
commit 88ede6efb5
2 changed files with 13 additions and 5 deletions

View File

@ -385,6 +385,13 @@ func (c *Conn) CheckNamedValue(*driver.NamedValue) error {
return nil
}
func (c *Conn) ResetSession(ctx context.Context) error {
if c.conn.IsClosed() {
return driver.ErrBadConn
}
return nil
}
type Stmt struct {
sd *pgconn.StatementDescription
conn *Conn

View File

@ -4,7 +4,6 @@ import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"math"
"os"
@ -719,7 +718,8 @@ func TestConnExecContextSuccess(t *testing.T) {
func TestConnExecContextFailureRetry(t *testing.T) {
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
// we get a connection, immediately close it, and then get it back
// We get a connection, immediately close it, and then get it back;
// DB.Conn along with Conn.ResetSession does the retry for us.
{
conn, err := stdlib.AcquireConn(db)
require.NoError(t, err)
@ -729,7 +729,7 @@ func TestConnExecContextFailureRetry(t *testing.T) {
conn, err := db.Conn(context.Background())
require.NoError(t, err)
_, err = conn.ExecContext(context.Background(), "select 1")
require.EqualValues(t, driver.ErrBadConn, err)
require.NoError(t, err)
})
}
@ -749,7 +749,8 @@ func TestConnQueryContextSuccess(t *testing.T) {
func TestConnQueryContextFailureRetry(t *testing.T) {
testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) {
// we get a connection, immediately close it, and then get it back
// We get a connection, immediately close it, and then get it back;
// DB.Conn along with Conn.ResetSession does the retry for us.
{
conn, err := stdlib.AcquireConn(db)
require.NoError(t, err)
@ -760,7 +761,7 @@ func TestConnQueryContextFailureRetry(t *testing.T) {
require.NoError(t, err)
_, err = conn.QueryContext(context.Background(), "select 1")
require.EqualValues(t, driver.ErrBadConn, err)
require.NoError(t, err)
})
}