Fix: prepared statement already exists

When a conn is going to execute a query, the first thing it does is to
deallocate any invalidated prepared statements from the statement cache.
However, the statements were removed from the cache regardless of
whether the deallocation succeeded. This would cause subsequent calls of
the same SQL to fail with "prepared statement already exists" error.

This problem is easy to trigger by running a query with a context that
is already canceled.

This commit changes the deallocate invalidated cached statements logic
so that the statements are only removed from the cache if the
deallocation was successful on the server.

https://github.com/jackc/pgx/issues/1847
pull/1895/head
Jack Christensen 2024-02-03 12:25:57 -06:00
parent fd4411453f
commit 832b4f9771
5 changed files with 61 additions and 13 deletions

10
conn.go
View File

@ -1359,12 +1359,12 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error
} }
if c.descriptionCache != nil { if c.descriptionCache != nil {
c.descriptionCache.HandleInvalidated() c.descriptionCache.RemoveInvalidated()
} }
var invalidatedStatements []*pgconn.StatementDescription var invalidatedStatements []*pgconn.StatementDescription
if c.statementCache != nil { if c.statementCache != nil {
invalidatedStatements = c.statementCache.HandleInvalidated() invalidatedStatements = c.statementCache.GetInvalidated()
} }
if len(invalidatedStatements) == 0 { if len(invalidatedStatements) == 0 {
@ -1376,7 +1376,6 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error
for _, sd := range invalidatedStatements { for _, sd := range invalidatedStatements {
pipeline.SendDeallocate(sd.Name) pipeline.SendDeallocate(sd.Name)
delete(c.preparedStatements, sd.Name)
} }
err := pipeline.Sync() err := pipeline.Sync()
@ -1389,5 +1388,10 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error
return fmt.Errorf("failed to deallocate cached statement(s): %w", err) return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
} }
c.statementCache.RemoveInvalidated()
for _, sd := range invalidatedStatements {
delete(c.preparedStatements, sd.Name)
}
return nil return nil
} }

View File

@ -1338,3 +1338,32 @@ func TestRawValuesUnderlyingMemoryReused(t *testing.T) {
t.Fatal("expected buffer from RawValues to be overwritten by subsequent queries but it was not") t.Fatal("expected buffer from RawValues to be overwritten by subsequent queries but it was not")
}) })
} }
// https://github.com/jackc/pgx/issues/1847
func TestConnDeallocateInvalidatedCachedStatementsWhenCanceled(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
var n int32
err := conn.QueryRow(ctx, "select 1 / $1::int", 1).Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 1, n)
// Divide by zero causes an error. baseRows.Close() calls Invalidate on the statement cache whenever an error was
// encountered by the query. Use this to purposely invalidate the query. If we had access to private fields of conn
// we could call conn.statementCache.InvalidateAll() instead.
err = conn.QueryRow(ctx, "select 1 / $1::int", 0).Scan(&n)
require.Error(t, err)
ctx2, cancel2 := context.WithCancel(ctx)
cancel2()
err = conn.QueryRow(ctx2, "select 1 / $1::int", 1).Scan(&n)
require.Error(t, err)
require.ErrorIs(t, err, context.Canceled)
err = conn.QueryRow(ctx, "select 1 / $1::int", 1).Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 1, n)
})
}

View File

@ -81,12 +81,16 @@ func (c *LRUCache) InvalidateAll() {
c.l = list.New() c.l = list.New()
} }
// HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated. // GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
// Typically, the caller will then deallocate them. func (c *LRUCache) GetInvalidated() []*pgconn.StatementDescription {
func (c *LRUCache) HandleInvalidated() []*pgconn.StatementDescription { return c.invalidStmts
invalidStmts := c.invalidStmts }
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
// never seen by the call to GetInvalidated.
func (c *LRUCache) RemoveInvalidated() {
c.invalidStmts = nil c.invalidStmts = nil
return invalidStmts
} }
// Len returns the number of cached prepared statement descriptions. // Len returns the number of cached prepared statement descriptions.

View File

@ -29,8 +29,13 @@ type Cache interface {
// InvalidateAll invalidates all statement descriptions. // InvalidateAll invalidates all statement descriptions.
InvalidateAll() InvalidateAll()
// HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated. // GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
HandleInvalidated() []*pgconn.StatementDescription GetInvalidated() []*pgconn.StatementDescription
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
// never seen by the call to GetInvalidated.
RemoveInvalidated()
// Len returns the number of cached prepared statement descriptions. // Len returns the number of cached prepared statement descriptions.
Len() int Len() int

View File

@ -54,10 +54,16 @@ func (c *UnlimitedCache) InvalidateAll() {
c.m = make(map[string]*pgconn.StatementDescription) c.m = make(map[string]*pgconn.StatementDescription)
} }
func (c *UnlimitedCache) HandleInvalidated() []*pgconn.StatementDescription { // GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
invalidStmts := c.invalidStmts func (c *UnlimitedCache) GetInvalidated() []*pgconn.StatementDescription {
return c.invalidStmts
}
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
// never seen by the call to GetInvalidated.
func (c *UnlimitedCache) RemoveInvalidated() {
c.invalidStmts = nil c.invalidStmts = nil
return invalidStmts
} }
// Len returns the number of cached prepared statement descriptions. // Len returns the number of cached prepared statement descriptions.