From d52bd74254a9b2a19ebac3b10ca590e99f58d6af Mon Sep 17 00:00:00 2001
From: David Hudson <david-h-3110@hotmail.co.uk>
Date: Fri, 13 Sep 2019 16:37:38 +0100
Subject: [PATCH] pgtype: Add ext type for gofrs uuid implementation

Add ext type for https://github.com/gofrs/uuid uuid type.

Change test and README from github.com/satori/go.uuid to
github.com/gofrs/uuid. The reason is due to this issue:
https://github.com/satori/go.uuid/issues/73. This was taken on board and
fixed in the community project of gofrs. The gofrs implementation has
the same interface as the original.
---
 README.md                          |   2 +-
 pgtype/ext/gofrs-uuid/uuid.go      | 161 +++++++++++++++++++++++++++++
 pgtype/ext/gofrs-uuid/uuid_test.go |  97 +++++++++++++++++
 query_test.go                      |   6 +-
 travis/install.bash                |   1 +
 5 files changed, 263 insertions(+), 4 deletions(-)
 create mode 100644 pgtype/ext/gofrs-uuid/uuid.go
 create mode 100644 pgtype/ext/gofrs-uuid/uuid_test.go

diff --git a/README.md b/README.md
index 0a4cacc3..1c466c11 100644
--- a/README.md
+++ b/README.md
@@ -85,11 +85,11 @@ skip tests for connection types that are not configured.
 To setup the normal test environment, first install these dependencies:
 
     go get github.com/cockroachdb/apd
+    go get github.com/gofrs/uuid
     go get github.com/hashicorp/go-version
     go get github.com/jackc/fake
     go get github.com/lib/pq
     go get github.com/pkg/errors
-    go get github.com/satori/go.uuid
     go get github.com/shopspring/decimal
     go get github.com/sirupsen/logrus
     go get go.uber.org/zap
diff --git a/pgtype/ext/gofrs-uuid/uuid.go b/pgtype/ext/gofrs-uuid/uuid.go
new file mode 100644
index 00000000..e859f6ef
--- /dev/null
+++ b/pgtype/ext/gofrs-uuid/uuid.go
@@ -0,0 +1,161 @@
+package uuid
+
+import (
+	"database/sql/driver"
+
+	"github.com/gofrs/uuid"
+	"github.com/pkg/errors"
+
+	"github.com/jackc/pgx/pgtype"
+)
+
+var errUndefined = errors.New("cannot encode status undefined")
+
+type UUID struct {
+	UUID   uuid.UUID
+	Status pgtype.Status
+}
+
+func (dst *UUID) Set(src interface{}) error {
+	switch value := src.(type) {
+	case uuid.UUID:
+		*dst = UUID{UUID: value, Status: pgtype.Present}
+	case [16]byte:
+		*dst = UUID{UUID: uuid.UUID(value), Status: pgtype.Present}
+	case []byte:
+		if len(value) != 16 {
+			return errors.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value))
+		}
+		*dst = UUID{Status: pgtype.Present}
+		copy(dst.UUID[:], value)
+	case string:
+		uuid, err := uuid.FromString(value)
+		if err != nil {
+			return err
+		}
+		*dst = UUID{UUID: uuid, Status: pgtype.Present}
+	default:
+		// If all else fails see if pgtype.UUID can handle it. If so, translate through that.
+		pgUUID := &pgtype.UUID{}
+		if err := pgUUID.Set(value); err != nil {
+			return errors.Errorf("cannot convert %v to UUID", value)
+		}
+
+		*dst = UUID{UUID: uuid.UUID(pgUUID.Bytes), Status: pgUUID.Status}
+	}
+
+	return nil
+}
+
+func (dst *UUID) Get() interface{} {
+	switch dst.Status {
+	case pgtype.Present:
+		return dst.UUID
+	case pgtype.Null:
+		return nil
+	default:
+		return dst.Status
+	}
+}
+
+func (src *UUID) AssignTo(dst interface{}) error {
+	switch src.Status {
+	case pgtype.Present:
+		switch v := dst.(type) {
+		case *uuid.UUID:
+			*v = src.UUID
+		case *[16]byte:
+			*v = [16]byte(src.UUID)
+			return nil
+		case *[]byte:
+			*v = make([]byte, 16)
+			copy(*v, src.UUID[:])
+			return nil
+		case *string:
+			*v = src.UUID.String()
+			return nil
+		default:
+			if nextDst, retry := pgtype.GetAssignToDstType(v); retry {
+				return src.AssignTo(nextDst)
+			}
+		}
+	case pgtype.Null:
+		return pgtype.NullAssignTo(dst)
+	}
+
+	return errors.Errorf("cannot assign %v into %T", src, dst)
+}
+
+func (dst *UUID) DecodeText(ci *pgtype.ConnInfo, src []byte) error {
+	if src == nil {
+		*dst = UUID{Status: pgtype.Null}
+		return nil
+	}
+
+	u, err := uuid.FromString(string(src))
+	if err != nil {
+		return err
+	}
+
+	*dst = UUID{UUID: u, Status: pgtype.Present}
+	return nil
+}
+
+func (dst *UUID) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error {
+	if src == nil {
+		*dst = UUID{Status: pgtype.Null}
+		return nil
+	}
+
+	if len(src) != 16 {
+		return errors.Errorf("invalid length for UUID: %v", len(src))
+	}
+
+	*dst = UUID{Status: pgtype.Present}
+	copy(dst.UUID[:], src)
+	return nil
+}
+
+func (src *UUID) EncodeText(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) {
+	switch src.Status {
+	case pgtype.Null:
+		return nil, nil
+	case pgtype.Undefined:
+		return nil, errUndefined
+	}
+
+	return append(buf, src.UUID.String()...), nil
+}
+
+func (src *UUID) EncodeBinary(ci *pgtype.ConnInfo, buf []byte) ([]byte, error) {
+	switch src.Status {
+	case pgtype.Null:
+		return nil, nil
+	case pgtype.Undefined:
+		return nil, errUndefined
+	}
+
+	return append(buf, src.UUID[:]...), nil
+}
+
+// Scan implements the database/sql Scanner interface.
+func (dst *UUID) Scan(src interface{}) error {
+	if src == nil {
+		*dst = UUID{Status: pgtype.Null}
+		return nil
+	}
+
+	switch src := src.(type) {
+	case string:
+		return dst.DecodeText(nil, []byte(src))
+	case []byte:
+		return dst.DecodeText(nil, src)
+	}
+
+	return errors.Errorf("cannot scan %T", src)
+}
+
+// Value implements the database/sql/driver Valuer interface.
+func (src *UUID) Value() (driver.Value, error) {
+	return pgtype.EncodeValueText(src)
+}
diff --git a/pgtype/ext/gofrs-uuid/uuid_test.go b/pgtype/ext/gofrs-uuid/uuid_test.go
new file mode 100644
index 00000000..d76edb18
--- /dev/null
+++ b/pgtype/ext/gofrs-uuid/uuid_test.go
@@ -0,0 +1,97 @@
+package uuid_test
+
+import (
+	"bytes"
+	"testing"
+
+	"github.com/jackc/pgx/pgtype"
+	gofrs "github.com/jackc/pgx/pgtype/ext/gofrs-uuid"
+	"github.com/jackc/pgx/pgtype/testutil"
+)
+
+func TestUUIDTranscode(t *testing.T) {
+	testutil.TestSuccessfulTranscode(t, "uuid", []interface{}{
+		&gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present},
+		&gofrs.UUID{Status: pgtype.Null},
+	})
+}
+
+func TestUUIDSet(t *testing.T) {
+	successfulTests := []struct {
+		source interface{}
+		result gofrs.UUID
+	}{
+		{
+			source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
+			result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present},
+		},
+		{
+			source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
+			result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present},
+		},
+		{
+			source: "00010203-0405-0607-0809-0a0b0c0d0e0f",
+			result: gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present},
+		},
+	}
+
+	for i, tt := range successfulTests {
+		var r gofrs.UUID
+		err := r.Set(tt.source)
+		if err != nil {
+			t.Errorf("%d: %v", i, err)
+		}
+
+		if r != tt.result {
+			t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r)
+		}
+	}
+}
+
+func TestUUIDAssignTo(t *testing.T) {
+	{
+		src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}
+		var dst [16]byte
+		expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
+
+		err := src.AssignTo(&dst)
+		if err != nil {
+			t.Error(err)
+		}
+
+		if dst != expected {
+			t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst)
+		}
+	}
+
+	{
+		src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}
+		var dst []byte
+		expected := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
+
+		err := src.AssignTo(&dst)
+		if err != nil {
+			t.Error(err)
+		}
+
+		if bytes.Compare(dst, expected) != 0 {
+			t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst)
+		}
+	}
+
+	{
+		src := gofrs.UUID{UUID: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}
+		var dst string
+		expected := "00010203-0405-0607-0809-0a0b0c0d0e0f"
+
+		err := src.AssignTo(&dst)
+		if err != nil {
+			t.Error(err)
+		}
+
+		if dst != expected {
+			t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst)
+		}
+	}
+
+}
diff --git a/query_test.go b/query_test.go
index ea1fd66e..500399b9 100644
--- a/query_test.go
+++ b/query_test.go
@@ -11,10 +11,10 @@ import (
 	"time"
 
 	"github.com/cockroachdb/apd"
+	"github.com/gofrs/uuid"
 	"github.com/jackc/pgx"
 	"github.com/jackc/pgx/pgtype"
-	satori "github.com/jackc/pgx/pgtype/ext/satori-uuid"
-	uuid "github.com/satori/go.uuid"
+	gofrs "github.com/jackc/pgx/pgtype/ext/gofrs-uuid"
 	"github.com/shopspring/decimal"
 )
 
@@ -1140,7 +1140,7 @@ func TestConnQueryDatabaseSQLDriverValuerWithBinaryPgTypeThatAcceptsSameType(t *
 	defer closeConn(t, conn)
 
 	conn.ConnInfo.RegisterDataType(pgtype.DataType{
-		Value: &satori.UUID{},
+		Value: &gofrs.UUID{},
 		Name:  "uuid",
 		OID:   2950,
 	})
diff --git a/travis/install.bash b/travis/install.bash
index 3c3e44cf..c3b344e5 100755
--- a/travis/install.bash
+++ b/travis/install.bash
@@ -4,6 +4,7 @@ set -eux
 go get -u github.com/cockroachdb/apd
 go get -u github.com/shopspring/decimal
 go get -u gopkg.in/inconshreveable/log15.v2
+go get -u github.com/gofrs/uuid
 go get -u github.com/jackc/fake
 go get -u github.com/lib/pq
 go get -u github.com/hashicorp/go-version