From ecdab4e9aeec5bc34507f45ac9b87cae19b9e2d7 Mon Sep 17 00:00:00 2001 From: Pius Alfred Date: Sat, 29 Mar 2025 15:55:02 +0300 Subject: [PATCH] add CollectFilteredRows --- rows.go | 36 +++++++++++++++++++++++++++++++++ rows_test.go | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+) diff --git a/rows.go b/rows.go index f6f26f47..2251868b 100644 --- a/rows.go +++ b/rows.go @@ -451,6 +451,42 @@ func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) { return AppendRows([]T{}, rows, fn) } +// AppendFilteredRows iterates through rows, calling fn for each row, and appending the results into a slice of T. +// If filter is not nil, only rows for which filter returns true will be appended, If filter is nil, all rows will be appended. +// This function closes the rows automatically on return. +func AppendFilteredRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T], filter func(T)bool) (S, error){ + defer rows.Close() + + for rows.Next() { + value, err := fn(rows) + if err != nil { + return nil, err + } + + if filter == nil{ + slice = append(slice, value) + continue + } + + if filter(value){ + slice = append(slice, value) + } + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return slice, nil +} + +// CollectFilteredRows iterates through rows, calling fn for each row, and collecting the results into a slice of T. +// If filter is not nil, only rows for which filter returns true will be collected, If filter is nil, all rows will be collected. +// This function closes the rows automatically on return. +func CollectFilteredRows[T any](rows Rows, fn RowToFunc[T], filter func(T)bool) ([]T, error){ + return AppendFilteredRows([]T{}, rows, fn, filter) +} + // CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true. // CollectOneRow is to CollectRows as QueryRow is to Query. // diff --git a/rows_test.go b/rows_test.go index 4cda957f..fe1efa31 100644 --- a/rows_test.go +++ b/rows_test.go @@ -175,6 +175,27 @@ func TestCollectRows(t *testing.T) { }) } +func TestCollectFilteredRows(t *testing.T) { + filter := func(value int32) bool { + return value <= 20 + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select n from generate_series(0, 99) n`) + numbers, err := pgx.CollectFilteredRows(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }, filter) + require.NoError(t, err) + + assert.Len(t, numbers, 21) + for i := range numbers { + assert.Equal(t, int32(i), numbers[i]) + } + }) +} + func TestCollectRowsEmpty(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { rows, _ := conn.Query(ctx, `select n from generate_series(1, 0) n`) @@ -219,6 +240,41 @@ func ExampleCollectRows() { // [1 2 3 4 5] } +// This example uses CollectFilteredRows with a manually written collector function. In most cases RowTo, RowToAddrOf, +// RowToStructByPos, RowToAddrOfStructByPos, or another generic function would be used. +// The filter function is used to filter out rows that don't meet the criteria. +// In this example, we filter out all rows where the number is less than 3. +func ExampleCollectFilteredRows() { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + rows, _ := conn.Query(ctx, `select n from generate_series(1, 5) n`) + numbers, err := pgx.CollectFilteredRows(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }, + func(n int32) bool { + return n >= 3 + }) + + if err != nil { + fmt.Printf("CollectFilteredRows error: %v", err) + return + } + + fmt.Println(numbers) + + // Output: + // [3 4 5] +} + func TestCollectOneRow(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { rows, _ := conn.Query(ctx, `select 42`)