Fix pipeline batch results not closing pipeline

when error occurs while reading directly from results instead of using
a callback.

https://github.com/jackc/pgx/issues/1578
pull/1579/head
Jack Christensen 2023-04-20 20:58:04 -05:00
parent 67f2a41587
commit 09371981f9
3 changed files with 33 additions and 15 deletions

View File

@ -381,17 +381,13 @@ func (br *pipelineBatchResults) Close() error {
}
}()
if br.err != nil {
return br.err
}
if br.lastRows != nil && br.lastRows.err != nil {
if br.err == nil && br.lastRows != nil && br.lastRows.err != nil {
br.err = br.lastRows.err
return br.err
}
if br.closed {
return nil
return br.err
}
// Read and run fn for all remaining items

View File

@ -720,6 +720,28 @@ func TestTxSendBatchRollback(t *testing.T) {
})
}
// https://github.com/jackc/pgx/issues/1578
func TestSendBatchErrorWhileReadingResultsWithoutCallback(t *testing.T) {
t.Parallel()
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
batch := &pgx.Batch{}
batch.Queue("select 4 / $1::int", 0)
batchResult := conn.SendBatch(ctx, batch)
_, execErr := batchResult.Exec()
require.Error(t, execErr)
closeErr := batchResult.Close()
require.Equal(t, execErr, closeErr)
// Try to use the connection.
_, err := conn.Exec(ctx, "select 1")
require.NoError(t, err)
})
}
func TestConnBeginBatchDeferredError(t *testing.T) {
t.Parallel()

18
conn.go
View File

@ -975,7 +975,7 @@ func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchR
func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
if c.statementCache == nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache}
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache, closed: true}
}
distinctNewQueries := []*pgconn.StatementDescription{}
@ -1007,7 +1007,7 @@ func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batc
func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
if c.descriptionCache == nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache}
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache, closed: true}
}
distinctNewQueries := []*pgconn.StatementDescription{}
@ -1074,18 +1074,18 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
err := pipeline.Sync()
if err != nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
}
for _, sd := range distinctNewQueries {
results, err := pipeline.GetResults()
if err != nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
}
resultSD, ok := results.(*pgconn.StatementDescription)
if !ok {
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results)}
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results), closed: true}
}
// Fill in the previously empty / pending statement descriptions.
@ -1095,12 +1095,12 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
results, err := pipeline.GetResults()
if err != nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
}
_, ok := results.(*pgconn.PipelineSync)
if !ok {
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results)}
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results), closed: true}
}
}
@ -1117,7 +1117,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
if err != nil {
// we wrap the error so we the user can understand which query failed inside the batch
err = fmt.Errorf("error building query %s: %w", bi.query, err)
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
}
if bi.sd.Name == "" {
@ -1129,7 +1129,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
err := pipeline.Sync()
if err != nil {
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
}
return &pipelineBatchResults{