From 36da5cc2178d1a31a56dc6e6f128843bd80dea0b Mon Sep 17 00:00:00 2001
From: Jack Christensen <jack@jackchristensen.com>
Date: Tue, 28 Feb 2017 21:45:33 -0600
Subject: [PATCH] Add text array transcoding

---
 pgtype/array.go          |  53 +++++++++++++++---
 pgtype/array_test.go     |   4 +-
 pgtype/int2array.go      | 115 ++++++++++++++++++++++++++++++++++++++-
 pgtype/int2array_test.go |  25 +++++++++
 pgtype/pgtype.go         |  44 +++++++++++++++
 pgtype/pgtype_test.go    |   2 +-
 pgtype/text_element.go   | 112 ++++++++++++++++++++++++++++++++++++++
 7 files changed, 342 insertions(+), 13 deletions(-)
 create mode 100644 pgtype/text_element.go

diff --git a/pgtype/array.go b/pgtype/array.go
index ebe537e8..75d2e440 100644
--- a/pgtype/array.go
+++ b/pgtype/array.go
@@ -102,9 +102,7 @@ type UntypedTextArray struct {
 }
 
 func ParseUntypedTextArray(src string) (*UntypedTextArray, error) {
-	uta := &UntypedTextArray{
-		Elements: []string{},
-	}
+	uta := &UntypedTextArray{}
 
 	buf := bytes.NewBufferString(src)
 
@@ -235,13 +233,12 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) {
 		return nil, fmt.Errorf("unexpected trailing data: %v", buf.String())
 	}
 
-	if len(explicitDimensions) > 0 {
+	if len(uta.Elements) == 0 {
+		uta.Dimensions = nil
+	} else if len(explicitDimensions) > 0 {
 		uta.Dimensions = explicitDimensions
 	} else {
 		uta.Dimensions = implicitDimensions
-		if len(uta.Dimensions) == 1 && uta.Dimensions[0].Length == 0 {
-			uta.Dimensions = []ArrayDimension{}
-		}
 	}
 
 	return uta, nil
@@ -334,3 +331,45 @@ func arrayParseInteger(buf *bytes.Buffer) (int32, error) {
 		}
 	}
 }
+
+func EncodeTextArrayDimensions(w io.Writer, dimensions []ArrayDimension) error {
+	var customDimensions bool
+	for _, dim := range dimensions {
+		if dim.LowerBound != 1 {
+			customDimensions = true
+		}
+	}
+
+	if !customDimensions {
+		return nil
+	}
+
+	for _, dim := range dimensions {
+		err := pgio.WriteByte(w, '[')
+		if err != nil {
+			return err
+		}
+
+		_, err = io.WriteString(w, strconv.FormatInt(int64(dim.LowerBound), 10))
+		if err != nil {
+			return err
+		}
+
+		err = pgio.WriteByte(w, ':')
+		if err != nil {
+			return err
+		}
+
+		_, err = io.WriteString(w, strconv.FormatInt(int64(dim.LowerBound+dim.Length-1), 10))
+		if err != nil {
+			return err
+		}
+
+		err = pgio.WriteByte(w, ']')
+		if err != nil {
+			return err
+		}
+	}
+
+	return pgio.WriteByte(w, '=')
+}
diff --git a/pgtype/array_test.go b/pgtype/array_test.go
index 3f527653..5e5f00e7 100644
--- a/pgtype/array_test.go
+++ b/pgtype/array_test.go
@@ -15,8 +15,8 @@ func TestParseUntypedTextArray(t *testing.T) {
 		{
 			source: "{}",
 			result: pgtype.UntypedTextArray{
-				Elements:   []string{},
-				Dimensions: []pgtype.ArrayDimension{},
+				Elements:   nil,
+				Dimensions: nil,
 			},
 		},
 		{
diff --git a/pgtype/int2array.go b/pgtype/int2array.go
index eb9cd32a..7345305f 100644
--- a/pgtype/int2array.go
+++ b/pgtype/int2array.go
@@ -22,6 +22,47 @@ func (a *Int2Array) AssignTo(dst interface{}) error {
 }
 
 func (a *Int2Array) DecodeText(r io.Reader) error {
+	size, err := pgio.ReadInt32(r)
+	if err != nil {
+		return err
+	}
+
+	if size == -1 {
+		*a = Int2Array{Status: Null}
+		return nil
+	}
+
+	buf := make([]byte, int(size))
+	_, err = io.ReadFull(r, buf)
+	if err != nil {
+		return err
+	}
+
+	uta, err := ParseUntypedTextArray(string(buf))
+	if err != nil {
+		return err
+	}
+
+	textElementReader := NewTextElementReader(r)
+	var elements []Int2
+
+	if len(uta.Elements) > 0 {
+		elements = make([]Int2, len(uta.Elements))
+
+		for i, s := range uta.Elements {
+			var elem Int2
+			textElementReader.Reset(s)
+			err = elem.DecodeText(textElementReader)
+			if err != nil {
+				return err
+			}
+
+			elements[i] = elem
+		}
+	}
+
+	*a = Int2Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present}
+
 	return nil
 }
 
@@ -70,7 +111,76 @@ func (a *Int2Array) EncodeText(w io.Writer) error {
 		return err
 	}
 
-	return nil
+	if len(a.Dimensions) == 0 {
+		_, err := pgio.WriteInt32(w, 2)
+		if err != nil {
+			return err
+		}
+
+		_, err = w.Write([]byte("{}"))
+		return err
+	}
+
+	buf := &bytes.Buffer{}
+
+	err := EncodeTextArrayDimensions(buf, a.Dimensions)
+	if err != nil {
+		return err
+	}
+
+	// dimElemCounts is the multiples of elements that each array lies on. For
+	// example, a single dimension array of length 4 would have a dimElemCounts of
+	// [4]. A multi-dimensional array of lengths [3,5,2] would have a
+	// dimElemCounts of [30,10,2]. This is used to simplify when to render a '{'
+	// or '}'.
+	dimElemCounts := make([]int, len(a.Dimensions))
+	dimElemCounts[len(a.Dimensions)-1] = int(a.Dimensions[len(a.Dimensions)-1].Length)
+	for i := len(a.Dimensions) - 2; i > -1; i-- {
+		dimElemCounts[i] = int(a.Dimensions[i].Length) * dimElemCounts[i+1]
+	}
+
+	textElementWriter := NewTextElementWriter(buf)
+
+	for i, elem := range a.Elements {
+		if i > 0 {
+			err = pgio.WriteByte(buf, ',')
+			if err != nil {
+				return err
+			}
+		}
+
+		for _, dec := range dimElemCounts {
+			if i%dec == 0 {
+				err = pgio.WriteByte(buf, '{')
+				if err != nil {
+					return err
+				}
+			}
+		}
+
+		textElementWriter.Reset()
+		err = elem.EncodeText(textElementWriter)
+		if err != nil {
+			return err
+		}
+
+		for _, dec := range dimElemCounts {
+			if (i+1)%dec == 0 {
+				err = pgio.WriteByte(buf, '}')
+				if err != nil {
+					return err
+				}
+			}
+		}
+	}
+
+	_, err = pgio.WriteInt32(w, int32(buf.Len()))
+	if err != nil {
+		return err
+	}
+
+	_, err = buf.WriteTo(w)
+	return err
 }
 
 func (a *Int2Array) EncodeBinary(w io.Writer) error {
@@ -94,8 +204,7 @@ func (a *Int2Array) EncodeBinary(w io.Writer) error {
 		}
 	}
 
-	// TODO - don't use magic number. Types with fixed OIDs should be constants.
-	arrayHeader.ElementOID = 21
+	arrayHeader.ElementOID = Int2OID
 	arrayHeader.Dimensions = a.Dimensions
 
 	// TODO - consider how to avoid having to buffer array before writing length -
diff --git a/pgtype/int2array_test.go b/pgtype/int2array_test.go
index 0f5bfeaf..5ea81990 100644
--- a/pgtype/int2array_test.go
+++ b/pgtype/int2array_test.go
@@ -22,6 +22,31 @@ func TestInt2ArrayTranscode(t *testing.T) {
 			Status:     pgtype.Present,
 		},
 		&pgtype.Int2Array{Status: pgtype.Null},
+		&pgtype.Int2Array{
+			Elements: []pgtype.Int2{
+				pgtype.Int2{Int: 1, Status: pgtype.Present},
+				pgtype.Int2{Int: 2, Status: pgtype.Present},
+				pgtype.Int2{Int: 3, Status: pgtype.Present},
+				pgtype.Int2{Int: 4, Status: pgtype.Present},
+				pgtype.Int2{Status: pgtype.Null},
+				pgtype.Int2{Int: 6, Status: pgtype.Present},
+			},
+			Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}},
+			Status:     pgtype.Present,
+		},
+		&pgtype.Int2Array{
+			Elements: []pgtype.Int2{
+				pgtype.Int2{Int: 1, Status: pgtype.Present},
+				pgtype.Int2{Int: 2, Status: pgtype.Present},
+				pgtype.Int2{Int: 3, Status: pgtype.Present},
+				pgtype.Int2{Int: 4, Status: pgtype.Present},
+			},
+			Dimensions: []pgtype.ArrayDimension{
+				{Length: 2, LowerBound: 4},
+				{Length: 2, LowerBound: 2},
+			},
+			Status: pgtype.Present,
+		},
 	})
 }
 
diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go
index 06c20db9..f9833363 100644
--- a/pgtype/pgtype.go
+++ b/pgtype/pgtype.go
@@ -7,6 +7,50 @@ import (
 	"github.com/jackc/pgx/pgio"
 )
 
+// PostgreSQL oids for common types
+const (
+	BoolOID             = 16
+	ByteaOID            = 17
+	CharOID             = 18
+	NameOID             = 19
+	Int8OID             = 20
+	Int2OID             = 21
+	Int4OID             = 23
+	TextOID             = 25
+	OIDOID              = 26
+	TidOID              = 27
+	XidOID              = 28
+	CidOID              = 29
+	JSONOID             = 114
+	CidrOID             = 650
+	CidrArrayOID        = 651
+	Float4OID           = 700
+	Float8OID           = 701
+	UnknownOID          = 705
+	InetOID             = 869
+	BoolArrayOID        = 1000
+	Int2ArrayOID        = 1005
+	Int4ArrayOID        = 1007
+	TextArrayOID        = 1009
+	ByteaArrayOID       = 1001
+	VarcharArrayOID     = 1015
+	Int8ArrayOID        = 1016
+	Float4ArrayOID      = 1021
+	Float8ArrayOID      = 1022
+	AclItemOID          = 1033
+	AclItemArrayOID     = 1034
+	InetArrayOID        = 1041
+	VarcharOID          = 1043
+	DateOID             = 1082
+	TimestampOID        = 1114
+	TimestampArrayOID   = 1115
+	TimestampTzOID      = 1184
+	TimestampTzArrayOID = 1185
+	RecordOID           = 2249
+	UUIDOID             = 2950
+	JSONBOID            = 3802
+)
+
 type Status byte
 
 const (
diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go
index 783062f7..a1a575f7 100644
--- a/pgtype/pgtype_test.go
+++ b/pgtype/pgtype_test.go
@@ -80,7 +80,7 @@ func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int
 		name       string
 		formatCode int16
 	}{
-		// {name: "TextFormat", formatCode: pgx.TextFormatCode},
+		{name: "TextFormat", formatCode: pgx.TextFormatCode},
 		{name: "BinaryFormat", formatCode: pgx.BinaryFormatCode},
 	}
 
diff --git a/pgtype/text_element.go b/pgtype/text_element.go
new file mode 100644
index 00000000..1a585d08
--- /dev/null
+++ b/pgtype/text_element.go
@@ -0,0 +1,112 @@
+package pgtype
+
+import (
+	"bytes"
+	"errors"
+	"io"
+
+	"github.com/jackc/pgx/pgio"
+)
+
+// TextElementWriter is a wrapper that makes TextEncoders composable into other
+// TextEncoders. TextEncoder first writes the length of the subsequent value.
+// This is not necessary when the value is part of another value such as an
+// array. TextElementWriter requires one int32 to be written first which it
+// ignores. No other integer writes are valid.
+type TextElementWriter struct {
+	w                   io.Writer
+	lengthHeaderIgnored bool
+}
+
+func NewTextElementWriter(w io.Writer) *TextElementWriter {
+	return &TextElementWriter{w: w}
+}
+
+func (w *TextElementWriter) WriteUint16(n uint16) (int, error) {
+	return 0, errors.New("WriteUint16 should never be called on TextElementWriter")
+}
+
+func (w *TextElementWriter) WriteUint32(n uint32) (int, error) {
+	if !w.lengthHeaderIgnored {
+		w.lengthHeaderIgnored = true
+
+		if int32(n) == -1 {
+			return io.WriteString(w.w, "NULL")
+		}
+
+		return 4, nil
+	}
+
+	return 0, errors.New("WriteUint32 should only be called once on TextElementWriter")
+}
+
+func (w *TextElementWriter) WriteUint64(n uint64) (int, error) {
+	if w.lengthHeaderIgnored {
+		return pgio.WriteUint64(w.w, n)
+	}
+
+	return 0, errors.New("WriteUint64 should never be called on TextElementWriter")
+}
+
+func (w *TextElementWriter) Write(buf []byte) (int, error) {
+	if w.lengthHeaderIgnored {
+		return w.w.Write(buf)
+	}
+
+	return 0, errors.New("int32 must be written first")
+}
+
+func (w *TextElementWriter) Reset() {
+	w.lengthHeaderIgnored = false
+}
+
+// TextElementReader is a wrapper that makes TextDecoders composable into other
+// TextDecoders. TextEncoders first read the length of the subsequent value.
+// This length value is not present when the value is part of another value such
+// as an array. TextElementReader provides a substitute length value from the
+// length of the string. No other integer reads are valid. Each time DecodeText
+// is called with a TextElementReader as the source the TextElementReader must
+// first have Reset called with the new element string data.
+type TextElementReader struct {
+	buf                 *bytes.Buffer
+	lengthHeaderIgnored bool
+}
+
+func NewTextElementReader(r io.Reader) *TextElementReader {
+	return &TextElementReader{buf: &bytes.Buffer{}}
+}
+
+func (r *TextElementReader) ReadUint16() (uint16, error) {
+	return 0, errors.New("ReadUint16 should never be called on TextElementReader")
+}
+
+func (r *TextElementReader) ReadUint32() (uint32, error) {
+	if !r.lengthHeaderIgnored {
+		r.lengthHeaderIgnored = true
+		if r.buf.String() == "NULL" {
+			n32 := int32(-1)
+			return uint32(n32), nil
+		}
+		return uint32(r.buf.Len()), nil
+	}
+
+	return 0, errors.New("ReadUint32 should only be called once on TextElementReader")
+}
+
+func (r *TextElementReader) WriteUint64(n uint64) (int, error) {
+	return 0, errors.New("ReadUint64 should never be called on TextElementReader")
+}
+
+func (r *TextElementReader) Read(buf []byte) (int, error) {
+	if r.lengthHeaderIgnored {
+		return r.buf.Read(buf)
+	}
+
+	return 0, errors.New("int32 must be read first")
+}
+
+func (r *TextElementReader) Reset(s string) {
+	r.lengthHeaderIgnored = false
+	r.buf.Reset()
+	r.buf.WriteString(s)
+}