diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b6c3a96..6d2d8ee8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -135,9 +135,10 @@ allows arbitrary rewriting of query SQL and arguments. The `RowScanner` interface allows a single argument to Rows.Scan to scan the entire row. -## QueryFunc Replaced +## Rows Result Helpers -`QueryFunc` has been replaced by using `ForEachScannedRow`. +* `CollectRows` and `RowTo*` functions simplify collecting results into a slice. +* `QueryFunc` has been replaced by using `ForEachScannedRow`. ## SendBatch Uses Pipeline Mode When Appropriate diff --git a/rows.go b/rows.go index 4d4c5ec6..0c630bc4 100644 --- a/rows.go +++ b/rows.go @@ -395,3 +395,124 @@ func ForEachScannedRow(rows Rows, scans []any, fn func() error) (pgconn.CommandT return rows.CommandTag(), nil } + +// CollectableRow is the subset of Rows methods that a RowToFunc is allowed to call. +type CollectableRow interface { + FieldDescriptions() []pgproto3.FieldDescription + Scan(dest ...any) error + Values() ([]any, error) + RawValues() [][]byte +} + +// RowToFunc is a function that scans or otherwise converts row to a T. +type RowToFunc[T any] func(row CollectableRow) (T, error) + +// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T. +func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) { + defer rows.Close() + + slice := []T{} + + for rows.Next() { + value, err := fn(rows) + if err != nil { + return nil, err + } + slice = append(slice, value) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return slice, nil +} + +// RowTo returns a T scanned from row. +func RowTo[T any](row CollectableRow) (T, error) { + var value T + err := row.Scan(&value) + return value, err +} + +// RowTo returns a the address of a T scanned from row. +func RowToAddrOf[T any](row CollectableRow) (*T, error) { + var value T + err := row.Scan(&value) + return &value, err +} + +// RowToMap returns a map scanned from row. +func RowToMap(row CollectableRow) (map[string]any, error) { + var value map[string]any + err := row.Scan((*mapRowScanner)(&value)) + return value, err +} + +type mapRowScanner map[string]any + +func (rs *mapRowScanner) ScanRow(rows Rows) error { + values, err := rows.Values() + if err != nil { + return err + } + + *rs = make(mapRowScanner, len(values)) + + for i := range values { + (*rs)[string(rows.FieldDescriptions()[i].Name)] = values[i] + } + + return nil +} + +// RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number a public fields as row +// has fields. The row and T fields will by matched by position. +func RowToStructByPos[T any](row CollectableRow) (T, error) { + var value T + err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) + return value, err +} + +// RowToAddrOfStructByPos returns the address of a T scanned from row. T must be a struct. T must have the same number a +// public fields as row has fields. The row and T fields will by matched by position. +func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) { + var value T + err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) + return &value, err +} + +type positionalStructRowScanner struct { + ptrToStruct any +} + +func (rs *positionalStructRowScanner) ScanRow(rows Rows) error { + dst := rs.ptrToStruct + dstValue := reflect.ValueOf(dst) + if dstValue.Kind() != reflect.Ptr { + return fmt.Errorf("dst not a pointer") + } + + dstElemValue := dstValue.Elem() + dstElemType := dstElemValue.Type() + + exportedFields := make([]int, 0, dstElemType.NumField()) + for i := 0; i < dstElemType.NumField(); i++ { + sf := dstElemType.Field(i) + if sf.PkgPath == "" { + exportedFields = append(exportedFields, i) + } + } + + rowFieldCount := len(rows.RawValues()) + if rowFieldCount > len(exportedFields) { + return fmt.Errorf("got %d values, but dst struct has only %d fields", rowFieldCount, len(exportedFields)) + } + + scanTargets := make([]any, rowFieldCount) + for i := 0; i < rowFieldCount; i++ { + scanTargets[i] = dstElemValue.Field(exportedFields[i]).Addr().Interface() + } + + return rows.Scan(scanTargets...) +} diff --git a/rows_test.go b/rows_test.go index 63bb77d5..9f07ee2e 100644 --- a/rows_test.go +++ b/rows_test.go @@ -10,6 +10,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -128,3 +129,98 @@ func ExampleForEachScannedRow() { // 2, 4 // 3, 6 } + +func TestCollectRows(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select n from generate_series(0, 99) n`) + numbers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + require.NoError(t, err) + + assert.Len(t, numbers, 100) + for i := range numbers { + assert.Equal(t, int32(i), numbers[i]) + } + }) +} + +func TestRowTo(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select n from generate_series(0, 99) n`) + numbers, err := pgx.CollectRows(rows, pgx.RowTo[int32]) + require.NoError(t, err) + + assert.Len(t, numbers, 100) + for i := range numbers { + assert.Equal(t, int32(i), numbers[i]) + } + }) +} + +func TestRowToAddrOf(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select n from generate_series(0, 99) n`) + numbers, err := pgx.CollectRows(rows, pgx.RowToAddrOf[int32]) + require.NoError(t, err) + + assert.Len(t, numbers, 100) + for i := range numbers { + assert.Equal(t, int32(i), *numbers[i]) + } + }) +} + +func TestRowToMap(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'Joe' as name, n as age from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToMap) + require.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "Joe", slice[i]["name"]) + assert.EqualValues(t, i, slice[i]["age"]) + } + }) +} + +func TestRowToStructPos(t *testing.T) { + type person struct { + Name string + Age int32 + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'Joe' as name, n as age from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToStructByPos[person]) + require.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "Joe", slice[i].Name) + assert.EqualValues(t, i, slice[i].Age) + } + }) +} + +func TestRowToAddrOfStructPos(t *testing.T) { + type person struct { + Name string + Age int32 + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'Joe' as name, n as age from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToAddrOfStructByPos[person]) + require.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "Joe", slice[i].Name) + assert.EqualValues(t, i, slice[i].Age) + } + }) +}