diff --git a/pgtype/array.go b/pgtype/array.go index ebe537e8..75d2e440 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -102,9 +102,7 @@ type UntypedTextArray struct { } func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { - uta := &UntypedTextArray{ - Elements: []string{}, - } + uta := &UntypedTextArray{} buf := bytes.NewBufferString(src) @@ -235,13 +233,12 @@ func ParseUntypedTextArray(src string) (*UntypedTextArray, error) { return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) } - if len(explicitDimensions) > 0 { + if len(uta.Elements) == 0 { + uta.Dimensions = nil + } else if len(explicitDimensions) > 0 { uta.Dimensions = explicitDimensions } else { uta.Dimensions = implicitDimensions - if len(uta.Dimensions) == 1 && uta.Dimensions[0].Length == 0 { - uta.Dimensions = []ArrayDimension{} - } } return uta, nil @@ -334,3 +331,45 @@ func arrayParseInteger(buf *bytes.Buffer) (int32, error) { } } } + +func EncodeTextArrayDimensions(w io.Writer, dimensions []ArrayDimension) error { + var customDimensions bool + for _, dim := range dimensions { + if dim.LowerBound != 1 { + customDimensions = true + } + } + + if !customDimensions { + return nil + } + + for _, dim := range dimensions { + err := pgio.WriteByte(w, '[') + if err != nil { + return err + } + + _, err = io.WriteString(w, strconv.FormatInt(int64(dim.LowerBound), 10)) + if err != nil { + return err + } + + err = pgio.WriteByte(w, ':') + if err != nil { + return err + } + + _, err = io.WriteString(w, strconv.FormatInt(int64(dim.LowerBound+dim.Length-1), 10)) + if err != nil { + return err + } + + err = pgio.WriteByte(w, ']') + if err != nil { + return err + } + } + + return pgio.WriteByte(w, '=') +} diff --git a/pgtype/array_test.go b/pgtype/array_test.go index 3f527653..5e5f00e7 100644 --- a/pgtype/array_test.go +++ b/pgtype/array_test.go @@ -15,8 +15,8 @@ func TestParseUntypedTextArray(t *testing.T) { { source: "{}", result: pgtype.UntypedTextArray{ - Elements: []string{}, - Dimensions: []pgtype.ArrayDimension{}, + Elements: nil, + Dimensions: nil, }, }, { diff --git a/pgtype/int2array.go b/pgtype/int2array.go index eb9cd32a..7345305f 100644 --- a/pgtype/int2array.go +++ b/pgtype/int2array.go @@ -22,6 +22,47 @@ func (a *Int2Array) AssignTo(dst interface{}) error { } func (a *Int2Array) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *a = Int2Array{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Int2 + + if len(uta.Elements) > 0 { + elements = make([]Int2, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Int2 + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *a = Int2Array{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + return nil } @@ -70,7 +111,76 @@ func (a *Int2Array) EncodeText(w io.Writer) error { return err } - return nil + if len(a.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, a.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(a.Dimensions)) + dimElemCounts[len(a.Dimensions)-1] = int(a.Dimensions[len(a.Dimensions)-1].Length) + for i := len(a.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(a.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range a.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err } func (a *Int2Array) EncodeBinary(w io.Writer) error { @@ -94,8 +204,7 @@ func (a *Int2Array) EncodeBinary(w io.Writer) error { } } - // TODO - don't use magic number. Types with fixed OIDs should be constants. - arrayHeader.ElementOID = 21 + arrayHeader.ElementOID = Int2OID arrayHeader.Dimensions = a.Dimensions // TODO - consider how to avoid having to buffer array before writing length - diff --git a/pgtype/int2array_test.go b/pgtype/int2array_test.go index 0f5bfeaf..5ea81990 100644 --- a/pgtype/int2array_test.go +++ b/pgtype/int2array_test.go @@ -22,6 +22,31 @@ func TestInt2ArrayTranscode(t *testing.T) { Status: pgtype.Present, }, &pgtype.Int2Array{Status: pgtype.Null}, + &pgtype.Int2Array{ + Elements: []pgtype.Int2{ + pgtype.Int2{Int: 1, Status: pgtype.Present}, + pgtype.Int2{Int: 2, Status: pgtype.Present}, + pgtype.Int2{Int: 3, Status: pgtype.Present}, + pgtype.Int2{Int: 4, Status: pgtype.Present}, + pgtype.Int2{Status: pgtype.Null}, + pgtype.Int2{Int: 6, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.Int2Array{ + Elements: []pgtype.Int2{ + pgtype.Int2{Int: 1, Status: pgtype.Present}, + pgtype.Int2{Int: 2, Status: pgtype.Present}, + pgtype.Int2{Int: 3, Status: pgtype.Present}, + pgtype.Int2{Int: 4, Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, }) } diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 06c20db9..f9833363 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -7,6 +7,50 @@ import ( "github.com/jackc/pgx/pgio" ) +// PostgreSQL oids for common types +const ( + BoolOID = 16 + ByteaOID = 17 + CharOID = 18 + NameOID = 19 + Int8OID = 20 + Int2OID = 21 + Int4OID = 23 + TextOID = 25 + OIDOID = 26 + TidOID = 27 + XidOID = 28 + CidOID = 29 + JSONOID = 114 + CidrOID = 650 + CidrArrayOID = 651 + Float4OID = 700 + Float8OID = 701 + UnknownOID = 705 + InetOID = 869 + BoolArrayOID = 1000 + Int2ArrayOID = 1005 + Int4ArrayOID = 1007 + TextArrayOID = 1009 + ByteaArrayOID = 1001 + VarcharArrayOID = 1015 + Int8ArrayOID = 1016 + Float4ArrayOID = 1021 + Float8ArrayOID = 1022 + AclItemOID = 1033 + AclItemArrayOID = 1034 + InetArrayOID = 1041 + VarcharOID = 1043 + DateOID = 1082 + TimestampOID = 1114 + TimestampArrayOID = 1115 + TimestampTzOID = 1184 + TimestampTzArrayOID = 1185 + RecordOID = 2249 + UUIDOID = 2950 + JSONBOID = 3802 +) + type Status byte const ( diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 783062f7..a1a575f7 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -80,7 +80,7 @@ func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []int name string formatCode int16 }{ - // {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "TextFormat", formatCode: pgx.TextFormatCode}, {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, } diff --git a/pgtype/text_element.go b/pgtype/text_element.go new file mode 100644 index 00000000..1a585d08 --- /dev/null +++ b/pgtype/text_element.go @@ -0,0 +1,112 @@ +package pgtype + +import ( + "bytes" + "errors" + "io" + + "github.com/jackc/pgx/pgio" +) + +// TextElementWriter is a wrapper that makes TextEncoders composable into other +// TextEncoders. TextEncoder first writes the length of the subsequent value. +// This is not necessary when the value is part of another value such as an +// array. TextElementWriter requires one int32 to be written first which it +// ignores. No other integer writes are valid. +type TextElementWriter struct { + w io.Writer + lengthHeaderIgnored bool +} + +func NewTextElementWriter(w io.Writer) *TextElementWriter { + return &TextElementWriter{w: w} +} + +func (w *TextElementWriter) WriteUint16(n uint16) (int, error) { + return 0, errors.New("WriteUint16 should never be called on TextElementWriter") +} + +func (w *TextElementWriter) WriteUint32(n uint32) (int, error) { + if !w.lengthHeaderIgnored { + w.lengthHeaderIgnored = true + + if int32(n) == -1 { + return io.WriteString(w.w, "NULL") + } + + return 4, nil + } + + return 0, errors.New("WriteUint32 should only be called once on TextElementWriter") +} + +func (w *TextElementWriter) WriteUint64(n uint64) (int, error) { + if w.lengthHeaderIgnored { + return pgio.WriteUint64(w.w, n) + } + + return 0, errors.New("WriteUint64 should never be called on TextElementWriter") +} + +func (w *TextElementWriter) Write(buf []byte) (int, error) { + if w.lengthHeaderIgnored { + return w.w.Write(buf) + } + + return 0, errors.New("int32 must be written first") +} + +func (w *TextElementWriter) Reset() { + w.lengthHeaderIgnored = false +} + +// TextElementReader is a wrapper that makes TextDecoders composable into other +// TextDecoders. TextEncoders first read the length of the subsequent value. +// This length value is not present when the value is part of another value such +// as an array. TextElementReader provides a substitute length value from the +// length of the string. No other integer reads are valid. Each time DecodeText +// is called with a TextElementReader as the source the TextElementReader must +// first have Reset called with the new element string data. +type TextElementReader struct { + buf *bytes.Buffer + lengthHeaderIgnored bool +} + +func NewTextElementReader(r io.Reader) *TextElementReader { + return &TextElementReader{buf: &bytes.Buffer{}} +} + +func (r *TextElementReader) ReadUint16() (uint16, error) { + return 0, errors.New("ReadUint16 should never be called on TextElementReader") +} + +func (r *TextElementReader) ReadUint32() (uint32, error) { + if !r.lengthHeaderIgnored { + r.lengthHeaderIgnored = true + if r.buf.String() == "NULL" { + n32 := int32(-1) + return uint32(n32), nil + } + return uint32(r.buf.Len()), nil + } + + return 0, errors.New("ReadUint32 should only be called once on TextElementReader") +} + +func (r *TextElementReader) WriteUint64(n uint64) (int, error) { + return 0, errors.New("ReadUint64 should never be called on TextElementReader") +} + +func (r *TextElementReader) Read(buf []byte) (int, error) { + if r.lengthHeaderIgnored { + return r.buf.Read(buf) + } + + return 0, errors.New("int32 must be read first") +} + +func (r *TextElementReader) Reset(s string) { + r.lengthHeaderIgnored = false + r.buf.Reset() + r.buf.WriteString(s) +}