diff --git a/rows.go b/rows.go index 78ef5326..4720330c 100644 --- a/rows.go +++ b/rows.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "strings" + "sync" "time" "github.com/jackc/pgx/v5/pgconn" @@ -541,7 +542,7 @@ func (rs *mapRowScanner) ScanRow(rows Rows) error { // ignored. func RowToStructByPos[T any](row CollectableRow) (T, error) { var value T - err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) + err := (&positionalStructRowScanner{ptrToStruct: &value}).ScanRow(row) return value, err } @@ -550,7 +551,7 @@ func RowToStructByPos[T any](row CollectableRow) (T, error) { // the field will be ignored. func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) { var value T - err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) + err := (&positionalStructRowScanner{ptrToStruct: &value}).ScanRow(row) return &value, err } @@ -558,46 +559,60 @@ type positionalStructRowScanner struct { ptrToStruct any } -func (rs *positionalStructRowScanner) ScanRow(rows Rows) error { - dst := rs.ptrToStruct - dstValue := reflect.ValueOf(dst) - if dstValue.Kind() != reflect.Ptr { - return fmt.Errorf("dst not a pointer") +func (rs *positionalStructRowScanner) ScanRow(rows CollectableRow) error { + typ := reflect.TypeOf(rs.ptrToStruct).Elem() + fields := lookupStructFields(typ) + if len(rows.RawValues()) > len(fields) { + return fmt.Errorf( + "got %d values, but dst struct has only %d fields", + len(rows.RawValues()), + len(fields), + ) } - - dstElemValue := dstValue.Elem() - scanTargets := rs.appendScanTargets(dstElemValue, nil) - - if len(rows.RawValues()) > len(scanTargets) { - return fmt.Errorf("got %d values, but dst struct has only %d fields", len(rows.RawValues()), len(scanTargets)) - } - + scanTargets := setupStructScanTargets(rs.ptrToStruct, fields) return rows.Scan(scanTargets...) } -func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any) []any { - dstElemType := dstElemValue.Type() +// Map from reflect.Type -> []structRowField +var positionalStructFieldMap sync.Map - if scanTargets == nil { - scanTargets = make([]any, 0, dstElemType.NumField()) +func lookupStructFields(t reflect.Type) []structRowField { + if cached, ok := positionalStructFieldMap.Load(t); ok { + return cached.([]structRowField) } - for i := 0; i < dstElemType.NumField(); i++ { - sf := dstElemType.Field(i) + fieldStack := make([]int, 0, 1) + fields := computeStructFields(t, make([]structRowField, 0, t.NumField()), &fieldStack) + fieldsIface, _ := positionalStructFieldMap.LoadOrStore(t, fields) + return fieldsIface.([]structRowField) +} + +func computeStructFields( + t reflect.Type, + fields []structRowField, + fieldStack *[]int, +) []structRowField { + tail := len(*fieldStack) + *fieldStack = append(*fieldStack, 0) + for i := 0; i < t.NumField(); i++ { + sf := t.Field(i) + (*fieldStack)[tail] = i // Handle anonymous struct embedding, but do not try to handle embedded pointers. if sf.Anonymous && sf.Type.Kind() == reflect.Struct { - scanTargets = rs.appendScanTargets(dstElemValue.Field(i), scanTargets) + fields = computeStructFields(sf.Type, fields, fieldStack) } else if sf.PkgPath == "" { dbTag, _ := sf.Tag.Lookup(structTagKey) if dbTag == "-" { // Field is ignored, skip it. continue } - scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface()) + fields = append(fields, structRowField{ + path: append([]int(nil), *fieldStack...), + }) } } - - return scanTargets + *fieldStack = (*fieldStack)[:tail] + return fields } // RowToStructByName returns a T scanned from row. T must be a struct. T must have the same number of named public @@ -605,7 +620,7 @@ func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Val // column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored. func RowToStructByName[T any](row CollectableRow) (T, error) { var value T - err := row.Scan(&namedStructRowScanner{ptrToStruct: &value}) + err := (&namedStructRowScanner{ptrToStruct: &value}).ScanRow(row) return value, err } @@ -615,7 +630,7 @@ func RowToStructByName[T any](row CollectableRow) (T, error) { // then the field will be ignored. func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) { var value T - err := row.Scan(&namedStructRowScanner{ptrToStruct: &value}) + err := (&namedStructRowScanner{ptrToStruct: &value}).ScanRow(row) return &value, err } @@ -624,7 +639,7 @@ func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) { // column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored. func RowToStructByNameLax[T any](row CollectableRow) (T, error) { var value T - err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true}) + err := (&namedStructRowScanner{ptrToStruct: &value, lax: true}).ScanRow(row) return value, err } @@ -634,7 +649,7 @@ func RowToStructByNameLax[T any](row CollectableRow) (T, error) { // then the field will be ignored. func RowToAddrOfStructByNameLax[T any](row CollectableRow) (*T, error) { var value T - err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true}) + err := (&namedStructRowScanner{ptrToStruct: &value, lax: true}).ScanRow(row) return &value, err } @@ -643,26 +658,152 @@ type namedStructRowScanner struct { lax bool } -func (rs *namedStructRowScanner) ScanRow(rows Rows) error { - dst := rs.ptrToStruct - dstValue := reflect.ValueOf(dst) - if dstValue.Kind() != reflect.Ptr { - return fmt.Errorf("dst not a pointer") - } - - dstElemValue := dstValue.Elem() - scanTargets, err := rs.appendScanTargets(dstElemValue, nil, rows.FieldDescriptions()) +func (rs *namedStructRowScanner) ScanRow(rows CollectableRow) error { + typ := reflect.TypeOf(rs.ptrToStruct).Elem() + fldDescs := rows.FieldDescriptions() + namedStructFields, err := lookupNamedStructFields(typ, fldDescs) if err != nil { return err } + if rs.lax && namedStructFields.missingField != "" { + return fmt.Errorf("cannot find field %s in returned row", namedStructFields.missingField) + } + fields := namedStructFields.fields + scanTargets := setupStructScanTargets(rs.ptrToStruct, fields) + return rows.Scan(scanTargets...) +} - for i, t := range scanTargets { - if t == nil { - return fmt.Errorf("struct doesn't have corresponding row field %s", rows.FieldDescriptions()[i].Name) +// Map from namedStructFieldMap -> *namedStructFields +var namedStructFieldMap sync.Map + +type namedStructFieldsKey struct { + t reflect.Type + colNames string +} + +type namedStructFields struct { + fields []structRowField + // missingField is the first field from the struct without a corresponding row field. + // This is used to construct the correct error message for non-lax queries. + missingField string +} + +func lookupNamedStructFields( + t reflect.Type, + fldDescs []pgconn.FieldDescription, +) (*namedStructFields, error) { + key := namedStructFieldsKey{ + t: t, + colNames: joinFieldNames(fldDescs), + } + if cached, ok := namedStructFieldMap.Load(key); ok { + return cached.(*namedStructFields), nil + } + + // We could probably do two-levels of caching, where we compute the key -> fields mapping + // for a type only once, cache it by type, then use that to compute the column -> fields + // mapping for a given set of columns. + fieldStack := make([]int, 0, 1) + fields, missingField := computeNamedStructFields( + fldDescs, + t, + make([]structRowField, len(fldDescs)), + &fieldStack, + ) + for i, f := range fields { + if f.path == nil { + return nil, fmt.Errorf( + "struct doesn't have corresponding row field %s", + fldDescs[i].Name, + ) } } - return rows.Scan(scanTargets...) + fieldsIface, _ := namedStructFieldMap.LoadOrStore( + key, + &namedStructFields{fields: fields, missingField: missingField}, + ) + return fieldsIface.(*namedStructFields), nil +} + +func joinFieldNames(fldDescs []pgconn.FieldDescription) string { + switch len(fldDescs) { + case 0: + return "" + case 1: + return fldDescs[0].Name + } + + totalSize := len(fldDescs) - 1 // Space for separator bytes. + for _, d := range fldDescs { + totalSize += len(d.Name) + } + var b strings.Builder + b.Grow(totalSize) + b.WriteString(fldDescs[0].Name) + for _, d := range fldDescs[1:] { + b.WriteByte(0) // Join with NUL byte as it's (presumably) not a valid column character. + b.WriteString(d.Name) + } + return b.String() +} + +func computeNamedStructFields( + fldDescs []pgconn.FieldDescription, + t reflect.Type, + fields []structRowField, + fieldStack *[]int, +) ([]structRowField, string) { + var missingField string + tail := len(*fieldStack) + *fieldStack = append(*fieldStack, 0) + for i := 0; i < t.NumField(); i++ { + sf := t.Field(i) + (*fieldStack)[tail] = i + if sf.PkgPath != "" && !sf.Anonymous { + // Field is unexported, skip it. + continue + } + // Handle anonymous struct embedding, but do not try to handle embedded pointers. + if sf.Anonymous && sf.Type.Kind() == reflect.Struct { + var missingSubField string + fields, missingSubField = computeNamedStructFields( + fldDescs, + sf.Type, + fields, + fieldStack, + ) + if missingField == "" { + missingField = missingSubField + } + } else { + dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey) + if dbTagPresent { + dbTag, _, _ = strings.Cut(dbTag, ",") + } + if dbTag == "-" { + // Field is ignored, skip it. + continue + } + colName := dbTag + if !dbTagPresent { + colName = sf.Name + } + fpos := fieldPosByName(fldDescs, colName) + if fpos == -1 { + if missingField == "" { + missingField = colName + } + continue + } + fields[fpos] = structRowField{ + path: append([]int(nil), *fieldStack...), + } + } + } + *fieldStack = (*fieldStack)[:tail] + + return fields, missingField } const structTagKey = "db" @@ -682,52 +823,21 @@ func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) { return } -func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any, fldDescs []pgconn.FieldDescription) ([]any, error) { - var err error - dstElemType := dstElemValue.Type() - - if scanTargets == nil { - scanTargets = make([]any, len(fldDescs)) - } - - for i := 0; i < dstElemType.NumField(); i++ { - sf := dstElemType.Field(i) - if sf.PkgPath != "" && !sf.Anonymous { - // Field is unexported, skip it. - continue - } - // Handle anonymous struct embedding, but do not try to handle embedded pointers. - if sf.Anonymous && sf.Type.Kind() == reflect.Struct { - scanTargets, err = rs.appendScanTargets(dstElemValue.Field(i), scanTargets, fldDescs) - if err != nil { - return nil, err - } - } else { - dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey) - if dbTagPresent { - dbTag, _, _ = strings.Cut(dbTag, ",") - } - if dbTag == "-" { - // Field is ignored, skip it. - continue - } - colName := dbTag - if !dbTagPresent { - colName = sf.Name - } - fpos := fieldPosByName(fldDescs, colName) - if fpos == -1 { - if rs.lax { - continue - } - return nil, fmt.Errorf("cannot find field %s in returned row", colName) - } - if fpos >= len(scanTargets) && !rs.lax { - return nil, fmt.Errorf("cannot find field %s in returned row", colName) - } - scanTargets[fpos] = dstElemValue.Field(i).Addr().Interface() - } - } - - return scanTargets, err +// structRowField describes a field of a struct. +// +// TODO: It would be a bit more efficient to track the path using the pointer +// offset within the (outermost) struct and use unsafe.Pointer arithmetic to +// construct references when scanning rows. However, it's not clear it's worth +// using unsafe for this. +type structRowField struct { + path []int +} + +func setupStructScanTargets(receiver any, fields []structRowField) []any { + scanTargets := make([]any, len(fields)) + v := reflect.ValueOf(receiver).Elem() + for i, f := range fields { + scanTargets[i] = v.FieldByIndex(f.path).Addr().Interface() + } + return scanTargets }