mirror of https://github.com/jackc/pgx.git
implement `RowToStructByName` and `RowToAddrOfStructByName`
parent
1376a2c0ed
commit
14be51536b
100
rows.go
100
rows.go
|
@ -5,6 +5,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/stmtcache"
|
||||
|
@ -544,3 +545,102 @@ func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Val
|
|||
|
||||
return scanTargets
|
||||
}
|
||||
|
||||
// RowToStructByName returns a T scanned from row. T must be a struct. T must have the same number a named public fields as row
|
||||
// has fields. The row and T fields will by matched by name.
|
||||
func RowToStructByName[T any](row CollectableRow) (T, error) {
|
||||
var value T
|
||||
err := row.Scan(&namedStructRowScanner{ptrToStruct: &value})
|
||||
return value, err
|
||||
}
|
||||
|
||||
// RowToAddrOfStructByPos returns the address of a T scanned from row. T must be a struct. T must have the same number a
|
||||
// named public fields as row has fields. The row and T fields will by matched by name.
|
||||
func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) {
|
||||
var value T
|
||||
err := row.Scan(&namedStructRowScanner{ptrToStruct: &value})
|
||||
return &value, err
|
||||
}
|
||||
|
||||
type namedStructRowScanner struct {
|
||||
ptrToStruct any
|
||||
}
|
||||
|
||||
func (rs *namedStructRowScanner) ScanRow(rows Rows) error {
|
||||
dst := rs.ptrToStruct
|
||||
dstValue := reflect.ValueOf(dst)
|
||||
if dstValue.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("dst not a pointer")
|
||||
}
|
||||
|
||||
dstElemValue := dstValue.Elem()
|
||||
scanTargets, err := rs.appendScanTargets(dstElemValue, nil, rows.FieldDescriptions())
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, t := range scanTargets {
|
||||
if t == nil {
|
||||
return fmt.Errorf("struct doesn't have corresponding row field %s", rows.FieldDescriptions()[i].Name)
|
||||
}
|
||||
}
|
||||
|
||||
return rows.Scan(scanTargets...)
|
||||
}
|
||||
|
||||
const structTagKey = "db"
|
||||
|
||||
func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) {
|
||||
i = -1
|
||||
for i, desc := range fldDescs {
|
||||
if strings.EqualFold(desc.Name, field) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any, fldDescs []pgconn.FieldDescription) ([]any, error) {
|
||||
var err error
|
||||
dstElemType := dstElemValue.Type()
|
||||
|
||||
if scanTargets == nil {
|
||||
scanTargets = make([]any, len(fldDescs))
|
||||
}
|
||||
|
||||
for i := 0; i < dstElemType.NumField(); i++ {
|
||||
sf := dstElemType.Field(i)
|
||||
if sf.PkgPath != "" && !sf.Anonymous {
|
||||
// Field is unexported, skip it.
|
||||
continue
|
||||
}
|
||||
// Handle anoymous struct embedding, but do not try to handle embedded pointers.
|
||||
if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
|
||||
scanTargets, err = rs.appendScanTargets(dstElemValue.Field(i), scanTargets, fldDescs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey)
|
||||
if dbTagPresent {
|
||||
dbTag = strings.Split(dbTag, ",")[0]
|
||||
}
|
||||
if dbTag == "-" {
|
||||
// Field is ignored, skip it.
|
||||
continue
|
||||
}
|
||||
colName := dbTag
|
||||
if !dbTagPresent {
|
||||
colName = sf.Name
|
||||
}
|
||||
fpos := fieldPosByName(fldDescs, colName)
|
||||
if fpos == -1 || fpos >= len(scanTargets) {
|
||||
return nil, fmt.Errorf("cannot find field %s in returned row", colName)
|
||||
}
|
||||
scanTargets[fpos] = dstElemValue.Field(i).Addr().Interface()
|
||||
}
|
||||
}
|
||||
|
||||
return scanTargets, err
|
||||
}
|
||||
|
|
36
rows_test.go
36
rows_test.go
|
@ -508,3 +508,39 @@ func TestRowToAddrOfStructPos(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRowToStructByNameEmbeddedStruct(t *testing.T) {
|
||||
type Name struct {
|
||||
Last string `db:"last_name"`
|
||||
First string `db:"first_name"`
|
||||
}
|
||||
|
||||
type person struct {
|
||||
Ignore bool `db:"-"`
|
||||
Name
|
||||
Age int32
|
||||
}
|
||||
|
||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
rows, _ := conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age from generate_series(0, 9) n`)
|
||||
slice, err := pgx.CollectRows(rows, pgx.RowToStructByName[person])
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Len(t, slice, 10)
|
||||
for i := range slice {
|
||||
assert.Equal(t, "Smith", slice[i].Name.Last)
|
||||
assert.Equal(t, "John", slice[i].Name.First)
|
||||
assert.EqualValues(t, i, slice[i].Age)
|
||||
}
|
||||
|
||||
// check missing fields in a returned row
|
||||
rows, _ = conn.Query(ctx, `select 'Smith' as last_name, n as age from generate_series(0, 9) n`)
|
||||
_, err = pgx.CollectRows(rows, pgx.RowToStructByName[person])
|
||||
assert.ErrorContains(t, err, "cannot find field first_name in returned row")
|
||||
|
||||
// check missing field in a destination struct
|
||||
rows, _ = conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age, null as ignore from generate_series(0, 9) n`)
|
||||
_, err = pgx.CollectRows(rows, pgx.RowToAddrOfStructByName[person])
|
||||
assert.ErrorContains(t, err, "struct doesn't have corresponding row field ignore")
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue