mirror of https://github.com/jackc/pgx.git
Cache reflection analysis in RowToStructBy...
Modify the RowToStructByPos/Name functions to store the computed mapping of columns to struct field locations in a cache to reuse between calls. Because this computation can be expensive and the same few results will frequently be reused, caching these results provides a significant speedup. For positional mappings, we can key the cache by just the struct-type. However, for named mappings, the key must include a representation of the columns, in order, since different columns produce different mappings.pull/1991/head
parent
8db0f280fb
commit
ec98406207
288
rows.go
288
rows.go
|
@ -6,6 +6,7 @@ import (
|
|||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
|
@ -541,7 +542,7 @@ func (rs *mapRowScanner) ScanRow(rows Rows) error {
|
|||
// ignored.
|
||||
func RowToStructByPos[T any](row CollectableRow) (T, error) {
|
||||
var value T
|
||||
err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value})
|
||||
err := (&positionalStructRowScanner{ptrToStruct: &value}).ScanRow(row)
|
||||
return value, err
|
||||
}
|
||||
|
||||
|
@ -550,7 +551,7 @@ func RowToStructByPos[T any](row CollectableRow) (T, error) {
|
|||
// the field will be ignored.
|
||||
func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) {
|
||||
var value T
|
||||
err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value})
|
||||
err := (&positionalStructRowScanner{ptrToStruct: &value}).ScanRow(row)
|
||||
return &value, err
|
||||
}
|
||||
|
||||
|
@ -558,46 +559,60 @@ type positionalStructRowScanner struct {
|
|||
ptrToStruct any
|
||||
}
|
||||
|
||||
func (rs *positionalStructRowScanner) ScanRow(rows Rows) error {
|
||||
dst := rs.ptrToStruct
|
||||
dstValue := reflect.ValueOf(dst)
|
||||
if dstValue.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("dst not a pointer")
|
||||
func (rs *positionalStructRowScanner) ScanRow(rows CollectableRow) error {
|
||||
typ := reflect.TypeOf(rs.ptrToStruct).Elem()
|
||||
fields := lookupStructFields(typ)
|
||||
if len(rows.RawValues()) > len(fields) {
|
||||
return fmt.Errorf(
|
||||
"got %d values, but dst struct has only %d fields",
|
||||
len(rows.RawValues()),
|
||||
len(fields),
|
||||
)
|
||||
}
|
||||
|
||||
dstElemValue := dstValue.Elem()
|
||||
scanTargets := rs.appendScanTargets(dstElemValue, nil)
|
||||
|
||||
if len(rows.RawValues()) > len(scanTargets) {
|
||||
return fmt.Errorf("got %d values, but dst struct has only %d fields", len(rows.RawValues()), len(scanTargets))
|
||||
}
|
||||
|
||||
scanTargets := setupStructScanTargets(rs.ptrToStruct, fields)
|
||||
return rows.Scan(scanTargets...)
|
||||
}
|
||||
|
||||
func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any) []any {
|
||||
dstElemType := dstElemValue.Type()
|
||||
// Map from reflect.Type -> []structRowField
|
||||
var positionalStructFieldMap sync.Map
|
||||
|
||||
if scanTargets == nil {
|
||||
scanTargets = make([]any, 0, dstElemType.NumField())
|
||||
func lookupStructFields(t reflect.Type) []structRowField {
|
||||
if cached, ok := positionalStructFieldMap.Load(t); ok {
|
||||
return cached.([]structRowField)
|
||||
}
|
||||
|
||||
for i := 0; i < dstElemType.NumField(); i++ {
|
||||
sf := dstElemType.Field(i)
|
||||
fieldStack := make([]int, 0, 1)
|
||||
fields := computeStructFields(t, make([]structRowField, 0, t.NumField()), &fieldStack)
|
||||
fieldsIface, _ := positionalStructFieldMap.LoadOrStore(t, fields)
|
||||
return fieldsIface.([]structRowField)
|
||||
}
|
||||
|
||||
func computeStructFields(
|
||||
t reflect.Type,
|
||||
fields []structRowField,
|
||||
fieldStack *[]int,
|
||||
) []structRowField {
|
||||
tail := len(*fieldStack)
|
||||
*fieldStack = append(*fieldStack, 0)
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
sf := t.Field(i)
|
||||
(*fieldStack)[tail] = i
|
||||
// Handle anonymous struct embedding, but do not try to handle embedded pointers.
|
||||
if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
|
||||
scanTargets = rs.appendScanTargets(dstElemValue.Field(i), scanTargets)
|
||||
fields = computeStructFields(sf.Type, fields, fieldStack)
|
||||
} else if sf.PkgPath == "" {
|
||||
dbTag, _ := sf.Tag.Lookup(structTagKey)
|
||||
if dbTag == "-" {
|
||||
// Field is ignored, skip it.
|
||||
continue
|
||||
}
|
||||
scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface())
|
||||
fields = append(fields, structRowField{
|
||||
path: append([]int(nil), *fieldStack...),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return scanTargets
|
||||
*fieldStack = (*fieldStack)[:tail]
|
||||
return fields
|
||||
}
|
||||
|
||||
// RowToStructByName returns a T scanned from row. T must be a struct. T must have the same number of named public
|
||||
|
@ -605,7 +620,7 @@ func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Val
|
|||
// column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored.
|
||||
func RowToStructByName[T any](row CollectableRow) (T, error) {
|
||||
var value T
|
||||
err := row.Scan(&namedStructRowScanner{ptrToStruct: &value})
|
||||
err := (&namedStructRowScanner{ptrToStruct: &value}).ScanRow(row)
|
||||
return value, err
|
||||
}
|
||||
|
||||
|
@ -615,7 +630,7 @@ func RowToStructByName[T any](row CollectableRow) (T, error) {
|
|||
// then the field will be ignored.
|
||||
func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) {
|
||||
var value T
|
||||
err := row.Scan(&namedStructRowScanner{ptrToStruct: &value})
|
||||
err := (&namedStructRowScanner{ptrToStruct: &value}).ScanRow(row)
|
||||
return &value, err
|
||||
}
|
||||
|
||||
|
@ -624,7 +639,7 @@ func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) {
|
|||
// column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored.
|
||||
func RowToStructByNameLax[T any](row CollectableRow) (T, error) {
|
||||
var value T
|
||||
err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true})
|
||||
err := (&namedStructRowScanner{ptrToStruct: &value, lax: true}).ScanRow(row)
|
||||
return value, err
|
||||
}
|
||||
|
||||
|
@ -634,7 +649,7 @@ func RowToStructByNameLax[T any](row CollectableRow) (T, error) {
|
|||
// then the field will be ignored.
|
||||
func RowToAddrOfStructByNameLax[T any](row CollectableRow) (*T, error) {
|
||||
var value T
|
||||
err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true})
|
||||
err := (&namedStructRowScanner{ptrToStruct: &value, lax: true}).ScanRow(row)
|
||||
return &value, err
|
||||
}
|
||||
|
||||
|
@ -643,28 +658,154 @@ type namedStructRowScanner struct {
|
|||
lax bool
|
||||
}
|
||||
|
||||
func (rs *namedStructRowScanner) ScanRow(rows Rows) error {
|
||||
dst := rs.ptrToStruct
|
||||
dstValue := reflect.ValueOf(dst)
|
||||
if dstValue.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("dst not a pointer")
|
||||
}
|
||||
|
||||
dstElemValue := dstValue.Elem()
|
||||
scanTargets, err := rs.appendScanTargets(dstElemValue, nil, rows.FieldDescriptions())
|
||||
func (rs *namedStructRowScanner) ScanRow(rows CollectableRow) error {
|
||||
typ := reflect.TypeOf(rs.ptrToStruct).Elem()
|
||||
fldDescs := rows.FieldDescriptions()
|
||||
namedStructFields, err := lookupNamedStructFields(typ, fldDescs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, t := range scanTargets {
|
||||
if t == nil {
|
||||
return fmt.Errorf("struct doesn't have corresponding row field %s", rows.FieldDescriptions()[i].Name)
|
||||
if rs.lax && namedStructFields.missingField != "" {
|
||||
return fmt.Errorf("cannot find field %s in returned row", namedStructFields.missingField)
|
||||
}
|
||||
}
|
||||
|
||||
fields := namedStructFields.fields
|
||||
scanTargets := setupStructScanTargets(rs.ptrToStruct, fields)
|
||||
return rows.Scan(scanTargets...)
|
||||
}
|
||||
|
||||
// Map from namedStructFieldMap -> *namedStructFields
|
||||
var namedStructFieldMap sync.Map
|
||||
|
||||
type namedStructFieldsKey struct {
|
||||
t reflect.Type
|
||||
colNames string
|
||||
}
|
||||
|
||||
type namedStructFields struct {
|
||||
fields []structRowField
|
||||
// missingField is the first field from the struct without a corresponding row field.
|
||||
// This is used to construct the correct error message for non-lax queries.
|
||||
missingField string
|
||||
}
|
||||
|
||||
func lookupNamedStructFields(
|
||||
t reflect.Type,
|
||||
fldDescs []pgconn.FieldDescription,
|
||||
) (*namedStructFields, error) {
|
||||
key := namedStructFieldsKey{
|
||||
t: t,
|
||||
colNames: joinFieldNames(fldDescs),
|
||||
}
|
||||
if cached, ok := namedStructFieldMap.Load(key); ok {
|
||||
return cached.(*namedStructFields), nil
|
||||
}
|
||||
|
||||
// We could probably do two-levels of caching, where we compute the key -> fields mapping
|
||||
// for a type only once, cache it by type, then use that to compute the column -> fields
|
||||
// mapping for a given set of columns.
|
||||
fieldStack := make([]int, 0, 1)
|
||||
fields, missingField := computeNamedStructFields(
|
||||
fldDescs,
|
||||
t,
|
||||
make([]structRowField, len(fldDescs)),
|
||||
&fieldStack,
|
||||
)
|
||||
for i, f := range fields {
|
||||
if f.path == nil {
|
||||
return nil, fmt.Errorf(
|
||||
"struct doesn't have corresponding row field %s",
|
||||
fldDescs[i].Name,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fieldsIface, _ := namedStructFieldMap.LoadOrStore(
|
||||
key,
|
||||
&namedStructFields{fields: fields, missingField: missingField},
|
||||
)
|
||||
return fieldsIface.(*namedStructFields), nil
|
||||
}
|
||||
|
||||
func joinFieldNames(fldDescs []pgconn.FieldDescription) string {
|
||||
switch len(fldDescs) {
|
||||
case 0:
|
||||
return ""
|
||||
case 1:
|
||||
return fldDescs[0].Name
|
||||
}
|
||||
|
||||
totalSize := len(fldDescs) - 1 // Space for separator bytes.
|
||||
for _, d := range fldDescs {
|
||||
totalSize += len(d.Name)
|
||||
}
|
||||
var b strings.Builder
|
||||
b.Grow(totalSize)
|
||||
b.WriteString(fldDescs[0].Name)
|
||||
for _, d := range fldDescs[1:] {
|
||||
b.WriteByte(0) // Join with NUL byte as it's (presumably) not a valid column character.
|
||||
b.WriteString(d.Name)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func computeNamedStructFields(
|
||||
fldDescs []pgconn.FieldDescription,
|
||||
t reflect.Type,
|
||||
fields []structRowField,
|
||||
fieldStack *[]int,
|
||||
) ([]structRowField, string) {
|
||||
var missingField string
|
||||
tail := len(*fieldStack)
|
||||
*fieldStack = append(*fieldStack, 0)
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
sf := t.Field(i)
|
||||
(*fieldStack)[tail] = i
|
||||
if sf.PkgPath != "" && !sf.Anonymous {
|
||||
// Field is unexported, skip it.
|
||||
continue
|
||||
}
|
||||
// Handle anonymous struct embedding, but do not try to handle embedded pointers.
|
||||
if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
|
||||
var missingSubField string
|
||||
fields, missingSubField = computeNamedStructFields(
|
||||
fldDescs,
|
||||
sf.Type,
|
||||
fields,
|
||||
fieldStack,
|
||||
)
|
||||
if missingField == "" {
|
||||
missingField = missingSubField
|
||||
}
|
||||
} else {
|
||||
dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey)
|
||||
if dbTagPresent {
|
||||
dbTag, _, _ = strings.Cut(dbTag, ",")
|
||||
}
|
||||
if dbTag == "-" {
|
||||
// Field is ignored, skip it.
|
||||
continue
|
||||
}
|
||||
colName := dbTag
|
||||
if !dbTagPresent {
|
||||
colName = sf.Name
|
||||
}
|
||||
fpos := fieldPosByName(fldDescs, colName)
|
||||
if fpos == -1 {
|
||||
if missingField == "" {
|
||||
missingField = colName
|
||||
}
|
||||
continue
|
||||
}
|
||||
fields[fpos] = structRowField{
|
||||
path: append([]int(nil), *fieldStack...),
|
||||
}
|
||||
}
|
||||
}
|
||||
*fieldStack = (*fieldStack)[:tail]
|
||||
|
||||
return fields, missingField
|
||||
}
|
||||
|
||||
const structTagKey = "db"
|
||||
|
||||
func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) {
|
||||
|
@ -682,52 +823,21 @@ func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) {
|
|||
return
|
||||
}
|
||||
|
||||
func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any, fldDescs []pgconn.FieldDescription) ([]any, error) {
|
||||
var err error
|
||||
dstElemType := dstElemValue.Type()
|
||||
|
||||
if scanTargets == nil {
|
||||
scanTargets = make([]any, len(fldDescs))
|
||||
// structRowField describes a field of a struct.
|
||||
//
|
||||
// TODO: It would be a bit more efficient to track the path using the pointer
|
||||
// offset within the (outermost) struct and use unsafe.Pointer arithmetic to
|
||||
// construct references when scanning rows. However, it's not clear it's worth
|
||||
// using unsafe for this.
|
||||
type structRowField struct {
|
||||
path []int
|
||||
}
|
||||
|
||||
for i := 0; i < dstElemType.NumField(); i++ {
|
||||
sf := dstElemType.Field(i)
|
||||
if sf.PkgPath != "" && !sf.Anonymous {
|
||||
// Field is unexported, skip it.
|
||||
continue
|
||||
func setupStructScanTargets(receiver any, fields []structRowField) []any {
|
||||
scanTargets := make([]any, len(fields))
|
||||
v := reflect.ValueOf(receiver).Elem()
|
||||
for i, f := range fields {
|
||||
scanTargets[i] = v.FieldByIndex(f.path).Addr().Interface()
|
||||
}
|
||||
// Handle anonymous struct embedding, but do not try to handle embedded pointers.
|
||||
if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
|
||||
scanTargets, err = rs.appendScanTargets(dstElemValue.Field(i), scanTargets, fldDescs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey)
|
||||
if dbTagPresent {
|
||||
dbTag, _, _ = strings.Cut(dbTag, ",")
|
||||
}
|
||||
if dbTag == "-" {
|
||||
// Field is ignored, skip it.
|
||||
continue
|
||||
}
|
||||
colName := dbTag
|
||||
if !dbTagPresent {
|
||||
colName = sf.Name
|
||||
}
|
||||
fpos := fieldPosByName(fldDescs, colName)
|
||||
if fpos == -1 {
|
||||
if rs.lax {
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf("cannot find field %s in returned row", colName)
|
||||
}
|
||||
if fpos >= len(scanTargets) && !rs.lax {
|
||||
return nil, fmt.Errorf("cannot find field %s in returned row", colName)
|
||||
}
|
||||
scanTargets[fpos] = dstElemValue.Field(i).Addr().Interface()
|
||||
}
|
||||
}
|
||||
|
||||
return scanTargets, err
|
||||
return scanTargets
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue