Add pgx.Oid serialization

pull/29/head
Jack Christensen 2014-08-08 10:51:54 -05:00
parent 4d4a45fc34
commit 0ddf94ef9d
4 changed files with 52 additions and 1 deletions

View File

@ -471,7 +471,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
wbuf.WriteInt16(TextFormatCode) wbuf.WriteInt16(TextFormatCode)
default: default:
switch oid { switch oid {
case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid: case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid:
wbuf.WriteInt16(BinaryFormatCode) wbuf.WriteInt16(BinaryFormatCode)
default: default:
wbuf.WriteInt16(TextFormatCode) wbuf.WriteInt16(TextFormatCode)
@ -529,6 +529,8 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
err = encodeTextArray(wbuf, arguments[i], TextOid) err = encodeTextArray(wbuf, arguments[i], TextOid)
case VarcharArrayOid: case VarcharArrayOid:
err = encodeTextArray(wbuf, arguments[i], VarcharOid) err = encodeTextArray(wbuf, arguments[i], VarcharOid)
case OidOid:
err = encodeOid(wbuf, arguments[i])
default: default:
return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
} }

View File

@ -208,6 +208,8 @@ func (rows *Rows) Scan(dest ...interface{}) (err error) {
*d = decodeInt2(vr) *d = decodeInt2(vr)
case *int32: case *int32:
*d = decodeInt4(vr) *d = decodeInt4(vr)
case *Oid:
*d = decodeOid(vr)
case *string: case *string:
*d = decodeText(vr) *d = decodeText(vr)
case *float32: case *float32:

View File

@ -332,6 +332,7 @@ func TestQueryRowCoreTypes(t *testing.T) {
f64 float64 f64 float64
b bool b bool
t time.Time t time.Time
oid pgx.Oid
} }
var actual, zero allTypes var actual, zero allTypes
@ -352,6 +353,7 @@ func TestQueryRowCoreTypes(t *testing.T) {
{"select $1::timestamptz", []interface{}{time.Unix(123, 5000)}, []interface{}{&actual.t}, allTypes{t: time.Unix(123, 5000)}}, {"select $1::timestamptz", []interface{}{time.Unix(123, 5000)}, []interface{}{&actual.t}, allTypes{t: time.Unix(123, 5000)}},
{"select $1::timestamp", []interface{}{time.Date(2010, 1, 2, 3, 4, 5, 0, time.Local)}, []interface{}{&actual.t}, allTypes{t: time.Date(2010, 1, 2, 3, 4, 5, 0, time.Local)}}, {"select $1::timestamp", []interface{}{time.Date(2010, 1, 2, 3, 4, 5, 0, time.Local)}, []interface{}{&actual.t}, allTypes{t: time.Date(2010, 1, 2, 3, 4, 5, 0, time.Local)}},
{"select $1::date", []interface{}{time.Date(1987, 1, 2, 0, 0, 0, 0, time.Local)}, []interface{}{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.Local)}}, {"select $1::date", []interface{}{time.Date(1987, 1, 2, 0, 0, 0, 0, time.Local)}, []interface{}{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.Local)}},
{"select $1::oid", []interface{}{pgx.Oid(42)}, []interface{}{&actual.oid}, allTypes{oid: 42}},
} }
for i, tt := range tests { for i, tt := range tests {

View File

@ -17,6 +17,7 @@ const (
Int2Oid = 21 Int2Oid = 21
Int4Oid = 23 Int4Oid = 23
TextOid = 25 TextOid = 25
OidOid = 26
Float4Oid = 700 Float4Oid = 700
Float8Oid = 701 Float8Oid = 701
Int2ArrayOid = 1005 Int2ArrayOid = 1005
@ -58,6 +59,7 @@ func init() {
DefaultOidFormats[Float8ArrayOid] = BinaryFormatCode DefaultOidFormats[Float8ArrayOid] = BinaryFormatCode
DefaultOidFormats[TextArrayOid] = BinaryFormatCode DefaultOidFormats[TextArrayOid] = BinaryFormatCode
DefaultOidFormats[VarcharArrayOid] = BinaryFormatCode DefaultOidFormats[VarcharArrayOid] = BinaryFormatCode
DefaultOidFormats[OidOid] = BinaryFormatCode
} }
type SerializationError string type SerializationError string
@ -680,6 +682,49 @@ func encodeInt4(w *WriteBuf, value interface{}) error {
return nil return nil
} }
func decodeOid(vr *ValueReader) Oid {
if vr.Len() == -1 {
vr.Fatal(ProtocolError("Cannot decode null into Oid"))
return 0
}
if vr.Type().DataType != OidOid {
vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", OidOid, vr.Type().DataType)))
return 0
}
switch vr.Type().FormatCode {
case TextFormatCode:
s := vr.ReadString(vr.Len())
n, err := strconv.ParseInt(s, 10, 32)
if err != nil {
vr.Fatal(ProtocolError(fmt.Sprintf("Received invalid Oid: %v", s)))
}
return Oid(n)
case BinaryFormatCode:
if vr.Len() != 4 {
vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an Oid: %d", vr.Len())))
return 0
}
return Oid(vr.ReadInt32())
default:
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
return Oid(0)
}
}
func encodeOid(w *WriteBuf, value interface{}) error {
v, ok := value.(Oid)
if !ok {
return fmt.Errorf("Expected Oid, received %T", value)
}
w.WriteInt32(4)
w.WriteInt32(int32(v))
return nil
}
func decodeFloat4(vr *ValueReader) float32 { func decodeFloat4(vr *ValueReader) float32 {
if vr.Len() == -1 { if vr.Len() == -1 {
vr.Fatal(ProtocolError("Cannot decode null into float32")) vr.Fatal(ProtocolError("Cannot decode null into float32"))