diff --git a/conn.go b/conn.go index 19833dc0..d97942aa 100644 --- a/conn.go +++ b/conn.go @@ -286,10 +286,14 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl Int4OID: &pgtype.Int4{}, Int8ArrayOID: &pgtype.Int8Array{}, Int8OID: &pgtype.Int8{}, + TextArrayOID: &pgtype.TextArray{}, + TextOID: &pgtype.Text{}, TimestampArrayOID: &pgtype.TimestampArray{}, TimestampOID: &pgtype.Timestamp{}, TimestampTzArrayOID: &pgtype.TimestamptzArray{}, TimestampTzOID: &pgtype.Timestamptz{}, + VarcharArrayOID: &pgtype.VarcharArray{}, + VarcharOID: &pgtype.Text{}, } if tlsConfig != nil { diff --git a/pgtype/array_test.go b/pgtype/array_test.go index 5e5f00e7..d1cdb4c5 100644 --- a/pgtype/array_test.go +++ b/pgtype/array_test.go @@ -40,6 +40,13 @@ func TestParseUntypedTextArray(t *testing.T) { Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, }, }, + { + source: `{""}`, + result: pgtype.UntypedTextArray{ + Elements: []string{""}, + Dimensions: []pgtype.ArrayDimension{{Length: 1, LowerBound: 1}}, + }, + }, { source: `{"He said, \"Hello.\""}`, result: pgtype.UntypedTextArray{ diff --git a/pgtype/bool.go b/pgtype/bool.go index 2889b787..076403f9 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -66,7 +66,7 @@ func (src *Bool) AssignTo(dst interface{}) error { return nil } } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/pgtype/boolarray.go b/pgtype/boolarray.go index 8dd68dc2..b6b5db02 100644 --- a/pgtype/boolarray.go +++ b/pgtype/boolarray.go @@ -18,7 +18,7 @@ func (dst *BoolArray) ConvertFrom(src interface{}) error { switch value := src.(type) { case BoolArray: *dst = value - + case []bool: if value == nil { *dst = BoolArray{Status: Null} @@ -37,7 +37,7 @@ func (dst *BoolArray) ConvertFrom(src interface{}) error { Status: Present, } } - + default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.ConvertFrom(originalSrc) @@ -50,7 +50,7 @@ func (dst *BoolArray) ConvertFrom(src interface{}) error { func (src *BoolArray) AssignTo(dst interface{}) error { switch v := dst.(type) { - + case *[]bool: if src.Status == Present { *v = make([]bool, len(src.Elements)) @@ -62,12 +62,12 @@ func (src *BoolArray) AssignTo(dst interface{}) error { } else { *v = nil } - + default: if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/pgtype/convert.go b/pgtype/convert.go index 7111f8bc..31bbf060 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -85,6 +85,25 @@ func underlyingBoolType(val interface{}) (interface{}, bool) { return nil, false } +// underlyingStringType gets the underlying type that can be converted to String +func underlyingStringType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.String: + convVal := refVal.String() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + + return nil, false +} + // underlyingPtrType dereferences a pointer func underlyingPtrType(val interface{}) (interface{}, bool) { refVal := reflect.ValueOf(val) diff --git a/pgtype/datearray.go b/pgtype/datearray.go index 877f328e..5e93501e 100644 --- a/pgtype/datearray.go +++ b/pgtype/datearray.go @@ -68,7 +68,7 @@ func (src *DateArray) AssignTo(dst interface{}) error { if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/pgtype/float4array.go b/pgtype/float4array.go index c06490cf..8834d213 100644 --- a/pgtype/float4array.go +++ b/pgtype/float4array.go @@ -18,7 +18,7 @@ func (dst *Float4Array) ConvertFrom(src interface{}) error { switch value := src.(type) { case Float4Array: *dst = value - + case []float32: if value == nil { *dst = Float4Array{Status: Null} @@ -37,7 +37,7 @@ func (dst *Float4Array) ConvertFrom(src interface{}) error { Status: Present, } } - + default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.ConvertFrom(originalSrc) @@ -50,7 +50,7 @@ func (dst *Float4Array) ConvertFrom(src interface{}) error { func (src *Float4Array) AssignTo(dst interface{}) error { switch v := dst.(type) { - + case *[]float32: if src.Status == Present { *v = make([]float32, len(src.Elements)) @@ -62,12 +62,12 @@ func (src *Float4Array) AssignTo(dst interface{}) error { } else { *v = nil } - + default: if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/pgtype/float8array.go b/pgtype/float8array.go index 776fc1e6..bad9ed9f 100644 --- a/pgtype/float8array.go +++ b/pgtype/float8array.go @@ -18,7 +18,7 @@ func (dst *Float8Array) ConvertFrom(src interface{}) error { switch value := src.(type) { case Float8Array: *dst = value - + case []float64: if value == nil { *dst = Float8Array{Status: Null} @@ -37,7 +37,7 @@ func (dst *Float8Array) ConvertFrom(src interface{}) error { Status: Present, } } - + default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.ConvertFrom(originalSrc) @@ -50,7 +50,7 @@ func (dst *Float8Array) ConvertFrom(src interface{}) error { func (src *Float8Array) AssignTo(dst interface{}) error { switch v := dst.(type) { - + case *[]float64: if src.Status == Present { *v = make([]float64, len(src.Elements)) @@ -62,12 +62,12 @@ func (src *Float8Array) AssignTo(dst interface{}) error { } else { *v = nil } - + default: if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/pgtype/inetarray.go b/pgtype/inetarray.go index eb5a4c88..cd12e917 100644 --- a/pgtype/inetarray.go +++ b/pgtype/inetarray.go @@ -97,7 +97,7 @@ func (src *InetArray) AssignTo(dst interface{}) error { if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/pgtype/int2array.go b/pgtype/int2array.go index 4fc6d882..a989347d 100644 --- a/pgtype/int2array.go +++ b/pgtype/int2array.go @@ -18,7 +18,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { switch value := src.(type) { case Int2Array: *dst = value - + case []int16: if value == nil { *dst = Int2Array{Status: Null} @@ -37,7 +37,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { Status: Present, } } - + case []uint16: if value == nil { *dst = Int2Array{Status: Null} @@ -56,7 +56,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { Status: Present, } } - + default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.ConvertFrom(originalSrc) @@ -69,7 +69,7 @@ func (dst *Int2Array) ConvertFrom(src interface{}) error { func (src *Int2Array) AssignTo(dst interface{}) error { switch v := dst.(type) { - + case *[]int16: if src.Status == Present { *v = make([]int16, len(src.Elements)) @@ -81,7 +81,7 @@ func (src *Int2Array) AssignTo(dst interface{}) error { } else { *v = nil } - + case *[]uint16: if src.Status == Present { *v = make([]uint16, len(src.Elements)) @@ -93,12 +93,12 @@ func (src *Int2Array) AssignTo(dst interface{}) error { } else { *v = nil } - + default: if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/pgtype/int4array.go b/pgtype/int4array.go index 40e1490d..89caf263 100644 --- a/pgtype/int4array.go +++ b/pgtype/int4array.go @@ -18,7 +18,7 @@ func (dst *Int4Array) ConvertFrom(src interface{}) error { switch value := src.(type) { case Int4Array: *dst = value - + case []int32: if value == nil { *dst = Int4Array{Status: Null} @@ -37,7 +37,7 @@ func (dst *Int4Array) ConvertFrom(src interface{}) error { Status: Present, } } - + case []uint32: if value == nil { *dst = Int4Array{Status: Null} @@ -56,7 +56,7 @@ func (dst *Int4Array) ConvertFrom(src interface{}) error { Status: Present, } } - + default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.ConvertFrom(originalSrc) @@ -69,7 +69,7 @@ func (dst *Int4Array) ConvertFrom(src interface{}) error { func (src *Int4Array) AssignTo(dst interface{}) error { switch v := dst.(type) { - + case *[]int32: if src.Status == Present { *v = make([]int32, len(src.Elements)) @@ -81,7 +81,7 @@ func (src *Int4Array) AssignTo(dst interface{}) error { } else { *v = nil } - + case *[]uint32: if src.Status == Present { *v = make([]uint32, len(src.Elements)) @@ -93,12 +93,12 @@ func (src *Int4Array) AssignTo(dst interface{}) error { } else { *v = nil } - + default: if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/pgtype/int8array.go b/pgtype/int8array.go index 35ecf946..003ed055 100644 --- a/pgtype/int8array.go +++ b/pgtype/int8array.go @@ -18,7 +18,7 @@ func (dst *Int8Array) ConvertFrom(src interface{}) error { switch value := src.(type) { case Int8Array: *dst = value - + case []int64: if value == nil { *dst = Int8Array{Status: Null} @@ -37,7 +37,7 @@ func (dst *Int8Array) ConvertFrom(src interface{}) error { Status: Present, } } - + case []uint64: if value == nil { *dst = Int8Array{Status: Null} @@ -56,7 +56,7 @@ func (dst *Int8Array) ConvertFrom(src interface{}) error { Status: Present, } } - + default: if originalSrc, ok := underlyingSliceType(src); ok { return dst.ConvertFrom(originalSrc) @@ -69,7 +69,7 @@ func (dst *Int8Array) ConvertFrom(src interface{}) error { func (src *Int8Array) AssignTo(dst interface{}) error { switch v := dst.(type) { - + case *[]int64: if src.Status == Present { *v = make([]int64, len(src.Elements)) @@ -81,7 +81,7 @@ func (src *Int8Array) AssignTo(dst interface{}) error { } else { *v = nil } - + case *[]uint64: if src.Status == Present { *v = make([]uint64, len(src.Elements)) @@ -93,12 +93,12 @@ func (src *Int8Array) AssignTo(dst interface{}) error { } else { *v = nil } - + default: if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 7d34ae34..304fd0ea 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -13,6 +13,7 @@ import ( ) // Test for renamed types +type _string string type _bool bool type _int8 int8 type _int16 int16 diff --git a/pgtype/text.go b/pgtype/text.go new file mode 100644 index 00000000..c9054468 --- /dev/null +++ b/pgtype/text.go @@ -0,0 +1,115 @@ +package pgtype + +import ( + "fmt" + "io" + "reflect" + + "github.com/jackc/pgx/pgio" +) + +type Text struct { + String string + Status Status +} + +func (dst *Text) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Text: + *dst = value + case string: + *dst = Text{String: value, Status: Present} + case *string: + if value == nil { + *dst = Text{Status: Null} + } else { + *dst = Text{String: *value, Status: Present} + } + default: + if originalSrc, ok := underlyingStringType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Text", value) + } + + return nil +} + +func (src *Text) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *string: + if src.Status != Present { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + *v = src.String + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if src.Status == Null { + el.Set(reflect.Zero(el.Type())) + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return src.AssignTo(el.Interface()) + case reflect.String: + if src.Status != Present { + return fmt.Errorf("cannot assign %v to %T", src, dst) + } + el.SetString(src.String) + return nil + } + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *Text) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Text{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = r.Read(buf) + if err != nil { + return err + } + + *dst = Text{String: string(buf), Status: Present} + return nil +} + +func (dst *Text) DecodeBinary(r io.Reader) error { + return dst.DecodeText(r) +} + +func (src Text) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, int32(len(src.String))) + if err != nil { + return nil + } + + _, err = io.WriteString(w, src.String) + return err +} + +func (src Text) EncodeBinary(w io.Writer) error { + return src.EncodeText(w) +} diff --git a/pgtype/text_test.go b/pgtype/text_test.go new file mode 100644 index 00000000..6e944857 --- /dev/null +++ b/pgtype/text_test.go @@ -0,0 +1,100 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestTextTranscode(t *testing.T) { + for _, pgTypeName := range []string{"text", "varchar"} { + testSuccessfulTranscode(t, pgTypeName, []interface{}{ + pgtype.Text{String: "", Status: pgtype.Present}, + pgtype.Text{String: "foo", Status: pgtype.Present}, + pgtype.Text{Status: pgtype.Null}, + }) + } +} + +func TestTextConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Text + }{ + {source: pgtype.Text{String: "foo", Status: pgtype.Present}, result: pgtype.Text{String: "foo", Status: pgtype.Present}}, + {source: "foo", result: pgtype.Text{String: "foo", Status: pgtype.Present}}, + {source: _string("bar"), result: pgtype.Text{String: "bar", Status: pgtype.Present}}, + {source: (*string)(nil), result: pgtype.Text{Status: pgtype.Null}}, + } + + for i, tt := range successfulTests { + var d pgtype.Text + err := d.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if d != tt.result { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, d) + } + } +} + +func TestTextAssignTo(t *testing.T) { + var s string + var ps *string + + simpleTests := []struct { + src pgtype.Text + dst interface{} + expected interface{} + }{ + {src: pgtype.Text{String: "foo", Status: pgtype.Present}, dst: &s, expected: "foo"}, + {src: pgtype.Text{Status: pgtype.Null}, dst: &ps, expected: ((*string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + pointerAllocTests := []struct { + src pgtype.Text + dst interface{} + expected interface{} + }{ + {src: pgtype.Text{String: "foo", Status: pgtype.Present}, dst: &ps, expected: "foo"}, + } + + for i, tt := range pointerAllocTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Elem().Interface(); dst != tt.expected { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.Text + dst interface{} + }{ + {src: pgtype.Text{Status: pgtype.Null}, dst: &s}, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/textarray.go b/pgtype/textarray.go new file mode 100644 index 00000000..c420e5c9 --- /dev/null +++ b/pgtype/textarray.go @@ -0,0 +1,297 @@ +package pgtype + +import ( + "bytes" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type TextArray struct { + Elements []Text + Dimensions []ArrayDimension + Status Status +} + +func (dst *TextArray) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case TextArray: + *dst = value + + case []string: + if value == nil { + *dst = TextArray{Status: Null} + } else if len(value) == 0 { + *dst = TextArray{Status: Present} + } else { + elements := make([]Text, len(value)) + for i := range value { + if err := elements[i].ConvertFrom(value[i]); err != nil { + return err + } + } + *dst = TextArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Text", value) + } + + return nil +} + +func (src *TextArray) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]string: + if src.Status == Present { + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *TextArray) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = TextArray{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + uta, err := ParseUntypedTextArray(string(buf)) + if err != nil { + return err + } + + textElementReader := NewTextElementReader(r) + var elements []Text + + if len(uta.Elements) > 0 { + elements = make([]Text, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Text + textElementReader.Reset(s) + err = elem.DecodeText(textElementReader) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = TextArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *TextArray) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = TextArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + err = arrayHeader.DecodeBinary(r) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = TextArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Text, elementCount) + + for i := range elements { + err = elements[i].DecodeBinary(r) + if err != nil { + return err + } + } + + *dst = TextArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *TextArray) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + if len(src.Dimensions) == 0 { + _, err := pgio.WriteInt32(w, 2) + if err != nil { + return err + } + + _, err = w.Write([]byte("{}")) + return err + } + + buf := &bytes.Buffer{} + + err := EncodeTextArrayDimensions(buf, src.Dimensions) + if err != nil { + return err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + textElementWriter := NewTextElementWriter(buf) + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(buf, ',') + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(buf, '{') + if err != nil { + return err + } + } + } + + textElementWriter.Reset() + if elem.String == "" && elem.Status == Present { + _, err := io.WriteString(buf, `""`) + if err != nil { + return err + } + } else { + err = elem.EncodeText(textElementWriter) + if err != nil { + return err + } + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(buf, '}') + if err != nil { + return err + } + } + } + } + + _, err = pgio.WriteInt32(w, int32(buf.Len())) + if err != nil { + return err + } + + _, err = buf.WriteTo(w) + return err +} + +func (src *TextArray) EncodeBinary(w io.Writer) error { + return src.encodeBinary(w, TextOID) +} + +func (src *TextArray) encodeBinary(w io.Writer, elementOID int32) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + var arrayHeader ArrayHeader + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + err := src.Elements[i].EncodeBinary(elemBuf) + if err != nil { + return err + } + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + } + } + + arrayHeader.ElementOID = elementOID + arrayHeader.Dimensions = src.Dimensions + + // TODO - consider how to avoid having to buffer array before writing length - + // or how not pay allocations for the byte order conversions. + headerBuf := &bytes.Buffer{} + err := arrayHeader.EncodeBinary(headerBuf) + if err != nil { + return err + } + + _, err = pgio.WriteInt32(w, int32(headerBuf.Len()+elemBuf.Len())) + if err != nil { + return err + } + + _, err = headerBuf.WriteTo(w) + if err != nil { + return err + } + + _, err = elemBuf.WriteTo(w) + if err != nil { + return err + } + + return err +} diff --git a/pgtype/textarray_test.go b/pgtype/textarray_test.go new file mode 100644 index 00000000..29e3a6c7 --- /dev/null +++ b/pgtype/textarray_test.go @@ -0,0 +1,151 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestTextArrayTranscode(t *testing.T) { + testSuccessfulTranscode(t, "text[]", []interface{}{ + &pgtype.TextArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.TextArray{ + Elements: []pgtype.Text{ + pgtype.Text{String: "foo", Status: pgtype.Present}, + pgtype.Text{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TextArray{Status: pgtype.Null}, + &pgtype.TextArray{ + Elements: []pgtype.Text{ + pgtype.Text{String: "bar", Status: pgtype.Present}, + pgtype.Text{String: "baz", Status: pgtype.Present}, + pgtype.Text{String: "quz", Status: pgtype.Present}, + pgtype.Text{String: "", Status: pgtype.Present}, + pgtype.Text{Status: pgtype.Null}, + pgtype.Text{String: "foo", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}, {Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.TextArray{ + Elements: []pgtype.Text{ + pgtype.Text{String: "bar", Status: pgtype.Present}, + pgtype.Text{String: "baz", Status: pgtype.Present}, + pgtype.Text{String: "quz", Status: pgtype.Present}, + pgtype.Text{String: "foo", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestTextArrayConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.TextArray + }{ + { + source: []string{"foo"}, + result: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]string)(nil)), + result: pgtype.TextArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.TextArray + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestTextArrayAssignTo(t *testing.T) { + var stringSlice []string + type _stringSlice []string + var namedStringSlice _stringSlice + + simpleTests := []struct { + src pgtype.TextArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + expected: []string{"foo"}, + }, + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedStringSlice, + expected: _stringSlice{"bar"}, + }, + { + src: pgtype.TextArray{Status: pgtype.Null}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.TextArray + dst interface{} + }{ + { + src: pgtype.TextArray{ + Elements: []pgtype.Text{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/timestamparray.go b/pgtype/timestamparray.go index f1b1d003..3acbb35f 100644 --- a/pgtype/timestamparray.go +++ b/pgtype/timestamparray.go @@ -68,7 +68,7 @@ func (src *TimestampArray) AssignTo(dst interface{}) error { if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/pgtype/timestamptzarray.go b/pgtype/timestamptzarray.go index 72b28e43..9df746e6 100644 --- a/pgtype/timestamptzarray.go +++ b/pgtype/timestamptzarray.go @@ -68,7 +68,7 @@ func (src *TimestamptzArray) AssignTo(dst interface{}) error { if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index e6e480b0..647ed7c0 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -67,7 +67,7 @@ func (src *<%= pgtype_array_type %>) AssignTo(dst interface{}) error { if originalDst, ok := underlyingPtrSliceType(dst); ok { return src.AssignTo(originalDst) } - return fmt.Errorf("cannot put decode %v into %T", src, dst) + return fmt.Errorf("cannot decode %v into %T", src, dst) } return nil diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index 47afdf1d..f984e12e 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -8,3 +8,4 @@ erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_type erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4OID typed_array.go.erb > float4array.go erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8OID typed_array.go.erb > float8array.go erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOID typed_array.go.erb > inetarray.go +erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOID typed_array.go.erb > textarray.go diff --git a/pgtype/varchararray.go b/pgtype/varchararray.go new file mode 100644 index 00000000..13d94bc0 --- /dev/null +++ b/pgtype/varchararray.go @@ -0,0 +1,31 @@ +package pgtype + +import ( + "io" +) + +type VarcharArray TextArray + +func (dst *VarcharArray) ConvertFrom(src interface{}) error { + return (*TextArray)(dst).ConvertFrom(src) +} + +func (src *VarcharArray) AssignTo(dst interface{}) error { + return (*TextArray)(src).AssignTo(dst) +} + +func (dst *VarcharArray) DecodeText(r io.Reader) error { + return (*TextArray)(dst).DecodeText(r) +} + +func (dst *VarcharArray) DecodeBinary(r io.Reader) error { + return (*TextArray)(dst).DecodeBinary(r) +} + +func (src *VarcharArray) EncodeText(w io.Writer) error { + return (*TextArray)(src).EncodeText(w) +} + +func (src *VarcharArray) EncodeBinary(w io.Writer) error { + return (*TextArray)(src).encodeBinary(w, VarcharOID) +} diff --git a/query.go b/query.go index 9019fca4..ffe51ecc 100644 --- a/query.go +++ b/query.go @@ -388,22 +388,6 @@ func (rows *Rows) Values() ([]interface{}, error) { values = append(values, decodeFloat4(vr)) case Float8OID: values = append(values, decodeFloat8(vr)) - case BoolArrayOID: - values = append(values, decodeBoolArray(vr)) - case Int2ArrayOID: - values = append(values, decodeInt2Array(vr)) - case Int4ArrayOID: - values = append(values, decodeInt4Array(vr)) - case Int8ArrayOID: - values = append(values, decodeInt8Array(vr)) - case Float4ArrayOID: - values = append(values, decodeFloat4Array(vr)) - case Float8ArrayOID: - values = append(values, decodeFloat8Array(vr)) - case TextArrayOID, VarcharArrayOID: - values = append(values, decodeTextArray(vr)) - case TimestampArrayOID, TimestampTzArrayOID: - values = append(values, decodeTimestampArray(vr)) case DateOID: values = append(values, decodeDate(vr)) case TimestampTzOID: @@ -479,22 +463,6 @@ func (rows *Rows) ValuesForStdlib() ([]interface{}, error) { values = append(values, decodeFloat4(vr)) case Float8OID: values = append(values, decodeFloat8(vr)) - case BoolArrayOID: - values = append(values, decodeBoolArray(vr)) - case Int2ArrayOID: - values = append(values, decodeInt2Array(vr)) - case Int4ArrayOID: - values = append(values, decodeInt4Array(vr)) - case Int8ArrayOID: - values = append(values, decodeInt8Array(vr)) - case Float4ArrayOID: - values = append(values, decodeFloat4Array(vr)) - case Float8ArrayOID: - values = append(values, decodeFloat8Array(vr)) - case TextArrayOID, VarcharArrayOID: - values = append(values, decodeTextArray(vr)) - case TimestampArrayOID, TimestampTzArrayOID: - values = append(values, decodeTimestampArray(vr)) case DateOID: values = append(values, decodeDate(vr)) case TimestampTzOID: diff --git a/query_test.go b/query_test.go index 364e6b57..801ba851 100644 --- a/query_test.go +++ b/query_test.go @@ -1179,9 +1179,6 @@ func TestQueryRowCoreStringSlice(t *testing.T) { if err == nil { t.Error("Expected null to cause error when scanned into slice, but it didn't") } - if err != nil && !strings.Contains(err.Error(), "Cannot decode null") { - t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err) - } ensureConnValid(t, conn) } diff --git a/values.go b/values.go index 3d7d63a2..c011c8cf 100644 --- a/values.go +++ b/values.go @@ -1073,8 +1073,6 @@ func Encode(wbuf *WriteBuf, oid OID, arg interface{}) error { } switch arg := arg.(type) { - case []string: - return encodeStringSlice(wbuf, oid, arg) case Char: return encodeChar(wbuf, oid, arg) case AclItem: @@ -1178,8 +1176,6 @@ func Decode(vr *ValueReader, d interface{}) error { *v = decodeText(vr) case *[]AclItem: *v = decodeAclItemArray(vr) - case *[]string: - *v = decodeTextArray(vr) case *[][]byte: *v = decodeByteaArray(vr) case *[]interface{}: @@ -2569,41 +2565,6 @@ func encodeFloat64Slice(w *WriteBuf, oid OID, slice []float64) error { return nil } -func decodeTextArray(vr *ValueReader) []string { - if vr.Len() == -1 { - return nil - } - - if vr.Type().DataType != TextArrayOID && vr.Type().DataType != VarcharArrayOID { - vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []string", vr.Type().DataType))) - return nil - } - - if vr.Type().FormatCode != BinaryFormatCode { - vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode))) - return nil - } - - numElems, err := decode1dArrayHeader(vr) - if err != nil { - vr.Fatal(err) - return nil - } - - a := make([]string, int(numElems)) - for i := 0; i < len(a); i++ { - elSize := vr.ReadInt32() - if elSize == -1 { - vr.Fatal(ProtocolError("Cannot decode null element")) - return nil - } - - a[i] = vr.ReadString(elSize) - } - - return a -} - // escapeAclItem escapes an AclItem before it is added to // its aclitem[] string representation. The PostgreSQL aclitem // datatype itself can need escapes because it follows the @@ -2808,39 +2769,6 @@ func decodeAclItemArray(vr *ValueReader) []AclItem { return aclItems } -func encodeStringSlice(w *WriteBuf, oid OID, slice []string) error { - var elOID OID - switch oid { - case VarcharArrayOID: - elOID = VarcharOID - case TextArrayOID: - elOID = TextOID - default: - return fmt.Errorf("cannot encode Go %s into oid %d", "[]string", oid) - } - - var totalStringSize int - for _, v := range slice { - totalStringSize += len(v) - } - - size := 20 + len(slice)*4 + totalStringSize - w.WriteInt32(int32(size)) - - w.WriteInt32(1) // number of dimensions - w.WriteInt32(0) // no nulls - w.WriteInt32(int32(elOID)) // type of elements - w.WriteInt32(int32(len(slice))) // number of elements - w.WriteInt32(1) // index of first element - - for _, v := range slice { - w.WriteInt32(int32(len(v))) - w.WriteBytes([]byte(v)) - } - - return nil -} - func decodeTimestampArray(vr *ValueReader) []time.Time { if vr.Len() == -1 { return nil