From 6e40968cfc1866ff2795771c7fcb08dd74faee68 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 22 Oct 2022 08:44:06 -0500 Subject: [PATCH] CollectOneRow prefers PostgreSQL error over pgx.ErrorNoRows fixes https://github.com/jackc/pgx/issues/1334 --- rows.go | 3 +++ rows_test.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/rows.go b/rows.go index 33d8ab09..9acb6fc6 100644 --- a/rows.go +++ b/rows.go @@ -433,6 +433,9 @@ func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) { var err error if !rows.Next() { + if err = rows.Err(); err != nil { + return value, err + } return value, ErrNoRows } diff --git a/rows_test.go b/rows_test.go index 7aeafac8..ca024566 100644 --- a/rows_test.go +++ b/rows_test.go @@ -218,6 +218,35 @@ func TestCollectOneRowIgnoresExtraRows(t *testing.T) { }) } +// https://github.com/jackc/pgx/issues/1334 +func TestCollectOneRowPrefersPostgreSQLErrorOverErrNoRows(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Exec(ctx, `create temporary table t (name text not null unique)`) + require.NoError(t, err) + + var name string + rows, _ := conn.Query(ctx, `insert into t (name) values ('foo') returning name`) + name, err = pgx.CollectOneRow(rows, func(row pgx.CollectableRow) (string, error) { + var n string + err := row.Scan(&n) + return n, err + }) + require.NoError(t, err) + require.Equal(t, "foo", name) + + rows, _ = conn.Query(ctx, `insert into t (name) values ('foo') returning name`) + name, err = pgx.CollectOneRow(rows, func(row pgx.CollectableRow) (string, error) { + var n string + err := row.Scan(&n) + return n, err + }) + require.Error(t, err) + var pgErr *pgconn.PgError + require.ErrorAs(t, err, &pgErr) + require.Equal(t, "23505", pgErr.Code) + }) +} + func TestRowTo(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { rows, _ := conn.Query(ctx, `select n from generate_series(0, 99) n`)