Prevent prematurely closing statements in database/sql

This error was introduced by 0f0d236599.
If the same statement was prepared multiple times then whenever Close
was called on one of the statements the underlying prepared statement
would be closed even if other statements were still using it.

https://github.com/jackc/pgx/issues/1754#issuecomment-1752004634
pull/1766/head
Jack Christensen 2023-10-10 21:51:42 -05:00
parent 1484fec57f
commit 7a2b93323c
2 changed files with 46 additions and 0 deletions

View File

@ -294,6 +294,7 @@ func (c connector) Connect(ctx context.Context) (driver.Conn, error) {
driver: c.driver,
connConfig: connConfig,
resetSessionFunc: c.ResetSession,
psRefCounts: make(map[*pgconn.StatementDescription]int),
}, nil
}
@ -375,6 +376,7 @@ func (dc *driverConnector) Connect(ctx context.Context) (driver.Conn, error) {
driver: dc.driver,
connConfig: *connConfig,
resetSessionFunc: func(context.Context, *pgx.Conn) error { return nil },
psRefCounts: make(map[*pgconn.StatementDescription]int),
}
return c, nil
@ -401,6 +403,14 @@ type Conn struct {
connConfig pgx.ConnConfig
resetSessionFunc func(context.Context, *pgx.Conn) error // Function is called before a connection is reused
lastResetSessionTime time.Time
// psRefCounts contains reference counts for prepared statements. Prepare uses the underlying pgx logic to generate
// deterministic statement names from the statement text. If this query has already been prepared then the existing
// *pgconn.StatementDescription will be returned. However, this means that if Close is called on the returned Stmt
// then the underlying prepared statement will be closed even when the underlying prepared statement is still in use
// by another database/sql Stmt. To prevent this psRefCounts keeps track of how many database/sql statements are using
// the same underlying statement and only closes the underlying statement when the reference count reaches 0.
psRefCounts map[*pgconn.StatementDescription]int
}
// Conn returns the underlying *pgx.Conn
@ -421,6 +431,7 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
if err != nil {
return nil, err
}
c.psRefCounts[sd]++
return &Stmt{sd: sd, conn: c}, nil
}
@ -554,6 +565,15 @@ type Stmt struct {
func (s *Stmt) Close() error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
refCount := s.conn.psRefCounts[s.sd]
if refCount == 1 {
delete(s.conn.psRefCounts, s.sd)
} else {
s.conn.psRefCounts[s.sd]--
return nil
}
return s.conn.conn.Deallocate(ctx, s.sd.SQL)
}

View File

@ -801,6 +801,32 @@ func TestConnPrepareContextSuccess(t *testing.T) {
})
}
// https://github.com/jackc/pgx/issues/1753#issuecomment-1746033281
// https://github.com/jackc/pgx/issues/1754#issuecomment-1752004634
func TestConnMultiplePrepareAndDeallocate(t *testing.T) {
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
sql := "select 42"
stmt1, err := db.PrepareContext(context.Background(), sql)
require.NoError(t, err)
stmt2, err := db.PrepareContext(context.Background(), sql)
require.NoError(t, err)
err = stmt1.Close()
require.NoError(t, err)
var preparedStmtCount int64
err = db.QueryRowContext(context.Background(), "select count(*) from pg_prepared_statements where statement = $1", sql).Scan(&preparedStmtCount)
require.NoError(t, err)
require.EqualValues(t, 1, preparedStmtCount)
err = stmt2.Close() // err isn't as useful as it should be as database/sql will ignore errors from Deallocate.
require.NoError(t, err)
err = db.QueryRowContext(context.Background(), "select count(*) from pg_prepared_statements where statement = $1", sql).Scan(&preparedStmtCount)
require.NoError(t, err)
require.EqualValues(t, 0, preparedStmtCount)
})
}
func TestConnExecContextSuccess(t *testing.T) {
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
_, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)")