diff --git a/conn.go b/conn.go index d564326c..b0f2cac0 100644 --- a/conn.go +++ b/conn.go @@ -741,9 +741,9 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(TextFormatCode) default: switch oid { - case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid: + case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid: wbuf.WriteInt16(BinaryFormatCode) - case TextOid, VarcharOid, DateOid, TimestampTzOid: + case TextOid, VarcharOid, DateOid: wbuf.WriteInt16(TextFormatCode) default: return SerializationError(fmt.Sprintf("Parameter %d oid %d is not a core type and argument type %T does not implement TextEncoder or BinaryEncoder", i, oid, arg)) diff --git a/values.go b/values.go index d4af61eb..ea86556f 100644 --- a/values.go +++ b/values.go @@ -324,6 +324,45 @@ func (n NullBool) EncodeBinary(w *WriteBuf) error { return encodeBool(w, n.Bool) } +// NullTime represents an bigint that may be null. +// NullTime implements the Scanner, TextEncoder, and BinaryEncoder interfaces +// so it may be used both as an argument to Query[Row] and a destination for +// Scan for prepared and unprepared queries. +// +// If Valid is false then the value is NULL. +type NullTime struct { + Time time.Time + Valid bool // Valid is true if Time is not NULL +} + +func (n *NullTime) Scan(rows *Rows, fd *FieldDescription, size int32) error { + if size == -1 { + n.Time, n.Valid = time.Time{}, false + return nil + } + n.Valid = true + n.Time = decodeTimestampTz(rows, fd, size) + + return rows.Err() +} + +func (n NullTime) EncodeText() (string, byte, error) { + if n.Valid { + return n.Time.Format("2006-01-02 15:04:05.999999 -0700"), UnsafeText, nil + } else { + return "", NullText, nil + } +} + +func (n NullTime) EncodeBinary(w *WriteBuf) error { + if !n.Valid { + w.WriteInt32(-1) + return nil + } + + return encodeTimestampTz(w, n.Time) +} + var literalPattern *regexp.Regexp = regexp.MustCompile(`\$\d+`) // QuoteString escapes and quotes a string making it safe for interpolation @@ -843,6 +882,8 @@ func encodeDate(w *WriteBuf, value interface{}) error { return encodeText(w, s) } +const microsecFromUnixEpochToY2K = 946684800 * 1000000 + func decodeTimestampTz(rows *Rows, fd *FieldDescription, size int32) time.Time { var zeroTime time.Time @@ -864,7 +905,6 @@ func decodeTimestampTz(rows *Rows, fd *FieldDescription, size int32) time.Time { if size != 8 { rows.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an timestamptz: %d", size))) } - microsecFromUnixEpochToY2K := int64(946684800 * 1000000) microsecSinceY2K := rows.mr.ReadInt64() microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000) @@ -880,6 +920,11 @@ func encodeTimestampTz(w *WriteBuf, value interface{}) error { return fmt.Errorf("Expected time.Time, received %T", value) } - s := t.Format("2006-01-02 15:04:05.999999 -0700") - return encodeText(w, s) + microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000 + microsecSinceY2K := microsecSinceUnixEpoch - microsecFromUnixEpochToY2K + + w.WriteInt32(8) + w.WriteInt64(microsecSinceY2K) + + return nil } diff --git a/values_test.go b/values_test.go index aa2151c1..0c208836 100644 --- a/values_test.go +++ b/values_test.go @@ -201,6 +201,7 @@ func TestNullX(t *testing.T) { f32 pgx.NullFloat32 f64 pgx.NullFloat64 b pgx.NullBool + t pgx.NullTime } var actual, zero allTypes @@ -225,6 +226,8 @@ func TestNullX(t *testing.T) { {"select $1::float8", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: false}}, []interface{}{&actual.f64}, allTypes{f64: pgx.NullFloat64{Float64: 0, Valid: false}}}, {"select $1::bool", []interface{}{pgx.NullBool{Bool: true, Valid: true}}, []interface{}{&actual.b}, allTypes{b: pgx.NullBool{Bool: true, Valid: true}}}, {"select $1::bool", []interface{}{pgx.NullBool{Bool: true, Valid: false}}, []interface{}{&actual.b}, allTypes{b: pgx.NullBool{Bool: false, Valid: false}}}, + {"select $1::timestamptz", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}}, + {"select $1::timestamptz", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: false}}, []interface{}{&actual.b}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, } for i, tt := range tests {