fixed marshaling and unmarshaling

non-blocking
Bekmamat 2020-09-19 21:50:56 +03:00 committed by Jack Christensen
parent fbe354aea1
commit d7f92427ad
4 changed files with 132 additions and 17 deletions

View File

@ -53,7 +53,9 @@ func parsePoint(src []byte) (*Point, error) {
if len(src) < 5 {
return nil, errors.Errorf("invalid length for point: %v", len(src))
}
if src[0] == '"' && src[len(src)-1] == '"' {
src = src[1 : len(src)-1]
}
parts := strings.SplitN(string(src[1:len(src)-1]), ",", 2)
if len(parts) < 2 {
return nil, errors.Errorf("invalid format for point")
@ -190,7 +192,11 @@ func (src Point) Value() (driver.Value, error) {
func (src Point) MarshalJSON() ([]byte, error) {
switch src.Status {
case Present:
return []byte(fmt.Sprintf("(%g,%g)", src.P.X, src.P.Y)), nil
var buff bytes.Buffer
buff.WriteByte('"')
buff.WriteString(fmt.Sprintf("(%g,%g)", src.P.X, src.P.Y))
buff.WriteByte('"')
return buff.Bytes(), nil
case Null:
return []byte("null"), nil
case Undefined:

View File

@ -72,7 +72,7 @@ func TestPoint_MarshalJSON(t *testing.T) {
name: "first",
point: pgtype.Point{
P: pgtype.Vec2{},
Status: 0,
Status: pgtype.Undefined,
},
want: nil,
wantErr: true,
@ -83,7 +83,7 @@ func TestPoint_MarshalJSON(t *testing.T) {
P: pgtype.Vec2{X: 12.245, Y: 432.12},
Status: pgtype.Present,
},
want: []byte("(12.245,432.12)"),
want: []byte(`"(12.245,432.12)"`),
wantErr: false,
},
{
@ -113,26 +113,26 @@ func TestPoint_MarshalJSON(t *testing.T) {
func TestPoint_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
status pgtype.Status
status pgtype.Status
arg []byte
wantErr bool
}{
{
name: "first",
status: pgtype.Present,
arg: []byte("(123.123,54.12)"),
name: "first",
status: pgtype.Present,
arg: []byte(`"(123.123,54.12)"`),
wantErr: false,
},
{
name: "second",
status: pgtype.Undefined,
arg: []byte("(123.123,54.1sad2)"),
name: "second",
status: pgtype.Undefined,
arg: []byte(`"(123.123,54.1sad2)"`),
wantErr: true,
},
{
name: "third",
status: pgtype.Null,
arg: []byte("null"),
name: "third",
status: pgtype.Null,
arg: []byte("null"),
wantErr: false,
},
}

17
uuid.go
View File

@ -1,6 +1,7 @@
package pgtype
import (
"bytes"
"database/sql/driver"
"encoding/hex"
"fmt"
@ -207,7 +208,11 @@ func (src UUID) Value() (driver.Value, error) {
func (src UUID) MarshalJSON() ([]byte, error) {
switch src.Status {
case Present:
return []byte(encodeUUID(src.Bytes)), nil
var buff bytes.Buffer
buff.WriteByte('"')
buff.WriteString(encodeUUID(src.Bytes))
buff.WriteByte('"')
return buff.Bytes(), nil
case Null:
return []byte("null"), nil
case Undefined:
@ -216,6 +221,12 @@ func (src UUID) MarshalJSON() ([]byte, error) {
return nil, errBadStatus
}
func (dst *UUID) UnmarshalJSON(bytes []byte) error {
return dst.Set(bytes)
func (dst *UUID) UnmarshalJSON(src []byte) error {
if bytes.Compare(src, []byte("null")) == 0 {
return dst.Set(nil)
}
if len(src) != 38 {
return errors.Errorf("invalid length for UUID: %v", len(src))
}
return dst.Set(string(src[1 : len(src)-1]))
}

View File

@ -2,6 +2,7 @@ package pgtype_test
import (
"bytes"
"reflect"
"testing"
"github.com/jackc/pgtype"
@ -127,3 +128,100 @@ func TestUUIDAssignTo(t *testing.T) {
}
}
func TestUUID_MarshalJSON(t *testing.T) {
tests := []struct {
name string
src pgtype.UUID
want []byte
wantErr bool
}{
{
name: "first",
src: pgtype.UUID{
Bytes: [16]byte{29, 72, 90, 122, 109, 24, 69, 153, 140, 108, 52, 66, 86, 22, 136, 122},
Status: pgtype.Present,
},
want: []byte(`"1d485a7a-6d18-4599-8c6c-34425616887a"`),
wantErr: false,
},
{
name: "second",
src: pgtype.UUID{
Bytes: [16]byte{},
Status: pgtype.Undefined,
},
want: nil,
wantErr: true,
},
{
name: "third",
src: pgtype.UUID{
Bytes: [16]byte{},
Status: pgtype.Null,
},
want: []byte("null"),
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.src.MarshalJSON()
if (err != nil) != tt.wantErr {
t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("MarshalJSON() got = %v, want %v", got, tt.want)
}
})
}
}
func TestUUID_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
want *pgtype.UUID
src []byte
wantErr bool
}{
{
name: "first",
want: &pgtype.UUID{
Bytes: [16]byte{29, 72, 90, 122, 109, 24, 69, 153, 140, 108, 52, 66, 86, 22, 136, 122},
Status: pgtype.Present,
},
src: []byte(`"1d485a7a-6d18-4599-8c6c-34425616887a"`),
wantErr: false,
},
{
name: "second",
want: &pgtype.UUID{
Bytes: [16]byte{},
Status: pgtype.Null,
},
src: []byte("null"),
wantErr: false,
},
{
name: "third",
want: &pgtype.UUID{
Bytes: [16]byte{},
Status: pgtype.Undefined,
},
src: []byte("1d485a7a-6d18-4599-8c6c-34425616887a"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := &pgtype.UUID{}
if err := got.UnmarshalJSON(tt.src); (err != nil) != tt.wantErr {
t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("UnmarshalJSON() got = %v, want %v", got, tt.want)
}
})
}
}