add CollectFilteredRows

pull/2297/head
Pius Alfred 2025-03-29 15:55:02 +03:00
parent 04bcc0219d
commit ecdab4e9ae
No known key found for this signature in database
GPG Key ID: 5970E91D854508CA
2 changed files with 92 additions and 0 deletions

36
rows.go
View File

@ -451,6 +451,42 @@ func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
return AppendRows([]T{}, rows, fn)
}
// AppendFilteredRows iterates through rows, calling fn for each row, and appending the results into a slice of T.
// If filter is not nil, only rows for which filter returns true will be appended, If filter is nil, all rows will be appended.
// This function closes the rows automatically on return.
func AppendFilteredRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T], filter func(T)bool) (S, error){
defer rows.Close()
for rows.Next() {
value, err := fn(rows)
if err != nil {
return nil, err
}
if filter == nil{
slice = append(slice, value)
continue
}
if filter(value){
slice = append(slice, value)
}
}
if err := rows.Err(); err != nil {
return nil, err
}
return slice, nil
}
// CollectFilteredRows iterates through rows, calling fn for each row, and collecting the results into a slice of T.
// If filter is not nil, only rows for which filter returns true will be collected, If filter is nil, all rows will be collected.
// This function closes the rows automatically on return.
func CollectFilteredRows[T any](rows Rows, fn RowToFunc[T], filter func(T)bool) ([]T, error){
return AppendFilteredRows([]T{}, rows, fn, filter)
}
// CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true.
// CollectOneRow is to CollectRows as QueryRow is to Query.
//

View File

@ -175,6 +175,27 @@ func TestCollectRows(t *testing.T) {
})
}
func TestCollectFilteredRows(t *testing.T) {
filter := func(value int32) bool {
return value <= 20
}
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
rows, _ := conn.Query(ctx, `select n from generate_series(0, 99) n`)
numbers, err := pgx.CollectFilteredRows(rows, func(row pgx.CollectableRow) (int32, error) {
var n int32
err := row.Scan(&n)
return n, err
}, filter)
require.NoError(t, err)
assert.Len(t, numbers, 21)
for i := range numbers {
assert.Equal(t, int32(i), numbers[i])
}
})
}
func TestCollectRowsEmpty(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
rows, _ := conn.Query(ctx, `select n from generate_series(1, 0) n`)
@ -219,6 +240,41 @@ func ExampleCollectRows() {
// [1 2 3 4 5]
}
// This example uses CollectFilteredRows with a manually written collector function. In most cases RowTo, RowToAddrOf,
// RowToStructByPos, RowToAddrOfStructByPos, or another generic function would be used.
// The filter function is used to filter out rows that don't meet the criteria.
// In this example, we filter out all rows where the number is less than 3.
func ExampleCollectFilteredRows() {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
if err != nil {
fmt.Printf("Unable to establish connection: %v", err)
return
}
rows, _ := conn.Query(ctx, `select n from generate_series(1, 5) n`)
numbers, err := pgx.CollectFilteredRows(rows, func(row pgx.CollectableRow) (int32, error) {
var n int32
err := row.Scan(&n)
return n, err
},
func(n int32) bool {
return n >= 3
})
if err != nil {
fmt.Printf("CollectFilteredRows error: %v", err)
return
}
fmt.Println(numbers)
// Output:
// [3 4 5]
}
func TestCollectOneRow(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
rows, _ := conn.Query(ctx, `select 42`)