From d7f92427adf195b84fe5f21caefd2309519e07e9 Mon Sep 17 00:00:00 2001 From: Bekmamat Date: Sat, 19 Sep 2020 21:50:56 +0300 Subject: [PATCH] fixed marshaling and unmarshaling --- point.go | 10 ++++-- point_test.go | 24 ++++++------- uuid.go | 17 +++++++-- uuid_test.go | 98 +++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 132 insertions(+), 17 deletions(-) diff --git a/point.go b/point.go index 55c6c8d1..8e6bacf2 100644 --- a/point.go +++ b/point.go @@ -53,7 +53,9 @@ func parsePoint(src []byte) (*Point, error) { if len(src) < 5 { return nil, errors.Errorf("invalid length for point: %v", len(src)) } - + if src[0] == '"' && src[len(src)-1] == '"' { + src = src[1 : len(src)-1] + } parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2) if len(parts) < 2 { return nil, errors.Errorf("invalid format for point") @@ -190,7 +192,11 @@ func (src Point) Value() (driver.Value, error) { func (src Point) MarshalJSON() ([]byte, error) { switch src.Status { case Present: - return []byte(fmt.Sprintf("(%g,%g)", src.P.X, src.P.Y)), nil + var buff bytes.Buffer + buff.WriteByte('"') + buff.WriteString(fmt.Sprintf("(%g,%g)", src.P.X, src.P.Y)) + buff.WriteByte('"') + return buff.Bytes(), nil case Null: return []byte("null"), nil case Undefined: diff --git a/point_test.go b/point_test.go index 3601cf02..63f8df07 100644 --- a/point_test.go +++ b/point_test.go @@ -72,7 +72,7 @@ func TestPoint_MarshalJSON(t *testing.T) { name: "first", point: pgtype.Point{ P: pgtype.Vec2{}, - Status: 0, + Status: pgtype.Undefined, }, want: nil, wantErr: true, @@ -83,7 +83,7 @@ func TestPoint_MarshalJSON(t *testing.T) { P: pgtype.Vec2{X: 12.245, Y: 432.12}, Status: pgtype.Present, }, - want: []byte("(12.245,432.12)"), + want: []byte(`"(12.245,432.12)"`), wantErr: false, }, { @@ -113,26 +113,26 @@ func TestPoint_MarshalJSON(t *testing.T) { func TestPoint_UnmarshalJSON(t *testing.T) { tests := []struct { name string - status pgtype.Status + status pgtype.Status arg []byte wantErr bool }{ { - name: "first", - status: pgtype.Present, - arg: []byte("(123.123,54.12)"), + name: "first", + status: pgtype.Present, + arg: []byte(`"(123.123,54.12)"`), wantErr: false, }, { - name: "second", - status: pgtype.Undefined, - arg: []byte("(123.123,54.1sad2)"), + name: "second", + status: pgtype.Undefined, + arg: []byte(`"(123.123,54.1sad2)"`), wantErr: true, }, { - name: "third", - status: pgtype.Null, - arg: []byte("null"), + name: "third", + status: pgtype.Null, + arg: []byte("null"), wantErr: false, }, } diff --git a/uuid.go b/uuid.go index caaef2a7..b1681a78 100644 --- a/uuid.go +++ b/uuid.go @@ -1,6 +1,7 @@ package pgtype import ( + "bytes" "database/sql/driver" "encoding/hex" "fmt" @@ -207,7 +208,11 @@ func (src UUID) Value() (driver.Value, error) { func (src UUID) MarshalJSON() ([]byte, error) { switch src.Status { case Present: - return []byte(encodeUUID(src.Bytes)), nil + var buff bytes.Buffer + buff.WriteByte('"') + buff.WriteString(encodeUUID(src.Bytes)) + buff.WriteByte('"') + return buff.Bytes(), nil case Null: return []byte("null"), nil case Undefined: @@ -216,6 +221,12 @@ func (src UUID) MarshalJSON() ([]byte, error) { return nil, errBadStatus } -func (dst *UUID) UnmarshalJSON(bytes []byte) error { - return dst.Set(bytes) +func (dst *UUID) UnmarshalJSON(src []byte) error { + if bytes.Compare(src, []byte("null")) == 0 { + return dst.Set(nil) + } + if len(src) != 38 { + return errors.Errorf("invalid length for UUID: %v", len(src)) + } + return dst.Set(string(src[1 : len(src)-1])) } diff --git a/uuid_test.go b/uuid_test.go index 9f7b19e2..8de5b9f6 100644 --- a/uuid_test.go +++ b/uuid_test.go @@ -2,6 +2,7 @@ package pgtype_test import ( "bytes" + "reflect" "testing" "github.com/jackc/pgtype" @@ -127,3 +128,100 @@ func TestUUIDAssignTo(t *testing.T) { } } + +func TestUUID_MarshalJSON(t *testing.T) { + tests := []struct { + name string + src pgtype.UUID + want []byte + wantErr bool + }{ + { + name: "first", + src: pgtype.UUID{ + Bytes: [16]byte{29, 72, 90, 122, 109, 24, 69, 153, 140, 108, 52, 66, 86, 22, 136, 122}, + Status: pgtype.Present, + }, + want: []byte(`"1d485a7a-6d18-4599-8c6c-34425616887a"`), + wantErr: false, + }, + { + name: "second", + src: pgtype.UUID{ + Bytes: [16]byte{}, + Status: pgtype.Undefined, + }, + want: nil, + wantErr: true, + }, + { + name: "third", + src: pgtype.UUID{ + Bytes: [16]byte{}, + Status: pgtype.Null, + }, + want: []byte("null"), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.src.MarshalJSON() + if (err != nil) != tt.wantErr { + t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalJSON() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUUID_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + want *pgtype.UUID + src []byte + wantErr bool + }{ + { + name: "first", + want: &pgtype.UUID{ + Bytes: [16]byte{29, 72, 90, 122, 109, 24, 69, 153, 140, 108, 52, 66, 86, 22, 136, 122}, + Status: pgtype.Present, + }, + src: []byte(`"1d485a7a-6d18-4599-8c6c-34425616887a"`), + wantErr: false, + }, + { + name: "second", + want: &pgtype.UUID{ + Bytes: [16]byte{}, + Status: pgtype.Null, + }, + src: []byte("null"), + wantErr: false, + }, + { + name: "third", + want: &pgtype.UUID{ + Bytes: [16]byte{}, + Status: pgtype.Undefined, + }, + src: []byte("1d485a7a-6d18-4599-8c6c-34425616887a"), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := &pgtype.UUID{} + if err := got.UnmarshalJSON(tt.src); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("UnmarshalJSON() got = %v, want %v", got, tt.want) + } + }) + } +}