Fix incomplete selects during batch

An incompletely read select followed by an insert would fail. This was
caused by query methods in the non-batch path always calling
ensureConnectionReadyForQuery. This ensures that connections interrupted
by context cancellation are still usable. However, in the batch case
query methods are not being called while reading the result. A
incompletely read select followed by another select would not manifest
this error due to it starting by reading until row description. But when
an incomplete select (which even a successful QueryRow would be
considered) is followed by an Exec, the CommandComplete message from the
select would be considered as the response to the subsequent Exec.

The fix is the batch tracking whether a CommandComplete is pending and
reading it before advancing to the next result. This is similar in
principle to ensureConnectionReadyForQuery, just specific to Batch.
pull/330/head
Jack Christensen 2017-09-21 11:19:52 -05:00
parent b4f9d149c1
commit 53e5d8e341
3 changed files with 99 additions and 10 deletions

View File

@ -17,13 +17,14 @@ type batchItem struct {
// Batch queries are a way of bundling multiple queries together to avoid
// unnecessary network round trips.
type Batch struct {
conn *Conn
connPool *ConnPool
items []*batchItem
resultsRead int
sent bool
ctx context.Context
err error
conn *Conn
connPool *ConnPool
items []*batchItem
resultsRead int
sent bool
pendingCommandComplete bool
ctx context.Context
err error
}
// BeginBatch returns a *Batch query for c.
@ -145,8 +146,15 @@ func (b *Batch) ExecResults() (CommandTag, error) {
default:
}
if err := b.ensureCommandComplete(); err != nil {
b.die(err)
return "", err
}
b.resultsRead++
b.pendingCommandComplete = true
for {
msg, err := b.conn.rxMsg()
if err != nil {
@ -155,6 +163,7 @@ func (b *Batch) ExecResults() (CommandTag, error) {
switch msg := msg.(type) {
case *pgproto3.CommandComplete:
b.pendingCommandComplete = false
return CommandTag(msg.CommandTag), nil
default:
if err := b.conn.processContextFreeMsg(msg); err != nil {
@ -182,8 +191,16 @@ func (b *Batch) QueryResults() (*Rows, error) {
default:
}
if err := b.ensureCommandComplete(); err != nil {
b.die(err)
rows.fatal(err)
return rows, err
}
b.resultsRead++
b.pendingCommandComplete = true
fieldDescriptions, err := b.conn.readUntilRowDescription()
if err != nil {
b.die(err)
@ -244,3 +261,25 @@ func (b *Batch) die(err error) {
b.connPool.Release(b.conn)
}
}
func (b *Batch) ensureCommandComplete() error {
for b.pendingCommandComplete {
msg, err := b.conn.rxMsg()
if err != nil {
return err
}
switch msg := msg.(type) {
case *pgproto3.CommandComplete:
b.pendingCommandComplete = false
return nil
default:
err = b.conn.processContextFreeMsg(msg)
if err != nil {
return err
}
}
}
return nil
}

View File

@ -477,7 +477,7 @@ func TestConnBeginBatchQuerySyntaxError(t *testing.T) {
}
}
func TestConnBeginBatchSelectInsert(t *testing.T) {
func TestConnBeginBatchQueryRowInsert(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
@ -525,3 +525,52 @@ func TestConnBeginBatchSelectInsert(t *testing.T) {
ensureConnValid(t, conn)
}
func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
sql := `create temporary table ledger(
id serial primary key,
description varchar not null,
amount int not null
);`
mustExec(t, conn, sql)
batch := conn.BeginBatch()
batch.Queue("select 1 union all select 2 union all select 3",
nil,
nil,
[]int16{pgx.BinaryFormatCode},
)
batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)",
[]interface{}{"q1", 1},
[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
nil,
)
err := batch.Send(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
rows, err := batch.QueryResults()
if err != nil {
t.Error(err)
}
rows.Close()
ct, err := batch.ExecResults()
if err != nil {
t.Error(err)
}
if ct.RowsAffected() != 2 {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected, 2)
}
batch.Close()
ensureConnValid(t, conn)
}

View File

@ -34,8 +34,6 @@ func (r *Row) Scan(dest ...interface{}) (err error) {
}
rows.Scan(dest...)
for rows.Next() {
}
rows.Close()
return rows.Err()
}
@ -151,6 +149,9 @@ func (rows *Rows) Next() bool {
rows.values = msg.Values
return true
case *pgproto3.CommandComplete:
if rows.batch != nil {
rows.batch.pendingCommandComplete = false
}
rows.Close()
return false