Fix incomplete selects during batch

An incompletely read select followed by an insert would fail. This was
caused by query methods in the non-batch path always calling
ensureConnectionReadyForQuery. This ensures that connections interrupted
by context cancellation are still usable. However, in the batch case
query methods are not being called while reading the result. A
incompletely read select followed by another select would not manifest
this error due to it starting by reading until row description. But when
an incomplete select (which even a successful QueryRow would be
considered) is followed by an Exec, the CommandComplete message from the
select would be considered as the response to the subsequent Exec.

The fix is the batch tracking whether a CommandComplete is pending and
reading it before advancing to the next result. This is similar in
principle to ensureConnectionReadyForQuery, just specific to Batch.
pull/330/head
Jack Christensen 2017-09-21 11:19:52 -05:00
parent b4f9d149c1
commit 53e5d8e341
3 changed files with 99 additions and 10 deletions

View File

@ -17,13 +17,14 @@ type batchItem struct {
// Batch queries are a way of bundling multiple queries together to avoid // Batch queries are a way of bundling multiple queries together to avoid
// unnecessary network round trips. // unnecessary network round trips.
type Batch struct { type Batch struct {
conn *Conn conn *Conn
connPool *ConnPool connPool *ConnPool
items []*batchItem items []*batchItem
resultsRead int resultsRead int
sent bool sent bool
ctx context.Context pendingCommandComplete bool
err error ctx context.Context
err error
} }
// BeginBatch returns a *Batch query for c. // BeginBatch returns a *Batch query for c.
@ -145,8 +146,15 @@ func (b *Batch) ExecResults() (CommandTag, error) {
default: default:
} }
if err := b.ensureCommandComplete(); err != nil {
b.die(err)
return "", err
}
b.resultsRead++ b.resultsRead++
b.pendingCommandComplete = true
for { for {
msg, err := b.conn.rxMsg() msg, err := b.conn.rxMsg()
if err != nil { if err != nil {
@ -155,6 +163,7 @@ func (b *Batch) ExecResults() (CommandTag, error) {
switch msg := msg.(type) { switch msg := msg.(type) {
case *pgproto3.CommandComplete: case *pgproto3.CommandComplete:
b.pendingCommandComplete = false
return CommandTag(msg.CommandTag), nil return CommandTag(msg.CommandTag), nil
default: default:
if err := b.conn.processContextFreeMsg(msg); err != nil { if err := b.conn.processContextFreeMsg(msg); err != nil {
@ -182,8 +191,16 @@ func (b *Batch) QueryResults() (*Rows, error) {
default: default:
} }
if err := b.ensureCommandComplete(); err != nil {
b.die(err)
rows.fatal(err)
return rows, err
}
b.resultsRead++ b.resultsRead++
b.pendingCommandComplete = true
fieldDescriptions, err := b.conn.readUntilRowDescription() fieldDescriptions, err := b.conn.readUntilRowDescription()
if err != nil { if err != nil {
b.die(err) b.die(err)
@ -244,3 +261,25 @@ func (b *Batch) die(err error) {
b.connPool.Release(b.conn) b.connPool.Release(b.conn)
} }
} }
func (b *Batch) ensureCommandComplete() error {
for b.pendingCommandComplete {
msg, err := b.conn.rxMsg()
if err != nil {
return err
}
switch msg := msg.(type) {
case *pgproto3.CommandComplete:
b.pendingCommandComplete = false
return nil
default:
err = b.conn.processContextFreeMsg(msg)
if err != nil {
return err
}
}
}
return nil
}

View File

@ -477,7 +477,7 @@ func TestConnBeginBatchQuerySyntaxError(t *testing.T) {
} }
} }
func TestConnBeginBatchSelectInsert(t *testing.T) { func TestConnBeginBatchQueryRowInsert(t *testing.T) {
t.Parallel() t.Parallel()
conn := mustConnect(t, *defaultConnConfig) conn := mustConnect(t, *defaultConnConfig)
@ -525,3 +525,52 @@ func TestConnBeginBatchSelectInsert(t *testing.T) {
ensureConnValid(t, conn) ensureConnValid(t, conn)
} }
func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
sql := `create temporary table ledger(
id serial primary key,
description varchar not null,
amount int not null
);`
mustExec(t, conn, sql)
batch := conn.BeginBatch()
batch.Queue("select 1 union all select 2 union all select 3",
nil,
nil,
[]int16{pgx.BinaryFormatCode},
)
batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)",
[]interface{}{"q1", 1},
[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
nil,
)
err := batch.Send(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
rows, err := batch.QueryResults()
if err != nil {
t.Error(err)
}
rows.Close()
ct, err := batch.ExecResults()
if err != nil {
t.Error(err)
}
if ct.RowsAffected() != 2 {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected, 2)
}
batch.Close()
ensureConnValid(t, conn)
}

View File

@ -34,8 +34,6 @@ func (r *Row) Scan(dest ...interface{}) (err error) {
} }
rows.Scan(dest...) rows.Scan(dest...)
for rows.Next() {
}
rows.Close() rows.Close()
return rows.Err() return rows.Err()
} }
@ -151,6 +149,9 @@ func (rows *Rows) Next() bool {
rows.values = msg.Values rows.values = msg.Values
return true return true
case *pgproto3.CommandComplete: case *pgproto3.CommandComplete:
if rows.batch != nil {
rows.batch.pendingCommandComplete = false
}
rows.Close() rows.Close()
return false return false