diff --git a/rows.go b/rows.go index ffe739b0..c5739963 100644 --- a/rows.go +++ b/rows.go @@ -578,7 +578,6 @@ func (rs *namedStructRowScanner) ScanRow(rows Rows) error { dstElemValue := dstValue.Elem() scanTargets, err := rs.appendScanTargets(dstElemValue, nil, rows.FieldDescriptions()) - if err != nil { return err } @@ -647,3 +646,92 @@ func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, s return scanTargets, err } + +// RowToStructByNameLax returns a T scanned from row. T must be a struct. T must have greater than or equal number of named public +// fields as row has fields. The row and T fields will by matched by name. The match is case-insensitive. The database +// column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored. +func RowToStructByNameLax[T any](row CollectableRow) (T, error) { + var value T + err := row.Scan(&namedStructRowLaxScanner{ptrToStruct: &value}) + return value, err +} + +// RowToAddrOfStructByNameLax returns the address of a T scanned from row. T must be a struct. T must have greater than or +// equal number of named public fields as row has fields. The row and T fields will by matched by name. The match is +// case-insensitive. The database column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" +// then the field will be ignored. +func RowToAddrOfStructByNameLax[T any](row CollectableRow) (*T, error) { + var value T + err := row.Scan(&namedStructRowLaxScanner{ptrToStruct: &value}) + return &value, err +} + +type namedStructRowLaxScanner struct { + ptrToStruct any +} + +func (rs *namedStructRowLaxScanner) ScanRow(rows Rows) error { + dst := rs.ptrToStruct + dstValue := reflect.ValueOf(dst) + if dstValue.Kind() != reflect.Ptr { + return fmt.Errorf("dst not a pointer") + } + + dstElemValue := dstValue.Elem() + scanTargets, err := rs.appendScanTargets(dstElemValue, nil, rows.FieldDescriptions()) + if err != nil { + return err + } + + for i, t := range scanTargets { + if t == nil { + return fmt.Errorf("struct doesn't have corresponding row field %s", rows.FieldDescriptions()[i].Name) + } + } + + return rows.Scan(scanTargets...) +} + +func (rs *namedStructRowLaxScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any, fldDescs []pgconn.FieldDescription) ([]any, error) { + var err error + dstElemType := dstElemValue.Type() + + if scanTargets == nil { + scanTargets = make([]any, len(fldDescs)) + } + + for i := 0; i < dstElemType.NumField(); i++ { + sf := dstElemType.Field(i) + if sf.PkgPath != "" && !sf.Anonymous { + // Field is unexported, skip it. + continue + } + // Handle anoymous struct embedding, but do not try to handle embedded pointers. + if sf.Anonymous && sf.Type.Kind() == reflect.Struct { + scanTargets, err = rs.appendScanTargets(dstElemValue.Field(i), scanTargets, fldDescs) + if err != nil { + return nil, err + } + } else { + dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey) + if dbTagPresent { + dbTag = strings.Split(dbTag, ",")[0] + } + if dbTag == "-" { + // Field is ignored, skip it. + continue + } + colName := dbTag + if !dbTagPresent { + colName = sf.Name + } + fpos := fieldPosByName(fldDescs, colName) + if fpos == -1 { + continue + } + scanTargets[fpos] = dstElemValue.Field(i).Addr().Interface() + } + } + + return scanTargets, err +} diff --git a/rows_test.go b/rows_test.go index ca9c3010..a5b6e942 100644 --- a/rows_test.go +++ b/rows_test.go @@ -635,3 +635,164 @@ insert into products (name, price) values // Fries: $5 // Soft Drink: $3 } + +func TestRowToStructByNameLax(t *testing.T) { + type person struct { + Last string + First string + Age int32 + Ignore bool `db:"-"` + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'John' as first, 'Smith' as last, n as age from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToStructByNameLax[person]) + assert.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "Smith", slice[i].Last) + assert.Equal(t, "John", slice[i].First) + assert.EqualValues(t, i, slice[i].Age) + } + + // check missing fields in a returned row + rows, _ = conn.Query(ctx, `select 'John' as first, n as age from generate_series(0, 9) n`) + slice, err = pgx.CollectRows(rows, pgx.RowToStructByNameLax[person]) + assert.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "John", slice[i].First) + assert.EqualValues(t, i, slice[i].Age) + } + + // check extra fields in a returned row + rows, _ = conn.Query(ctx, `select 'John' as first, 'Smith' as last, n as age, null as ignore from generate_series(0, 9) n`) + _, err = pgx.CollectRows(rows, pgx.RowToAddrOfStructByNameLax[person]) + assert.ErrorContains(t, err, "struct doesn't have corresponding row field ignore") + + // check missing fields in a destination struct + rows, _ = conn.Query(ctx, `select 'Smith' as last, 'D.' as middle, n as age from generate_series(0, 9) n`) + _, err = pgx.CollectRows(rows, pgx.RowToAddrOfStructByNameLax[person]) + assert.ErrorContains(t, err, "struct doesn't have corresponding row field middle") + + // check ignored fields in a destination struct + rows, _ = conn.Query(ctx, `select 'Smith' as last, n as age, null as ignore from generate_series(0, 9) n`) + _, err = pgx.CollectRows(rows, pgx.RowToAddrOfStructByNameLax[person]) + assert.ErrorContains(t, err, "struct doesn't have corresponding row field ignore") + }) +} + +func TestRowToStructByNameLaxEmbeddedStruct(t *testing.T) { + type Name struct { + Last string `db:"last_name"` + First string `db:"first_name"` + } + + type person struct { + Ignore bool `db:"-"` + Name + Age int32 + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToStructByNameLax[person]) + assert.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "Smith", slice[i].Name.Last) + assert.Equal(t, "John", slice[i].Name.First) + assert.EqualValues(t, i, slice[i].Age) + } + + // check missing fields in a returned row + rows, _ = conn.Query(ctx, `select 'John' as first_name, n as age from generate_series(0, 9) n`) + slice, err = pgx.CollectRows(rows, pgx.RowToStructByNameLax[person]) + assert.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "John", slice[i].Name.First) + assert.EqualValues(t, i, slice[i].Age) + } + + // check extra fields in a returned row + rows, _ = conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age, null as ignore from generate_series(0, 9) n`) + _, err = pgx.CollectRows(rows, pgx.RowToAddrOfStructByNameLax[person]) + assert.ErrorContains(t, err, "struct doesn't have corresponding row field ignore") + + // check missing fields in a destination struct + rows, _ = conn.Query(ctx, `select 'Smith' as last_name, 'D.' as middle_name, n as age from generate_series(0, 9) n`) + _, err = pgx.CollectRows(rows, pgx.RowToAddrOfStructByNameLax[person]) + assert.ErrorContains(t, err, "struct doesn't have corresponding row field middle_name") + + // check ignored fields in a destination struct + rows, _ = conn.Query(ctx, `select 'Smith' as last_name, n as age, null as ignore from generate_series(0, 9) n`) + _, err = pgx.CollectRows(rows, pgx.RowToAddrOfStructByNameLax[person]) + assert.ErrorContains(t, err, "struct doesn't have corresponding row field ignore") + }) +} + +func ExampleRowToStructByNameLax() { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + fmt.Printf("Unable to establish connection: %v", err) + return + } + + if conn.PgConn().ParameterStatus("crdb_version") != "" { + // Skip test / example when running on CockroachDB. Since an example can't be skipped fake success instead. + fmt.Println(`Cheeseburger: $10 +Fries: $5 +Soft Drink: $3`) + return + } + + // Setup example schema and data. + _, err = conn.Exec(ctx, ` +create temporary table products ( + id int primary key generated by default as identity, + name varchar(100) not null, + price int not null +); + +insert into products (name, price) values + ('Cheeseburger', 10), + ('Double Cheeseburger', 14), + ('Fries', 5), + ('Soft Drink', 3); +`) + if err != nil { + fmt.Printf("Unable to setup example schema and data: %v", err) + return + } + + type product struct { + ID int32 + Name string + Type string + Price int32 + } + + rows, _ := conn.Query(ctx, "select * from products where price < $1 order by price desc", 12) + products, err := pgx.CollectRows(rows, pgx.RowToStructByNameLax[product]) + if err != nil { + fmt.Printf("CollectRows error: %v", err) + return + } + + for _, p := range products { + fmt.Printf("%s: $%d\n", p.Name, p.Price) + } + + // Output: + // Cheeseburger: $10 + // Fries: $5 + // Soft Drink: $3 +}