From fb55203324d02724bdd984775fe8924c75ef66fa Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 31 Jul 2014 13:35:44 -0500 Subject: [PATCH] Add support for varchar[] --- conn.go | 6 ++++-- query.go | 2 +- query_test.go | 2 ++ values.go | 46 ++++++++++++++++++++++++---------------------- 4 files changed, 31 insertions(+), 25 deletions(-) diff --git a/conn.go b/conn.go index 10c17345..5f1cf262 100644 --- a/conn.go +++ b/conn.go @@ -471,7 +471,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(TextFormatCode) default: switch oid { - case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid: + case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid: wbuf.WriteInt16(BinaryFormatCode) default: wbuf.WriteInt16(TextFormatCode) @@ -526,7 +526,9 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} case Float8ArrayOid: err = encodeFloat8Array(wbuf, arguments[i]) case TextArrayOid: - err = encodeTextArray(wbuf, arguments[i]) + err = encodeTextArray(wbuf, arguments[i], TextOid) + case VarcharArrayOid: + err = encodeTextArray(wbuf, arguments[i], VarcharOid) default: return SerializationError(fmt.Sprintf("%T is not a core type and it does not implement Encoder", arg)) } diff --git a/query.go b/query.go index 033f0a2e..e659ad22 100644 --- a/query.go +++ b/query.go @@ -306,7 +306,7 @@ func (rows *Rows) Values() ([]interface{}, error) { values = append(values, decodeFloat4Array(vr)) case Float8ArrayOid: values = append(values, decodeFloat8Array(vr)) - case TextArrayOid: + case TextArrayOid, VarcharArrayOid: values = append(values, decodeTextArray(vr)) case DateOid: values = append(values, decodeDate(vr)) diff --git a/query_test.go b/query_test.go index 44955ccd..506afa3c 100644 --- a/query_test.go +++ b/query_test.go @@ -738,6 +738,8 @@ func TestQueryRowCoreStringSlice(t *testing.T) { }{ {"select $1::text[]", []string{"Adam", "Eve", "UTF-8 Characters Å Æ Ë Ͽ"}}, {"select $1::text[]", []string{}}, + {"select $1::varchar[]", []string{"Adam", "Eve", "UTF-8 Characters Å Æ Ë Ͽ"}}, + {"select $1::varchar[]", []string{}}, } for i, tt := range tests { diff --git a/values.go b/values.go index 57fa83e8..f5e6c97d 100644 --- a/values.go +++ b/values.go @@ -11,24 +11,25 @@ import ( // PostgreSQL oids for common types const ( - BoolOid = 16 - ByteaOid = 17 - Int8Oid = 20 - Int2Oid = 21 - Int4Oid = 23 - TextOid = 25 - Float4Oid = 700 - Float8Oid = 701 - Int2ArrayOid = 1005 - Int4ArrayOid = 1007 - TextArrayOid = 1009 - Int8ArrayOid = 1016 - Float4ArrayOid = 1021 - Float8ArrayOid = 1022 - VarcharOid = 1043 - DateOid = 1082 - TimestampOid = 1114 - TimestampTzOid = 1184 + BoolOid = 16 + ByteaOid = 17 + Int8Oid = 20 + Int2Oid = 21 + Int4Oid = 23 + TextOid = 25 + Float4Oid = 700 + Float8Oid = 701 + Int2ArrayOid = 1005 + Int4ArrayOid = 1007 + TextArrayOid = 1009 + VarcharArrayOid = 1015 + Int8ArrayOid = 1016 + Float4ArrayOid = 1021 + Float8ArrayOid = 1022 + VarcharOid = 1043 + DateOid = 1082 + TimestampOid = 1114 + TimestampTzOid = 1184 ) // PostgreSQL format codes @@ -56,6 +57,7 @@ func init() { DefaultOidFormats[Float4ArrayOid] = BinaryFormatCode DefaultOidFormats[Float8ArrayOid] = BinaryFormatCode DefaultOidFormats[TextArrayOid] = BinaryFormatCode + DefaultOidFormats[VarcharArrayOid] = BinaryFormatCode } type SerializationError string @@ -1318,8 +1320,8 @@ func decodeTextArray(vr *ValueReader) []string { return nil } - if vr.Type().DataType != TextArrayOid { - vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v but received type oid %v", TextArrayOid, vr.Type().DataType))) + if vr.Type().DataType != TextArrayOid && vr.Type().DataType != VarcharArrayOid { + vr.Fatal(ProtocolError(fmt.Sprintf("Expected type oid %v or %v but received type oid %v", TextArrayOid, VarcharArrayOid, vr.Type().DataType))) return nil } @@ -1348,7 +1350,7 @@ func decodeTextArray(vr *ValueReader) []string { return a } -func encodeTextArray(w *WriteBuf, value interface{}) error { +func encodeTextArray(w *WriteBuf, value interface{}, elOid Oid) error { slice, ok := value.([]string) if !ok { return fmt.Errorf("Expected []string, received %T", value) @@ -1364,7 +1366,7 @@ func encodeTextArray(w *WriteBuf, value interface{}) error { w.WriteInt32(1) // number of dimensions w.WriteInt32(0) // no nulls - w.WriteInt32(TextOid) // type of elements + w.WriteInt32(int32(elOid)) // type of elements w.WriteInt32(int32(len(slice))) // number of elements w.WriteInt32(1) // index of first element