Add basic record to []interface{} decoding

refs #155
pull/159/head
Jack Christensen 2016-06-21 15:00:47 -05:00
parent 9d284da48e
commit 30cb421551
4 changed files with 126 additions and 6 deletions

View File

@ -1,15 +1,13 @@
# Unreleased
## Features
* Add PrepareEx
## Fixes
* Fix *ConnPool.Deallocate() not deleting prepared statement from map
## Features
* Add PrepareEx
* Add basic record to []interface{} decoding
* Encode and decode between all Go and PostgreSQL integer types with bounds checking
* Decode inet/cidr to net.IP
* Encode/decode [][]byte to/from bytea[]

11
conn.go
View File

@ -332,7 +332,14 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
}
func (c *Conn) loadPgTypes() error {
rows, err := c.Query("select t.oid, t.typname from pg_type t left join pg_type base_type on t.typelem=base_type.oid where t.typtype='b' and (base_type.oid is null or base_type.typtype='b');")
rows, err := c.Query(`select t.oid, t.typname
from pg_type t
left join pg_type base_type on t.typelem=base_type.oid
where (
t.typtype='b'
and (base_type.oid is null or base_type.typtype='b')
)
or t.typname in('record');`)
if err != nil {
return err
}
@ -910,7 +917,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
wbuf.WriteInt16(TextFormatCode)
default:
switch oid {
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, ByteaArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid:
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, ByteaArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid, RecordOid:
wbuf.WriteInt16(BinaryFormatCode)
default:
wbuf.WriteInt16(TextFormatCode)

View File

@ -27,6 +27,7 @@ const (
CidrArrayOid = 651
Float4Oid = 700
Float8Oid = 701
UnknownOid = 705
InetOid = 869
BoolArrayOid = 1000
Int2ArrayOid = 1005
@ -44,6 +45,7 @@ const (
TimestampArrayOid = 1115
TimestampTzOid = 1184
TimestampTzArrayOid = 1185
RecordOid = 2249
UuidOid = 2950
JsonbOid = 3802
)
@ -91,8 +93,11 @@ func init() {
"int4": BinaryFormatCode,
"int8": BinaryFormatCode,
"oid": BinaryFormatCode,
"record": BinaryFormatCode,
"text": BinaryFormatCode,
"timestamp": BinaryFormatCode,
"timestamptz": BinaryFormatCode,
"varchar": BinaryFormatCode,
}
}
@ -807,6 +812,8 @@ func Decode(vr *ValueReader, d interface{}) error {
*v = decodeTimestampArray(vr)
case *[][]byte:
*v = decodeByteaArray(vr)
case *[]interface{}:
*v = decodeRecord(vr)
case *time.Time:
switch vr.Type().DataType {
case DateOid:
@ -1613,6 +1620,77 @@ func encodeIP(w *WriteBuf, oid Oid, value net.IP) error {
return encodeIPNet(w, oid, ipnet)
}
func decodeRecord(vr *ValueReader) []interface{} {
if vr.Len() == -1 {
return nil
}
if vr.Type().FormatCode != BinaryFormatCode {
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
return nil
}
if vr.Type().DataType != RecordOid {
vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []interface{}", vr.Type().DataType)))
return nil
}
valueCount := vr.ReadInt32()
record := make([]interface{}, 0, int(valueCount))
for i := int32(0); i < valueCount; i++ {
fd := FieldDescription{FormatCode: BinaryFormatCode}
fieldVR := ValueReader{mr: vr.mr, fd: &fd}
fd.DataType = vr.ReadOid()
fieldVR.valueBytesRemaining = vr.ReadInt32()
vr.valueBytesRemaining -= fieldVR.valueBytesRemaining
switch fd.DataType {
case BoolOid:
record = append(record, decodeBool(&fieldVR))
case ByteaOid:
record = append(record, decodeBytea(&fieldVR))
case Int8Oid:
record = append(record, decodeInt8(&fieldVR))
case Int2Oid:
record = append(record, decodeInt2(&fieldVR))
case Int4Oid:
record = append(record, decodeInt4(&fieldVR))
case OidOid:
record = append(record, decodeOid(&fieldVR))
case Float4Oid:
record = append(record, decodeFloat4(&fieldVR))
case Float8Oid:
record = append(record, decodeFloat8(&fieldVR))
case DateOid:
record = append(record, decodeDate(&fieldVR))
case TimestampTzOid:
record = append(record, decodeTimestampTz(&fieldVR))
case TimestampOid:
record = append(record, decodeTimestamp(&fieldVR))
case InetOid, CidrOid:
record = append(record, decodeInet(&fieldVR))
case TextOid, VarcharOid, UnknownOid:
record = append(record, decodeText(&fieldVR))
default:
vr.Fatal(fmt.Errorf("decodeRecord cannot decode oid %d", fd.DataType))
return nil
}
// Consume any remaining data
if fieldVR.Len() > 0 {
fieldVR.ReadBytes(fieldVR.Len())
}
if fieldVR.Err() != nil {
vr.Fatal(fieldVR.Err())
return nil
}
}
return record
}
func decode1dArrayHeader(vr *ValueReader) (length int32, err error) {
numDims := vr.ReadInt32()
if numDims > 1 {

View File

@ -959,3 +959,40 @@ func TestPointerPointerNonZero(t *testing.T) {
t.Errorf("Expected dest to be nil, got %#v", dest)
}
}
func TestRowDecode(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
tests := []struct {
sql string
expected []interface{}
}{
{
"select row(1, 'cat', '2015-01-01 08:12:42'::timestamptz)",
[]interface{}{
int32(1),
"cat",
time.Date(2015, 1, 1, 8, 12, 42, 0, time.Local),
},
},
}
for i, tt := range tests {
var actual []interface{}
err := conn.QueryRow(tt.sql).Scan(&actual)
if err != nil {
t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql)
continue
}
if !reflect.DeepEqual(actual, tt.expected) {
t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.expected, actual, tt.sql)
}
ensureConnValid(t, conn)
}
}