From cab445ddd2102f6d05aa4c8dcf7e6e304faaa772 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 14 Apr 2017 16:46:39 -0500 Subject: [PATCH] Add satori-uuid type Make pgtype.EncodeValueText public --- pgtype/box.go | 2 +- pgtype/circle.go | 2 +- pgtype/database_sql.go | 2 +- pgtype/daterange.go | 2 +- pgtype/ext/satori-uuid/uuid.go | 164 ++++++++++++++++++++++++++++ pgtype/ext/satori-uuid/uuid_test.go | 97 ++++++++++++++++ pgtype/hstore.go | 2 +- pgtype/inet.go | 2 +- pgtype/int4range.go | 2 +- pgtype/int8range.go | 2 +- pgtype/interval.go | 2 +- pgtype/line.go | 2 +- pgtype/lseg.go | 2 +- pgtype/macaddr.go | 2 +- pgtype/numrange.go | 2 +- pgtype/path.go | 2 +- pgtype/point.go | 2 +- pgtype/polygon.go | 2 +- pgtype/tid.go | 2 +- pgtype/tsrange.go | 2 +- pgtype/tstzrange.go | 2 +- pgtype/typed_range.go.erb | 2 +- pgtype/uuid.go | 2 +- pgtype/varbit.go | 2 +- 24 files changed, 283 insertions(+), 22 deletions(-) create mode 100644 pgtype/ext/satori-uuid/uuid.go create mode 100644 pgtype/ext/satori-uuid/uuid_test.go diff --git a/pgtype/box.go b/pgtype/box.go index 138953a5..2e4f39ee 100644 --- a/pgtype/box.go +++ b/pgtype/box.go @@ -164,5 +164,5 @@ func (dst *Box) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Box) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/circle.go b/pgtype/circle.go index 62e2e8b3..8c8f4693 100644 --- a/pgtype/circle.go +++ b/pgtype/circle.go @@ -146,5 +146,5 @@ func (dst *Circle) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Circle) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/database_sql.go b/pgtype/database_sql.go index 2ddd842d..e255b646 100644 --- a/pgtype/database_sql.go +++ b/pgtype/database_sql.go @@ -31,7 +31,7 @@ func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { return nil, errors.New("cannot convert to database/sql compatible value") } -func encodeValueText(src TextEncoder) (interface{}, error) { +func EncodeValueText(src TextEncoder) (interface{}, error) { buf := &bytes.Buffer{} null, err := src.EncodeText(nil, buf) if err != nil { diff --git a/pgtype/daterange.go b/pgtype/daterange.go index d78c4803..5cecca20 100644 --- a/pgtype/daterange.go +++ b/pgtype/daterange.go @@ -264,5 +264,5 @@ func (dst *Daterange) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Daterange) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/ext/satori-uuid/uuid.go b/pgtype/ext/satori-uuid/uuid.go new file mode 100644 index 00000000..1b65f48a --- /dev/null +++ b/pgtype/ext/satori-uuid/uuid.go @@ -0,0 +1,164 @@ +package uuid + +import ( + "database/sql/driver" + "errors" + "fmt" + "io" + + "github.com/jackc/pgx/pgtype" + uuid "github.com/satori/go.uuid" +) + +var errUndefined = errors.New("cannot encode status undefined") + +type Uuid struct { + UUID uuid.UUID + Status pgtype.Status +} + +func (dst *Uuid) Set(src interface{}) error { + switch value := src.(type) { + case uuid.UUID: + *dst = Uuid{UUID: value, Status: pgtype.Present} + case [16]byte: + *dst = Uuid{UUID: uuid.UUID(value), Status: pgtype.Present} + case []byte: + if len(value) != 16 { + return fmt.Errorf("[]byte must be 16 bytes to convert to Uuid: %d", len(value)) + } + *dst = Uuid{Status: pgtype.Present} + copy(dst.UUID[:], value) + case string: + uuid, err := uuid.FromString(value) + if err != nil { + return err + } + *dst = Uuid{UUID: uuid, Status: pgtype.Present} + default: + // If all else fails see if pgtype.Uuid can handle it. If so, translate through that. + pgUuid := &pgtype.Uuid{} + if err := pgUuid.Set(value); err != nil { + return fmt.Errorf("cannot convert %v to Uuid", value) + } + + *dst = Uuid{UUID: uuid.UUID(pgUuid.Bytes), Status: pgUuid.Status} + } + + return nil +} + +func (dst *Uuid) Get() interface{} { + switch dst.Status { + case pgtype.Present: + return dst.UUID + case pgtype.Null: + return nil + default: + return dst.Status + } +} + +func (src *Uuid) AssignTo(dst interface{}) error { + switch src.Status { + case pgtype.Present: + switch v := dst.(type) { + case *uuid.UUID: + *v = src.UUID + case *[16]byte: + *v = [16]byte(src.UUID) + return nil + case *[]byte: + *v = make([]byte, 16) + copy(*v, src.UUID[:]) + return nil + case *string: + *v = src.UUID.String() + return nil + default: + if nextDst, retry := pgtype.GetAssignToDstType(v); retry { + return src.AssignTo(nextDst) + } + } + case pgtype.Null: + return pgtype.NullAssignTo(dst) + } + + return fmt.Errorf("cannot assign %v into %T", src, dst) +} + +func (dst *Uuid) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + *dst = Uuid{Status: pgtype.Null} + return nil + } + + u, err := uuid.FromString(string(src)) + if err != nil { + return err + } + + *dst = Uuid{UUID: u, Status: pgtype.Present} + return nil +} + +func (dst *Uuid) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + *dst = Uuid{Status: pgtype.Null} + return nil + } + + if len(src) != 16 { + return fmt.Errorf("invalid length for Uuid: %v", len(src)) + } + + *dst = Uuid{Status: pgtype.Present} + copy(dst.UUID[:], src) + return nil +} + +func (src *Uuid) EncodeText(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case pgtype.Null: + return true, nil + case pgtype.Undefined: + return false, errUndefined + } + + _, err := io.WriteString(w, src.UUID.String()) + return false, err +} + +func (src *Uuid) EncodeBinary(ci *pgtype.ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case pgtype.Null: + return true, nil + case pgtype.Undefined: + return false, errUndefined + } + + _, err := w.Write(src.UUID[:]) + return false, err +} + +// Scan implements the database/sql Scanner interface. +func (dst *Uuid) Scan(src interface{}) error { + if src == nil { + *dst = Uuid{Status: pgtype.Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *Uuid) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/pgtype/ext/satori-uuid/uuid_test.go b/pgtype/ext/satori-uuid/uuid_test.go new file mode 100644 index 00000000..993fb837 --- /dev/null +++ b/pgtype/ext/satori-uuid/uuid_test.go @@ -0,0 +1,97 @@ +package uuid_test + +import ( + "bytes" + "testing" + + "github.com/jackc/pgx/pgtype" + satori "github.com/jackc/pgx/pgtype/ext/satori-uuid" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestUuidTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ + &satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + &satori.Uuid{Status: pgtype.Null}, + }) +} + +func TestUuidSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result satori.Uuid + }{ + { + source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + { + source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + { + source: "00010203-0405-0607-0809-0a0b0c0d0e0f", + result: satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r satori.Uuid + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestUuidAssignTo(t *testing.T) { + { + src := satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst [16]byte + expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst []byte + expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if bytes.Compare(dst, expected) != 0 { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := satori.Uuid{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst string + expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + +} diff --git a/pgtype/hstore.go b/pgtype/hstore.go index 3d55f783..04df2acc 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -463,5 +463,5 @@ func (dst *Hstore) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Hstore) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/inet.go b/pgtype/inet.go index 62734088..e3a7ec88 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -221,5 +221,5 @@ func (dst *Inet) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Inet) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/int4range.go b/pgtype/int4range.go index 8b04cf3c..12a48dab 100644 --- a/pgtype/int4range.go +++ b/pgtype/int4range.go @@ -264,5 +264,5 @@ func (dst *Int4range) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Int4range) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/int8range.go b/pgtype/int8range.go index f8e056cb..3541dbe2 100644 --- a/pgtype/int8range.go +++ b/pgtype/int8range.go @@ -264,5 +264,5 @@ func (dst *Int8range) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Int8range) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/interval.go b/pgtype/interval.go index 1cbdffc3..050d5610 100644 --- a/pgtype/interval.go +++ b/pgtype/interval.go @@ -267,5 +267,5 @@ func (dst *Interval) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Interval) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/line.go b/pgtype/line.go index 08a74e84..06f01f21 100644 --- a/pgtype/line.go +++ b/pgtype/line.go @@ -144,5 +144,5 @@ func (dst *Line) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Line) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/lseg.go b/pgtype/lseg.go index b86256e0..986724cc 100644 --- a/pgtype/lseg.go +++ b/pgtype/lseg.go @@ -164,5 +164,5 @@ func (dst *Lseg) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Lseg) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/macaddr.go b/pgtype/macaddr.go index cfbb513d..0fe092e4 100644 --- a/pgtype/macaddr.go +++ b/pgtype/macaddr.go @@ -150,5 +150,5 @@ func (dst *Macaddr) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Macaddr) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/numrange.go b/pgtype/numrange.go index a1b5b184..b0baec9a 100644 --- a/pgtype/numrange.go +++ b/pgtype/numrange.go @@ -264,5 +264,5 @@ func (dst *Numrange) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Numrange) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/path.go b/pgtype/path.go index fb4193d9..2fd6cfc7 100644 --- a/pgtype/path.go +++ b/pgtype/path.go @@ -203,5 +203,5 @@ func (dst *Path) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Path) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/point.go b/pgtype/point.go index 788a76c9..3d51766e 100644 --- a/pgtype/point.go +++ b/pgtype/point.go @@ -138,5 +138,5 @@ func (dst *Point) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Point) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/polygon.go b/pgtype/polygon.go index 1e2df011..af99ee3d 100644 --- a/pgtype/polygon.go +++ b/pgtype/polygon.go @@ -182,5 +182,5 @@ func (dst *Polygon) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Polygon) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/tid.go b/pgtype/tid.go index f24c6244..7976afde 100644 --- a/pgtype/tid.go +++ b/pgtype/tid.go @@ -142,5 +142,5 @@ func (dst *Tid) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Tid) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/tsrange.go b/pgtype/tsrange.go index 3bf5f5ca..78a94af2 100644 --- a/pgtype/tsrange.go +++ b/pgtype/tsrange.go @@ -264,5 +264,5 @@ func (dst *Tsrange) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Tsrange) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/tstzrange.go b/pgtype/tstzrange.go index 8e80a8f9..d1fc7326 100644 --- a/pgtype/tstzrange.go +++ b/pgtype/tstzrange.go @@ -264,5 +264,5 @@ func (dst *Tstzrange) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Tstzrange) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/typed_range.go.erb b/pgtype/typed_range.go.erb index 922b98b4..e46f71c7 100644 --- a/pgtype/typed_range.go.erb +++ b/pgtype/typed_range.go.erb @@ -264,5 +264,5 @@ func (dst *<%= range_type %>) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src <%= range_type %>) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/uuid.go b/pgtype/uuid.go index 03029ffd..c830c086 100644 --- a/pgtype/uuid.go +++ b/pgtype/uuid.go @@ -169,5 +169,5 @@ func (dst *Uuid) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Uuid) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) } diff --git a/pgtype/varbit.go b/pgtype/varbit.go index d28e95cd..00c34e10 100644 --- a/pgtype/varbit.go +++ b/pgtype/varbit.go @@ -137,5 +137,5 @@ func (dst *Varbit) Scan(src interface{}) error { // Value implements the database/sql/driver Valuer interface. func (src *Varbit) Value() (driver.Value, error) { - return encodeValueText(src) + return EncodeValueText(src) }