From 62f034758660c08051586a030f8404390c99efd7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 9 Jul 2022 16:59:29 -0500 Subject: [PATCH] Add CollectOneRow --- CHANGELOG.md | 1 + rows.go | 21 +++++++++++++++++++++ rows_test.go | 41 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 44964ac7..4438e0bc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -138,6 +138,7 @@ The `RowScanner` interface allows a single argument to Rows.Scan to scan the ent ## Rows Result Helpers * `CollectRows` and `RowTo*` functions simplify collecting results into a slice. +* `CollectOneRow` collects one row using `RowTo*` functions. * `ForEachRow` simplifies scanning each row and executing code using the scanned values. `ForEachRow` replaces `QueryFunc`. ## SendBatch Uses Pipeline Mode When Appropriate diff --git a/rows.go b/rows.go index 2afb31df..90a24d28 100644 --- a/rows.go +++ b/rows.go @@ -428,6 +428,27 @@ func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) { return slice, nil } +// CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true. +// CollectOneRow is to CollectRows as QueryRow is to Query. +func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) { + defer rows.Close() + + var value T + var err error + + if !rows.Next() { + return value, ErrNoRows + } + + value, err = fn(rows) + if err != nil { + return value, err + } + + rows.Close() + return value, rows.Err() +} + // RowTo returns a T scanned from row. func RowTo[T any](row CollectableRow) (T, error) { var value T diff --git a/rows_test.go b/rows_test.go index 8806b1f6..0e3bc179 100644 --- a/rows_test.go +++ b/rows_test.go @@ -147,6 +147,47 @@ func TestCollectRows(t *testing.T) { }) } +func TestCollectOneRow(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 42`) + n, err := pgx.CollectOneRow(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + assert.NoError(t, err) + assert.Equal(t, int32(42), n) + }) +} + +func TestCollectOneRowNotFound(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 42 where false`) + n, err := pgx.CollectOneRow(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + assert.ErrorIs(t, err, pgx.ErrNoRows) + assert.Equal(t, int32(0), n) + }) +} + +func TestCollectOneRowIgnoresExtraRows(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select n from generate_series(42, 99) n`) + n, err := pgx.CollectOneRow(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + require.NoError(t, err) + + assert.NoError(t, err) + assert.Equal(t, int32(42), n) + }) +} + func TestRowTo(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { rows, _ := conn.Query(ctx, `select n from generate_series(0, 99) n`)