From 9ee7d29cf914b870b417219bb94cce1c7e625748 Mon Sep 17 00:00:00 2001 From: Julien GOTTELAND Date: Sat, 19 Aug 2023 18:24:39 +0200 Subject: [PATCH] Add CollectExactlyOneRow function --- conn.go | 8 ++++++-- rows.go | 33 ++++++++++++++++++++++++++++++++- rows_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 3 deletions(-) diff --git a/conn.go b/conn.go index 7c7081b4..b867acfe 100644 --- a/conn.go +++ b/conn.go @@ -99,8 +99,12 @@ func (ident Identifier) Sanitize() string { return strings.Join(parts, ".") } -// ErrNoRows occurs when rows are expected but none are returned. -var ErrNoRows = errors.New("no rows in result set") +var ( + // ErrNoRows occurs when rows are expected but none are returned. + ErrNoRows = errors.New("no rows in result set") + // ErrTooManyRows occurs when more rows than expected are returned. + ErrTooManyRows = errors.New("too many rows in result set") +) var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache") var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache") diff --git a/rows.go b/rows.go index ab8a12f8..a58cc7ee 100644 --- a/rows.go +++ b/rows.go @@ -48,7 +48,7 @@ type Rows interface { // Callers should check rows.Err() after rows.Next() returns false to detect // whether result-set reading ended prematurely due to an error. See // Conn.Query for details. - // + // // For simpler error handling, consider using the higher-level pgx v5 // CollectRows() and ForEachRow() helpers instead. Next() bool @@ -465,6 +465,37 @@ func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) { return value, rows.Err() } +// CollectExactlyOneRow calls fn for the first row in rows and returns the result. +// - If no rows are found returns an error where errors.Is(ErrNoRows) is true. +// - If more than 1 row is found returns the first result and an error where errors.Is(ErrTooManyRows) is true. +func CollectExactlyOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) { + defer rows.Close() + + var ( + err error + value T + ) + + if !rows.Next() { + if err = rows.Err(); err != nil { + return value, err + } + + return value, ErrNoRows + } + + value, err = fn(rows) + if err != nil { + return value, err + } + + if rows.Next() { + return value, ErrTooManyRows + } + + return value, rows.Err() +} + // RowTo returns a T scanned from row. func RowTo[T any](row CollectableRow) (T, error) { var value T diff --git a/rows_test.go b/rows_test.go index b2d1137a..be221ea4 100644 --- a/rows_test.go +++ b/rows_test.go @@ -274,6 +274,45 @@ func TestCollectOneRowPrefersPostgreSQLErrorOverErrNoRows(t *testing.T) { }) } +func TestCollectExactlyOneRow(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 42`) + n, err := pgx.CollectExactlyOneRow(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + assert.NoError(t, err) + assert.Equal(t, int32(42), n) + }) +} + +func TestCollectExactlyOneRowNotFound(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 42 where false`) + n, err := pgx.CollectExactlyOneRow(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + assert.ErrorIs(t, err, pgx.ErrNoRows) + assert.Equal(t, int32(0), n) + }) +} + +func TestCollectExactlyOneRowExtraRows(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select n from generate_series(42, 99) n`) + n, err := pgx.CollectExactlyOneRow(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + assert.ErrorIs(t, err, pgx.ErrTooManyRows) + assert.Equal(t, int32(0), n) + }) +} + func TestRowTo(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { rows, _ := conn.Query(ctx, `select n from generate_series(0, 99) n`)