From 0ddf94ef9d13a10aa6cb21ecb99985919edebbe9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 8 Aug 2014 10:51:54 -0500 Subject: [PATCH] Add pgx.Oid serialization --- conn.go | 4 +++- query.go | 2 ++ query_test.go | 2 ++ values.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 52 insertions(+), 1 deletion(-) diff --git a/conn.go b/conn.go index 98229ad5..5b7dfc5a 100644 --- a/conn.go +++ b/conn.go @@ -471,7 +471,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(TextFormatCode) default: switch oid { - case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid: + case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid: wbuf.WriteInt16(BinaryFormatCode) default: wbuf.WriteInt16(TextFormatCode) @@ -529,6 +529,8 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} err = encodeTextArray(wbuf, arguments[i], TextOid) case VarcharArrayOid: err = encodeTextArray(wbuf, arguments[i], VarcharOid) + case OidOid: + err = encodeOid(wbuf, arguments[i]) default: return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) } diff --git a/query.go b/query.go index e659ad22..d6c73e20 100644 --- a/query.go +++ b/query.go @@ -208,6 +208,8 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) { *d = decodeInt2(vr) case *int32: *d = decodeInt4(vr) + case *Oid: + *d = decodeOid(vr) case *string: *d = decodeText(vr) case *float32: diff --git a/query_test.go b/query_test.go index 0e6e36fa..9e31a009 100644 --- a/query_test.go +++ b/query_test.go @@ -332,6 +332,7 @@ func TestQueryRowCoreTypes(t *testing.T) { f64 float64 b bool t time.Time + oid pgx.Oid } var actual, zero allTypes @@ -352,6 +353,7 @@ func TestQueryRowCoreTypes(t *testing.T) { {"select $1::timestamptz", []interface{}{time.Unix(123, 5000)}, []interface{}{&actual.t}, allTypes{t: time.Unix(123, 5000)}}, {"select $1::timestamp", []interface{}{time.Date(2010, 1, 2, 3, 4, 5, 0, time.Local)}, []interface{}{&actual.t}, allTypes{t: time.Date(2010, 1, 2, 3, 4, 5, 0, time.Local)}}, {"select $1::date", []interface{}{time.Date(1987, 1, 2, 0, 0, 0, 0, time.Local)}, []interface{}{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.Local)}}, + {"select $1::oid", []interface{}{pgx.Oid(42)}, []interface{}{&actual.oid}, allTypes{oid: 42}}, } for i, tt := range tests { diff --git a/values.go b/values.go index 65af3568..cd2c4e09 100644 --- a/values.go +++ b/values.go @@ -17,6 +17,7 @@ const ( Int2Oid = 21 Int4Oid = 23 TextOid = 25 + OidOid = 26 Float4Oid = 700 Float8Oid = 701 Int2ArrayOid = 1005 @@ -58,6 +59,7 @@ func init() { DefaultOidFormats[Float8ArrayOid] = BinaryFormatCode DefaultOidFormats[TextArrayOid] = BinaryFormatCode DefaultOidFormats[VarcharArrayOid] = BinaryFormatCode + DefaultOidFormats[OidOid] = BinaryFormatCode } type SerializationError string @@ -680,6 +682,49 @@ func encodeInt4(w *WriteBuf, value interface{}) error { return nil } +func decodeOid(vr *ValueReader) Oid { + if vr.Len() == -1 { + vr.Fatal(ProtocolError("Cannot decode null into Oid")) + return 0 + } + + if vr.Type().DataType != OidOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", OidOid, vr.Type().DataType))) + return 0 + } + + switch vr.Type().FormatCode { + case TextFormatCode: + s := vr.ReadString(vr.Len()) + n, err := strconv.ParseInt(s, 10, 32) + if err != nil { + vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s))) + } + return Oid(n) + case BinaryFormatCode: + if vr.Len() != 4 { + vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len()))) + return 0 + } + return Oid(vr.ReadInt32()) + default: + vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) + return Oid(0) + } +} + +func encodeOid(w *WriteBuf, value interface{}) error { + v, ok := value.(Oid) + if !ok { + return fmt.Errorf("Expected Oid, received %T", value) + } + + w.WriteInt32(4) + w.WriteInt32(int32(v)) + + return nil +} + func decodeFloat4(vr *ValueReader) float32 { if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into float32"))