diff --git a/CHANGELOG.md b/CHANGELOG.md index 4aae2ce5..93c69721 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,15 +1,13 @@ # Unreleased -## Features - -* Add PrepareEx - ## Fixes * Fix *ConnPool.Deallocate() not deleting prepared statement from map ## Features +* Add PrepareEx +* Add basic record to []interface{} decoding * Encode and decode between all Go and PostgreSQL integer types with bounds checking * Decode inet/cidr to net.IP * Encode/decode [][]byte to/from bytea[] diff --git a/conn.go b/conn.go index f9b89a78..c2519003 100644 --- a/conn.go +++ b/conn.go @@ -332,7 +332,14 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } func (c *Conn) loadPgTypes() error { - rows, err := c.Query("select t.oid, t.typname from pg_type t left join pg_type base_type on t.typelem=base_type.oid where t.typtype='b' and (base_type.oid is null or base_type.typtype='b');") + rows, err := c.Query(`select t.oid, t.typname +from pg_type t +left join pg_type base_type on t.typelem=base_type.oid +where ( + t.typtype='b' + and (base_type.oid is null or base_type.typtype='b') + ) + or t.typname in('record');`) if err != nil { return err } @@ -910,7 +917,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(TextFormatCode) default: switch oid { - case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, ByteaArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid: + case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, ByteaArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid, InetArrayOid, CidrArrayOid, RecordOid: wbuf.WriteInt16(BinaryFormatCode) default: wbuf.WriteInt16(TextFormatCode) diff --git a/values.go b/values.go index 734ad2ba..f80d7519 100644 --- a/values.go +++ b/values.go @@ -27,6 +27,7 @@ const ( CidrArrayOid = 651 Float4Oid = 700 Float8Oid = 701 + UnknownOid = 705 InetOid = 869 BoolArrayOid = 1000 Int2ArrayOid = 1005 @@ -44,6 +45,7 @@ const ( TimestampArrayOid = 1115 TimestampTzOid = 1184 TimestampTzArrayOid = 1185 + RecordOid = 2249 UuidOid = 2950 JsonbOid = 3802 ) @@ -91,8 +93,11 @@ func init() { "int4": BinaryFormatCode, "int8": BinaryFormatCode, "oid": BinaryFormatCode, + "record": BinaryFormatCode, + "text": BinaryFormatCode, "timestamp": BinaryFormatCode, "timestamptz": BinaryFormatCode, + "varchar": BinaryFormatCode, } } @@ -807,6 +812,8 @@ func Decode(vr *ValueReader, d interface{}) error { *v = decodeTimestampArray(vr) case *[][]byte: *v = decodeByteaArray(vr) + case *[]interface{}: + *v = decodeRecord(vr) case *time.Time: switch vr.Type().DataType { case DateOid: @@ -1613,6 +1620,77 @@ func encodeIP(w *WriteBuf, oid Oid, value net.IP) error { return encodeIPNet(w, oid, ipnet) } +func decodeRecord(vr *ValueReader) []interface{} { + if vr.Len() == -1 { + return nil + } + + if vr.Type().FormatCode != BinaryFormatCode { + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return nil + } + + if vr.Type().DataType != RecordOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []interface{}", vr.Type().DataType))) + return nil + } + + valueCount := vr.ReadInt32() + record := make([]interface{}, 0, int(valueCount)) + + for i := int32(0); i < valueCount; i++ { + fd := FieldDescription{FormatCode: BinaryFormatCode} + fieldVR := ValueReader{mr: vr.mr, fd: &fd} + fd.DataType = vr.ReadOid() + fieldVR.valueBytesRemaining = vr.ReadInt32() + vr.valueBytesRemaining -= fieldVR.valueBytesRemaining + + switch fd.DataType { + case BoolOid: + record = append(record, decodeBool(&fieldVR)) + case ByteaOid: + record = append(record, decodeBytea(&fieldVR)) + case Int8Oid: + record = append(record, decodeInt8(&fieldVR)) + case Int2Oid: + record = append(record, decodeInt2(&fieldVR)) + case Int4Oid: + record = append(record, decodeInt4(&fieldVR)) + case OidOid: + record = append(record, decodeOid(&fieldVR)) + case Float4Oid: + record = append(record, decodeFloat4(&fieldVR)) + case Float8Oid: + record = append(record, decodeFloat8(&fieldVR)) + case DateOid: + record = append(record, decodeDate(&fieldVR)) + case TimestampTzOid: + record = append(record, decodeTimestampTz(&fieldVR)) + case TimestampOid: + record = append(record, decodeTimestamp(&fieldVR)) + case InetOid, CidrOid: + record = append(record, decodeInet(&fieldVR)) + case TextOid, VarcharOid, UnknownOid: + record = append(record, decodeText(&fieldVR)) + default: + vr.Fatal(fmt.Errorf("decodeRecord cannot decode oid %d", fd.DataType)) + return nil + } + + // Consume any remaining data + if fieldVR.Len() > 0 { + fieldVR.ReadBytes(fieldVR.Len()) + } + + if fieldVR.Err() != nil { + vr.Fatal(fieldVR.Err()) + return nil + } + } + + return record +} + func decode1dArrayHeader(vr *ValueReader) (length int32, err error) { numDims := vr.ReadInt32() if numDims > 1 { diff --git a/values_test.go b/values_test.go index 14a8aa17..0e29c7d1 100644 --- a/values_test.go +++ b/values_test.go @@ -959,3 +959,40 @@ func TestPointerPointerNonZero(t *testing.T) { t.Errorf("Expected dest to be nil, got %#v", dest) } } + +func TestRowDecode(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + tests := []struct { + sql string + expected []interface{} + }{ + { + "select row(1, 'cat', '2015-01-01 08:12:42'::timestamptz)", + []interface{}{ + int32(1), + "cat", + time.Date(2015, 1, 1, 8, 12, 42, 0, time.Local), + }, + }, + } + + for i, tt := range tests { + var actual []interface{} + + err := conn.QueryRow(tt.sql).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql) + continue + } + + if !reflect.DeepEqual(actual, tt.expected) { + t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.expected, actual, tt.sql) + } + + ensureConnValid(t, conn) + } +}