diff --git a/conn.go b/conn.go index 0b06a4aa..f256b714 100644 --- a/conn.go +++ b/conn.go @@ -549,7 +549,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(TextFormatCode) default: switch oid { - case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampArrayOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid: + case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampArrayOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid: wbuf.WriteInt16(BinaryFormatCode) default: wbuf.WriteInt16(TextFormatCode) @@ -610,7 +610,9 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} case VarcharArrayOid: err = encodeTextArray(wbuf, arguments[i], VarcharOid) case TimestampArrayOid: - err = encodeTimestampArray(wbuf, arguments[i], VarcharOid) + err = encodeTimestampArray(wbuf, arguments[i], TimestampOid) + case TimestampTzArrayOid: + err = encodeTimestampArray(wbuf, arguments[i], TimestampTzOid) case OidOid: err = encodeOid(wbuf, arguments[i]) default: diff --git a/query.go b/query.go index 9dcbd885..8f7e672a 100644 --- a/query.go +++ b/query.go @@ -342,7 +342,7 @@ func (rows *Rows) Values() ([]interface{}, error) { values = append(values, decodeFloat8Array(vr)) case TextArrayOid, VarcharArrayOid: values = append(values, decodeTextArray(vr)) - case TimestampArrayOid: + case TimestampArrayOid, TimestampTzArrayOid: values = append(values, decodeTimestampArray(vr)) case DateOid: values = append(values, decodeDate(vr)) diff --git a/values.go b/values.go index 79ea6c96..1ba54b01 100644 --- a/values.go +++ b/values.go @@ -12,28 +12,29 @@ import ( // PostgreSQL oids for common types const ( - BoolOid = 16 - ByteaOid = 17 - Int8Oid = 20 - Int2Oid = 21 - Int4Oid = 23 - TextOid = 25 - OidOid = 26 - Float4Oid = 700 - Float8Oid = 701 - BoolArrayOid = 1000 - Int2ArrayOid = 1005 - Int4ArrayOid = 1007 - TextArrayOid = 1009 - VarcharArrayOid = 1015 - Int8ArrayOid = 1016 - Float4ArrayOid = 1021 - Float8ArrayOid = 1022 - VarcharOid = 1043 - DateOid = 1082 - TimestampOid = 1114 - TimestampArrayOid = 1115 - TimestampTzOid = 1184 + BoolOid = 16 + ByteaOid = 17 + Int8Oid = 20 + Int2Oid = 21 + Int4Oid = 23 + TextOid = 25 + OidOid = 26 + Float4Oid = 700 + Float8Oid = 701 + BoolArrayOid = 1000 + Int2ArrayOid = 1005 + Int4ArrayOid = 1007 + TextArrayOid = 1009 + VarcharArrayOid = 1015 + Int8ArrayOid = 1016 + Float4ArrayOid = 1021 + Float8ArrayOid = 1022 + VarcharOid = 1043 + DateOid = 1082 + TimestampOid = 1114 + TimestampArrayOid = 1115 + TimestampTzOid = 1184 + TimestampTzArrayOid = 1185 ) // PostgreSQL format codes @@ -60,6 +61,7 @@ func init() { DefaultTypeFormats["_text"] = BinaryFormatCode DefaultTypeFormats["_varchar"] = BinaryFormatCode DefaultTypeFormats["_timestamp"] = BinaryFormatCode + DefaultTypeFormats["_timestamptz"] = BinaryFormatCode DefaultTypeFormats["bool"] = BinaryFormatCode DefaultTypeFormats["bytea"] = BinaryFormatCode DefaultTypeFormats["date"] = BinaryFormatCode @@ -1596,7 +1598,7 @@ func decodeTimestampArray(vr *ValueReader) []time.Time { return nil } - if vr.Type().DataType != TimestampArrayOid { + if vr.Type().DataType != TimestampArrayOid && vr.Type().DataType != TimestampTzArrayOid { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []time.Time", vr.Type().DataType))) return nil } @@ -1638,7 +1640,7 @@ func encodeTimestampArray(w *WriteBuf, value interface{}, elOid Oid) error { return fmt.Errorf("Expected []time.Time, received %T", value) } - encodeArrayHeader(w, TimestampOid, len(slice), 12) + encodeArrayHeader(w, int(elOid), len(slice), 12) for _, t := range slice { w.WriteInt32(8) microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000 diff --git a/values_test.go b/values_test.go index 285bd3bd..3c5b35e6 100644 --- a/values_test.go +++ b/values_test.go @@ -194,7 +194,15 @@ func TestArrayDecoding(t *testing.T) { "select $1::timestamp[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, func(t *testing.T, query, scan interface{}) { if reflect.DeepEqual(query, *(scan.(*[]time.Time))) == false { - t.Errorf("failed to encode time.Time[]") + t.Errorf("failed to encode time.Time[] to timestamp[]") + } + }, + }, + { + "select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, + func(t *testing.T, query, scan interface{}) { + if reflect.DeepEqual(query, *(scan.(*[]time.Time))) == false { + t.Errorf("failed to encode time.Time[] to timestamptz[]") } }, },