mirror of https://github.com/jackc/pgx.git
add CollectFilteredRows
parent
04bcc0219d
commit
ecdab4e9ae
36
rows.go
36
rows.go
|
@ -451,6 +451,42 @@ func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
|
|||
return AppendRows([]T{}, rows, fn)
|
||||
}
|
||||
|
||||
// AppendFilteredRows iterates through rows, calling fn for each row, and appending the results into a slice of T.
|
||||
// If filter is not nil, only rows for which filter returns true will be appended, If filter is nil, all rows will be appended.
|
||||
// This function closes the rows automatically on return.
|
||||
func AppendFilteredRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T], filter func(T)bool) (S, error){
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
value, err := fn(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if filter == nil{
|
||||
slice = append(slice, value)
|
||||
continue
|
||||
}
|
||||
|
||||
if filter(value){
|
||||
slice = append(slice, value)
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
// CollectFilteredRows iterates through rows, calling fn for each row, and collecting the results into a slice of T.
|
||||
// If filter is not nil, only rows for which filter returns true will be collected, If filter is nil, all rows will be collected.
|
||||
// This function closes the rows automatically on return.
|
||||
func CollectFilteredRows[T any](rows Rows, fn RowToFunc[T], filter func(T)bool) ([]T, error){
|
||||
return AppendFilteredRows([]T{}, rows, fn, filter)
|
||||
}
|
||||
|
||||
// CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true.
|
||||
// CollectOneRow is to CollectRows as QueryRow is to Query.
|
||||
//
|
||||
|
|
56
rows_test.go
56
rows_test.go
|
@ -175,6 +175,27 @@ func TestCollectRows(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestCollectFilteredRows(t *testing.T) {
|
||||
filter := func(value int32) bool {
|
||||
return value <= 20
|
||||
}
|
||||
|
||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
rows, _ := conn.Query(ctx, `select n from generate_series(0, 99) n`)
|
||||
numbers, err := pgx.CollectFilteredRows(rows, func(row pgx.CollectableRow) (int32, error) {
|
||||
var n int32
|
||||
err := row.Scan(&n)
|
||||
return n, err
|
||||
}, filter)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, numbers, 21)
|
||||
for i := range numbers {
|
||||
assert.Equal(t, int32(i), numbers[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCollectRowsEmpty(t *testing.T) {
|
||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
rows, _ := conn.Query(ctx, `select n from generate_series(1, 0) n`)
|
||||
|
@ -219,6 +240,41 @@ func ExampleCollectRows() {
|
|||
// [1 2 3 4 5]
|
||||
}
|
||||
|
||||
// This example uses CollectFilteredRows with a manually written collector function. In most cases RowTo, RowToAddrOf,
|
||||
// RowToStructByPos, RowToAddrOfStructByPos, or another generic function would be used.
|
||||
// The filter function is used to filter out rows that don't meet the criteria.
|
||||
// In this example, we filter out all rows where the number is less than 3.
|
||||
func ExampleCollectFilteredRows() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||
if err != nil {
|
||||
fmt.Printf("Unable to establish connection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
rows, _ := conn.Query(ctx, `select n from generate_series(1, 5) n`)
|
||||
numbers, err := pgx.CollectFilteredRows(rows, func(row pgx.CollectableRow) (int32, error) {
|
||||
var n int32
|
||||
err := row.Scan(&n)
|
||||
return n, err
|
||||
},
|
||||
func(n int32) bool {
|
||||
return n >= 3
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("CollectFilteredRows error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println(numbers)
|
||||
|
||||
// Output:
|
||||
// [3 4 5]
|
||||
}
|
||||
|
||||
func TestCollectOneRow(t *testing.T) {
|
||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
rows, _ := conn.Query(ctx, `select 42`)
|
||||
|
|
Loading…
Reference in New Issue