mirror of https://github.com/jackc/pgx.git
Fix incomplete selects during batch
An incompletely read select followed by an insert would fail. This was caused by query methods in the non-batch path always calling ensureConnectionReadyForQuery. This ensures that connections interrupted by context cancellation are still usable. However, in the batch case query methods are not being called while reading the result. A incompletely read select followed by another select would not manifest this error due to it starting by reading until row description. But when an incomplete select (which even a successful QueryRow would be considered) is followed by an Exec, the CommandComplete message from the select would be considered as the response to the subsequent Exec. The fix is the batch tracking whether a CommandComplete is pending and reading it before advancing to the next result. This is similar in principle to ensureConnectionReadyForQuery, just specific to Batch.pull/330/head
parent
b4f9d149c1
commit
53e5d8e341
53
batch.go
53
batch.go
|
@ -17,13 +17,14 @@ type batchItem struct {
|
|||
// Batch queries are a way of bundling multiple queries together to avoid
|
||||
// unnecessary network round trips.
|
||||
type Batch struct {
|
||||
conn *Conn
|
||||
connPool *ConnPool
|
||||
items []*batchItem
|
||||
resultsRead int
|
||||
sent bool
|
||||
ctx context.Context
|
||||
err error
|
||||
conn *Conn
|
||||
connPool *ConnPool
|
||||
items []*batchItem
|
||||
resultsRead int
|
||||
sent bool
|
||||
pendingCommandComplete bool
|
||||
ctx context.Context
|
||||
err error
|
||||
}
|
||||
|
||||
// BeginBatch returns a *Batch query for c.
|
||||
|
@ -145,8 +146,15 @@ func (b *Batch) ExecResults() (CommandTag, error) {
|
|||
default:
|
||||
}
|
||||
|
||||
if err := b.ensureCommandComplete(); err != nil {
|
||||
b.die(err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
b.resultsRead++
|
||||
|
||||
b.pendingCommandComplete = true
|
||||
|
||||
for {
|
||||
msg, err := b.conn.rxMsg()
|
||||
if err != nil {
|
||||
|
@ -155,6 +163,7 @@ func (b *Batch) ExecResults() (CommandTag, error) {
|
|||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.CommandComplete:
|
||||
b.pendingCommandComplete = false
|
||||
return CommandTag(msg.CommandTag), nil
|
||||
default:
|
||||
if err := b.conn.processContextFreeMsg(msg); err != nil {
|
||||
|
@ -182,8 +191,16 @@ func (b *Batch) QueryResults() (*Rows, error) {
|
|||
default:
|
||||
}
|
||||
|
||||
if err := b.ensureCommandComplete(); err != nil {
|
||||
b.die(err)
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
b.resultsRead++
|
||||
|
||||
b.pendingCommandComplete = true
|
||||
|
||||
fieldDescriptions, err := b.conn.readUntilRowDescription()
|
||||
if err != nil {
|
||||
b.die(err)
|
||||
|
@ -244,3 +261,25 @@ func (b *Batch) die(err error) {
|
|||
b.connPool.Release(b.conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Batch) ensureCommandComplete() error {
|
||||
for b.pendingCommandComplete {
|
||||
msg, err := b.conn.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.CommandComplete:
|
||||
b.pendingCommandComplete = false
|
||||
return nil
|
||||
default:
|
||||
err = b.conn.processContextFreeMsg(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -477,7 +477,7 @@ func TestConnBeginBatchQuerySyntaxError(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestConnBeginBatchSelectInsert(t *testing.T) {
|
||||
func TestConnBeginBatchQueryRowInsert(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
|
@ -525,3 +525,52 @@ func TestConnBeginBatchSelectInsert(t *testing.T) {
|
|||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
sql := `create temporary table ledger(
|
||||
id serial primary key,
|
||||
description varchar not null,
|
||||
amount int not null
|
||||
);`
|
||||
mustExec(t, conn, sql)
|
||||
|
||||
batch := conn.BeginBatch()
|
||||
batch.Queue("select 1 union all select 2 union all select 3",
|
||||
nil,
|
||||
nil,
|
||||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)",
|
||||
[]interface{}{"q1", 1},
|
||||
[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
|
||||
nil,
|
||||
)
|
||||
|
||||
err := batch.Send(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rows, err := batch.QueryResults()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
ct, err := batch.ExecResults()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if ct.RowsAffected() != 2 {
|
||||
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected, 2)
|
||||
}
|
||||
|
||||
batch.Close()
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
|
5
query.go
5
query.go
|
@ -34,8 +34,6 @@ func (r *Row) Scan(dest ...interface{}) (err error) {
|
|||
}
|
||||
|
||||
rows.Scan(dest...)
|
||||
for rows.Next() {
|
||||
}
|
||||
rows.Close()
|
||||
return rows.Err()
|
||||
}
|
||||
|
@ -151,6 +149,9 @@ func (rows *Rows) Next() bool {
|
|||
rows.values = msg.Values
|
||||
return true
|
||||
case *pgproto3.CommandComplete:
|
||||
if rows.batch != nil {
|
||||
rows.batch.pendingCommandComplete = false
|
||||
}
|
||||
rows.Close()
|
||||
return false
|
||||
|
||||
|
|
Loading…
Reference in New Issue