mirror of https://github.com/jackc/pgx.git
Fix prepared statement already exists on batch prepare failure
When a batch successfully prepared some statements, but then failed to prepare others, the prepared statements that were successfully prepared were not properly cleaned up. This could lead to a "prepared statement already exists" error on subsequent attempts to prepare the same statement. https://github.com/jackc/pgx/issues/1847#issuecomment-2347858887pull/2127/head
parent
672c4a3a24
commit
fd0c65478e
|
@ -1008,6 +1008,36 @@ func TestSendBatchSimpleProtocol(t *testing.T) {
|
|||
assert.False(t, rows.Next())
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/1847#issuecomment-2347858887
|
||||
func TestConnSendBatchErrorDoesNotLeaveOrphanedPreparedStatement(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test")
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(col1 text primary key);`)
|
||||
|
||||
batch := &pgx.Batch{}
|
||||
batch.Queue("select col1 from foo")
|
||||
batch.Queue("select col1 from baz")
|
||||
err := conn.SendBatch(ctx, batch).Close()
|
||||
require.EqualError(t, err, `ERROR: relation "baz" does not exist (SQLSTATE 42P01)`)
|
||||
|
||||
mustExec(t, conn, `create temporary table baz(col1 text primary key);`)
|
||||
|
||||
// Since table baz now exists, the batch should succeed.
|
||||
|
||||
batch = &pgx.Batch{}
|
||||
batch.Queue("select col1 from foo")
|
||||
batch.Queue("select col1 from baz")
|
||||
err = conn.SendBatch(ctx, batch).Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func ExampleConn_SendBatch() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
|
75
conn.go
75
conn.go
|
@ -1126,47 +1126,64 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
|
|||
|
||||
// Prepare any needed queries
|
||||
if len(distinctNewQueries) > 0 {
|
||||
for _, sd := range distinctNewQueries {
|
||||
pipeline.SendPrepare(sd.Name, sd.SQL, nil)
|
||||
}
|
||||
err := func() (err error) {
|
||||
for _, sd := range distinctNewQueries {
|
||||
pipeline.SendPrepare(sd.Name, sd.SQL, nil)
|
||||
}
|
||||
|
||||
err := pipeline.Sync()
|
||||
if err != nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
|
||||
}
|
||||
// Store all statements we are preparing into the cache. It's fine if it overflows because HandleInvalidated will
|
||||
// clean them up later.
|
||||
if sdCache != nil {
|
||||
for _, sd := range distinctNewQueries {
|
||||
sdCache.Put(sd)
|
||||
}
|
||||
}
|
||||
|
||||
// If something goes wrong preparing the statements, we need to invalidate the cache entries we just added.
|
||||
defer func() {
|
||||
if err != nil && sdCache != nil {
|
||||
for _, sd := range distinctNewQueries {
|
||||
sdCache.Invalidate(sd.SQL)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
err = pipeline.Sync()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, sd := range distinctNewQueries {
|
||||
results, err := pipeline.GetResults()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resultSD, ok := results.(*pgconn.StatementDescription)
|
||||
if !ok {
|
||||
return fmt.Errorf("expected statement description, got %T", results)
|
||||
}
|
||||
|
||||
// Fill in the previously empty / pending statement descriptions.
|
||||
sd.ParamOIDs = resultSD.ParamOIDs
|
||||
sd.Fields = resultSD.Fields
|
||||
}
|
||||
|
||||
for _, sd := range distinctNewQueries {
|
||||
results, err := pipeline.GetResults()
|
||||
if err != nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
|
||||
return err
|
||||
}
|
||||
|
||||
resultSD, ok := results.(*pgconn.StatementDescription)
|
||||
_, ok := results.(*pgconn.PipelineSync)
|
||||
if !ok {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results), closed: true}
|
||||
return fmt.Errorf("expected sync, got %T", results)
|
||||
}
|
||||
|
||||
// Fill in the previously empty / pending statement descriptions.
|
||||
sd.ParamOIDs = resultSD.ParamOIDs
|
||||
sd.Fields = resultSD.Fields
|
||||
}
|
||||
|
||||
results, err := pipeline.GetResults()
|
||||
return nil
|
||||
}()
|
||||
if err != nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
|
||||
}
|
||||
|
||||
_, ok := results.(*pgconn.PipelineSync)
|
||||
if !ok {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results), closed: true}
|
||||
}
|
||||
}
|
||||
|
||||
// Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later.
|
||||
if sdCache != nil {
|
||||
for _, sd := range distinctNewQueries {
|
||||
sdCache.Put(sd)
|
||||
}
|
||||
}
|
||||
|
||||
// Queue the queries.
|
||||
|
|
Loading…
Reference in New Issue