Add composite to arbitrary struct encoding and decoding

query-exec-mode
Jack Christensen 2022-02-05 14:24:34 -06:00
parent 727fc19cb7
commit 3a94113118
3 changed files with 165 additions and 0 deletions

View File

@ -4,6 +4,7 @@ import (
"fmt"
"math"
"net"
"reflect"
"strconv"
"time"
)
@ -619,3 +620,39 @@ func (w byteSliceWrapper) UUIDValue() (UUID, error) {
copy(uuid.Bytes[:], w)
return uuid, nil
}
// structWrapper implements CompositeIndexGetter for a struct.
type structWrapper struct {
s interface{}
exportedFields []reflect.Value
}
func (w structWrapper) IsNull() bool {
return w.s == nil
}
func (w structWrapper) Index(i int) interface{} {
if i >= len(w.exportedFields) {
return fmt.Errorf("%#v only has %d public fields - %d is out of bounds", w.s, len(w.exportedFields), i)
}
return w.exportedFields[i].Interface()
}
// ptrStructWrapper implements CompositeIndexScanner for a pointer to a struct.
type ptrStructWrapper struct {
s interface{}
exportedFields []reflect.Value
}
func (w *ptrStructWrapper) ScanNull() error {
return fmt.Errorf("cannot scan NULL into %#v", w.s)
}
func (w *ptrStructWrapper) ScanIndex(i int) interface{} {
if i >= len(w.exportedFields) {
return fmt.Errorf("%#v only has %d public fields - %d is out of bounds", w.s, len(w.exportedFields), i)
}
return w.exportedFields[i].Addr().Interface()
}

View File

@ -123,3 +123,42 @@ create type point3d as (
require.Equalf(t, input, output, "%v", format.name)
}
}
func TestCompositeCodecTranscodeStructWrapper(t *testing.T) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)
_, err := conn.Exec(context.Background(), `drop type if exists point3d;
create type point3d as (
x float8,
y float8,
z float8
);`)
require.NoError(t, err)
defer conn.Exec(context.Background(), "drop type point3d")
dt, err := conn.LoadDataType(context.Background(), "point3d")
require.NoError(t, err)
conn.ConnInfo().RegisterDataType(*dt)
formats := []struct {
name string
code int16
}{
{name: "TextFormat", code: pgx.TextFormatCode},
{name: "BinaryFormat", code: pgx.BinaryFormatCode},
}
type anotherPoint struct {
X, Y, Z float64
}
for _, format := range formats {
input := anotherPoint{X: 1, Y: 2, Z: 3}
var output anotherPoint
err := conn.QueryRow(context.Background(), "select $1::point3d", pgx.QueryResultFormats{format.code}, input).Scan(&output)
require.NoErrorf(t, err, "%v", format.name)
require.Equalf(t, input, output, "%v", format.name)
}
}

View File

@ -203,12 +203,14 @@ func NewConnInfo() *ConnInfo {
TryWrapDerefPointerEncodePlan,
TryWrapBuiltinTypeEncodePlan,
TryWrapFindUnderlyingTypeEncodePlan,
TryWrapStructEncodePlan,
},
TryWrapScanPlanFuncs: []TryWrapScanPlanFunc{
TryPointerPointerScanPlan,
TryWrapBuiltinTypeScanPlan,
TryFindUnderlyingTypeScanPlan,
TryWrapStructScanPlan,
},
}
@ -887,6 +889,47 @@ func (plan *pointerEmptyInterfaceScanPlan) Scan(src []byte, dst interface{}) err
return nil
}
// TryWrapStructPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter.
func TryWrapStructScanPlan(target interface{}) (plan WrappedScanPlanNextSetter, nextValue interface{}, ok bool) {
targetValue := reflect.ValueOf(target)
if targetValue.Kind() != reflect.Ptr {
return nil, nil, false
}
targetElemValue := targetValue.Elem()
targetElemType := targetElemValue.Type()
if targetElemType.Kind() == reflect.Struct {
exportedFields := getExportedFieldValues(targetElemValue)
if len(exportedFields) == 0 {
return nil, nil, false
}
w := ptrStructWrapper{
s: target,
exportedFields: exportedFields,
}
return &wrapAnyPtrStructScanPlan{}, &w, true
}
return nil, nil, false
}
type wrapAnyPtrStructScanPlan struct {
next ScanPlan
}
func (plan *wrapAnyPtrStructScanPlan) SetNext(next ScanPlan) { plan.next = next }
func (plan *wrapAnyPtrStructScanPlan) Scan(src []byte, target interface{}) error {
w := ptrStructWrapper{
s: target,
exportedFields: getExportedFieldValues(reflect.ValueOf(target).Elem()),
}
return plan.next.Scan(src, &w)
}
// PlanScan prepares a plan to scan a value into target.
func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, target interface{}) ScanPlan {
if _, ok := target.(*UndecodedBytes); ok {
@ -1406,6 +1449,52 @@ func (plan *wrapFmtStringerEncodePlan) Encode(value interface{}, buf []byte) (ne
return plan.next.Encode(fmtStringerWrapper{value.(fmt.Stringer)}, buf)
}
// TryWrapStructPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter.
func TryWrapStructEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) {
if reflect.TypeOf(value).Kind() == reflect.Struct {
exportedFields := getExportedFieldValues(reflect.ValueOf(value))
if len(exportedFields) == 0 {
return nil, nil, false
}
w := structWrapper{
s: value,
exportedFields: exportedFields,
}
return &wrapAnyStructEncodePlan{}, w, true
}
return nil, nil, false
}
type wrapAnyStructEncodePlan struct {
next EncodePlan
}
func (plan *wrapAnyStructEncodePlan) SetNext(next EncodePlan) { plan.next = next }
func (plan *wrapAnyStructEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) {
w := structWrapper{
s: value,
exportedFields: getExportedFieldValues(reflect.ValueOf(value)),
}
return plan.next.Encode(w, buf)
}
func getExportedFieldValues(structValue reflect.Value) []reflect.Value {
structType := structValue.Type()
exportedFields := make([]reflect.Value, 0, structValue.NumField())
for i := 0; i < structType.NumField(); i++ {
sf := structType.Field(i)
if sf.IsExported() {
exportedFields = append(exportedFields, structValue.Field(i))
}
}
return exportedFields
}
// Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return
// (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data
// written.