Jack Christensen 2021-09-11 10:32:02 -05:00
parent 5320ad87c8
commit 38cd1b40aa
2 changed files with 80 additions and 47 deletions

View File

@ -41,6 +41,9 @@ type BatchResults interface {
// QueryRow reads the results from the next query in the batch as if the query has been sent with Conn.QueryRow. // QueryRow reads the results from the next query in the batch as if the query has been sent with Conn.QueryRow.
QueryRow() Row QueryRow() Row
// QueryFunc reads the results from the next query in the batch as if the query has been sent with Conn.QueryFunc.
QueryFunc(scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error)
// Close closes the batch operation. This must be called before the underlying connection can be used again. Any error // Close closes the batch operation. This must be called before the underlying connection can be used again. Any error
// that occurred during a batch operation may have made it impossible to resyncronize the connection with the server. // that occurred during a batch operation may have made it impossible to resyncronize the connection with the server.
// In this case the underlying connection will have been closed. // In this case the underlying connection will have been closed.
@ -135,6 +138,33 @@ func (br *batchResults) Query() (Rows, error) {
return rows, nil return rows, nil
} }
// QueryFunc reads the results from the next query in the batch as if the query has been sent with Conn.QueryFunc.
func (br *batchResults) QueryFunc(scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) {
rows, err := br.Query()
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
err = rows.Scan(scans...)
if err != nil {
return nil, err
}
err = f(rows)
if err != nil {
return nil, err
}
}
if err := rows.Err(); err != nil {
return nil, err
}
return rows.CommandTag(), nil
}
// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow. // QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow.
func (br *batchResults) QueryRow() Row { func (br *batchResults) QueryRow() Row {
rows, _ := br.Query() rows, _ := br.Query()

View File

@ -32,6 +32,7 @@ func TestConnSendBatch(t *testing.T) {
batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2) batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2)
batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3) batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3)
batch.Queue("select id, description, amount from ledger order by id") batch.Queue("select id, description, amount from ledger order by id")
batch.Queue("select id, description, amount from ledger order by id")
batch.Queue("select sum(amount) from ledger") batch.Queue("select sum(amount) from ledger")
br := conn.SendBatch(context.Background(), batch) br := conn.SendBatch(context.Background(), batch)
@ -60,6 +61,16 @@ func TestConnSendBatch(t *testing.T) {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
} }
selectFromLedgerExpectedRows := []struct {
id int32
description string
amount int32
}{
{1, "q1", 1},
{2, "q2", 2},
{3, "q3", 3},
}
rows, err := br.Query() rows, err := br.Query()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -68,62 +79,54 @@ func TestConnSendBatch(t *testing.T) {
var id int32 var id int32
var description string var description string
var amount int32 var amount int32
if !rows.Next() { rowCount := 0
t.Fatal("expected a row to be available")
}
if err := rows.Scan(&id, &description, &amount); err != nil {
t.Fatal(err)
}
if id != 1 {
t.Errorf("id => %v, want %v", id, 1)
}
if description != "q1" {
t.Errorf("description => %v, want %v", description, "q1")
}
if amount != 1 {
t.Errorf("amount => %v, want %v", amount, 1)
}
if !rows.Next() { for rows.Next() {
t.Fatal("expected a row to be available") if rowCount >= len(selectFromLedgerExpectedRows) {
} t.Fatalf("got too many rows: %d", rowCount)
if err := rows.Scan(&id, &description, &amount); err != nil { }
t.Fatal(err)
}
if id != 2 {
t.Errorf("id => %v, want %v", id, 2)
}
if description != "q2" {
t.Errorf("description => %v, want %v", description, "q2")
}
if amount != 2 {
t.Errorf("amount => %v, want %v", amount, 2)
}
if !rows.Next() { if err := rows.Scan(&id, &description, &amount); err != nil {
t.Fatal("expected a row to be available") t.Fatalf("row %d: %v", rowCount, err)
} }
if err := rows.Scan(&id, &description, &amount); err != nil {
t.Fatal(err)
}
if id != 3 {
t.Errorf("id => %v, want %v", id, 3)
}
if description != "q3" {
t.Errorf("description => %v, want %v", description, "q3")
}
if amount != 3 {
t.Errorf("amount => %v, want %v", amount, 3)
}
if rows.Next() { if id != selectFromLedgerExpectedRows[rowCount].id {
t.Fatal("did not expect a row to be available") t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id)
}
if description != selectFromLedgerExpectedRows[rowCount].description {
t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description)
}
if amount != selectFromLedgerExpectedRows[rowCount].amount {
t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount)
}
rowCount++
} }
if rows.Err() != nil { if rows.Err() != nil {
t.Fatal(rows.Err()) t.Fatal(rows.Err())
} }
rowCount = 0
_, err = br.QueryFunc([]interface{}{&id, &description, &amount}, func(pgx.QueryFuncRow) error {
if id != selectFromLedgerExpectedRows[rowCount].id {
t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id)
}
if description != selectFromLedgerExpectedRows[rowCount].description {
t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description)
}
if amount != selectFromLedgerExpectedRows[rowCount].amount {
t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount)
}
rowCount++
return nil
})
if err != nil {
t.Error(err)
}
err = br.QueryRow().Scan(&amount) err = br.QueryRow().Scan(&amount)
if err != nil { if err != nil {
t.Error(err) t.Error(err)