From 86620c5e91fc6cc990e72214c91fa54cf88014ec Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Mar 2017 13:32:32 -0600 Subject: [PATCH] Add pgtype.ByteaArray Also fix up quoting array elements for text arrays. --- array.go | 14 +++ boolarray.go | 7 +- byteaarray.go | 287 ++++++++++++++++++++++++++++++++++++++++++++ byteaarray_test.go | 119 ++++++++++++++++++ datearray.go | 7 +- float4array.go | 7 +- float8array.go | 7 +- inetarray.go | 7 +- int2array.go | 7 +- int4array.go | 7 +- int8array.go | 7 +- textarray.go | 7 +- textarray_test.go | 8 +- timestamparray.go | 7 +- timestamptzarray.go | 7 +- typed_array.go.erb | 7 +- typed_array_gen.sh | 1 + 17 files changed, 437 insertions(+), 76 deletions(-) create mode 100644 byteaarray.go create mode 100644 byteaarray_test.go diff --git a/array.go b/array.go index 6b705103..90092c8d 100644 --- a/array.go +++ b/array.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "strconv" + "strings" "unicode" "github.com/jackc/pgx/pgio" @@ -371,3 +372,16 @@ func EncodeTextArrayDimensions(w io.Writer, dimensions []ArrayDimension) error { return pgio.WriteByte(w, '=') } + +var quoteArrayReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) + +func quoteArrayElement(src string) string { + return `"` + quoteArrayReplacer.Replace(src) + `"` +} + +func QuoteArrayElementIfNeeded(src string) string { + if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || src[0] == ' ' || src[len(src)-1] == ' ' || strings.ContainsAny(src, `{},"\`) { + return quoteArrayElement(src) + } + return src +} diff --git a/boolarray.go b/boolarray.go index f7323281..65a6bc9c 100644 --- a/boolarray.go +++ b/boolarray.go @@ -208,13 +208,8 @@ func (src *BoolArray) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/byteaarray.go b/byteaarray.go new file mode 100644 index 00000000..7a4f1601 --- /dev/null +++ b/byteaarray.go @@ -0,0 +1,287 @@ +package pgtype + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type ByteaArray struct { + Elements []Bytea + Dimensions []ArrayDimension + Status Status +} + +func (dst *ByteaArray) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case ByteaArray: + *dst = value + + case [][]byte: + if value == nil { + *dst = ByteaArray{Status: Null} + } else if len(value) == 0 { + *dst = ByteaArray{Status: Present} + } else { + elements := make([]Bytea, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = ByteaArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Bytea", value) + } + + return nil +} + +func (src *ByteaArray) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[][]byte: + if src.Status == Present { + *v = make([][]byte, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *ByteaArray) DecodeText(src []byte) error { + if src == nil { + *dst = ByteaArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Bytea + + if len(uta.Elements) > 0 { + elements = make([]Bytea, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Bytea + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = ByteaArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *ByteaArray) DecodeBinary(src []byte) error { + if src == nil { + *dst = ByteaArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = ByteaArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Bytea, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(elemSrc) + if err != nil { + return err + } + } + + *dst = ByteaArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *ByteaArray) EncodeText(w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if len(src.Dimensions) == 0 { + _, err := io.WriteString(w, "{}") + return false, err + } + + err := EncodeTextArrayDimensions(w, src.Dimensions) + if err != nil { + return false, err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(w, ',') + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(w, '{') + if err != nil { + return false, err + } + } + } + + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(elemBuf) + if err != nil { + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else { + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(w, '}') + if err != nil { + return false, err + } + } + } + } + + return false, nil +} + +func (src *ByteaArray) EncodeBinary(w io.Writer) (bool, error) { + return src.encodeBinary(w, ByteaOID) +} + +func (src *ByteaArray) encodeBinary(w io.Writer, elementOID int32) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + arrayHeader := ArrayHeader{ + ElementOID: elementOID, + Dimensions: src.Dimensions, + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(w) + if err != nil { + return false, err + } + + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return false, err + } + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } + } + } + + return false, err +} diff --git a/byteaarray_test.go b/byteaarray_test.go new file mode 100644 index 00000000..b39776d9 --- /dev/null +++ b/byteaarray_test.go @@ -0,0 +1,119 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestByteaArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "bytea[]", []interface{}{ + &pgtype.ByteaArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.ByteaArray{Status: pgtype.Null}, + &pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Status: pgtype.Null}, + pgtype.Bytea{Bytes: []byte{1}, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.ByteaArray{ + Elements: []pgtype.Bytea{ + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{1}, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestByteaArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.ByteaArray + }{ + { + source: [][]byte{{1, 2, 3}}, + result: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([][]byte)(nil)), + result: pgtype.ByteaArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.ByteaArray + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestByteaArrayAssignTo(t *testing.T) { + var byteByteSlice [][]byte + + simpleTests := []struct { + src pgtype.ByteaArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.ByteaArray{ + Elements: []pgtype.Bytea{{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &byteByteSlice, + expected: [][]byte{{1, 2, 3}}, + }, + { + src: pgtype.ByteaArray{Status: pgtype.Null}, + dst: &byteByteSlice, + expected: (([][]byte)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/datearray.go b/datearray.go index 9552739b..623ff9b3 100644 --- a/datearray.go +++ b/datearray.go @@ -209,13 +209,8 @@ func (src *DateArray) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/float4array.go b/float4array.go index 9ab08dcc..c55f76d0 100644 --- a/float4array.go +++ b/float4array.go @@ -208,13 +208,8 @@ func (src *Float4Array) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/float8array.go b/float8array.go index ce7e3b90..d08a5351 100644 --- a/float8array.go +++ b/float8array.go @@ -208,13 +208,8 @@ func (src *Float8Array) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/inetarray.go b/inetarray.go index 32cde554..12d9493b 100644 --- a/inetarray.go +++ b/inetarray.go @@ -240,13 +240,8 @@ func (src *InetArray) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/int2array.go b/int2array.go index f7cc2492..37ee9926 100644 --- a/int2array.go +++ b/int2array.go @@ -239,13 +239,8 @@ func (src *Int2Array) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/int4array.go b/int4array.go index fa710af7..f6f62e4b 100644 --- a/int4array.go +++ b/int4array.go @@ -239,13 +239,8 @@ func (src *Int4Array) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/int8array.go b/int8array.go index 65f42477..92d8ec46 100644 --- a/int8array.go +++ b/int8array.go @@ -239,13 +239,8 @@ func (src *Int8Array) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/textarray.go b/textarray.go index c3e595e0..182e76f5 100644 --- a/textarray.go +++ b/textarray.go @@ -208,13 +208,8 @@ func (src *TextArray) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/textarray_test.go b/textarray_test.go index 29e3a6c7..a22e003d 100644 --- a/textarray_test.go +++ b/textarray_test.go @@ -25,12 +25,12 @@ func TestTextArrayTranscode(t *testing.T) { &pgtype.TextArray{Status: pgtype.Null}, &pgtype.TextArray{ Elements: []pgtype.Text{ - pgtype.Text{String: "bar", Status: pgtype.Present}, - pgtype.Text{String: "baz", Status: pgtype.Present}, - pgtype.Text{String: "quz", Status: pgtype.Present}, + pgtype.Text{String: "bar ", Status: pgtype.Present}, + pgtype.Text{String: "NuLL", Status: pgtype.Present}, + pgtype.Text{String: `wow"quz\`, Status: pgtype.Present}, pgtype.Text{String: "", Status: pgtype.Present}, pgtype.Text{Status: pgtype.Null}, - pgtype.Text{String: "foo", Status: pgtype.Present}, + pgtype.Text{String: "null", Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, Status: pgtype.Present, diff --git a/timestamparray.go b/timestamparray.go index 21e4de98..b0fb25fa 100644 --- a/timestamparray.go +++ b/timestamparray.go @@ -209,13 +209,8 @@ func (src *TimestampArray) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/timestamptzarray.go b/timestamptzarray.go index 597b1842..25374717 100644 --- a/timestamptzarray.go +++ b/timestamptzarray.go @@ -209,13 +209,8 @@ func (src *TimestamptzArray) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/typed_array.go.erb b/typed_array.go.erb index 2e9b77ea..f9dba308 100644 --- a/typed_array.go.erb +++ b/typed_array.go.erb @@ -207,13 +207,8 @@ func (src *<%= pgtype_array_type %>) EncodeText(w io.Writer) (bool, error) { if err != nil { return false, err } - } else if elemBuf.Len() == 0 { - _, err = io.WriteString(w, `""`) - if err != nil { - return false, err - } } else { - _, err = elemBuf.WriteTo(w) + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) if err != nil { return false, err } diff --git a/typed_array_gen.sh b/typed_array_gen.sh index 43109700..c63414c8 100644 --- a/typed_array_gen.sh +++ b/typed_array_gen.sh @@ -9,3 +9,4 @@ erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]fl erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8OID text_null=NULL typed_array.go.erb > float8array.go erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOID text_null=NULL typed_array.go.erb > inetarray.go erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOID text_null='"NULL"' typed_array.go.erb > textarray.go +erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_oid=ByteaOID text_null=NULL typed_array.go.erb > byteaarray.go