From ec9840620789766dd065077d314b92cbbaf02ec8 Mon Sep 17 00:00:00 2001
From: Zach Olstein <zach.olstein408@gmail.com>
Date: Sun, 24 Mar 2024 17:28:31 -0700
Subject: [PATCH] Cache reflection analysis in RowToStructBy...

Modify the RowToStructByPos/Name functions to store the computed mapping
of columns to struct field locations in a cache to reuse between calls.
Because this computation can be expensive and the same few results will
frequently be reused, caching these results provides a significant
speedup.

For positional mappings, we can key the cache by just the struct-type.
However, for named mappings, the key must include a representation of
the columns, in order, since different columns produce different
mappings.
---
 rows.go | 290 ++++++++++++++++++++++++++++++++++++++------------------
 1 file changed, 200 insertions(+), 90 deletions(-)

diff --git a/rows.go b/rows.go
index 78ef5326..4720330c 100644
--- a/rows.go
+++ b/rows.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"reflect"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/jackc/pgx/v5/pgconn"
@@ -541,7 +542,7 @@ func (rs *mapRowScanner) ScanRow(rows Rows) error {
 // ignored.
 func RowToStructByPos[T any](row CollectableRow) (T, error) {
 	var value T
-	err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value})
+	err := (&positionalStructRowScanner{ptrToStruct: &value}).ScanRow(row)
 	return value, err
 }
 
@@ -550,7 +551,7 @@ func RowToStructByPos[T any](row CollectableRow) (T, error) {
 // the field will be ignored.
 func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) {
 	var value T
-	err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value})
+	err := (&positionalStructRowScanner{ptrToStruct: &value}).ScanRow(row)
 	return &value, err
 }
 
@@ -558,46 +559,60 @@ type positionalStructRowScanner struct {
 	ptrToStruct any
 }
 
-func (rs *positionalStructRowScanner) ScanRow(rows Rows) error {
-	dst := rs.ptrToStruct
-	dstValue := reflect.ValueOf(dst)
-	if dstValue.Kind() != reflect.Ptr {
-		return fmt.Errorf("dst not a pointer")
+func (rs *positionalStructRowScanner) ScanRow(rows CollectableRow) error {
+	typ := reflect.TypeOf(rs.ptrToStruct).Elem()
+	fields := lookupStructFields(typ)
+	if len(rows.RawValues()) > len(fields) {
+		return fmt.Errorf(
+			"got %d values, but dst struct has only %d fields",
+			len(rows.RawValues()),
+			len(fields),
+		)
 	}
-
-	dstElemValue := dstValue.Elem()
-	scanTargets := rs.appendScanTargets(dstElemValue, nil)
-
-	if len(rows.RawValues()) > len(scanTargets) {
-		return fmt.Errorf("got %d values, but dst struct has only %d fields", len(rows.RawValues()), len(scanTargets))
-	}
-
+	scanTargets := setupStructScanTargets(rs.ptrToStruct, fields)
 	return rows.Scan(scanTargets...)
 }
 
-func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any) []any {
-	dstElemType := dstElemValue.Type()
+// Map from reflect.Type -> []structRowField
+var positionalStructFieldMap sync.Map
 
-	if scanTargets == nil {
-		scanTargets = make([]any, 0, dstElemType.NumField())
+func lookupStructFields(t reflect.Type) []structRowField {
+	if cached, ok := positionalStructFieldMap.Load(t); ok {
+		return cached.([]structRowField)
 	}
 
-	for i := 0; i < dstElemType.NumField(); i++ {
-		sf := dstElemType.Field(i)
+	fieldStack := make([]int, 0, 1)
+	fields := computeStructFields(t, make([]structRowField, 0, t.NumField()), &fieldStack)
+	fieldsIface, _ := positionalStructFieldMap.LoadOrStore(t, fields)
+	return fieldsIface.([]structRowField)
+}
+
+func computeStructFields(
+	t reflect.Type,
+	fields []structRowField,
+	fieldStack *[]int,
+) []structRowField {
+	tail := len(*fieldStack)
+	*fieldStack = append(*fieldStack, 0)
+	for i := 0; i < t.NumField(); i++ {
+		sf := t.Field(i)
+		(*fieldStack)[tail] = i
 		// Handle anonymous struct embedding, but do not try to handle embedded pointers.
 		if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
-			scanTargets = rs.appendScanTargets(dstElemValue.Field(i), scanTargets)
+			fields = computeStructFields(sf.Type, fields, fieldStack)
 		} else if sf.PkgPath == "" {
 			dbTag, _ := sf.Tag.Lookup(structTagKey)
 			if dbTag == "-" {
 				// Field is ignored, skip it.
 				continue
 			}
-			scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface())
+			fields = append(fields, structRowField{
+				path: append([]int(nil), *fieldStack...),
+			})
 		}
 	}
-
-	return scanTargets
+	*fieldStack = (*fieldStack)[:tail]
+	return fields
 }
 
 // RowToStructByName returns a T scanned from row. T must be a struct. T must have the same number of named public
@@ -605,7 +620,7 @@ func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Val
 // column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored.
 func RowToStructByName[T any](row CollectableRow) (T, error) {
 	var value T
-	err := row.Scan(&namedStructRowScanner{ptrToStruct: &value})
+	err := (&namedStructRowScanner{ptrToStruct: &value}).ScanRow(row)
 	return value, err
 }
 
@@ -615,7 +630,7 @@ func RowToStructByName[T any](row CollectableRow) (T, error) {
 // then the field will be ignored.
 func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) {
 	var value T
-	err := row.Scan(&namedStructRowScanner{ptrToStruct: &value})
+	err := (&namedStructRowScanner{ptrToStruct: &value}).ScanRow(row)
 	return &value, err
 }
 
@@ -624,7 +639,7 @@ func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) {
 // column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored.
 func RowToStructByNameLax[T any](row CollectableRow) (T, error) {
 	var value T
-	err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true})
+	err := (&namedStructRowScanner{ptrToStruct: &value, lax: true}).ScanRow(row)
 	return value, err
 }
 
@@ -634,7 +649,7 @@ func RowToStructByNameLax[T any](row CollectableRow) (T, error) {
 // then the field will be ignored.
 func RowToAddrOfStructByNameLax[T any](row CollectableRow) (*T, error) {
 	var value T
-	err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true})
+	err := (&namedStructRowScanner{ptrToStruct: &value, lax: true}).ScanRow(row)
 	return &value, err
 }
 
@@ -643,26 +658,152 @@ type namedStructRowScanner struct {
 	lax         bool
 }
 
-func (rs *namedStructRowScanner) ScanRow(rows Rows) error {
-	dst := rs.ptrToStruct
-	dstValue := reflect.ValueOf(dst)
-	if dstValue.Kind() != reflect.Ptr {
-		return fmt.Errorf("dst not a pointer")
-	}
-
-	dstElemValue := dstValue.Elem()
-	scanTargets, err := rs.appendScanTargets(dstElemValue, nil, rows.FieldDescriptions())
+func (rs *namedStructRowScanner) ScanRow(rows CollectableRow) error {
+	typ := reflect.TypeOf(rs.ptrToStruct).Elem()
+	fldDescs := rows.FieldDescriptions()
+	namedStructFields, err := lookupNamedStructFields(typ, fldDescs) // cached per (struct type, column names)
 	if err != nil {
 		return err
 	}
+	if !rs.lax && namedStructFields.missingField != "" { // only strict scans require a column for every struct field
+		return fmt.Errorf("cannot find field %s in returned row", namedStructFields.missingField)
+	}
+	fields := namedStructFields.fields
+	scanTargets := setupStructScanTargets(rs.ptrToStruct, fields)
+	return rows.Scan(scanTargets...)
+}
 
-	for i, t := range scanTargets {
-		if t == nil {
-			return fmt.Errorf("struct doesn't have corresponding row field %s", rows.FieldDescriptions()[i].Name)
+// Map from namedStructFieldsKey -> *namedStructFields
+var namedStructFieldMap sync.Map
+
+type namedStructFieldsKey struct {
+	t        reflect.Type
+	colNames string
+}
+
+type namedStructFields struct {
+	fields []structRowField
+	// missingField is the first field from the struct without a corresponding row field.
+	// This is used to construct the correct error message for non-lax queries.
+	missingField string
+}
+
+func lookupNamedStructFields(
+	t reflect.Type,
+	fldDescs []pgconn.FieldDescription,
+) (*namedStructFields, error) {
+	key := namedStructFieldsKey{
+		t:        t,
+		colNames: joinFieldNames(fldDescs),
+	}
+	if cached, ok := namedStructFieldMap.Load(key); ok {
+		return cached.(*namedStructFields), nil
+	}
+
+	// We could probably do two-levels of caching, where we compute the key -> fields mapping
+	// for a type only once, cache it by type, then use that to compute the column -> fields
+	// mapping for a given set of columns.
+	fieldStack := make([]int, 0, 1)
+	fields, missingField := computeNamedStructFields(
+		fldDescs,
+		t,
+		make([]structRowField, len(fldDescs)),
+		&fieldStack,
+	)
+	for i, f := range fields {
+		if f.path == nil {
+			return nil, fmt.Errorf(
+				"struct doesn't have corresponding row field %s",
+				fldDescs[i].Name,
+			)
 		}
 	}
 
-	return rows.Scan(scanTargets...)
+	fieldsIface, _ := namedStructFieldMap.LoadOrStore(
+		key,
+		&namedStructFields{fields: fields, missingField: missingField},
+	)
+	return fieldsIface.(*namedStructFields), nil
+}
+
+func joinFieldNames(fldDescs []pgconn.FieldDescription) string {
+	switch len(fldDescs) {
+	case 0:
+		return ""
+	case 1:
+		return fldDescs[0].Name
+	}
+
+	totalSize := len(fldDescs) - 1 // Space for separator bytes.
+	for _, d := range fldDescs {
+		totalSize += len(d.Name)
+	}
+	var b strings.Builder
+	b.Grow(totalSize)
+	b.WriteString(fldDescs[0].Name)
+	for _, d := range fldDescs[1:] {
+		b.WriteByte(0) // Join with NUL byte as it's (presumably) not a valid column character.
+		b.WriteString(d.Name)
+	}
+	return b.String()
+}
+
+func computeNamedStructFields(
+	fldDescs []pgconn.FieldDescription,
+	t reflect.Type,
+	fields []structRowField,
+	fieldStack *[]int,
+) ([]structRowField, string) {
+	var missingField string
+	tail := len(*fieldStack)
+	*fieldStack = append(*fieldStack, 0)
+	for i := 0; i < t.NumField(); i++ {
+		sf := t.Field(i)
+		(*fieldStack)[tail] = i
+		if sf.PkgPath != "" && !sf.Anonymous {
+			// Field is unexported, skip it.
+			continue
+		}
+		// Handle anonymous struct embedding, but do not try to handle embedded pointers.
+		if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
+			var missingSubField string
+			fields, missingSubField = computeNamedStructFields(
+				fldDescs,
+				sf.Type,
+				fields,
+				fieldStack,
+			)
+			if missingField == "" {
+				missingField = missingSubField
+			}
+		} else {
+			dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey)
+			if dbTagPresent {
+				dbTag, _, _ = strings.Cut(dbTag, ",")
+			}
+			if dbTag == "-" {
+				// Field is ignored, skip it.
+				continue
+			}
+			colName := dbTag
+			if !dbTagPresent {
+				colName = sf.Name
+			}
+			fpos := fieldPosByName(fldDescs, colName)
+			if fpos == -1 {
+				if missingField == "" {
+					missingField = colName
+				}
+				continue
+			}
+			fields[fpos] = structRowField{
+				path: append([]int(nil), *fieldStack...),
+			}
+		}
+	}
+	*fieldStack = (*fieldStack)[:tail]
+
+	return fields, missingField
 }
 
 const structTagKey = "db"
@@ -682,52 +823,21 @@ func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) {
 	return
 }
 
-func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any, fldDescs []pgconn.FieldDescription) ([]any, error) {
-	var err error
-	dstElemType := dstElemValue.Type()
-
-	if scanTargets == nil {
-		scanTargets = make([]any, len(fldDescs))
-	}
-
-	for i := 0; i < dstElemType.NumField(); i++ {
-		sf := dstElemType.Field(i)
-		if sf.PkgPath != "" && !sf.Anonymous {
-			// Field is unexported, skip it.
-			continue
-		}
-		// Handle anonymous struct embedding, but do not try to handle embedded pointers.
-		if sf.Anonymous && sf.Type.Kind() == reflect.Struct {
-			scanTargets, err = rs.appendScanTargets(dstElemValue.Field(i), scanTargets, fldDescs)
-			if err != nil {
-				return nil, err
-			}
-		} else {
-			dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey)
-			if dbTagPresent {
-				dbTag, _, _ = strings.Cut(dbTag, ",")
-			}
-			if dbTag == "-" {
-				// Field is ignored, skip it.
-				continue
-			}
-			colName := dbTag
-			if !dbTagPresent {
-				colName = sf.Name
-			}
-			fpos := fieldPosByName(fldDescs, colName)
-			if fpos == -1 {
-				if rs.lax {
-					continue
-				}
-				return nil, fmt.Errorf("cannot find field %s in returned row", colName)
-			}
-			if fpos >= len(scanTargets) && !rs.lax {
-				return nil, fmt.Errorf("cannot find field %s in returned row", colName)
-			}
-			scanTargets[fpos] = dstElemValue.Field(i).Addr().Interface()
-		}
-	}
-
-	return scanTargets, err
+// structRowField describes a field of a struct.
+//
+// TODO: It would be a bit more efficient to track the path using the pointer
+// offset within the (outermost) struct and use unsafe.Pointer arithmetic to
+// construct references when scanning rows. However, it's not clear it's worth
+// using unsafe for this.
+type structRowField struct {
+	path []int
+}
+
+func setupStructScanTargets(receiver any, fields []structRowField) []any {
+	scanTargets := make([]any, len(fields))
+	v := reflect.ValueOf(receiver).Elem()
+	for i, f := range fields {
+		scanTargets[i] = v.FieldByIndex(f.path).Addr().Interface()
+	}
+	return scanTargets
 }