From 09371981f92f113f676010458462a3b0ebc38c1f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 20 Apr 2023 20:58:04 -0500 Subject: [PATCH] Fix pipeline batch results not closing pipeline when error occurs while reading directly from results instead of using a callback. https://github.com/jackc/pgx/issues/1578 --- batch.go | 8 ++------ batch_test.go | 22 ++++++++++++++++++++++ conn.go | 18 +++++++++--------- 3 files changed, 33 insertions(+), 15 deletions(-) diff --git a/batch.go b/batch.go index a61758a1..fca182f8 100644 --- a/batch.go +++ b/batch.go @@ -381,17 +381,13 @@ func (br *pipelineBatchResults) Close() error { } }() - if br.err != nil { - return br.err - } - - if br.lastRows != nil && br.lastRows.err != nil { + if br.err == nil && br.lastRows != nil && br.lastRows.err != nil { br.err = br.lastRows.err return br.err } if br.closed { - return nil + return br.err } // Read and run fn for all remaining items diff --git a/batch_test.go b/batch_test.go index 2ade0d4a..f65b018e 100644 --- a/batch_test.go +++ b/batch_test.go @@ -720,6 +720,28 @@ func TestTxSendBatchRollback(t *testing.T) { }) } +// https://github.com/jackc/pgx/issues/1578 +func TestSendBatchErrorWhileReadingResultsWithoutCallback(t *testing.T) { + t.Parallel() + + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + batch := &pgx.Batch{} + batch.Queue("select 4 / $1::int", 0) + + batchResult := conn.SendBatch(ctx, batch) + + _, execErr := batchResult.Exec() + require.Error(t, execErr) + + closeErr := batchResult.Close() + require.Equal(t, execErr, closeErr) + + // Try to use the connection. + _, err := conn.Exec(ctx, "select 1") + require.NoError(t, err) + }) +} + func TestConnBeginBatchDeferredError(t *testing.T) { t.Parallel() diff --git a/conn.go b/conn.go index 0bbf3cee..dc986035 100644 --- a/conn.go +++ b/conn.go @@ -975,7 +975,7 @@ func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchR func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) { if c.statementCache == nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache} + return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache, closed: true} } distinctNewQueries := []*pgconn.StatementDescription{} @@ -1007,7 +1007,7 @@ func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batc func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) { if c.descriptionCache == nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache} + return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache, closed: true} } distinctNewQueries := []*pgconn.StatementDescription{} @@ -1074,18 +1074,18 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d err := pipeline.Sync() if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } for _, sd := range distinctNewQueries { results, err := pipeline.GetResults() if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } resultSD, ok := results.(*pgconn.StatementDescription) if !ok { - return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results)} + return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results), closed: true} } // Fill in the previously empty / pending statement descriptions. @@ -1095,12 +1095,12 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d results, err := pipeline.GetResults() if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } _, ok := results.(*pgconn.PipelineSync) if !ok { - return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results)} + return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results), closed: true} } } @@ -1117,7 +1117,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d if err != nil { // we wrap the error so we the user can understand which query failed inside the batch err = fmt.Errorf("error building query %s: %w", bi.query, err) - return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } if bi.sd.Name == "" { @@ -1129,7 +1129,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d err := pipeline.Sync() if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err} + return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } return &pipelineBatchResults{