From 61b4fb76895c2677afba96fa30b9db3c3725da2b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 21 Jan 2022 16:50:30 -0600 Subject: [PATCH] Convert time to Codec --- pgtype/builtin_wrappers.go | 32 +++++ pgtype/pgtype.go | 2 +- pgtype/time.go | 273 +++++++++++++++++++------------------ pgtype/time_test.go | 140 +++++-------------- 4 files changed, 211 insertions(+), 236 deletions(-) diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index 5b6a032a..3020f9bb 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -399,6 +399,38 @@ func (w timeWrapper) TimestamptzValue() (Timestamptz, error) { return Timestamptz{Time: time.Time(w), Valid: true}, nil } +func (w *timeWrapper) ScanTime(v Time) error { + if !v.Valid { + return fmt.Errorf("cannot scan NULL into *time.Time") + } + + // 24:00:00 is max allowed time in PostgreSQL, but time.Time will normalize that to 00:00:00 the next day. + var maxRepresentableByTime int64 = 24*60*60*1000000 - 1 + if v.Microseconds > maxRepresentableByTime { + return fmt.Errorf("%d microseconds cannot be represented as time.Time", v.Microseconds) + } + + usec := v.Microseconds + hours := usec / microsecondsPerHour + usec -= hours * microsecondsPerHour + minutes := usec / microsecondsPerMinute + usec -= minutes * microsecondsPerMinute + seconds := usec / microsecondsPerSecond + usec -= seconds * microsecondsPerSecond + ns := usec * 1000 + *w = timeWrapper(time.Date(2000, 1, 1, int(hours), int(minutes), int(seconds), int(ns), time.UTC)) + return nil +} + +func (w timeWrapper) TimeValue() (Time, error) { + t := time.Time(w) + usec := int64(t.Hour())*microsecondsPerHour + + int64(t.Minute())*microsecondsPerMinute + + int64(t.Second())*microsecondsPerSecond + + int64(t.Nanosecond())/1000 + return Time{Microseconds: usec, Valid: true}, nil +} + type durationWrapper time.Duration func (w *durationWrapper) ScanInterval(v Interval) error { diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 5e7e2501..058abf5e 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -340,7 +340,7 @@ func NewConnInfo() *ConnInfo { // ci.RegisterDataType(DataType{Value: &Record{}, Name: "record", OID: RecordOID}) ci.RegisterDataType(DataType{Name: "text", OID: TextOID, Codec: TextCodec{}}) ci.RegisterDataType(DataType{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) - ci.RegisterDataType(DataType{Value: &Time{}, Name: "time", OID: TimeOID}) + ci.RegisterDataType(DataType{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) ci.RegisterDataType(DataType{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) ci.RegisterDataType(DataType{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}}) // ci.RegisterDataType(DataType{Value: &Tsrange{}, Name: "tsrange", OID: TsrangeOID}) diff --git a/pgtype/time.go b/pgtype/time.go index 3252a633..47dabe99 100644 --- a/pgtype/time.go +++ b/pgtype/time.go @@ -5,11 +5,18 @@ import ( "encoding/binary" "fmt" "strconv" - "time" "github.com/jackc/pgio" ) +type TimeScanner interface { + ScanTime(v Time) error +} + +type TimeValuer interface { + TimeValue() (Time, error) +} + // Time represents the PostgreSQL time type. The PostgreSQL time is a time of day without time zone. // // Time is represented as the number of microseconds since midnight in the same way that PostgreSQL does. Other time @@ -20,86 +27,151 @@ type Time struct { Valid bool } -// Set converts src into a Time and stores in dst. -func (dst *Time) Set(src interface{}) error { +func (t *Time) ScanTime(v Time) error { + *t = v + return nil +} + +func (t Time) TimeValue() (Time, error) { + return t, nil +} + +// Scan implements the database/sql Scanner interface. +func (t *Time) Scan(src interface{}) error { if src == nil { - *dst = Time{} + *t = Time{} return nil } - if value, ok := src.(interface{ Get() interface{} }); ok { - value2 := value.Get() - if value2 != value { - return dst.Set(value2) - } + switch src := src.(type) { + case string: + return scanPlanTextAnyToTimeScanner{}.Scan(nil, 0, TextFormatCode, []byte(src), t) } - switch value := src.(type) { - case time.Time: - usec := int64(value.Hour())*microsecondsPerHour + - int64(value.Minute())*microsecondsPerMinute + - int64(value.Second())*microsecondsPerSecond + - int64(value.Nanosecond())/1000 - *dst = Time{Microseconds: usec, Valid: true} - case *time.Time: - if value == nil { - *dst = Time{} - } else { - return dst.Set(*value) - } - default: - if originalSrc, ok := underlyingTimeType(src); ok { - return dst.Set(originalSrc) - } - return fmt.Errorf("cannot convert %v to Time", value) + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (t Time) Value() (driver.Value, error) { + if !t.Valid { + return nil, nil + } + + buf, err := TimeCodec{}.PlanEncode(nil, 0, TextFormatCode, t).Encode(t, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +type TimeCodec struct{} + +func (TimeCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (TimeCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (TimeCodec) PlanEncode(ci *ConnInfo, oid uint32, format int16, value interface{}) EncodePlan { + if _, ok := value.(TimeValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanTimeCodecBinary{} + case TextFormatCode: + return encodePlanTimeCodecText{} } return nil } -func (dst Time) Get() interface{} { - if !dst.Valid { - return nil +type encodePlanTimeCodecBinary struct{} + +func (encodePlanTimeCodecBinary) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + t, err := value.(TimeValuer).TimeValue() + if err != nil { + return nil, err } - return dst.Microseconds + + if !t.Valid { + return nil, nil + } + + return pgio.AppendInt64(buf, t.Microseconds), nil } -func (src *Time) AssignTo(dst interface{}) error { - if !src.Valid { - return NullAssignTo(dst) +type encodePlanTimeCodecText struct{} + +func (encodePlanTimeCodecText) Encode(value interface{}, buf []byte) (newBuf []byte, err error) { + t, err := value.(TimeValuer).TimeValue() + if err != nil { + return nil, err } - switch v := dst.(type) { - case *time.Time: - // 24:00:00 is max allowed time in PostgreSQL, but time.Time will normalize that to 00:00:00 the next day. - var maxRepresentableByTime int64 = 24*60*60*1000000 - 1 - if src.Microseconds > maxRepresentableByTime { - return fmt.Errorf("%d microseconds cannot be represented as time.Time", src.Microseconds) - } - - usec := src.Microseconds - hours := usec / microsecondsPerHour - usec -= hours * microsecondsPerHour - minutes := usec / microsecondsPerMinute - usec -= minutes * microsecondsPerMinute - seconds := usec / microsecondsPerSecond - usec -= seconds * microsecondsPerSecond - ns := usec * 1000 - *v = time.Date(2000, 1, 1, int(hours), int(minutes), int(seconds), int(ns), time.UTC) - return nil - default: - if nextDst, retry := GetAssignToDstType(dst); retry { - return src.AssignTo(nextDst) - } - return fmt.Errorf("unable to assign to %T", dst) + if !t.Valid { + return nil, nil } + + usec := t.Microseconds + hours := usec / microsecondsPerHour + usec -= hours * microsecondsPerHour + minutes := usec / microsecondsPerMinute + usec -= minutes * microsecondsPerMinute + seconds := usec / microsecondsPerSecond + usec -= seconds * microsecondsPerSecond + + s := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, usec) + + return append(buf, s...), nil } -// DecodeText decodes from src into dst. -func (dst *Time) DecodeText(ci *ConnInfo, src []byte) error { +func (TimeCodec) PlanScan(ci *ConnInfo, oid uint32, format int16, target interface{}, actualTarget bool) ScanPlan { + + switch format { + case BinaryFormatCode: + switch target.(type) { + case TimeScanner: + return scanPlanBinaryTimeToTimeScanner{} + } + case TextFormatCode: + switch target.(type) { + case TimeScanner: + return scanPlanTextAnyToTimeScanner{} + } + } + + return nil +} + +type scanPlanBinaryTimeToTimeScanner struct{} + +func (scanPlanBinaryTimeToTimeScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(TimeScanner) + if src == nil { - *dst = Time{} - return nil + return scanner.ScanTime(Time{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for time: %v", len(src)) + } + + usec := int64(binary.BigEndian.Uint64(src)) + + return scanner.ScanTime(Time{Microseconds: usec, Valid: true}) +} + +type scanPlanTextAnyToTimeScanner struct{} + +func (scanPlanTextAnyToTimeScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + scanner := (dst).(TimeScanner) + + if src == nil { + return scanner.ScanTime(Time{}) } s := string(src) @@ -140,79 +212,22 @@ func (dst *Time) DecodeText(ci *ConnInfo, src []byte) error { usec += n } - *dst = Time{Microseconds: usec, Valid: true} - - return nil + return scanner.ScanTime(Time{Microseconds: usec, Valid: true}) } -// DecodeBinary decodes from src into dst. -func (dst *Time) DecodeBinary(ci *ConnInfo, src []byte) error { +func (c TimeCodec) DecodeDatabaseSQLValue(ci *ConnInfo, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, ci, oid, format, src) +} + +func (c TimeCodec) DecodeValue(ci *ConnInfo, oid uint32, format int16, src []byte) (interface{}, error) { if src == nil { - *dst = Time{} - return nil - } - - if len(src) != 8 { - return fmt.Errorf("invalid length for time: %v", len(src)) - } - - usec := int64(binary.BigEndian.Uint64(src)) - *dst = Time{Microseconds: usec, Valid: true} - - return nil -} - -// EncodeText writes the text encoding of src into w. -func (src Time) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { return nil, nil } - usec := src.Microseconds - hours := usec / microsecondsPerHour - usec -= hours * microsecondsPerHour - minutes := usec / microsecondsPerMinute - usec -= minutes * microsecondsPerMinute - seconds := usec / microsecondsPerSecond - usec -= seconds * microsecondsPerSecond - - s := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, usec) - - return append(buf, s...), nil -} - -// EncodeBinary writes the binary encoding of src into w. If src.Time is not in -// the UTC time zone it returns an error. -func (src Time) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { - if !src.Valid { - return nil, nil + var t Time + err := codecScan(c, ci, oid, format, src, &t) + if err != nil { + return nil, err } - - return pgio.AppendInt64(buf, src.Microseconds), nil -} - -// Scan implements the database/sql Scanner interface. -func (dst *Time) Scan(src interface{}) error { - if src == nil { - *dst = Time{} - return nil - } - - switch src := src.(type) { - case string: - return dst.DecodeText(nil, []byte(src)) - case []byte: - srcCopy := make([]byte, len(src)) - copy(srcCopy, src) - return dst.DecodeText(nil, srcCopy) - case time.Time: - return dst.Set(src) - } - - return fmt.Errorf("cannot scan %T", src) -} - -// Value implements the database/sql/driver Valuer interface. -func (src Time) Value() (driver.Value, error) { - return EncodeValueText(src) + return t, nil } diff --git a/pgtype/time_test.go b/pgtype/time_test.go index f710ed03..8394a951 100644 --- a/pgtype/time_test.go +++ b/pgtype/time_test.go @@ -1,117 +1,45 @@ package pgtype_test import ( - "reflect" "testing" "time" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" ) -func TestTimeTranscode(t *testing.T) { - testutil.TestSuccessfulTranscode(t, "time", []interface{}{ - &pgtype.Time{Microseconds: 0, Valid: true}, - &pgtype.Time{Microseconds: 1, Valid: true}, - &pgtype.Time{Microseconds: 86399999999, Valid: true}, - &pgtype.Time{Microseconds: 86400000000, Valid: true}, - &pgtype.Time{}, +func TestTimeCodec(t *testing.T) { + testPgxCodec(t, "time", []PgxTranscodeTestCase{ + { + pgtype.Time{Microseconds: 0, Valid: true}, + new(pgtype.Time), + isExpectedEq(pgtype.Time{Microseconds: 0, Valid: true}), + }, + { + pgtype.Time{Microseconds: 1, Valid: true}, + new(pgtype.Time), + isExpectedEq(pgtype.Time{Microseconds: 1, Valid: true}), + }, + { + pgtype.Time{Microseconds: 86399999999, Valid: true}, + new(pgtype.Time), + isExpectedEq(pgtype.Time{Microseconds: 86399999999, Valid: true}), + }, + { + pgtype.Time{Microseconds: 86400000000, Valid: true}, + new(pgtype.Time), + isExpectedEq(pgtype.Time{Microseconds: 86400000000, Valid: true}), + }, + { + time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), + new(pgtype.Time), + isExpectedEq(pgtype.Time{Microseconds: 0, Valid: true}), + }, + { + pgtype.Time{Microseconds: 0, Valid: true}, + new(time.Time), + isExpectedEq(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)), + }, + {pgtype.Time{}, new(pgtype.Time), isExpectedEq(pgtype.Time{})}, + {nil, new(pgtype.Time), isExpectedEq(pgtype.Time{})}, }) } - -func TestTimeSet(t *testing.T) { - type _time time.Time - - successfulTests := []struct { - source interface{} - result pgtype.Time - }{ - {source: time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), result: pgtype.Time{Microseconds: 0, Valid: true}}, - {source: time.Date(1900, 1, 1, 1, 0, 0, 0, time.UTC), result: pgtype.Time{Microseconds: 3600000000, Valid: true}}, - {source: time.Date(1900, 1, 1, 0, 1, 0, 0, time.UTC), result: pgtype.Time{Microseconds: 60000000, Valid: true}}, - {source: time.Date(1900, 1, 1, 0, 0, 1, 0, time.UTC), result: pgtype.Time{Microseconds: 1000000, Valid: true}}, - {source: time.Date(1970, 1, 1, 0, 0, 0, 1, time.UTC), result: pgtype.Time{Microseconds: 0, Valid: true}}, - {source: time.Date(1970, 1, 1, 0, 0, 0, 1000, time.UTC), result: pgtype.Time{Microseconds: 1, Valid: true}}, - {source: time.Date(1999, 12, 31, 23, 59, 59, 999999999, time.UTC), result: pgtype.Time{Microseconds: 86399999999, Valid: true}}, - {source: time.Date(2015, 1, 1, 0, 0, 0, 2000, time.Local), result: pgtype.Time{Microseconds: 2, Valid: true}}, - {source: func(t time.Time) *time.Time { return &t }(time.Date(2015, 1, 1, 0, 0, 0, 2000, time.Local)), result: pgtype.Time{Microseconds: 2, Valid: true}}, - {source: nil, result: pgtype.Time{}}, - {source: (*time.Time)(nil), result: pgtype.Time{}}, - {source: _time(time.Date(1970, 1, 1, 0, 0, 0, 3000, time.UTC)), result: pgtype.Time{Microseconds: 3, Valid: true}}, - } - - for i, tt := range successfulTests { - var r pgtype.Time - err := r.Set(tt.source) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if r != tt.result { - t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) - } - } -} - -func TestTimeAssignTo(t *testing.T) { - var tim time.Time - var ptim *time.Time - - simpleTests := []struct { - src pgtype.Time - dst interface{} - expected interface{} - }{ - {src: pgtype.Time{Microseconds: 0, Valid: true}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)}, - {src: pgtype.Time{Microseconds: 3600000000, Valid: true}, dst: &tim, expected: time.Date(2000, 1, 1, 1, 0, 0, 0, time.UTC)}, - {src: pgtype.Time{Microseconds: 60000000, Valid: true}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 1, 0, 0, time.UTC)}, - {src: pgtype.Time{Microseconds: 1000000, Valid: true}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 0, 1, 0, time.UTC)}, - {src: pgtype.Time{Microseconds: 1, Valid: true}, dst: &tim, expected: time.Date(2000, 1, 1, 0, 0, 0, 1000, time.UTC)}, - {src: pgtype.Time{Microseconds: 86399999999, Valid: true}, dst: &tim, expected: time.Date(2000, 1, 1, 23, 59, 59, 999999000, time.UTC)}, - {src: pgtype.Time{Microseconds: 0}, dst: &ptim, expected: ((*time.Time)(nil))}, - } - - for i, tt := range simpleTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - pointerAllocTests := []struct { - src pgtype.Time - dst interface{} - expected interface{} - }{ - {src: pgtype.Time{Microseconds: 0, Valid: true}, dst: &ptim, expected: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)}, - } - - for i, tt := range pointerAllocTests { - err := tt.src.AssignTo(tt.dst) - if err != nil { - t.Errorf("%d: %v", i, err) - } - - if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { - t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) - } - } - - errorTests := []struct { - src pgtype.Time - dst interface{} - }{ - {src: pgtype.Time{Microseconds: 86400000000, Valid: true}, dst: &tim}, - } - - for i, tt := range errorTests { - err := tt.src.AssignTo(tt.dst) - if err == nil { - t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) - } - } -}