Fix prepared statement already exists on batch prepare failure

When a batch successfully prepared some statements but then failed to
prepare others, the statements that had already been prepared were not
properly cleaned up. This could lead to a "prepared statement already
exists" error on subsequent attempts to prepare the same statement.

https://github.com/jackc/pgx/issues/1847#issuecomment-2347858887
pull/2127/head
Jack Christensen 2024-09-13 08:03:37 -05:00
parent 672c4a3a24
commit fd0c65478e
2 changed files with 76 additions and 29 deletions
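
The fix (in the conn.go hunk below) follows a put-then-invalidate-on-error shape: register every statement description in the cache up front, then use a deferred check on a named error return to drop those entries again if any part of the preparation fails. A minimal standalone sketch of that pattern, using a hypothetical statementCache type and prepareAll helper rather than the pgx internals:

package main

import "fmt"

// statementCache is a hypothetical stand-in for a prepared statement cache.
type statementCache struct {
	entries map[string]string // statement name -> SQL
}

func (c *statementCache) Put(name, sql string)   { c.entries[name] = sql }
func (c *statementCache) Invalidate(name string) { delete(c.entries, name) }

// prepareAll registers every statement in the cache before "preparing" it, and
// relies on the deferred check of the named return value err to remove all of
// the just-added entries if any single prepare fails.
func prepareAll(cache *statementCache, stmts map[string]string) (err error) {
	for name, sql := range stmts {
		cache.Put(name, sql)
	}
	defer func() {
		if err != nil {
			for name := range stmts {
				cache.Invalidate(name)
			}
		}
	}()

	for name, sql := range stmts {
		if sql == "" { // stand-in for a server-side prepare failure
			return fmt.Errorf("prepare %s: empty statement", name)
		}
	}
	return nil
}

func main() {
	cache := &statementCache{entries: map[string]string{}}
	err := prepareAll(cache, map[string]string{"stmt_1": "select 1", "stmt_2": ""})
	fmt.Println(err, len(cache.entries)) // prepare stmt_2: empty statement 0
}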

batch_test.go

@@ -1008,6 +1008,36 @@ func TestSendBatchSimpleProtocol(t *testing.T) {
 	assert.False(t, rows.Next())
 }
 
+// https://github.com/jackc/pgx/issues/1847#issuecomment-2347858887
+func TestConnSendBatchErrorDoesNotLeaveOrphanedPreparedStatement(t *testing.T) {
+	t.Parallel()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+		pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test")
+
+		mustExec(t, conn, `create temporary table foo(col1 text primary key);`)
+
+		batch := &pgx.Batch{}
+		batch.Queue("select col1 from foo")
+		batch.Queue("select col1 from baz")
+
+		err := conn.SendBatch(ctx, batch).Close()
+		require.EqualError(t, err, `ERROR: relation "baz" does not exist (SQLSTATE 42P01)`)
+
+		mustExec(t, conn, `create temporary table baz(col1 text primary key);`)
+
+		// Since table baz now exists, the batch should succeed.
+		batch = &pgx.Batch{}
+		batch.Queue("select col1 from foo")
+		batch.Queue("select col1 from baz")
+
+		err = conn.SendBatch(ctx, batch).Close()
+		require.NoError(t, err)
+	})
+}
+
 func ExampleConn_SendBatch() {
 	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
 	defer cancel()

conn.go

@@ -1126,47 +1126,64 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
 	// Prepare any needed queries
 	if len(distinctNewQueries) > 0 {
-		for _, sd := range distinctNewQueries {
-			pipeline.SendPrepare(sd.Name, sd.SQL, nil)
-		}
-
-		err := pipeline.Sync()
-		if err != nil {
-			return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
-		}
-
-		for _, sd := range distinctNewQueries {
-			results, err := pipeline.GetResults()
-			if err != nil {
-				return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
-			}
-
-			resultSD, ok := results.(*pgconn.StatementDescription)
-			if !ok {
-				return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results), closed: true}
-			}
-
-			// Fill in the previously empty / pending statement descriptions.
-			sd.ParamOIDs = resultSD.ParamOIDs
-			sd.Fields = resultSD.Fields
-		}
-
-		results, err := pipeline.GetResults()
-		if err != nil {
-			return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
-		}
-
-		_, ok := results.(*pgconn.PipelineSync)
-		if !ok {
-			return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results), closed: true}
-		}
-	}
-
-	// Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later.
-	if sdCache != nil {
-		for _, sd := range distinctNewQueries {
-			sdCache.Put(sd)
-		}
+		err := func() (err error) {
+			for _, sd := range distinctNewQueries {
+				pipeline.SendPrepare(sd.Name, sd.SQL, nil)
+			}
+
+			// Store all statements we are preparing into the cache. It's fine if it overflows because HandleInvalidated will
+			// clean them up later.
+			if sdCache != nil {
+				for _, sd := range distinctNewQueries {
+					sdCache.Put(sd)
+				}
+			}
+
+			// If something goes wrong preparing the statements, we need to invalidate the cache entries we just added.
+			defer func() {
+				if err != nil && sdCache != nil {
+					for _, sd := range distinctNewQueries {
+						sdCache.Invalidate(sd.SQL)
+					}
+				}
+			}()
+
+			err = pipeline.Sync()
+			if err != nil {
+				return err
+			}
+
+			for _, sd := range distinctNewQueries {
+				results, err := pipeline.GetResults()
+				if err != nil {
+					return err
+				}
+
+				resultSD, ok := results.(*pgconn.StatementDescription)
+				if !ok {
+					return fmt.Errorf("expected statement description, got %T", results)
+				}
+
+				// Fill in the previously empty / pending statement descriptions.
+				sd.ParamOIDs = resultSD.ParamOIDs
+				sd.Fields = resultSD.Fields
+			}
+
+			results, err := pipeline.GetResults()
+			if err != nil {
+				return err
+			}
+
+			_, ok := results.(*pgconn.PipelineSync)
+			if !ok {
+				return fmt.Errorf("expected sync, got %T", results)
+			}
+
+			return nil
+		}()
+		if err != nil {
+			return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
+		}
 	}
 
 	// Queue the queries.
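
The conn.go hunk wraps the whole prepare phase in an immediately invoked closure with a named err result, so the single deferred check covers every early return, whether it comes from Sync, GetResults, or an unexpected message type. The caller then wraps whatever error escapes the closure in one pipelineBatchResults instead of constructing it at each exit point, and the cache entries added before the failure are invalidated rather than left behind as orphaned prepared statements.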