From ee2622a8e699a7b052f272b08a6354bfdb90dbcd Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 6 Sep 2022 18:32:10 -0500 Subject: [PATCH] RowToStructByPos supports embedded structs https://github.com/jackc/pgx/issues/1273#issuecomment-1236966785 --- rows.go | 42 +++++++++++++++++++++++++----------------- rows_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 17 deletions(-) diff --git a/rows.go b/rows.go index 80df4bb2..33d8ab09 100644 --- a/rows.go +++ b/rows.go @@ -511,25 +511,33 @@ func (rs *positionalStructRowScanner) ScanRow(rows Rows) error { } dstElemValue := dstValue.Elem() - dstElemType := dstElemValue.Type() + scanTargets := rs.appendScanTargets(dstElemValue, nil) - exportedFields := make([]int, 0, dstElemType.NumField()) - for i := 0; i < dstElemType.NumField(); i++ { - sf := dstElemType.Field(i) - if sf.PkgPath == "" { - exportedFields = append(exportedFields, i) - } - } - - rowFieldCount := len(rows.RawValues()) - if rowFieldCount > len(exportedFields) { - return fmt.Errorf("got %d values, but dst struct has only %d fields", rowFieldCount, len(exportedFields)) - } - - scanTargets := make([]any, rowFieldCount) - for i := 0; i < rowFieldCount; i++ { - scanTargets[i] = dstElemValue.Field(exportedFields[i]).Addr().Interface() + if len(rows.RawValues()) > len(scanTargets) { + return fmt.Errorf("got %d values, but dst struct has only %d fields", len(rows.RawValues()), len(scanTargets)) } return rows.Scan(scanTargets...) } + +func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any) []any { + dstElemType := dstElemValue.Type() + + if scanTargets == nil { + scanTargets = make([]any, 0, dstElemType.NumField()) + } + + for i := 0; i < dstElemType.NumField(); i++ { + sf := dstElemType.Field(i) + if sf.PkgPath == "" { + // Handle anoymous struct embedding, but do not try to handle embedded pointers. + if sf.Anonymous && sf.Type.Kind() == reflect.Struct { + scanTargets = append(scanTargets, rs.appendScanTargets(dstElemValue.Field(i), scanTargets)...) + } else { + scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface()) + } + } + } + + return scanTargets +} diff --git a/rows_test.go b/rows_test.go index 6771469f..7aeafac8 100644 --- a/rows_test.go +++ b/rows_test.go @@ -329,6 +329,50 @@ func TestRowToStructByPos(t *testing.T) { }) } +func TestRowToStructByPosEmbeddedStruct(t *testing.T) { + type Name struct { + First string + Last string + } + + type person struct { + Name + Age int32 + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToStructByPos[person]) + require.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "John", slice[i].Name.First) + assert.Equal(t, "Smith", slice[i].Name.Last) + assert.EqualValues(t, i, slice[i].Age) + } + }) +} + +// Pointer to struct is not supported. But check that we don't panic. +func TestRowToStructByPosEmbeddedPointerToStruct(t *testing.T) { + type Name struct { + First string + Last string + } + + type person struct { + *Name + Age int32 + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age from generate_series(0, 9) n`) + _, err := pgx.CollectRows(rows, pgx.RowToStructByPos[person]) + require.EqualError(t, err, "got 3 values, but dst struct has only 2 fields") + }) +} + func ExampleRowToStructByPos() { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel()