From 5f363cb1f02554168b67e4d3e5dbeece248464e0 Mon Sep 17 00:00:00 2001 From: Jeffrey Stiles Date: Mon, 27 Jan 2020 16:19:43 -0800 Subject: [PATCH] Add JSON marshalling for Bool, Date, JSON/B, Timestamptz --- bool.go | 34 +++++++++++++++++++++++++++ bool_test.go | 43 ++++++++++++++++++++++++++++++++++ date.go | 56 ++++++++++++++++++++++++++++++++++++++++++++ date_test.go | 49 ++++++++++++++++++++++++++++++++++++++ int4_test.go | 20 ++++++++++++++++ int8_test.go | 20 ++++++++++++++++ json.go | 23 ++++++++++++++++++ json_test.go | 41 ++++++++++++++++++++++++++++++++ jsonb.go | 8 +++++++ text_test.go | 20 ++++++++++++++++ timestamptz.go | 57 +++++++++++++++++++++++++++++++++++++++++++++ timestamptz_test.go | 47 +++++++++++++++++++++++++++++++++++++ 12 files changed, 418 insertions(+) diff --git a/bool.go b/bool.go index ad55dce4..db02f663 100644 --- a/bool.go +++ b/bool.go @@ -2,6 +2,7 @@ package pgtype import ( "database/sql/driver" + "encoding/json" "strconv" errors "golang.org/x/xerrors" @@ -163,3 +164,36 @@ func (src Bool) Value() (driver.Value, error) { return nil, errUndefined } } + +func (src Bool) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + if src.Bool { + return []byte("true"), nil + } else { + return []byte("false"), nil + } + case Null: + return []byte("null"), nil + case Undefined: + return nil, errUndefined + } + + return nil, errBadStatus +} + +func (dst *Bool) UnmarshalJSON(b []byte) error { + var v *bool + err := json.Unmarshal(b, &v) + if err != nil { + return err + } + + if v == nil { + *dst = Bool{Status: Null} + } else { + *dst = Bool{Bool: *v, Status: Present} + } + + return nil +} diff --git a/bool_test.go b/bool_test.go index 64b4064d..8e7a5220 100644 --- a/bool_test.go +++ b/bool_test.go @@ -95,3 +95,46 @@ func TestBoolAssignTo(t *testing.T) { } } } + +func TestBoolMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Bool + result string + }{ + {source: pgtype.Bool{Status: pgtype.Null}, result: "null"}, + {source: pgtype.Bool{Bool: true, Status: pgtype.Present}, result: "true"}, + {source: pgtype.Bool{Bool: false, Status: pgtype.Present}, result: "false"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestBoolUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Bool + }{ + {source: "null", result: pgtype.Bool{Status: pgtype.Null}}, + {source: "true", result: pgtype.Bool{Bool: true, Status: pgtype.Present}}, + {source: "false", result: pgtype.Bool{Bool: false, Status: pgtype.Present}}, + } + for i, tt := range successfulTests { + var r pgtype.Bool + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/date.go b/date.go index 8e35b22a..eaf95dde 100644 --- a/date.go +++ b/date.go @@ -3,6 +3,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "encoding/json" "time" "github.com/jackc/pgio" @@ -208,3 +209,58 @@ func (src Date) Value() (driver.Value, error) { return nil, errUndefined } } + +func (src Date) MarshalJSON() ([]byte, error) { + switch src.Status { + case Null: + return []byte("null"), nil + case Undefined: + return nil, errUndefined + } + + if src.Status != Present { + return nil, errBadStatus + } + + var s string + + switch src.InfinityModifier { + case None: + s = src.Time.Format("2006-01-02") + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return json.Marshal(s) +} + +func (dst *Date) UnmarshalJSON(b []byte) error { + var s *string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + + if s == nil { + *dst = Date{Status: Null} + return nil + } + + switch *s { + case "infinity": + *dst = Date{Status: Present, InfinityModifier: Infinity} + case "-infinity": + *dst = Date{Status: Present, InfinityModifier: -Infinity} + default: + t, err := time.ParseInLocation("2006-01-02", *s, time.UTC) + if err != nil { + return err + } + + *dst = Date{Time: t, Status: Present} + } + + return nil +} diff --git a/date_test.go b/date_test.go index bcdbbf20..0b77898b 100644 --- a/date_test.go +++ b/date_test.go @@ -116,3 +116,52 @@ func TestDateAssignTo(t *testing.T) { } } } + +func TestDateMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Date + result string + }{ + {source: pgtype.Date{Status: pgtype.Null}, result: "null"}, + {source: pgtype.Date{Time: time.Date(2012, 3, 29, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}, result: "\"2012-03-29\""}, + {source: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Status: pgtype.Present}, result: "\"2012-03-29\""}, + {source: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Status: pgtype.Present}, result: "\"2012-03-29\""}, + {source: pgtype.Date{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, result: "\"infinity\""}, + {source: pgtype.Date{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, result: "\"-infinity\""}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestDateUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Date + }{ + {source: "null", result: pgtype.Date{Status: pgtype.Null}}, + {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 0, 0, 0, 0, time.UTC), Status: pgtype.Present}}, + {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Status: pgtype.Present}}, + {source: "\"2012-03-29\"", result: pgtype.Date{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Status: pgtype.Present}}, + {source: "\"infinity\"", result: pgtype.Date{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}}, + {source: "\"-infinity\"", result: pgtype.Date{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}}, + } + for i, tt := range successfulTests { + var r pgtype.Date + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r.Time.Year() != tt.result.Time.Year() || r.Time.Month() != tt.result.Time.Month() || r.Time.Day() != tt.result.Time.Day() || r.Status != tt.result.Status || r.InfinityModifier != tt.result.InfinityModifier { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/int4_test.go b/int4_test.go index 77fba8a5..c679de74 100644 --- a/int4_test.go +++ b/int4_test.go @@ -142,6 +142,26 @@ func TestInt4AssignTo(t *testing.T) { } } +func TestInt4MarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Int4 + result string + }{ + {source: pgtype.Int4{Int: 0, Status: pgtype.Null}, result: "null"}, + {source: pgtype.Int4{Int: 1, Status: pgtype.Present}, result: "1"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + func TestInt4UnmarshalJSON(t *testing.T) { successfulTests := []struct { source string diff --git a/int8_test.go b/int8_test.go index 73600eda..fb6f581b 100644 --- a/int8_test.go +++ b/int8_test.go @@ -143,6 +143,26 @@ func TestInt8AssignTo(t *testing.T) { } } +func TestInt8MarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Int8 + result string + }{ + {source: pgtype.Int8{Int: 0, Status: pgtype.Null}, result: "null"}, + {source: pgtype.Int8{Int: 1, Status: pgtype.Present}, result: "1"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + func TestInt8UnmarshalJSON(t *testing.T) { successfulTests := []struct { source string diff --git a/json.go b/json.go index 592dfa31..58a5b093 100644 --- a/json.go +++ b/json.go @@ -165,3 +165,26 @@ func (src JSON) Value() (driver.Value, error) { return nil, errUndefined } } + +func (src JSON) MarshalJSON() ([]byte, error) { + switch src.Status { + case Present: + return src.Bytes, nil + case Null: + return []byte("null"), nil + case Undefined: + return nil, errUndefined + } + + return nil, errBadStatus +} + +func (dst *JSON) UnmarshalJSON(b []byte) error { + if b == nil || string(b) == "null" { + *dst = JSON{Status: Null} + } else { + *dst = JSON{Bytes: b, Status: Present} + } + return nil + +} diff --git a/json_test.go b/json_test.go index 918b33d5..bbd3959e 100644 --- a/json_test.go +++ b/json_test.go @@ -134,3 +134,44 @@ func TestJSONAssignTo(t *testing.T) { } } } + +func TestJSONMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.JSON + result string + }{ + {source: pgtype.JSON{Status: pgtype.Null}, result: "null"}, + {source: pgtype.JSON{Bytes: []byte("{\"a\": 1}"), Status: pgtype.Present}, result: "{\"a\": 1}"}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestJSONUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.JSON + }{ + {source: "null", result: pgtype.JSON{Status: pgtype.Null}}, + {source: "{\"a\": 1}", result: pgtype.JSON{Bytes: []byte("{\"a\": 1}"), Status: pgtype.Present}}, + } + for i, tt := range successfulTests { + var r pgtype.JSON + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r.Bytes) != string(tt.result.Bytes) || r.Status != tt.result.Status { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} diff --git a/jsonb.go b/jsonb.go index c70be144..43e23fb2 100644 --- a/jsonb.go +++ b/jsonb.go @@ -68,3 +68,11 @@ func (dst *JSONB) Scan(src interface{}) error { func (src JSONB) Value() (driver.Value, error) { return (JSON)(src).Value() } + +func (src JSONB) MarshalJSON() ([]byte, error) { + return (JSON)(src).MarshalJSON() +} + +func (dst *JSONB) UnmarshalJSON(b []byte) error { + return (*JSON)(dst).UnmarshalJSON(b) +} diff --git a/text_test.go b/text_test.go index 3bacba68..cca3a05d 100644 --- a/text_test.go +++ b/text_test.go @@ -122,6 +122,26 @@ func TestTextAssignTo(t *testing.T) { } } +func TestTextMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Text + result string + }{ + {source: pgtype.Text{String: "", Status: pgtype.Null}, result: "null"}, + {source: pgtype.Text{String: "a", Status: pgtype.Present}, result: "\"a\""}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + func TestTextUnmarshalJSON(t *testing.T) { successfulTests := []struct { source string diff --git a/timestamptz.go b/timestamptz.go index 9af39b16..7ed86eb8 100644 --- a/timestamptz.go +++ b/timestamptz.go @@ -3,6 +3,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "encoding/json" "time" "github.com/jackc/pgio" @@ -220,3 +221,59 @@ func (src Timestamptz) Value() (driver.Value, error) { return nil, errUndefined } } + +func (src Timestamptz) MarshalJSON() ([]byte, error) { + switch src.Status { + case Null: + return []byte("null"), nil + case Undefined: + return nil, errUndefined + } + + if src.Status != Present { + return nil, errBadStatus + } + + var s string + + switch src.InfinityModifier { + case None: + s = src.Time.Format(time.RFC3339Nano) + case Infinity: + s = "infinity" + case NegativeInfinity: + s = "-infinity" + } + + return json.Marshal(s) +} + +func (dst *Timestamptz) UnmarshalJSON(b []byte) error { + var s *string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + + if s == nil { + *dst = Timestamptz{Status: Null} + return nil + } + + switch *s { + case "infinity": + *dst = Timestamptz{Status: Present, InfinityModifier: Infinity} + case "-infinity": + *dst = Timestamptz{Status: Present, InfinityModifier: -Infinity} + default: + // PostgreSQL uses ISO 8601 for to_json function and casting from a string to timestamptz + tim, err := time.Parse(time.RFC3339Nano, *s) + if err != nil { + return err + } + + *dst = Timestamptz{Time: tim, Status: Present} + } + + return nil +} diff --git a/timestamptz_test.go b/timestamptz_test.go index f6aec068..a020b1ec 100644 --- a/timestamptz_test.go +++ b/timestamptz_test.go @@ -120,3 +120,50 @@ func TestTimestamptzAssignTo(t *testing.T) { } } } + +func TestTimestamptzMarshalJSON(t *testing.T) { + successfulTests := []struct { + source pgtype.Timestamptz + result string + }{ + {source: pgtype.Timestamptz{Status: pgtype.Null}, result: "null"}, + {source: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Status: pgtype.Present}, result: "\"2012-03-29T10:05:45-06:00\""}, + {source: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Status: pgtype.Present}, result: "\"2012-03-29T10:05:45.555-06:00\""}, + {source: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}, result: "\"infinity\""}, + {source: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}, result: "\"-infinity\""}, + } + for i, tt := range successfulTests { + r, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if string(r) != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) + } + } +} + +func TestTimestamptzUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + source string + result pgtype.Timestamptz + }{ + {source: "null", result: pgtype.Timestamptz{Status: pgtype.Null}}, + {source: "\"2012-03-29T10:05:45-06:00\"", result: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.FixedZone("", -6*60*60)), Status: pgtype.Present}}, + {source: "\"2012-03-29T10:05:45.555-06:00\"", result: pgtype.Timestamptz{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.FixedZone("", -6*60*60)), Status: pgtype.Present}}, + {source: "\"infinity\"", result: pgtype.Timestamptz{InfinityModifier: pgtype.Infinity, Status: pgtype.Present}}, + {source: "\"-infinity\"", result: pgtype.Timestamptz{InfinityModifier: pgtype.NegativeInfinity, Status: pgtype.Present}}, + } + for i, tt := range successfulTests { + var r pgtype.Timestamptz + err := r.UnmarshalJSON([]byte(tt.source)) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !r.Time.Equal(tt.result.Time) || r.Status != tt.result.Status || r.InfinityModifier != tt.result.InfinityModifier { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +}