diff --git a/conn.go b/conn.go index fdc7fe59..ce0539e6 100644 --- a/conn.go +++ b/conn.go @@ -731,7 +731,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(TextFormatCode) default: switch oid { - case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid: + case BoolOid, ByteaOid, Int2Oid, Int4Oid, Int8Oid, Float4Oid, Float8Oid, TimestampTzOid, TimestampTzArrayOid, TimestampOid, TimestampArrayOid, DateOid, BoolArrayOid, Int2ArrayOid, Int4ArrayOid, Int8ArrayOid, Float4ArrayOid, Float8ArrayOid, TextArrayOid, VarcharArrayOid, OidOid, InetOid, CidrOid: wbuf.WriteInt16(BinaryFormatCode) default: wbuf.WriteInt16(TextFormatCode) diff --git a/values.go b/values.go index 6633ab8b..50e30e2d 100644 --- a/values.go +++ b/values.go @@ -386,9 +386,10 @@ func (n NullBool) Encode(w *WriteBuf, oid Oid) error { return encodeBool(w, n.Bool) } -// NullTime represents an bigint that may be null. NullTime implements the +// NullTime represents an time.Time that may be null. NullTime implements the // Scanner and Encoder interfaces so it may be used both as an argument to -// Query[Row] and a destination for Scan. +// Query[Row] and a destination for Scan. It corresponds with the PostgreSQL +// types timestamptz, timestamp, and date. // // If Valid is false then the value is NULL. type NullTime struct { @@ -398,7 +399,7 @@ type NullTime struct { func (n *NullTime) Scan(vr *ValueReader) error { oid := vr.Type().DataType - if oid != TimestampTzOid && oid != TimestampOid { + if oid != TimestampTzOid && oid != TimestampOid && oid != DateOid { return SerializationError(fmt.Sprintf("NullTime.Scan cannot decode OID %d", vr.Type().DataType)) } @@ -408,10 +409,13 @@ func (n *NullTime) Scan(vr *ValueReader) error { } n.Valid = true - if oid == TimestampTzOid { + switch oid { + case TimestampTzOid: n.Time = decodeTimestampTz(vr) - } else { + case TimestampOid: n.Time = decodeTimestamp(vr) + case DateOid: + n.Time = decodeDate(vr) } return vr.Err() @@ -420,7 +424,7 @@ func (n *NullTime) Scan(vr *ValueReader) error { func (n NullTime) FormatCode() int16 { return BinaryFormatCode } func (n NullTime) Encode(w *WriteBuf, oid Oid) error { - if oid != TimestampTzOid && oid != TimestampOid { + if oid != TimestampTzOid && oid != TimestampOid && oid != DateOid { return SerializationError(fmt.Sprintf("NullTime.Encode cannot encode into OID %d", oid)) } @@ -429,10 +433,15 @@ func (n NullTime) Encode(w *WriteBuf, oid Oid) error { return nil } - if oid == TimestampTzOid { + switch oid { + case TimestampTzOid: return encodeTimestampTz(w, n.Time) - } else { + case TimestampOid: return encodeTimestamp(w, n.Time) + case DateOid: + return encodeDate(w, n.Time) + default: + panic("unreachable") } } @@ -1055,8 +1064,16 @@ func encodeDate(w *WriteBuf, value interface{}) error { return fmt.Errorf("Expected time.Time, received %T", value) } - s := t.Format("2006-01-02") - return encodeText(w, s) + tUnix := time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC).Unix() + dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix() + + secSinceDateEpoch := tUnix - dateEpoch + daysSinceDateEpoch := secSinceDateEpoch / 86400 + + w.WriteInt32(4) + w.WriteInt32(int32(daysSinceDateEpoch)) + + return nil } const microsecFromUnixEpochToY2K = 946684800 * 1000000 diff --git a/values_test.go b/values_test.go index bdfa7cc4..46e79d11 100644 --- a/values_test.go +++ b/values_test.go @@ -17,6 +17,12 @@ func TestDateTranscode(t *testing.T) { defer closeConn(t, conn) dates := []time.Time{ + time.Date(1, 1, 1, 0, 0, 0, 0, time.Local), + time.Date(1000, 1, 1, 0, 0, 0, 0, time.Local), + time.Date(1600, 1, 1, 0, 0, 0, 0, time.Local), + time.Date(1700, 1, 1, 0, 0, 0, 0, time.Local), + time.Date(1800, 1, 1, 0, 0, 0, 0, time.Local), + time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local), time.Date(1990, 1, 1, 0, 0, 0, 0, time.Local), time.Date(1999, 12, 31, 0, 0, 0, 0, time.Local), time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), @@ -24,6 +30,11 @@ func TestDateTranscode(t *testing.T) { time.Date(2004, 2, 29, 0, 0, 0, 0, time.Local), time.Date(2013, 7, 4, 0, 0, 0, 0, time.Local), time.Date(2013, 12, 25, 0, 0, 0, 0, time.Local), + time.Date(2029, 1, 1, 0, 0, 0, 0, time.Local), + time.Date(2081, 1, 1, 0, 0, 0, 0, time.Local), + time.Date(2096, 2, 29, 0, 0, 0, 0, time.Local), + time.Date(2550, 1, 1, 0, 0, 0, 0, time.Local), + time.Date(9999, 12, 31, 0, 0, 0, 0, time.Local), } for _, actualDate := range dates { @@ -280,6 +291,8 @@ func TestNullX(t *testing.T) { {"select $1::timestamptz", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, {"select $1::timestamp", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}}, {"select $1::timestamp", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, + {"select $1::date", []interface{}{pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}}}, + {"select $1::date", []interface{}{pgx.NullTime{Time: time.Date(1990, 1, 1, 0, 0, 0, 0, time.Local), Valid: false}}, []interface{}{&actual.t}, allTypes{t: pgx.NullTime{Time: time.Time{}, Valid: false}}}, {"select 42::int4, $1::float8", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.i32, &actual.f64}, allTypes{i32: pgx.NullInt32{Int32: 42, Valid: true}, f64: pgx.NullFloat64{Float64: 1.23, Valid: true}}}, } @@ -473,7 +486,7 @@ func TestNullXMismatch(t *testing.T) { {"select $1::date", []interface{}{pgx.NullFloat32{Float32: 1.23, Valid: true}}, []interface{}{&actual.f32}, "cannot encode into OID 1082"}, {"select $1::date", []interface{}{pgx.NullFloat64{Float64: 1.23, Valid: true}}, []interface{}{&actual.f64}, "cannot encode into OID 1082"}, {"select $1::date", []interface{}{pgx.NullBool{Bool: true, Valid: true}}, []interface{}{&actual.b}, "cannot encode into OID 1082"}, - {"select $1::date", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, "cannot encode into OID 1082"}, + {"select $1::int4", []interface{}{pgx.NullTime{Time: time.Unix(123, 5000), Valid: true}}, []interface{}{&actual.t}, "cannot encode into OID 23"}, } for i, tt := range tests {