implement scanning iterator `AllRowsScanned`

pull/2241/head
xobotyi 2025-01-24 18:07:46 +01:00
parent 548aaceffc
commit 812c9373f0
2 changed files with 114 additions and 0 deletions

25
rows.go
View File

@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"iter"
"reflect"
"strings"
"sync"
@ -666,6 +667,30 @@ func RowToAddrOfStructByNameLax[T any](row CollectableRow) (*T, error) {
return &value, err
}
// AllRowsScanned returns iterator that read and scans rows one-by-one. It closes
// the rows automatically on return.
//
// In case rows.Err() returns non-nil error after all rows are read, it will
// trigger extra yield with zero value and the error.
func AllRowsScanned[T any](rows Rows, fn RowToFunc[T]) iter.Seq2[T, error] {
return func(yield func(T, error) bool) {
defer rows.Close()
for rows.Next() {
if !yield(fn(rows)) {
break
}
}
// we don't have another choice but to push one more time
// in order to propagate the error to user
if err := rows.Err(); err != nil {
var zero T
yield(zero, err)
}
}
}
type namedStructRowScanner struct {
ptrToStruct any
lax bool

View File

@ -993,3 +993,92 @@ insert into products (name, price) values
// Fries: $5
// Soft Drink: $3
}
func ExampleAllRowsScanned() {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
if err != nil {
fmt.Printf("Unable to establish connection: %v", err)
return
}
if conn.PgConn().ParameterStatus("crdb_version") != "" {
// Skip test / example when running on CockroachDB. Since an example can't be skipped fake success instead.
fmt.Println(`Cheeseburger: $10
Fries: $5
Soft Drink: $3`)
return
}
// Setup example schema and data.
_, err = conn.Exec(ctx, `
create temporary table products (
id int primary key generated by default as identity,
name varchar(100) not null,
price int not null
);
insert into products (name, price) values
('Cheeseburger', 10),
('Double Cheeseburger', 14),
('Fries', 5),
('Soft Drink', 3);
`)
if err != nil {
fmt.Printf("Unable to setup example schema and data: %v", err)
return
}
type product struct {
ID int32
Name string
Type string
Price int32
}
result := make([]product, 0, 3)
rows, _ := conn.Query(ctx, "select * from products where price < $1 order by price desc", 12)
for row, err := range pgx.AllRowsScanned[product](rows, pgx.RowToStructByNameLax) {
if err != nil {
fmt.Printf("AllRowsScanned error: %v", err)
return
}
// our business logic here
result = append(result, row)
}
for _, p := range result {
fmt.Printf("%s: $%d\n", p.Name, p.Price)
}
// Output:
// Cheeseburger: $10
// Fries: $5
// Soft Drink: $3
}
func TestAllRowsScanned(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
type resultRow struct {
N int32 `db:"n"`
}
rows, _ := conn.Query(ctx, `select n from generate_series(0, 99) n`)
results := make([]resultRow, 0, 100)
for row, err := range pgx.AllRowsScanned[resultRow](rows, pgx.RowToStructByName) {
require.NoError(t, err)
results = append(results, row)
}
assert.Len(t, results, 100)
for i := range results {
assert.Equal(t, int32(i), results[i].N)
}
})
}