From 1e36edf4b0076fbdeb8ab1aa00c2dfc3c2af9d85 Mon Sep 17 00:00:00 2001 From: Kelsey Francis Date: Sun, 27 Aug 2017 19:31:22 -0700 Subject: [PATCH] Add UUIDArray type Also change UUID.Set() to convert nil to NULL in order for UUIDArray.Set() to support converting [][]byte slices that contain nil. --- pgtype/pgtype.go | 2 + pgtype/uuid.go | 17 +- pgtype/uuid_array.go | 355 ++++++++++++++++++++++++++++++++++++++ pgtype/uuid_array_test.go | 205 ++++++++++++++++++++++ pgtype/uuid_test.go | 8 + 5 files changed, 583 insertions(+), 4 deletions(-) create mode 100644 pgtype/uuid_array.go create mode 100644 pgtype/uuid_array_test.go diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 00175a30..be13ec77 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -49,6 +49,7 @@ const ( NumericOID = 1700 RecordOID = 2249 UUIDOID = 2950 + UUIDArrayOID = 2951 JSONBOID = 3802 ) @@ -223,6 +224,7 @@ func init() { "_text": &TextArray{}, "_timestamp": &TimestampArray{}, "_timestamptz": &TimestamptzArray{}, + "_uuid": &UUIDArray{}, "_varchar": &VarcharArray{}, "aclitem": &ACLItem{}, "bool": &Bool{}, diff --git a/pgtype/uuid.go b/pgtype/uuid.go index 33e79536..f8297b39 100644 --- a/pgtype/uuid.go +++ b/pgtype/uuid.go @@ -14,15 +14,24 @@ type UUID struct { } func (dst *UUID) Set(src interface{}) error { + if src == nil { + *dst = UUID{Status: Null} + return nil + } + switch value := src.(type) { case [16]byte: *dst = UUID{Bytes: value, Status: Present} case []byte: - if len(value) != 16 { - return errors.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) + if value != nil { + if len(value) != 16 { + return errors.Errorf("[]byte must be 16 bytes to convert to UUID: %d", len(value)) + } + *dst = UUID{Status: Present} + copy(dst.Bytes[:], value) + } else { + *dst = UUID{Status: Null} } - *dst = UUID{Status: Present} - copy(dst.Bytes[:], value) case string: uuid, err := parseUUID(value) if err != nil { diff --git a/pgtype/uuid_array.go b/pgtype/uuid_array.go new file mode 100644 index 00000000..c18aec4f --- /dev/null +++ b/pgtype/uuid_array.go @@ -0,0 +1,355 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + + "github.com/jackc/pgx/pgio" + "github.com/pkg/errors" +) + +type UUIDArray struct { + Elements []UUID + Dimensions []ArrayDimension + Status Status +} + +func (dst *UUIDArray) Set(src interface{}) error { + if src == nil { + *dst = UUIDArray{Status: Null} + return nil + } + + switch value := src.(type) { + + case [][16]byte: + if value == nil { + *dst = UUIDArray{Status: Null} + } else if len(value) == 0 { + *dst = UUIDArray{Status: Present} + } else { + elements := make([]UUID, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = UUIDArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case [][]byte: + if value == nil { + *dst = UUIDArray{Status: Null} + } else if len(value) == 0 { + *dst = UUIDArray{Status: Present} + } else { + elements := make([]UUID, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = UUIDArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + case []string: + if value == nil { + *dst = UUIDArray{Status: Null} + } else if len(value) == 0 { + *dst = UUIDArray{Status: Present} + } else { + elements := make([]UUID, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = UUIDArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingPtrType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to UUIDArray", value) + } + + return nil +} + +func (dst *UUIDArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *UUIDArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[][16]byte: + *v = make([][16]byte, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[][]byte: + *v = make([][]byte, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *UUIDArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = UUIDArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []UUID + + if len(uta.Elements) > 0 { + elements = make([]UUID, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem UUID + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = UUIDArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *UUIDArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = UUIDArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = UUIDArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]UUID, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = UUIDArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *UUIDArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +func (src *UUIDArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("uuid"); ok { + arrayHeader.ElementOID = int32(dt.OID) + } else { + return nil, errors.Errorf("unable to find oid for type name %v", "uuid") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + buf = arrayHeader.EncodeBinary(ci, buf) + + for i := range src.Elements { + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + + elemBuf, err := src.Elements[i].EncodeBinary(ci, buf) + if err != nil { + return nil, err + } + if elemBuf != nil { + buf = elemBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *UUIDArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *UUIDArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/uuid_array_test.go b/pgtype/uuid_array_test.go new file mode 100644 index 00000000..ee9d3dfa --- /dev/null +++ b/pgtype/uuid_array_test.go @@ -0,0 +1,205 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestUUIDArrayTranscode(t *testing.T) { + testutil.TestSuccessfulTranscode(t, "uuid[]", []interface{}{ + &pgtype.UUIDArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.UUIDArray{Status: pgtype.Null}, + &pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, + {Status: pgtype.Null}, + {Bytes: [16]byte{64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79}, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + {Bytes: [16]byte{48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63}, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestUUIDArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.UUIDArray + }{ + { + source: nil, + result: pgtype.UUIDArray{Status: pgtype.Null}, + }, + { + source: [][16]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][16]byte{}, + result: pgtype.UUIDArray{Status: pgtype.Present}, + }, + { + source: ([][16]byte)(nil), + result: pgtype.UUIDArray{Status: pgtype.Null}, + }, + { + source: [][]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: [][]byte{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + {16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, + nil, + {32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, + }, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{ + {Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + {Bytes: [16]byte{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, Status: pgtype.Present}, + {Status: pgtype.Null}, + {Bytes: [16]byte{32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 4}}, + Status: pgtype.Present}, + }, + { + source: [][]byte{}, + result: pgtype.UUIDArray{Status: pgtype.Present}, + }, + { + source: ([][]byte)(nil), + result: pgtype.UUIDArray{Status: pgtype.Null}, + }, + { + source: []string{"00010203-0405-0607-0809-0a0b0c0d0e0f"}, + result: pgtype.UUIDArray{ + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: []string{}, + result: pgtype.UUIDArray{Status: pgtype.Present}, + }, + { + source: ([]string)(nil), + result: pgtype.UUIDArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.UUIDArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestUUIDArrayAssignTo(t *testing.T) { + var byteArraySlice [][16]byte + var byteSliceSlice [][]byte + var stringSlice []string + + simpleTests := []struct { + src pgtype.UUIDArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &byteArraySlice, + expected: [][16]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + }, + { + src: pgtype.UUIDArray{Status: pgtype.Null}, + dst: &byteArraySlice, + expected: ([][16]byte)(nil), + }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &byteSliceSlice, + expected: [][]byte{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}}, + }, + { + src: pgtype.UUIDArray{Status: pgtype.Null}, + dst: &byteSliceSlice, + expected: ([][]byte)(nil), + }, + { + src: pgtype.UUIDArray{ + Elements: []pgtype.UUID{{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + expected: []string{"00010203-0405-0607-0809-0a0b0c0d0e0f"}, + }, + { + src: pgtype.UUIDArray{Status: pgtype.Null}, + dst: &stringSlice, + expected: ([]string)(nil), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/uuid_test.go b/pgtype/uuid_test.go index 5ab52b35..162d999f 100644 --- a/pgtype/uuid_test.go +++ b/pgtype/uuid_test.go @@ -20,6 +20,10 @@ func TestUUIDSet(t *testing.T) { source interface{} result pgtype.UUID }{ + { + source: nil, + result: pgtype.UUID{Status: pgtype.Null}, + }, { source: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, @@ -28,6 +32,10 @@ func TestUUIDSet(t *testing.T) { source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, }, + { + source: ([]byte)(nil), + result: pgtype.UUID{Status: pgtype.Null}, + }, { source: "00010203-0405-0607-0809-0a0b0c0d0e0f", result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present},