mirror of https://github.com/jackc/pgx.git
Add composite to arbitrary struct encoding and decoding
parent
727fc19cb7
commit
3a94113118
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
@ -619,3 +620,39 @@ func (w byteSliceWrapper) UUIDValue() (UUID, error) {
|
|||
copy(uuid.Bytes[:], w)
|
||||
return uuid, nil
|
||||
}
|
||||
|
||||
// structWrapper implements CompositeIndexGetter for a struct.
|
||||
type structWrapper struct {
|
||||
s interface{}
|
||||
exportedFields []reflect.Value
|
||||
}
|
||||
|
||||
func (w structWrapper) IsNull() bool {
|
||||
return w.s == nil
|
||||
}
|
||||
|
||||
func (w structWrapper) Index(i int) interface{} {
|
||||
if i >= len(w.exportedFields) {
|
||||
return fmt.Errorf("%#v only has %d public fields - %d is out of bounds", w.s, len(w.exportedFields), i)
|
||||
}
|
||||
|
||||
return w.exportedFields[i].Interface()
|
||||
}
|
||||
|
||||
// ptrStructWrapper implements CompositeIndexScanner for a pointer to a struct.
|
||||
type ptrStructWrapper struct {
|
||||
s interface{}
|
||||
exportedFields []reflect.Value
|
||||
}
|
||||
|
||||
func (w *ptrStructWrapper) ScanNull() error {
|
||||
return fmt.Errorf("cannot scan NULL into %#v", w.s)
|
||||
}
|
||||
|
||||
func (w *ptrStructWrapper) ScanIndex(i int) interface{} {
|
||||
if i >= len(w.exportedFields) {
|
||||
return fmt.Errorf("%#v only has %d public fields - %d is out of bounds", w.s, len(w.exportedFields), i)
|
||||
}
|
||||
|
||||
return w.exportedFields[i].Addr().Interface()
|
||||
}
|
||||
|
|
|
@ -123,3 +123,42 @@ create type point3d as (
|
|||
require.Equalf(t, input, output, "%v", format.name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompositeCodecTranscodeStructWrapper(t *testing.T) {
|
||||
conn := testutil.MustConnectPgx(t)
|
||||
defer testutil.MustCloseContext(t, conn)
|
||||
|
||||
_, err := conn.Exec(context.Background(), `drop type if exists point3d;
|
||||
|
||||
create type point3d as (
|
||||
x float8,
|
||||
y float8,
|
||||
z float8
|
||||
);`)
|
||||
require.NoError(t, err)
|
||||
defer conn.Exec(context.Background(), "drop type point3d")
|
||||
|
||||
dt, err := conn.LoadDataType(context.Background(), "point3d")
|
||||
require.NoError(t, err)
|
||||
conn.ConnInfo().RegisterDataType(*dt)
|
||||
|
||||
formats := []struct {
|
||||
name string
|
||||
code int16
|
||||
}{
|
||||
{name: "TextFormat", code: pgx.TextFormatCode},
|
||||
{name: "BinaryFormat", code: pgx.BinaryFormatCode},
|
||||
}
|
||||
|
||||
type anotherPoint struct {
|
||||
X, Y, Z float64
|
||||
}
|
||||
|
||||
for _, format := range formats {
|
||||
input := anotherPoint{X: 1, Y: 2, Z: 3}
|
||||
var output anotherPoint
|
||||
err := conn.QueryRow(context.Background(), "select $1::point3d", pgx.QueryResultFormats{format.code}, input).Scan(&output)
|
||||
require.NoErrorf(t, err, "%v", format.name)
|
||||
require.Equalf(t, input, output, "%v", format.name)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -203,12 +203,14 @@ func NewConnInfo() *ConnInfo {
|
|||
TryWrapDerefPointerEncodePlan,
|
||||
TryWrapBuiltinTypeEncodePlan,
|
||||
TryWrapFindUnderlyingTypeEncodePlan,
|
||||
TryWrapStructEncodePlan,
|
||||
},
|
||||
|
||||
TryWrapScanPlanFuncs: []TryWrapScanPlanFunc{
|
||||
TryPointerPointerScanPlan,
|
||||
TryWrapBuiltinTypeScanPlan,
|
||||
TryFindUnderlyingTypeScanPlan,
|
||||
TryWrapStructScanPlan,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -887,6 +889,47 @@ func (plan *pointerEmptyInterfaceScanPlan) Scan(src []byte, dst interface{}) err
|
|||
return nil
|
||||
}
|
||||
|
||||
// TryWrapStructPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter.
|
||||
func TryWrapStructScanPlan(target interface{}) (plan WrappedScanPlanNextSetter, nextValue interface{}, ok bool) {
|
||||
targetValue := reflect.ValueOf(target)
|
||||
if targetValue.Kind() != reflect.Ptr {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
targetElemValue := targetValue.Elem()
|
||||
targetElemType := targetElemValue.Type()
|
||||
|
||||
if targetElemType.Kind() == reflect.Struct {
|
||||
exportedFields := getExportedFieldValues(targetElemValue)
|
||||
if len(exportedFields) == 0 {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
w := ptrStructWrapper{
|
||||
s: target,
|
||||
exportedFields: exportedFields,
|
||||
}
|
||||
return &wrapAnyPtrStructScanPlan{}, &w, true
|
||||
}
|
||||
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
type wrapAnyPtrStructScanPlan struct {
|
||||
next ScanPlan
|
||||
}
|
||||
|
||||
func (plan *wrapAnyPtrStructScanPlan) SetNext(next ScanPlan) { plan.next = next }
|
||||
|
||||
func (plan *wrapAnyPtrStructScanPlan) Scan(src []byte, target interface{}) error {
|
||||
w := ptrStructWrapper{
|
||||
s: target,
|
||||
exportedFields: getExportedFieldValues(reflect.ValueOf(target).Elem()),
|
||||
}
|
||||
|
||||
return plan.next.Scan(src, &w)
|
||||
}
|
||||
|
||||
// PlanScan prepares a plan to scan a value into target.
|
||||
func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, target interface{}) ScanPlan {
|
||||
if _, ok := target.(*UndecodedBytes); ok {
|
||||
|
@ -1406,6 +1449,52 @@ func (plan *wrapFmtStringerEncodePlan) Encode(value interface{}, buf []byte) (ne
|
|||
return plan.next.Encode(fmtStringerWrapper{value.(fmt.Stringer)}, buf)
|
||||
}
|
||||
|
||||
// TryWrapStructPlan tries to wrap a struct with a wrapper that implements CompositeIndexGetter.
|
||||
func TryWrapStructEncodePlan(value interface{}) (plan WrappedEncodePlanNextSetter, nextValue interface{}, ok bool) {
|
||||
if reflect.TypeOf(value).Kind() == reflect.Struct {
|
||||
exportedFields := getExportedFieldValues(reflect.ValueOf(value))
|
||||
if len(exportedFields) == 0 {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
w := structWrapper{
|
||||
s: value,
|
||||
exportedFields: exportedFields,
|
||||
}
|
||||
return &wrapAnyStructEncodePlan{}, w, true
|
||||
}
|
||||
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
type wrapAnyStructEncodePlan struct {
|
||||
next EncodePlan
|
||||
}
|
||||
|
||||
func (plan *wrapAnyStructEncodePlan) SetNext(next EncodePlan) { plan.next = next }
|
||||
|
||||
func (plan *wrapAnyStructEncodePlan) Encode(value interface{}, buf []byte) (newBuf []byte, err error) {
|
||||
w := structWrapper{
|
||||
s: value,
|
||||
exportedFields: getExportedFieldValues(reflect.ValueOf(value)),
|
||||
}
|
||||
|
||||
return plan.next.Encode(w, buf)
|
||||
}
|
||||
|
||||
func getExportedFieldValues(structValue reflect.Value) []reflect.Value {
|
||||
structType := structValue.Type()
|
||||
exportedFields := make([]reflect.Value, 0, structValue.NumField())
|
||||
for i := 0; i < structType.NumField(); i++ {
|
||||
sf := structType.Field(i)
|
||||
if sf.IsExported() {
|
||||
exportedFields = append(exportedFields, structValue.Field(i))
|
||||
}
|
||||
}
|
||||
|
||||
return exportedFields
|
||||
}
|
||||
|
||||
// Encode appends the encoded bytes of value to buf. If value is the SQL value NULL then append nothing and return
|
||||
// (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data
|
||||
// written.
|
||||
|
|
Loading…
Reference in New Issue