From 14be51536bbf5e183b68ee9a5fcadaf0d045e503 Mon Sep 17 00:00:00 2001 From: Pavlo Golub Date: Thu, 3 Nov 2022 16:49:20 +0100 Subject: [PATCH] implement `RowToStructByName` and `RowToAddrOfStructByName` --- rows.go | 100 +++++++++++++++++++++++++++++++++++++++++++++++++++ rows_test.go | 36 +++++++++++++++++++ 2 files changed, 136 insertions(+) diff --git a/rows.go b/rows.go index 5fea9883..23f33efc 100644 --- a/rows.go +++ b/rows.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "reflect" + "strings" "time" "github.com/jackc/pgx/v5/internal/stmtcache" @@ -544,3 +545,102 @@ func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Val return scanTargets } + +// RowToStructByName returns a T scanned from row. T must be a struct. T must have the same number a named public fields as row +// has fields. The row and T fields will by matched by name. +func RowToStructByName[T any](row CollectableRow) (T, error) { + var value T + err := row.Scan(&namedStructRowScanner{ptrToStruct: &value}) + return value, err +} + +// RowToAddrOfStructByPos returns the address of a T scanned from row. T must be a struct. T must have the same number a +// named public fields as row has fields. The row and T fields will by matched by name. +func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) { + var value T + err := row.Scan(&namedStructRowScanner{ptrToStruct: &value}) + return &value, err +} + +type namedStructRowScanner struct { + ptrToStruct any +} + +func (rs *namedStructRowScanner) ScanRow(rows Rows) error { + dst := rs.ptrToStruct + dstValue := reflect.ValueOf(dst) + if dstValue.Kind() != reflect.Ptr { + return fmt.Errorf("dst not a pointer") + } + + dstElemValue := dstValue.Elem() + scanTargets, err := rs.appendScanTargets(dstElemValue, nil, rows.FieldDescriptions()) + + if err != nil { + return err + } + + for i, t := range scanTargets { + if t == nil { + return fmt.Errorf("struct doesn't have corresponding row field %s", rows.FieldDescriptions()[i].Name) + } + } + + return rows.Scan(scanTargets...) +} + +const structTagKey = "db" + +func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) { + i = -1 + for i, desc := range fldDescs { + if strings.EqualFold(desc.Name, field) { + return i + } + } + return +} + +func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any, fldDescs []pgconn.FieldDescription) ([]any, error) { + var err error + dstElemType := dstElemValue.Type() + + if scanTargets == nil { + scanTargets = make([]any, len(fldDescs)) + } + + for i := 0; i < dstElemType.NumField(); i++ { + sf := dstElemType.Field(i) + if sf.PkgPath != "" && !sf.Anonymous { + // Field is unexported, skip it. + continue + } + // Handle anoymous struct embedding, but do not try to handle embedded pointers. + if sf.Anonymous && sf.Type.Kind() == reflect.Struct { + scanTargets, err = rs.appendScanTargets(dstElemValue.Field(i), scanTargets, fldDescs) + if err != nil { + return nil, err + } + } else { + dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey) + if dbTagPresent { + dbTag = strings.Split(dbTag, ",")[0] + } + if dbTag == "-" { + // Field is ignored, skip it. + continue + } + colName := dbTag + if !dbTagPresent { + colName = sf.Name + } + fpos := fieldPosByName(fldDescs, colName) + if fpos == -1 || fpos >= len(scanTargets) { + return nil, fmt.Errorf("cannot find field %s in returned row", colName) + } + scanTargets[fpos] = dstElemValue.Field(i).Addr().Interface() + } + } + + return scanTargets, err +} diff --git a/rows_test.go b/rows_test.go index 60cf1cd4..aacd3def 100644 --- a/rows_test.go +++ b/rows_test.go @@ -508,3 +508,39 @@ func TestRowToAddrOfStructPos(t *testing.T) { } }) } + +func TestRowToStructByNameEmbeddedStruct(t *testing.T) { + type Name struct { + Last string `db:"last_name"` + First string `db:"first_name"` + } + + type person struct { + Ignore bool `db:"-"` + Name + Age int32 + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToStructByName[person]) + assert.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "Smith", slice[i].Name.Last) + assert.Equal(t, "John", slice[i].Name.First) + assert.EqualValues(t, i, slice[i].Age) + } + + // check missing fields in a returned row + rows, _ = conn.Query(ctx, `select 'Smith' as last_name, n as age from generate_series(0, 9) n`) + _, err = pgx.CollectRows(rows, pgx.RowToStructByName[person]) + assert.ErrorContains(t, err, "cannot find field first_name in returned row") + + // check missing field in a destination struct + rows, _ = conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age, null as ignore from generate_series(0, 9) n`) + _, err = pgx.CollectRows(rows, pgx.RowToAddrOfStructByName[person]) + assert.ErrorContains(t, err, "struct doesn't have corresponding row field ignore") + }) +}