From 7a2b93323c1851ebecf7d742e16b79b4915dcf8c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 10 Oct 2023 21:51:42 -0500 Subject: [PATCH] Prevent prematurely closing statements in database/sql This error was introduced by 0f0d23659950bbf7a1677e50aac09b1e29ad7c60. If the same statement was prepared multiple times then whenever Close was called on one of the statements the underlying prepared statement would be closed even if other statements were still using it. https://github.com/jackc/pgx/issues/1754#issuecomment-1752004634 --- stdlib/sql.go | 20 ++++++++++++++++++++ stdlib/sql_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/stdlib/sql.go b/stdlib/sql.go index a50720f9..f688f70c 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -294,6 +294,7 @@ func (c connector) Connect(ctx context.Context) (driver.Conn, error) { driver: c.driver, connConfig: connConfig, resetSessionFunc: c.ResetSession, + psRefCounts: make(map[*pgconn.StatementDescription]int), }, nil } @@ -375,6 +376,7 @@ func (dc *driverConnector) Connect(ctx context.Context) (driver.Conn, error) { driver: dc.driver, connConfig: *connConfig, resetSessionFunc: func(context.Context, *pgx.Conn) error { return nil }, + psRefCounts: make(map[*pgconn.StatementDescription]int), } return c, nil @@ -401,6 +403,14 @@ type Conn struct { connConfig pgx.ConnConfig resetSessionFunc func(context.Context, *pgx.Conn) error // Function is called before a connection is reused lastResetSessionTime time.Time + + // psRefCounts contains reference counts for prepared statements. Prepare uses the underlying pgx logic to generate + // deterministic statement names from the statement text. If this query has already been prepared then the existing + // *pgconn.StatementDescription will be returned. However, this means that if Close is called on the returned Stmt + // then the underlying prepared statement will be closed even when the underlying prepared statement is still in use + // by another database/sql Stmt. To prevent this psRefCounts keeps track of how many database/sql statements are using + // the same underlying statement and only closes the underlying statement when the reference count reaches 0. + psRefCounts map[*pgconn.StatementDescription]int } // Conn returns the underlying *pgx.Conn @@ -421,6 +431,7 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e if err != nil { return nil, err } + c.psRefCounts[sd]++ return &Stmt{sd: sd, conn: c}, nil } @@ -554,6 +565,15 @@ type Stmt struct { func (s *Stmt) Close() error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() + + refCount := s.conn.psRefCounts[s.sd] + if refCount == 1 { + delete(s.conn.psRefCounts, s.sd) + } else { + s.conn.psRefCounts[s.sd]-- + return nil + } + return s.conn.conn.Deallocate(ctx, s.sd.SQL) } diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index eb732135..daf77e8c 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -801,6 +801,32 @@ func TestConnPrepareContextSuccess(t *testing.T) { }) } +// https://github.com/jackc/pgx/issues/1753#issuecomment-1746033281 +// https://github.com/jackc/pgx/issues/1754#issuecomment-1752004634 +func TestConnMultiplePrepareAndDeallocate(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + sql := "select 42" + stmt1, err := db.PrepareContext(context.Background(), sql) + require.NoError(t, err) + stmt2, err := db.PrepareContext(context.Background(), sql) + require.NoError(t, err) + err = stmt1.Close() + require.NoError(t, err) + + var preparedStmtCount int64 + err = db.QueryRowContext(context.Background(), "select count(*) from pg_prepared_statements where statement = $1", sql).Scan(&preparedStmtCount) + require.NoError(t, err) + require.EqualValues(t, 1, preparedStmtCount) + + err = stmt2.Close() // err isn't as useful as it should be as database/sql will ignore errors from Deallocate. + require.NoError(t, err) + + err = db.QueryRowContext(context.Background(), "select count(*) from pg_prepared_statements where statement = $1", sql).Scan(&preparedStmtCount) + require.NoError(t, err) + require.EqualValues(t, 0, preparedStmtCount) + }) +} + func TestConnExecContextSuccess(t *testing.T) { testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { _, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)")