Encode / decode named types with compatible underlying type

Handle string, int, int8, int16, int32, int64, uint, uint8, uint16,
uint32, uint64.
pull/163/head
Jack Christensen 2016-07-05 18:01:44 -05:00
parent 30cb421551
commit 71d8b5b438
3 changed files with 166 additions and 5 deletions

View File

@ -11,6 +11,7 @@
* Encode and decode between all Go and PostgreSQL integer types with bounds checking
* Decode inet/cidr to net.IP
* Encode/decode [][]byte to/from bytea[]
* Encode/decode named types whoses underlying types are string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64
## Performance

View File

@ -615,12 +615,14 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error {
return encodeByteSliceSlice(wbuf, oid, arg)
}
if v := reflect.ValueOf(arg); v.Kind() == reflect.Ptr {
if v.IsNil() {
refVal := reflect.ValueOf(arg)
if refVal.Kind() == reflect.Ptr {
if refVal.IsNil() {
wbuf.WriteInt32(-1)
return nil
} else {
arg = v.Elem().Interface()
arg = refVal.Elem().Interface()
return Encode(wbuf, oid, arg)
}
}
@ -691,10 +693,42 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error {
case Oid:
return encodeOid(wbuf, oid, arg)
default:
if strippedArg, ok := stripNamedType(&refVal); ok {
return Encode(wbuf, oid, strippedArg)
}
return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
}
}
func stripNamedType(val *reflect.Value) (interface{}, bool) {
switch val.Kind() {
case reflect.Int:
return int(val.Int()), true
case reflect.Int8:
return int8(val.Int()), true
case reflect.Int16:
return int16(val.Int()), true
case reflect.Int32:
return int32(val.Int()), true
case reflect.Int64:
return int64(val.Int()), true
case reflect.Uint:
return uint(val.Uint()), true
case reflect.Uint8:
return uint8(val.Uint()), true
case reflect.Uint16:
return uint16(val.Uint()), true
case reflect.Uint32:
return uint32(val.Uint()), true
case reflect.Uint64:
return uint64(val.Uint()), true
case reflect.String:
return val.String(), true
}
return nil, false
}
// Decode decodes from vr into d. d must be a pointer. This allows
// implementations of the Decoder interface to delegate the actual work of
// decoding to the built-in functionality.
@ -846,9 +880,11 @@ func Decode(vr *ValueReader, d interface{}) error {
case *[]net.IPNet:
*v = decodeInetArray(vr)
default:
// if d is a pointer to pointer, strip the pointer and try again
if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr {
if el := v.Elem(); el.Kind() == reflect.Ptr {
el := v.Elem()
switch el.Kind() {
// if d is a pointer to pointer, strip the pointer and try again
case reflect.Ptr:
// -1 is a null value
if vr.Len() == -1 {
if !el.IsNil() {
@ -864,6 +900,26 @@ func Decode(vr *ValueReader, d interface{}) error {
d = el.Interface()
return Decode(vr, d)
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
n := decodeInt(vr)
if el.OverflowInt(n) {
return fmt.Errorf("Scan cannot decode %d into %T", n, d)
}
el.SetInt(n)
return nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
n := decodeInt(vr)
if n < 0 {
return fmt.Errorf("%d is less than zero for %T", n, d)
}
if el.OverflowUint(uint64(n)) {
return fmt.Errorf("Scan cannot decode %d into %T", n, d)
}
el.SetUint(uint64(n))
return nil
case reflect.String:
el.SetString(decodeText(vr))
return nil
}
}
return fmt.Errorf("Scan cannot decode into %T", d)

View File

@ -960,6 +960,110 @@ func TestPointerPointerNonZero(t *testing.T) {
}
}
func TestEncodeTypeRename(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
type _int int
inInt := _int(3)
var outInt _int
type _int8 int8
inInt8 := _int8(3)
var outInt8 _int8
type _int16 int16
inInt16 := _int16(3)
var outInt16 _int16
type _int32 int32
inInt32 := _int32(4)
var outInt32 _int32
type _int64 int64
inInt64 := _int64(5)
var outInt64 _int64
type _uint uint
inUint := _uint(6)
var outUint _uint
type _uint8 uint8
inUint8 := _uint8(7)
var outUint8 _uint8
type _uint16 uint16
inUint16 := _uint16(8)
var outUint16 _uint16
type _uint32 uint32
inUint32 := _uint32(9)
var outUint32 _uint32
type _uint64 uint64
inUint64 := _uint64(10)
var outUint64 _uint64
type _string string
inString := _string("foo")
var outString _string
err := conn.QueryRow("select $1::int, $2::int, $3::int2, $4::int4, $5::int8, $6::int, $7::int, $8::int, $9::int, $10::int, $11::text",
inInt, inInt8, inInt16, inInt32, inInt64, inUint, inUint8, inUint16, inUint32, inUint64, inString,
).Scan(&outInt, &outInt8, &outInt16, &outInt32, &outInt64, &outUint, &outUint8, &outUint16, &outUint32, &outUint64, &outString)
if err != nil {
t.Fatalf("Failed with type rename: %v", err)
}
if inInt != outInt {
t.Errorf("int rename: expected %v, got %v", inInt, outInt)
}
if inInt8 != outInt8 {
t.Errorf("int8 rename: expected %v, got %v", inInt8, outInt8)
}
if inInt16 != outInt16 {
t.Errorf("int16 rename: expected %v, got %v", inInt16, outInt16)
}
if inInt32 != outInt32 {
t.Errorf("int32 rename: expected %v, got %v", inInt32, outInt32)
}
if inInt64 != outInt64 {
t.Errorf("int64 rename: expected %v, got %v", inInt64, outInt64)
}
if inUint != outUint {
t.Errorf("uint rename: expected %v, got %v", inUint, outUint)
}
if inUint8 != outUint8 {
t.Errorf("uint8 rename: expected %v, got %v", inUint8, outUint8)
}
if inUint16 != outUint16 {
t.Errorf("uint16 rename: expected %v, got %v", inUint16, outUint16)
}
if inUint32 != outUint32 {
t.Errorf("uint32 rename: expected %v, got %v", inUint32, outUint32)
}
if inUint64 != outUint64 {
t.Errorf("uint64 rename: expected %v, got %v", inUint64, outUint64)
}
if inString != outString {
t.Errorf("string rename: expected %v, got %v", inString, outString)
}
ensureConnValid(t, conn)
}
func TestRowDecode(t *testing.T) {
t.Parallel()