diff --git a/stdlib/sql.go b/stdlib/sql.go index a50720f9..f688f70c 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -294,6 +294,7 @@ func (c connector) Connect(ctx context.Context) (driver.Conn, error) { driver: c.driver, connConfig: connConfig, resetSessionFunc: c.ResetSession, + psRefCounts: make(map[*pgconn.StatementDescription]int), }, nil } @@ -375,6 +376,7 @@ func (dc *driverConnector) Connect(ctx context.Context) (driver.Conn, error) { driver: dc.driver, connConfig: *connConfig, resetSessionFunc: func(context.Context, *pgx.Conn) error { return nil }, + psRefCounts: make(map[*pgconn.StatementDescription]int), } return c, nil @@ -401,6 +403,14 @@ type Conn struct { connConfig pgx.ConnConfig resetSessionFunc func(context.Context, *pgx.Conn) error // Function is called before a connection is reused lastResetSessionTime time.Time + + // psRefCounts contains reference counts for prepared statements. Prepare uses the underlying pgx logic to generate + // deterministic statement names from the statement text. If this query has already been prepared then the existing + // *pgconn.StatementDescription will be returned. However, this means that if Close is called on the returned Stmt + // then the underlying prepared statement will be closed even when the underlying prepared statement is still in use + // by another database/sql Stmt. To prevent this psRefCounts keeps track of how many database/sql statements are using + // the same underlying statement and only closes the underlying statement when the reference count reaches 0. + psRefCounts map[*pgconn.StatementDescription]int } // Conn returns the underlying *pgx.Conn @@ -421,6 +431,7 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e if err != nil { return nil, err } + c.psRefCounts[sd]++ return &Stmt{sd: sd, conn: c}, nil } @@ -554,6 +565,15 @@ type Stmt struct { func (s *Stmt) Close() error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() + + refCount := s.conn.psRefCounts[s.sd] + if refCount == 1 { + delete(s.conn.psRefCounts, s.sd) + } else { + s.conn.psRefCounts[s.sd]-- + return nil + } + return s.conn.conn.Deallocate(ctx, s.sd.SQL) } diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index eb732135..daf77e8c 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -801,6 +801,32 @@ func TestConnPrepareContextSuccess(t *testing.T) { }) } +// https://github.com/jackc/pgx/issues/1753#issuecomment-1746033281 +// https://github.com/jackc/pgx/issues/1754#issuecomment-1752004634 +func TestConnMultiplePrepareAndDeallocate(t *testing.T) { + testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { + sql := "select 42" + stmt1, err := db.PrepareContext(context.Background(), sql) + require.NoError(t, err) + stmt2, err := db.PrepareContext(context.Background(), sql) + require.NoError(t, err) + err = stmt1.Close() + require.NoError(t, err) + + var preparedStmtCount int64 + err = db.QueryRowContext(context.Background(), "select count(*) from pg_prepared_statements where statement = $1", sql).Scan(&preparedStmtCount) + require.NoError(t, err) + require.EqualValues(t, 1, preparedStmtCount) + + err = stmt2.Close() // err isn't as useful as it should be as database/sql will ignore errors from Deallocate. + require.NoError(t, err) + + err = db.QueryRowContext(context.Background(), "select count(*) from pg_prepared_statements where statement = $1", sql).Scan(&preparedStmtCount) + require.NoError(t, err) + require.EqualValues(t, 0, preparedStmtCount) + }) +} + func TestConnExecContextSuccess(t *testing.T) { testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { _, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)")