diff --git a/pgtype/date.go b/pgtype/date.go new file mode 100644 index 00000000..840414b6 --- /dev/null +++ b/pgtype/date.go @@ -0,0 +1,93 @@ +package pgtype + +import ( + "fmt" + "io" + "time" + + "github.com/jackc/pgx/pgio" +) + +type Date struct { + // time.Time is embedded to hide internal implementation. Possibly do date + // implementation at some point rather than simply delegating to time.Time. + t time.Time +} + +func (d *Date) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + return fmt.Errorf("invalid length for int8: %v", size) + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + d.t, err = time.Parse("2006-01-02", string(buf)) + if err != nil { + return err + } + + return nil +} + +func (d *Date) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size != 4 { + return fmt.Errorf("invalid length for date: %v", size) + } + + dayOffset, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + d.t = time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.Local) + + return nil +} + +func (d Date) EncodeText(w io.Writer) error { + _, err := pgio.WriteInt32(w, 10) + if err != nil { + return nil + } + + _, err = w.Write([]byte(d.t.Format("2006-01-02"))) + return err +} + +func (d Date) EncodeBinary(w io.Writer) error { + _, err := pgio.WriteInt32(w, 4) + if err != nil { + return err + } + + tUnix := time.Date(d.t.Year(), d.t.Month(), d.t.Day(), 0, 0, 0, 0, time.UTC).Unix() + dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix() + + secSinceDateEpoch := tUnix - dateEpoch + daysSinceDateEpoch := secSinceDateEpoch / 86400 + + _, err = pgio.WriteInt32(w, int32(daysSinceDateEpoch)) + return err +} + +func (d Date) Time() time.Time { + return d.t +} + +func DateFromTime(t time.Time) Date { + return Date{t: t} +} diff --git a/values.go b/values.go index 3614febb..4fc2a500 100644 --- a/values.go +++ b/values.go @@ -2228,43 +2228,42 @@ func encodeJSONB(w *WriteBuf, oid OID, value interface{}) error { } func decodeDate(vr *ValueReader) time.Time { - var zeroTime time.Time - if vr.Len() == -1 { vr.Fatal(ProtocolError("Cannot decode null into time.Time")) - return zeroTime + return time.Time{} } if vr.Type().DataType != DateOID { vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into time.Time", vr.Type().DataType))) - return zeroTime + return time.Time{} } - if vr.Type().FormatCode != BinaryFormatCode { + vr.err = errRewoundLen + + var d pgtype.Date + var err error + switch vr.Type().FormatCode { + case TextFormatCode: + err = d.DecodeText(&valueReader2{vr}) + case BinaryFormatCode: + err = d.DecodeBinary(&valueReader2{vr}) + default: vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return zeroTime + return time.Time{} } - if vr.Len() != 4 { - vr.Fatal(ProtocolError(fmt.Sprintf("Received an invalid size for an date: %d", vr.Len()))) + if err != nil { + vr.Fatal(err) + return time.Time{} } - dayOffset := vr.ReadInt32() - return time.Date(2000, 1, int(1+dayOffset), 0, 0, 0, 0, time.Local) + + return d.Time() } func encodeTime(w *WriteBuf, oid OID, value time.Time) error { switch oid { case DateOID: - tUnix := time.Date(value.Year(), value.Month(), value.Day(), 0, 0, 0, 0, time.UTC).Unix() - dateEpoch := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).Unix() - - secSinceDateEpoch := tUnix - dateEpoch - daysSinceDateEpoch := secSinceDateEpoch / 86400 - - w.WriteInt32(4) - w.WriteInt32(int32(daysSinceDateEpoch)) - - return nil + return pgtype.DateFromTime(value).EncodeBinary(w) case TimestampTzOID, TimestampOID: microsecSinceUnixEpoch := value.Unix()*1000000 + int64(value.Nanosecond())/1000 microsecSinceY2K := microsecSinceUnixEpoch - microsecFromUnixEpochToY2K