mirror of https://github.com/jackc/pgx.git
Fix pipeline batch results not closing pipeline
when error occurs while reading directly from results instead of using a callback. https://github.com/jackc/pgx/issues/1578pull/1579/head
parent
67f2a41587
commit
09371981f9
8
batch.go
8
batch.go
|
@ -381,17 +381,13 @@ func (br *pipelineBatchResults) Close() error {
|
|||
}
|
||||
}()
|
||||
|
||||
if br.err != nil {
|
||||
return br.err
|
||||
}
|
||||
|
||||
if br.lastRows != nil && br.lastRows.err != nil {
|
||||
if br.err == nil && br.lastRows != nil && br.lastRows.err != nil {
|
||||
br.err = br.lastRows.err
|
||||
return br.err
|
||||
}
|
||||
|
||||
if br.closed {
|
||||
return nil
|
||||
return br.err
|
||||
}
|
||||
|
||||
// Read and run fn for all remaining items
|
||||
|
|
|
@ -720,6 +720,28 @@ func TestTxSendBatchRollback(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/1578
|
||||
func TestSendBatchErrorWhileReadingResultsWithoutCallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
batch := &pgx.Batch{}
|
||||
batch.Queue("select 4 / $1::int", 0)
|
||||
|
||||
batchResult := conn.SendBatch(ctx, batch)
|
||||
|
||||
_, execErr := batchResult.Exec()
|
||||
require.Error(t, execErr)
|
||||
|
||||
closeErr := batchResult.Close()
|
||||
require.Equal(t, execErr, closeErr)
|
||||
|
||||
// Try to use the connection.
|
||||
_, err := conn.Exec(ctx, "select 1")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnBeginBatchDeferredError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
18
conn.go
18
conn.go
|
@ -975,7 +975,7 @@ func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchR
|
|||
|
||||
func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
|
||||
if c.statementCache == nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache}
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache, closed: true}
|
||||
}
|
||||
|
||||
distinctNewQueries := []*pgconn.StatementDescription{}
|
||||
|
@ -1007,7 +1007,7 @@ func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batc
|
|||
|
||||
func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
|
||||
if c.descriptionCache == nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache}
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache, closed: true}
|
||||
}
|
||||
|
||||
distinctNewQueries := []*pgconn.StatementDescription{}
|
||||
|
@ -1074,18 +1074,18 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
|
|||
|
||||
err := pipeline.Sync()
|
||||
if err != nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
|
||||
}
|
||||
|
||||
for _, sd := range distinctNewQueries {
|
||||
results, err := pipeline.GetResults()
|
||||
if err != nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
|
||||
}
|
||||
|
||||
resultSD, ok := results.(*pgconn.StatementDescription)
|
||||
if !ok {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results)}
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results), closed: true}
|
||||
}
|
||||
|
||||
// Fill in the previously empty / pending statement descriptions.
|
||||
|
@ -1095,12 +1095,12 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
|
|||
|
||||
results, err := pipeline.GetResults()
|
||||
if err != nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
|
||||
}
|
||||
|
||||
_, ok := results.(*pgconn.PipelineSync)
|
||||
if !ok {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results)}
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results), closed: true}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1117,7 +1117,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
|
|||
if err != nil {
|
||||
// we wrap the error so we the user can understand which query failed inside the batch
|
||||
err = fmt.Errorf("error building query %s: %w", bi.query, err)
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
|
||||
}
|
||||
|
||||
if bi.sd.Name == "" {
|
||||
|
@ -1129,7 +1129,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
|
|||
|
||||
err := pipeline.Sync()
|
||||
if err != nil {
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
||||
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
|
||||
}
|
||||
|
||||
return &pipelineBatchResults{
|
||||
|
|
Loading…
Reference in New Issue