From 38cd1b40aab7244bd3d593d5153619e03b09edca Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Sep 2021 10:32:02 -0500 Subject: [PATCH] Add QueryFunc to BatchResults https://github.com/jackc/pgx/issues/1048#issuecomment-915123822 --- batch.go | 30 ++++++++++++++++ batch_test.go | 97 ++++++++++++++++++++++++++------------------------- 2 files changed, 80 insertions(+), 47 deletions(-) diff --git a/batch.go b/batch.go index 4b96ca19..f0479ea6 100644 --- a/batch.go +++ b/batch.go @@ -41,6 +41,9 @@ type BatchResults interface { // QueryRow reads the results from the next query in the batch as if the query has been sent with Conn.QueryRow. QueryRow() Row + // QueryFunc reads the results from the next query in the batch as if the query has been sent with Conn.QueryFunc. + QueryFunc(scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) + // Close closes the batch operation. This must be called before the underlying connection can be used again. Any error // that occurred during a batch operation may have made it impossible to resyncronize the connection with the server. // In this case the underlying connection will have been closed. @@ -135,6 +138,33 @@ func (br *batchResults) Query() (Rows, error) { return rows, nil } +// QueryFunc reads the results from the next query in the batch as if the query has been sent with Conn.QueryFunc. +func (br *batchResults) QueryFunc(scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { + rows, err := br.Query() + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + err = rows.Scan(scans...) + if err != nil { + return nil, err + } + + err = f(rows) + if err != nil { + return nil, err + } + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return rows.CommandTag(), nil +} + // QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow. func (br *batchResults) QueryRow() Row { rows, _ := br.Query() diff --git a/batch_test.go b/batch_test.go index 0b95cd47..988a1682 100644 --- a/batch_test.go +++ b/batch_test.go @@ -32,6 +32,7 @@ func TestConnSendBatch(t *testing.T) { batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2) batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3) batch.Queue("select id, description, amount from ledger order by id") + batch.Queue("select id, description, amount from ledger order by id") batch.Queue("select sum(amount) from ledger") br := conn.SendBatch(context.Background(), batch) @@ -60,6 +61,16 @@ func TestConnSendBatch(t *testing.T) { t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) } + selectFromLedgerExpectedRows := []struct { + id int32 + description string + amount int32 + }{ + {1, "q1", 1}, + {2, "q2", 2}, + {3, "q3", 3}, + } + rows, err := br.Query() if err != nil { t.Error(err) @@ -68,62 +79,54 @@ func TestConnSendBatch(t *testing.T) { var id int32 var description string var amount int32 - if !rows.Next() { - t.Fatal("expected a row to be available") - } - if err := rows.Scan(&id, &description, &amount); err != nil { - t.Fatal(err) - } - if id != 1 { - t.Errorf("id => %v, want %v", id, 1) - } - if description != "q1" { - t.Errorf("description => %v, want %v", description, "q1") - } - if amount != 1 { - t.Errorf("amount => %v, want %v", amount, 1) - } + rowCount := 0 - if !rows.Next() { - t.Fatal("expected a row to be available") - } - if err := rows.Scan(&id, &description, &amount); err != nil { - t.Fatal(err) - } - if id != 2 { - t.Errorf("id => %v, want %v", id, 2) - } - if description != "q2" { - t.Errorf("description => %v, want %v", description, "q2") - } - if amount != 2 { - t.Errorf("amount => %v, want %v", amount, 2) - } + for rows.Next() { + if rowCount >= len(selectFromLedgerExpectedRows) { + t.Fatalf("got too many rows: %d", rowCount) + } - if !rows.Next() { - t.Fatal("expected a row to be available") - } - if err := rows.Scan(&id, &description, &amount); err != nil { - t.Fatal(err) - } - if id != 3 { - t.Errorf("id => %v, want %v", id, 3) - } - if description != "q3" { - t.Errorf("description => %v, want %v", description, "q3") - } - if amount != 3 { - t.Errorf("amount => %v, want %v", amount, 3) - } + if err := rows.Scan(&id, &description, &amount); err != nil { + t.Fatalf("row %d: %v", rowCount, err) + } - if rows.Next() { - t.Fatal("did not expect a row to be available") + if id != selectFromLedgerExpectedRows[rowCount].id { + t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id) + } + if description != selectFromLedgerExpectedRows[rowCount].description { + t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description) + } + if amount != selectFromLedgerExpectedRows[rowCount].amount { + t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount) + } + + rowCount++ } if rows.Err() != nil { t.Fatal(rows.Err()) } + rowCount = 0 + _, err = br.QueryFunc([]interface{}{&id, &description, &amount}, func(pgx.QueryFuncRow) error { + if id != selectFromLedgerExpectedRows[rowCount].id { + t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id) + } + if description != selectFromLedgerExpectedRows[rowCount].description { + t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description) + } + if amount != selectFromLedgerExpectedRows[rowCount].amount { + t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount) + } + + rowCount++ + + return nil + }) + if err != nil { + t.Error(err) + } + err = br.QueryRow().Scan(&amount) if err != nil { t.Error(err)