mirror of https://github.com/jackc/pgx.git
Encode / decode named types with compatible underlying type
Handle string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64.pull/163/head
parent
30cb421551
commit
71d8b5b438
|
@ -11,6 +11,7 @@
|
|||
* Encode and decode between all Go and PostgreSQL integer types with bounds checking
|
||||
* Decode inet/cidr to net.IP
|
||||
* Encode/decode [][]byte to/from bytea[]
|
||||
* Encode/decode named types whoses underlying types are string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64
|
||||
|
||||
## Performance
|
||||
|
||||
|
|
66
values.go
66
values.go
|
@ -615,12 +615,14 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error {
|
|||
return encodeByteSliceSlice(wbuf, oid, arg)
|
||||
}
|
||||
|
||||
if v := reflect.ValueOf(arg); v.Kind() == reflect.Ptr {
|
||||
if v.IsNil() {
|
||||
refVal := reflect.ValueOf(arg)
|
||||
|
||||
if refVal.Kind() == reflect.Ptr {
|
||||
if refVal.IsNil() {
|
||||
wbuf.WriteInt32(-1)
|
||||
return nil
|
||||
} else {
|
||||
arg = v.Elem().Interface()
|
||||
arg = refVal.Elem().Interface()
|
||||
return Encode(wbuf, oid, arg)
|
||||
}
|
||||
}
|
||||
|
@ -691,10 +693,42 @@ func Encode(wbuf *WriteBuf, oid Oid, arg interface{}) error {
|
|||
case Oid:
|
||||
return encodeOid(wbuf, oid, arg)
|
||||
default:
|
||||
if strippedArg, ok := stripNamedType(&refVal); ok {
|
||||
return Encode(wbuf, oid, strippedArg)
|
||||
}
|
||||
return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
|
||||
}
|
||||
}
|
||||
|
||||
func stripNamedType(val *reflect.Value) (interface{}, bool) {
|
||||
switch val.Kind() {
|
||||
case reflect.Int:
|
||||
return int(val.Int()), true
|
||||
case reflect.Int8:
|
||||
return int8(val.Int()), true
|
||||
case reflect.Int16:
|
||||
return int16(val.Int()), true
|
||||
case reflect.Int32:
|
||||
return int32(val.Int()), true
|
||||
case reflect.Int64:
|
||||
return int64(val.Int()), true
|
||||
case reflect.Uint:
|
||||
return uint(val.Uint()), true
|
||||
case reflect.Uint8:
|
||||
return uint8(val.Uint()), true
|
||||
case reflect.Uint16:
|
||||
return uint16(val.Uint()), true
|
||||
case reflect.Uint32:
|
||||
return uint32(val.Uint()), true
|
||||
case reflect.Uint64:
|
||||
return uint64(val.Uint()), true
|
||||
case reflect.String:
|
||||
return val.String(), true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Decode decodes from vr into d. d must be a pointer. This allows
|
||||
// implementations of the Decoder interface to delegate the actual work of
|
||||
// decoding to the built-in functionality.
|
||||
|
@ -846,9 +880,11 @@ func Decode(vr *ValueReader, d interface{}) error {
|
|||
case *[]net.IPNet:
|
||||
*v = decodeInetArray(vr)
|
||||
default:
|
||||
// if d is a pointer to pointer, strip the pointer and try again
|
||||
if v := reflect.ValueOf(d); v.Kind() == reflect.Ptr {
|
||||
if el := v.Elem(); el.Kind() == reflect.Ptr {
|
||||
el := v.Elem()
|
||||
switch el.Kind() {
|
||||
// if d is a pointer to pointer, strip the pointer and try again
|
||||
case reflect.Ptr:
|
||||
// -1 is a null value
|
||||
if vr.Len() == -1 {
|
||||
if !el.IsNil() {
|
||||
|
@ -864,6 +900,26 @@ func Decode(vr *ValueReader, d interface{}) error {
|
|||
d = el.Interface()
|
||||
return Decode(vr, d)
|
||||
}
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
n := decodeInt(vr)
|
||||
if el.OverflowInt(n) {
|
||||
return fmt.Errorf("Scan cannot decode %d into %T", n, d)
|
||||
}
|
||||
el.SetInt(n)
|
||||
return nil
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
n := decodeInt(vr)
|
||||
if n < 0 {
|
||||
return fmt.Errorf("%d is less than zero for %T", n, d)
|
||||
}
|
||||
if el.OverflowUint(uint64(n)) {
|
||||
return fmt.Errorf("Scan cannot decode %d into %T", n, d)
|
||||
}
|
||||
el.SetUint(uint64(n))
|
||||
return nil
|
||||
case reflect.String:
|
||||
el.SetString(decodeText(vr))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("Scan cannot decode into %T", d)
|
||||
|
|
104
values_test.go
104
values_test.go
|
@ -960,6 +960,110 @@ func TestPointerPointerNonZero(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestEncodeTypeRename(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
type _int int
|
||||
inInt := _int(3)
|
||||
var outInt _int
|
||||
|
||||
type _int8 int8
|
||||
inInt8 := _int8(3)
|
||||
var outInt8 _int8
|
||||
|
||||
type _int16 int16
|
||||
inInt16 := _int16(3)
|
||||
var outInt16 _int16
|
||||
|
||||
type _int32 int32
|
||||
inInt32 := _int32(4)
|
||||
var outInt32 _int32
|
||||
|
||||
type _int64 int64
|
||||
inInt64 := _int64(5)
|
||||
var outInt64 _int64
|
||||
|
||||
type _uint uint
|
||||
inUint := _uint(6)
|
||||
var outUint _uint
|
||||
|
||||
type _uint8 uint8
|
||||
inUint8 := _uint8(7)
|
||||
var outUint8 _uint8
|
||||
|
||||
type _uint16 uint16
|
||||
inUint16 := _uint16(8)
|
||||
var outUint16 _uint16
|
||||
|
||||
type _uint32 uint32
|
||||
inUint32 := _uint32(9)
|
||||
var outUint32 _uint32
|
||||
|
||||
type _uint64 uint64
|
||||
inUint64 := _uint64(10)
|
||||
var outUint64 _uint64
|
||||
|
||||
type _string string
|
||||
inString := _string("foo")
|
||||
var outString _string
|
||||
|
||||
err := conn.QueryRow("select $1::int, $2::int, $3::int2, $4::int4, $5::int8, $6::int, $7::int, $8::int, $9::int, $10::int, $11::text",
|
||||
inInt, inInt8, inInt16, inInt32, inInt64, inUint, inUint8, inUint16, inUint32, inUint64, inString,
|
||||
).Scan(&outInt, &outInt8, &outInt16, &outInt32, &outInt64, &outUint, &outUint8, &outUint16, &outUint32, &outUint64, &outString)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed with type rename: %v", err)
|
||||
}
|
||||
|
||||
if inInt != outInt {
|
||||
t.Errorf("int rename: expected %v, got %v", inInt, outInt)
|
||||
}
|
||||
|
||||
if inInt8 != outInt8 {
|
||||
t.Errorf("int8 rename: expected %v, got %v", inInt8, outInt8)
|
||||
}
|
||||
|
||||
if inInt16 != outInt16 {
|
||||
t.Errorf("int16 rename: expected %v, got %v", inInt16, outInt16)
|
||||
}
|
||||
|
||||
if inInt32 != outInt32 {
|
||||
t.Errorf("int32 rename: expected %v, got %v", inInt32, outInt32)
|
||||
}
|
||||
|
||||
if inInt64 != outInt64 {
|
||||
t.Errorf("int64 rename: expected %v, got %v", inInt64, outInt64)
|
||||
}
|
||||
|
||||
if inUint != outUint {
|
||||
t.Errorf("uint rename: expected %v, got %v", inUint, outUint)
|
||||
}
|
||||
|
||||
if inUint8 != outUint8 {
|
||||
t.Errorf("uint8 rename: expected %v, got %v", inUint8, outUint8)
|
||||
}
|
||||
|
||||
if inUint16 != outUint16 {
|
||||
t.Errorf("uint16 rename: expected %v, got %v", inUint16, outUint16)
|
||||
}
|
||||
|
||||
if inUint32 != outUint32 {
|
||||
t.Errorf("uint32 rename: expected %v, got %v", inUint32, outUint32)
|
||||
}
|
||||
|
||||
if inUint64 != outUint64 {
|
||||
t.Errorf("uint64 rename: expected %v, got %v", inUint64, outUint64)
|
||||
}
|
||||
|
||||
if inString != outString {
|
||||
t.Errorf("string rename: expected %v, got %v", inString, outString)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestRowDecode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
Loading…
Reference in New Issue