diff --git a/batch.go b/batch.go index 0e86fead..8f6ea4f0 100644 --- a/batch.go +++ b/batch.go @@ -139,7 +139,10 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) { } commandTag, err := br.mrr.ResultReader().Close() - br.err = err + if err != nil { + br.err = err + br.mrr.Close() + } if br.conn.batchTracer != nil { br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ diff --git a/batch_test.go b/batch_test.go index f65b018e..44abbd8e 100644 --- a/batch_test.go +++ b/batch_test.go @@ -742,6 +742,27 @@ func TestSendBatchErrorWhileReadingResultsWithoutCallback(t *testing.T) { }) } +func TestSendBatchErrorWhileReadingResultsWithExecWhereSomeRowsAreReturned(t *testing.T) { + t.Parallel() + + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + batch := &pgx.Batch{} + batch.Queue("select 4 / n from generate_series(-2, 2) n") + + batchResult := conn.SendBatch(ctx, batch) + + _, execErr := batchResult.Exec() + require.Error(t, execErr) + + closeErr := batchResult.Close() + require.Equal(t, execErr, closeErr) + + // Try to use the connection. + _, err := conn.Exec(ctx, "select 1") + require.NoError(t, err) + }) +} + func TestConnBeginBatchDeferredError(t *testing.T) { t.Parallel()