diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index fe58eee0..1799de55 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -4,6 +4,7 @@ import ( "fmt" "math" "net" + "reflect" "strconv" "time" ) @@ -619,3 +620,39 @@ func (w byteSliceWrapper) UUIDValue() (UUID, error) { copy(uuid.Bytes[:], w) return uuid, nil } + +// structWrapper implements CompositeIndexGetter for a struct. +type structWrapper struct { + s interface{} + exportedFields []reflect.Value +} + +func (w structWrapper) IsNull() bool { + return w.s == nil +} + +func (w structWrapper) Index(i int) interface{} { + if i >= len(w.exportedFields) { + return fmt.Errorf("%#v only has %d public fields - %d is out of bounds", w.s, len(w.exportedFields), i) + } + + return w.exportedFields[i].Interface() +} + +// ptrStructWrapper implements CompositeIndexScanner for a pointer to a struct. +type ptrStructWrapper struct { + s interface{} + exportedFields []reflect.Value +} + +func (w *ptrStructWrapper) ScanNull() error { + return fmt.Errorf("cannot scan NULL into %#v", w.s) +} + +func (w *ptrStructWrapper) ScanIndex(i int) interface{} { + if i >= len(w.exportedFields) { + return fmt.Errorf("%#v only has %d public fields - %d is out of bounds", w.s, len(w.exportedFields), i) + } + + return w.exportedFields[i].Addr().Interface() +} diff --git a/pgtype/composite_test.go b/pgtype/composite_test.go index c9319c2d..9a0eff2a 100644 --- a/pgtype/composite_test.go +++ b/pgtype/composite_test.go @@ -123,3 +123,42 @@ create type point3d as ( require.Equalf(t, input, output, "%v", format.name) } } + +func TestCompositeCodecTranscodeStructWrapper(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + _, err := conn.Exec(context.Background(), `drop type if exists point3d; + +create type point3d as ( + x float8, + y float8, + z float8 +);`) + require.NoError(t, err) + defer conn.Exec(context.Background(), "drop type point3d") + + dt, err := conn.LoadDataType(context.Background(), "point3d") + require.NoError(t, err) + conn.ConnInfo().RegisterDataType(*dt) + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + type anotherPoint struct { + X, Y, Z float64 + } + + for _, format := range formats { + input := anotherPoint{X: 1, Y: 2, Z: 3} + var output anotherPoint + err := conn.QueryRow(context.Background(), "select $1::point3d", pgx.QueryResultFormats{format.code}, input).Scan(&output) + require.NoErrorf(t, err, "%v", format.name) + require.Equalf(t, input, output, "%v", format.name) + } +} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 8c4a8c49..8db5ae3f 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -203,12 +203,14 @@ func NewConnInfo() *ConnInfo { TryWrapDerefPointerEncodePlan, TryWrapBuiltinTypeEncodePlan, TryWrapFindUnderlyingTypeEncodePlan, + TryWrapStructEncodePlan, }, TryWrapScanPlanFuncs: []TryWrapScanPlanFunc{ TryPointerPointerScanPlan, TryWrapBuiltinTypeScanPlan, TryFindUnderlyingTypeScanPlan, + TryWrapStructScanPlan, }, } @@ -887,6 +889,47 @@ func (plan *pointerEmptyInterfaceScanPlan) Scan(src []byte, dst interface{}) err return nil } +// TryWrapStructPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter. +func TryWrapStructScanPlan(target interface{}) (plan WrappedScanPlanNextSetter, nextValue interface{}, ok bool) { + targetValue := reflect.ValueOf(target) + if targetValue.Kind() != reflect.Ptr { + return nil, nil, false + } + + targetElemValue := targetValue.Elem() + targetElemType := targetElemValue.Type() + + if targetElemType.Kind() == reflect.Struct { + exportedFields := getExportedFieldValues(targetElemValue) + if len(exportedFields) == 0 { + return nil, nil, false + } + + w := ptrStructWrapper{ + s: target, + exportedFields: exportedFields, + } + return &wrapAnyPtrStructScanPlan{}, &w, true + } + + return nil, nil, false +} + +type wrapAnyPtrStructScanPlan struct { + next ScanPlan +} + +func (plan *wrapAnyPtrStructScanPlan) SetNext(next ScanPlan) { plan.next = next } + +func (plan *wrapAnyPtrStructScanPlan) Scan(src []byte, target interface{}) error { + w := ptrStructWrapper{ + s: target, + exportedFields: getExportedFieldValues(reflect.ValueOf(target).Elem()), + } + + return plan.next.Scan(src, &w) +} + // PlanScan prepares a plan to scan a value into target. func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, target interface{}) ScanPlan { if _, ok := target.(*UndecodedBytes); ok { @@ -1406,6 +1449,52 @@ func (plan *wrapFmtStringerEncodePlan) Encode(value interface{}, buf []byte) (ne return plan.next.Encode(fmtStringerWrapper{value.(fmt.Stringer)}, buf) } +// TryWrapStructPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter. +func TryWrapStructEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) { + if reflect.TypeOf(value).Kind() == reflect.Struct { + exportedFields := getExportedFieldValues(reflect.ValueOf(value)) + if len(exportedFields) == 0 { + return nil, nil, false + } + + w := structWrapper{ + s: value, + exportedFields: exportedFields, + } + return &wrapAnyStructEncodePlan{}, w, true + } + + return nil, nil, false +} + +type wrapAnyStructEncodePlan struct { + next EncodePlan +} + +func (plan *wrapAnyStructEncodePlan) SetNext(next EncodePlan) { plan.next = next } + +func (plan *wrapAnyStructEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + w := structWrapper{ + s: value, + exportedFields: getExportedFieldValues(reflect.ValueOf(value)), + } + + return plan.next.Encode(w, buf) +} + +func getExportedFieldValues(structValue reflect.Value) []reflect.Value { + structType := structValue.Type() + exportedFields := make([]reflect.Value, 0, structValue.NumField()) + for i := 0; i < structType.NumField(); i++ { + sf := structType.Field(i) + if sf.IsExported() { + exportedFields = append(exportedFields, structValue.Field(i)) + } + } + + return exportedFields +} + // Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data // written.