From 71d8b5b4384d3af9c047c3b0d1ae5c386e3869c5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 5 Jul 2016 18:01:44 -0500 Subject: [PATCH] Encode / decode named types with compatible underlying type Handle string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64. --- CHANGELOG.md | 1 + values.go | 66 ++++++++++++++++++++++++++++--- values_test.go | 104 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 166 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 93c69721..bbeb2c29 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ * Encode and decode between all Go and PostgreSQL integer types with bounds checking * Decode inet/cidr to net.IP * Encode/decode [][]byte to/from bytea[] +* Encode/decode named types whoses underlying types are string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64 ## Performance diff --git a/values.go b/values.go index f80d7519..b6e0a84b 100644 --- a/values.go +++ b/values.go @@ -615,12 +615,14 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { return encodeByteSliceSlice(wbuf, oid, arg) } - if v := reflect.ValueOf(arg); v.Kind() == reflect.Ptr { - if v.IsNil() { + refVal := reflect.ValueOf(arg) + + if refVal.Kind() == reflect.Ptr { + if refVal.IsNil() { wbuf.WriteInt32(-1) return nil } else { - arg = v.Elem().Interface() + arg = refVal.Elem().Interface() return Encode(wbuf, oid, arg) } } @@ -691,10 +693,42 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error { case Oid: return encodeOid(wbuf, oid, arg) default: + if strippedArg, ok := stripNamedType(&refVal); ok { + return Encode(wbuf, oid, strippedArg) + } return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) } } +func stripNamedType(val *reflect.Value) (interface{}, bool) { + switch val.Kind() { + case reflect.Int: + return int(val.Int()), true + case reflect.Int8: + return int8(val.Int()), true + case reflect.Int16: + return int16(val.Int()), true + case reflect.Int32: + return int32(val.Int()), true + case reflect.Int64: + return int64(val.Int()), true + case reflect.Uint: + return uint(val.Uint()), true + case reflect.Uint8: + return uint8(val.Uint()), true + case reflect.Uint16: + return uint16(val.Uint()), true + case reflect.Uint32: + return uint32(val.Uint()), true + case reflect.Uint64: + return uint64(val.Uint()), true + case reflect.String: + return val.String(), true + } + + return nil, false +} + // Decode decodes from vr into d. d must be a pointer. This allows // implementations of the Decoder interface to delegate the actual work of // decoding to the built-in functionality. @@ -846,9 +880,11 @@ func Decode(vr *ValueReader, d interface{}) error { case *[]net.IPNet: *v = decodeInetArray(vr) default: - // if d is a pointer to pointer, strip the pointer and try again if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr { - if el := v.Elem(); el.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if d is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: // -1 is a null value if vr.Len() == -1 { if !el.IsNil() { @@ -864,6 +900,26 @@ func Decode(vr *ValueReader, d interface{}) error { d = el.Interface() return Decode(vr, d) } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + n := decodeInt(vr) + if el.OverflowInt(n) { + return fmt.Errorf("Scan cannot decode %d into %T", n, d) + } + el.SetInt(n) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + n := decodeInt(vr) + if n < 0 { + return fmt.Errorf("%d is less than zero for %T", n, d) + } + if el.OverflowUint(uint64(n)) { + return fmt.Errorf("Scan cannot decode %d into %T", n, d) + } + el.SetUint(uint64(n)) + return nil + case reflect.String: + el.SetString(decodeText(vr)) + return nil } } return fmt.Errorf("Scan cannot decode into %T", d) diff --git a/values_test.go b/values_test.go index 0e29c7d1..063598d9 100644 --- a/values_test.go +++ b/values_test.go @@ -960,6 +960,110 @@ func TestPointerPointerNonZero(t *testing.T) { } } +func TestEncodeTypeRename(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + type _int int + inInt := _int(3) + var outInt _int + + type _int8 int8 + inInt8 := _int8(3) + var outInt8 _int8 + + type _int16 int16 + inInt16 := _int16(3) + var outInt16 _int16 + + type _int32 int32 + inInt32 := _int32(4) + var outInt32 _int32 + + type _int64 int64 + inInt64 := _int64(5) + var outInt64 _int64 + + type _uint uint + inUint := _uint(6) + var outUint _uint + + type _uint8 uint8 + inUint8 := _uint8(7) + var outUint8 _uint8 + + type _uint16 uint16 + inUint16 := _uint16(8) + var outUint16 _uint16 + + type _uint32 uint32 + inUint32 := _uint32(9) + var outUint32 _uint32 + + type _uint64 uint64 + inUint64 := _uint64(10) + var outUint64 _uint64 + + type _string string + inString := _string("foo") + var outString _string + + err := conn.QueryRow("select $1::int, $2::int, $3::int2, $4::int4, $5::int8, $6::int, $7::int, $8::int, $9::int, $10::int, $11::text", + inInt, inInt8, inInt16, inInt32, inInt64, inUint, inUint8, inUint16, inUint32, inUint64, inString, + ).Scan(&outInt, &outInt8, &outInt16, &outInt32, &outInt64, &outUint, &outUint8, &outUint16, &outUint32, &outUint64, &outString) + if err != nil { + t.Fatalf("Failed with type rename: %v", err) + } + + if inInt != outInt { + t.Errorf("int rename: expected %v, got %v", inInt, outInt) + } + + if inInt8 != outInt8 { + t.Errorf("int8 rename: expected %v, got %v", inInt8, outInt8) + } + + if inInt16 != outInt16 { + t.Errorf("int16 rename: expected %v, got %v", inInt16, outInt16) + } + + if inInt32 != outInt32 { + t.Errorf("int32 rename: expected %v, got %v", inInt32, outInt32) + } + + if inInt64 != outInt64 { + t.Errorf("int64 rename: expected %v, got %v", inInt64, outInt64) + } + + if inUint != outUint { + t.Errorf("uint rename: expected %v, got %v", inUint, outUint) + } + + if inUint8 != outUint8 { + t.Errorf("uint8 rename: expected %v, got %v", inUint8, outUint8) + } + + if inUint16 != outUint16 { + t.Errorf("uint16 rename: expected %v, got %v", inUint16, outUint16) + } + + if inUint32 != outUint32 { + t.Errorf("uint32 rename: expected %v, got %v", inUint32, outUint32) + } + + if inUint64 != outUint64 { + t.Errorf("uint64 rename: expected %v, got %v", inUint64, outUint64) + } + + if inString != outString { + t.Errorf("string rename: expected %v, got %v", inString, outString) + } + + ensureConnValid(t, conn) +} + func TestRowDecode(t *testing.T) { t.Parallel()