mirror of https://github.com/jackc/pgx.git
Add CollectRows and RowTo* functions
Collect functionality was originally developed in pgxutilpull/1281/head
parent
3dafb5d4ee
commit
da192291f7
|
@ -135,9 +135,10 @@ allows arbitrary rewriting of query SQL and arguments.
|
||||||
|
|
||||||
The `RowScanner` interface allows a single argument to Rows.Scan to scan the entire row.
|
The `RowScanner` interface allows a single argument to Rows.Scan to scan the entire row.
|
||||||
|
|
||||||
## QueryFunc Replaced
|
## Rows Result Helpers
|
||||||
|
|
||||||
`QueryFunc` has been replaced by using `ForEachScannedRow`.
|
* `CollectRows` and `RowTo*` functions simplify collecting results into a slice.
|
||||||
|
* `QueryFunc` has been replaced by using `ForEachScannedRow`.
|
||||||
|
|
||||||
## SendBatch Uses Pipeline Mode When Appropriate
|
## SendBatch Uses Pipeline Mode When Appropriate
|
||||||
|
|
||||||
|
|
121
rows.go
121
rows.go
|
@ -395,3 +395,124 @@ func ForEachScannedRow(rows Rows, scans []any, fn func() error) (pgconn.CommandT
|
||||||
|
|
||||||
return rows.CommandTag(), nil
|
return rows.CommandTag(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CollectableRow is the subset of Rows methods that a RowToFunc is allowed to call.
|
||||||
|
type CollectableRow interface {
|
||||||
|
FieldDescriptions() []pgproto3.FieldDescription
|
||||||
|
Scan(dest ...any) error
|
||||||
|
Values() ([]any, error)
|
||||||
|
RawValues() [][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// RowToFunc is a function that scans or otherwise converts row to a T.
|
||||||
|
type RowToFunc[T any] func(row CollectableRow) (T, error)
|
||||||
|
|
||||||
|
// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T.
|
||||||
|
func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
slice := []T{}
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
value, err := fn(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
slice = append(slice, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return slice, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RowTo returns a T scanned from row.
|
||||||
|
func RowTo[T any](row CollectableRow) (T, error) {
|
||||||
|
var value T
|
||||||
|
err := row.Scan(&value)
|
||||||
|
return value, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// RowTo returns a the address of a T scanned from row.
|
||||||
|
func RowToAddrOf[T any](row CollectableRow) (*T, error) {
|
||||||
|
var value T
|
||||||
|
err := row.Scan(&value)
|
||||||
|
return &value, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// RowToMap returns a map scanned from row.
|
||||||
|
func RowToMap(row CollectableRow) (map[string]any, error) {
|
||||||
|
var value map[string]any
|
||||||
|
err := row.Scan((*mapRowScanner)(&value))
|
||||||
|
return value, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type mapRowScanner map[string]any
|
||||||
|
|
||||||
|
func (rs *mapRowScanner) ScanRow(rows Rows) error {
|
||||||
|
values, err := rows.Values()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*rs = make(mapRowScanner, len(values))
|
||||||
|
|
||||||
|
for i := range values {
|
||||||
|
(*rs)[string(rows.FieldDescriptions()[i].Name)] = values[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number a public fields as row
|
||||||
|
// has fields. The row and T fields will by matched by position.
|
||||||
|
func RowToStructByPos[T any](row CollectableRow) (T, error) {
|
||||||
|
var value T
|
||||||
|
err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value})
|
||||||
|
return value, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// RowToAddrOfStructByPos returns the address of a T scanned from row. T must be a struct. T must have the same number a
|
||||||
|
// public fields as row has fields. The row and T fields will by matched by position.
|
||||||
|
func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) {
|
||||||
|
var value T
|
||||||
|
err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value})
|
||||||
|
return &value, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type positionalStructRowScanner struct {
|
||||||
|
ptrToStruct any
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs *positionalStructRowScanner) ScanRow(rows Rows) error {
|
||||||
|
dst := rs.ptrToStruct
|
||||||
|
dstValue := reflect.ValueOf(dst)
|
||||||
|
if dstValue.Kind() != reflect.Ptr {
|
||||||
|
return fmt.Errorf("dst not a pointer")
|
||||||
|
}
|
||||||
|
|
||||||
|
dstElemValue := dstValue.Elem()
|
||||||
|
dstElemType := dstElemValue.Type()
|
||||||
|
|
||||||
|
exportedFields := make([]int, 0, dstElemType.NumField())
|
||||||
|
for i := 0; i < dstElemType.NumField(); i++ {
|
||||||
|
sf := dstElemType.Field(i)
|
||||||
|
if sf.PkgPath == "" {
|
||||||
|
exportedFields = append(exportedFields, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rowFieldCount := len(rows.RawValues())
|
||||||
|
if rowFieldCount > len(exportedFields) {
|
||||||
|
return fmt.Errorf("got %d values, but dst struct has only %d fields", rowFieldCount, len(exportedFields))
|
||||||
|
}
|
||||||
|
|
||||||
|
scanTargets := make([]any, rowFieldCount)
|
||||||
|
for i := 0; i < rowFieldCount; i++ {
|
||||||
|
scanTargets[i] = dstElemValue.Field(exportedFields[i]).Addr().Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
return rows.Scan(scanTargets...)
|
||||||
|
}
|
||||||
|
|
96
rows_test.go
96
rows_test.go
|
@ -10,6 +10,7 @@ import (
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgx/v5/pgconn"
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
"github.com/jackc/pgx/v5/pgxtest"
|
"github.com/jackc/pgx/v5/pgxtest"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -128,3 +129,98 @@ func ExampleForEachScannedRow() {
|
||||||
// 2, 4
|
// 2, 4
|
||||||
// 3, 6
|
// 3, 6
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCollectRows(t *testing.T) {
|
||||||
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
rows, _ := conn.Query(ctx, `select n from generate_series(0, 99) n`)
|
||||||
|
numbers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (int32, error) {
|
||||||
|
var n int32
|
||||||
|
err := row.Scan(&n)
|
||||||
|
return n, err
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, numbers, 100)
|
||||||
|
for i := range numbers {
|
||||||
|
assert.Equal(t, int32(i), numbers[i])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRowTo(t *testing.T) {
|
||||||
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
rows, _ := conn.Query(ctx, `select n from generate_series(0, 99) n`)
|
||||||
|
numbers, err := pgx.CollectRows(rows, pgx.RowTo[int32])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, numbers, 100)
|
||||||
|
for i := range numbers {
|
||||||
|
assert.Equal(t, int32(i), numbers[i])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRowToAddrOf(t *testing.T) {
|
||||||
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
rows, _ := conn.Query(ctx, `select n from generate_series(0, 99) n`)
|
||||||
|
numbers, err := pgx.CollectRows(rows, pgx.RowToAddrOf[int32])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, numbers, 100)
|
||||||
|
for i := range numbers {
|
||||||
|
assert.Equal(t, int32(i), *numbers[i])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRowToMap(t *testing.T) {
|
||||||
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
rows, _ := conn.Query(ctx, `select 'Joe' as name, n as age from generate_series(0, 9) n`)
|
||||||
|
slice, err := pgx.CollectRows(rows, pgx.RowToMap)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, slice, 10)
|
||||||
|
for i := range slice {
|
||||||
|
assert.Equal(t, "Joe", slice[i]["name"])
|
||||||
|
assert.EqualValues(t, i, slice[i]["age"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRowToStructPos(t *testing.T) {
|
||||||
|
type person struct {
|
||||||
|
Name string
|
||||||
|
Age int32
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
rows, _ := conn.Query(ctx, `select 'Joe' as name, n as age from generate_series(0, 9) n`)
|
||||||
|
slice, err := pgx.CollectRows(rows, pgx.RowToStructByPos[person])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, slice, 10)
|
||||||
|
for i := range slice {
|
||||||
|
assert.Equal(t, "Joe", slice[i].Name)
|
||||||
|
assert.EqualValues(t, i, slice[i].Age)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRowToAddrOfStructPos(t *testing.T) {
|
||||||
|
type person struct {
|
||||||
|
Name string
|
||||||
|
Age int32
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
rows, _ := conn.Query(ctx, `select 'Joe' as name, n as age from generate_series(0, 9) n`)
|
||||||
|
slice, err := pgx.CollectRows(rows, pgx.RowToAddrOfStructByPos[person])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, slice, 10)
|
||||||
|
for i := range slice {
|
||||||
|
assert.Equal(t, "Joe", slice[i].Name)
|
||||||
|
assert.EqualValues(t, i, slice[i].Age)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue