diff --git a/README.md b/README.md index 0a4cacc3..1c466c11 100644 --- a/README.md +++ b/README.md @@ -85,11 +85,11 @@ skip tests for connection types that are not configured. To setup the normal test environment, first install these dependencies: go get github.com/cockroachdb/apd + go get github.com/gofrs/uuid go get github.com/hashicorp/go-version go get github.com/jackc/fake go get github.com/lib/pq go get github.com/pkg/errors - go get github.com/satori/go.uuid go get github.com/shopspring/decimal go get github.com/sirupsen/logrus go get go.uber.org/zap diff --git a/pgtype/ext/gofrs-uuid/uuid.go b/pgtype/ext/gofrs-uuid/uuid.go new file mode 100644 index 00000000..e859f6ef --- /dev/null +++ b/pgtype/ext/gofrs-uuid/uuid.go @@ -0,0 +1,161 @@ +package uuid + +import ( + "database/sql/driver" + + "github.com/gofrs/uuid" + "github.com/pkg/errors" + + "github.com/jackc/pgx/pgtype" +) + +var errUndefined = errors.New("cannot encode status undefined") + +type UUID struct { + UUID uuid.UUID + Status pgtype.Status +} + +func (dst *UUID) Set(src interface{}) error { + switch value := src.(type) { + case uuid.UUID: + *dst = UUID{UUID: value, Status: pgtype.Present} + case [16]byte: + *dst = UUID{UUID: uuid.UUID(value), Status: pgtype.Present} + case []byte: + if len(value) != 16 { + return errors.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) + } + *dst = UUID{Status: pgtype.Present} + copy(dst.UUID[:], value) + case string: + uuid, err := uuid.FromString(value) + if err != nil { + return err + } + *dst = UUID{UUID: uuid, Status: pgtype.Present} + default: + // If all else fails see if pgtype.UUID can handle it. If so, translate through that. + pgUUID := &pgtype.UUID{} + if err := pgUUID.Set(value); err != nil { + return errors.Errorf("cannot convert %v to UUID", value) + } + + *dst = UUID{UUID: uuid.UUID(pgUUID.Bytes), Status: pgUUID.Status} + } + + return nil +} + +func (dst *UUID) Get() interface{} { + switch dst.Status { + case pgtype.Present: + return dst.UUID + case pgtype.Null: + return nil + default: + return dst.Status + } +} + +func (src *UUID) AssignTo(dst interface{}) error { + switch src.Status { + case pgtype.Present: + switch v := dst.(type) { + case *uuid.UUID: + *v = src.UUID + case *[16]byte: + *v = [16]byte(src.UUID) + return nil + case *[]byte: + *v = make([]byte, 16) + copy(*v, src.UUID[:]) + return nil + case *string: + *v = src.UUID.String() + return nil + default: + if nextDst, retry := pgtype.GetAssignToDstType(v); retry { + return src.AssignTo(nextDst) + } + } + case pgtype.Null: + return pgtype.NullAssignTo(dst) + } + + return errors.Errorf("cannot assign %v into %T", src, dst) +} + +func (dst *UUID) DecodeText(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + *dst = UUID{Status: pgtype.Null} + return nil + } + + u, err := uuid.FromString(string(src)) + if err != nil { + return err + } + + *dst = UUID{UUID: u, Status: pgtype.Present} + return nil +} + +func (dst *UUID) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { + if src == nil { + *dst = UUID{Status: pgtype.Null} + return nil + } + + if len(src) != 16 { + return errors.Errorf("invalid length for UUID: %v", len(src)) + } + + *dst = UUID{Status: pgtype.Present} + copy(dst.UUID[:], src) + return nil +} + +func (src *UUID) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case pgtype.Null: + return nil, nil + case pgtype.Undefined: + return nil, errUndefined + } + + return append(buf, src.UUID.String()...), nil +} + +func (src *UUID) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case pgtype.Null: + return nil, nil + case pgtype.Undefined: + return nil, errUndefined + } + + return append(buf, src.UUID[:]...), nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *UUID) Scan(src interface{}) error { + if src == nil { + *dst = UUID{Status: pgtype.Null} + return nil + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + return dst.DecodeText(nil, src) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *UUID) Value() (driver.Value, error) { + return pgtype.EncodeValueText(src) +} diff --git a/pgtype/ext/gofrs-uuid/uuid_test.go b/pgtype/ext/gofrs-uuid/uuid_test.go new file mode 100644 index 00000000..d76edb18 --- /dev/null +++ b/pgtype/ext/gofrs-uuid/uuid_test.go @@ -0,0 +1,97 @@ +package uuid_test + +import ( + "bytes" + "testing" + + "github.com/jackc/pgx/pgtype" + gofrs "github.com/jackc/pgx/pgtype/ext/gofrs-uuid" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestUUIDTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{ + &gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + &gofrs.UUID{Status: pgtype.Null}, + }) +} + +func TestUUIDSet(t *testing.T) { + successfulTests := []struct { + source interface{} + result gofrs.UUID + }{ + { + source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + { + source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + { + source: "00010203-0405-0607-0809-0a0b0c0d0e0f", + result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, + } + + for i, tt := range successfulTests { + var r gofrs.UUID + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if r != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestUUIDAssignTo(t *testing.T) { + { + src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst [16]byte + expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst []byte + expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if bytes.Compare(dst, expected) != 0 { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + + { + src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst string + expected := "00010203-0405-0607-0809-0a0b0c0d0e0f" + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + +} diff --git a/query_test.go b/query_test.go index ea1fd66e..500399b9 100644 --- a/query_test.go +++ b/query_test.go @@ -11,10 +11,10 @@ import ( "time" "github.com/cockroachdb/apd" + "github.com/gofrs/uuid" "github.com/jackc/pgx" "github.com/jackc/pgx/pgtype" - satori "github.com/jackc/pgx/pgtype/ext/satori-uuid" - uuid "github.com/satori/go.uuid" + gofrs "github.com/jackc/pgx/pgtype/ext/gofrs-uuid" "github.com/shopspring/decimal" ) @@ -1140,7 +1140,7 @@ func TestConnQueryDatabaseSQLDriverValuerWithBinaryPgTypeThatAcceptsSameType(t * defer closeConn(t, conn) conn.ConnInfo.RegisterDataType(pgtype.DataType{ - Value: &satori.UUID{}, + Value: &gofrs.UUID{}, Name: "uuid", OID: 2950, }) diff --git a/travis/install.bash b/travis/install.bash index 3c3e44cf..c3b344e5 100755 --- a/travis/install.bash +++ b/travis/install.bash @@ -4,6 +4,7 @@ set -eux go get -u github.com/cockroachdb/apd go get -u github.com/shopspring/decimal go get -u gopkg.in/inconshreveable/log15.v2 +go get -u github.com/gofrs/uuid go get -u github.com/jackc/fake go get -u github.com/lib/pq go get -u github.com/hashicorp/go-version