diff --git a/pgtype/bool_array.go b/pgtype/bool_array.go index 1cb46cf6..6adfbb00 100644 --- a/pgtype/bool_array.go +++ b/pgtype/bool_array.go @@ -238,10 +238,6 @@ func (src *BoolArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *BoolArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, BoolOid) -} - -func (src *BoolArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -250,10 +246,15 @@ func (src *BoolArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("bool"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "bool") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/bytea_array.go b/pgtype/bytea_array.go index 30405509..d318fa3b 100644 --- a/pgtype/bytea_array.go +++ b/pgtype/bytea_array.go @@ -238,10 +238,6 @@ func (src *ByteaArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *ByteaArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, ByteaOid) -} - -func (src *ByteaArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -250,10 +246,15 @@ func (src *ByteaArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("bytea"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "bytea") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/cidr_array.go b/pgtype/cidr_array.go index 32d2e7bf..3ab83ecd 100644 --- a/pgtype/cidr_array.go +++ b/pgtype/cidr_array.go @@ -270,10 +270,6 @@ func (src *CidrArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *CidrArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, CidrOid) -} - -func (src *CidrArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -282,10 +278,15 @@ func (src *CidrArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("cidr"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "cidr") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/date_array.go b/pgtype/date_array.go index ba68d561..8bc8ff72 100644 --- a/pgtype/date_array.go +++ b/pgtype/date_array.go @@ -239,10 +239,6 @@ func (src *DateArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *DateArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, DateOid) -} - -func (src *DateArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -251,10 +247,15 @@ func (src *DateArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("date"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "date") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/float4_array.go b/pgtype/float4_array.go index 40152bcf..6abc1a31 100644 --- a/pgtype/float4_array.go +++ b/pgtype/float4_array.go @@ -238,10 +238,6 @@ func (src *Float4Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *Float4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, Float4Oid) -} - -func (src *Float4Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -250,10 +246,15 @@ func (src *Float4Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32 } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("float4"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "float4") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/float8_array.go b/pgtype/float8_array.go index d0ee0d70..050efa3f 100644 --- a/pgtype/float8_array.go +++ b/pgtype/float8_array.go @@ -238,10 +238,6 @@ func (src *Float8Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *Float8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, Float8Oid) -} - -func (src *Float8Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -250,10 +246,15 @@ func (src *Float8Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32 } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("float8"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "float8") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/hstore_array.go b/pgtype/hstore_array.go new file mode 100644 index 00000000..ba192462 --- /dev/null +++ b/pgtype/hstore_array.go @@ -0,0 +1,297 @@ +package pgtype + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/jackc/pgx/pgio" +) + +type HstoreArray struct { + Elements []Hstore + Dimensions []ArrayDimension + Status Status +} + +func (dst *HstoreArray) Set(src interface{}) error { + switch value := src.(type) { + + case []map[string]string: + if value == nil { + *dst = HstoreArray{Status: Null} + } else if len(value) == 0 { + *dst = HstoreArray{Status: Present} + } else { + elements := make([]Hstore, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = HstoreArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return fmt.Errorf("cannot convert %v to Hstore", value) + } + + return nil +} + +func (dst *HstoreArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *HstoreArray) AssignTo(dst interface{}) error { + switch v := dst.(type) { + + case *[]map[string]string: + if src.Status == Present { + *v = make([]map[string]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + } else { + *v = nil + } + + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +func (dst *HstoreArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = HstoreArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []Hstore + + if len(uta.Elements) > 0 { + elements = make([]Hstore, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem Hstore + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = HstoreArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (dst *HstoreArray) DecodeBinary(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = HstoreArray{Status: Null} + return nil + } + + var arrayHeader ArrayHeader + rp, err := arrayHeader.DecodeBinary(ci, src) + if err != nil { + return err + } + + if len(arrayHeader.Dimensions) == 0 { + *dst = HstoreArray{Dimensions: arrayHeader.Dimensions, Status: Present} + return nil + } + + elementCount := arrayHeader.Dimensions[0].Length + for _, d := range arrayHeader.Dimensions[1:] { + elementCount *= d.Length + } + + elements := make([]Hstore, elementCount) + + for i := range elements { + elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) + rp += 4 + var elemSrc []byte + if elemLen >= 0 { + elemSrc = src[rp : rp+elemLen] + rp += elemLen + } + err = elements[i].DecodeBinary(ci, elemSrc) + if err != nil { + return err + } + } + + *dst = HstoreArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present} + return nil +} + +func (src *HstoreArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + if len(src.Dimensions) == 0 { + _, err := io.WriteString(w, "{}") + return false, err + } + + err := EncodeTextArrayDimensions(w, src.Dimensions) + if err != nil { + return false, err + } + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + for i, elem := range src.Elements { + if i > 0 { + err = pgio.WriteByte(w, ',') + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + err = pgio.WriteByte(w, '{') + if err != nil { + return false, err + } + } + } + + elemBuf := &bytes.Buffer{} + null, err := elem.EncodeText(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = io.WriteString(w, `NULL`) + if err != nil { + return false, err + } + } else { + _, err = io.WriteString(w, QuoteArrayElementIfNeeded(elemBuf.String())) + if err != nil { + return false, err + } + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + err = pgio.WriteByte(w, '}') + if err != nil { + return false, err + } + } + } + } + + return false, nil +} + +func (src *HstoreArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { + switch src.Status { + case Null: + return true, nil + case Undefined: + return false, errUndefined + } + + arrayHeader := ArrayHeader{ + Dimensions: src.Dimensions, + } + + if dt, ok := ci.DataTypeForName("hstore"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "hstore") + } + + for i := range src.Elements { + if src.Elements[i].Status == Null { + arrayHeader.ContainsNull = true + break + } + } + + err := arrayHeader.EncodeBinary(ci, w) + if err != nil { + return false, err + } + + elemBuf := &bytes.Buffer{} + + for i := range src.Elements { + elemBuf.Reset() + + null, err := src.Elements[i].EncodeBinary(ci, elemBuf) + if err != nil { + return false, err + } + if null { + _, err = pgio.WriteInt32(w, -1) + if err != nil { + return false, err + } + } else { + _, err = pgio.WriteInt32(w, int32(elemBuf.Len())) + if err != nil { + return false, err + } + _, err = elemBuf.WriteTo(w) + if err != nil { + return false, err + } + } + } + + return false, err +} diff --git a/pgtype/hstore_array_test.go b/pgtype/hstore_array_test.go new file mode 100644 index 00000000..e23c7b3b --- /dev/null +++ b/pgtype/hstore_array_test.go @@ -0,0 +1,183 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx" + "github.com/jackc/pgx/pgtype" +) + +func TestHstoreArrayTranscode(t *testing.T) { + conn := mustConnectPgx(t) + defer mustClose(t, conn) + + text := func(s string) pgtype.Text { + return pgtype.Text{String: s, Status: pgtype.Present} + } + + values := []pgtype.Hstore{ + pgtype.Hstore{Map: map[string]pgtype.Text{}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar")}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("bar"), "baz": text("quz")}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"NULL": text("bar")}, Status: pgtype.Present}, + pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("NULL")}, Status: pgtype.Present}, + pgtype.Hstore{Status: pgtype.Null}, + } + + specialStrings := []string{ + `"`, + `'`, + `\`, + `\\`, + `=>`, + ` `, + `\ / / \\ => " ' " '`, + } + for _, s := range specialStrings { + // Special key values + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s + "foo": text("bar")}, Status: pgtype.Present}) // at beginning + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s + "bar": text("bar")}, Status: pgtype.Present}) // in middle + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo" + s: text("bar")}, Status: pgtype.Present}) // at end + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{s: text("bar")}, Status: pgtype.Present}) // is key + + // Special value values + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s + "bar")}, Status: pgtype.Present}) // at beginning + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s + "bar")}, Status: pgtype.Present}) // in middle + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text("foo" + s)}, Status: pgtype.Present}) // at end + values = append(values, pgtype.Hstore{Map: map[string]pgtype.Text{"foo": text(s)}, Status: pgtype.Present}) // is key + } + + src := pgtype.HstoreArray{ + Elements: values, + Dimensions: []pgtype.ArrayDimension{{Length: int32(len(values)), LowerBound: 1}}, + Status: pgtype.Present, + } + + ps, err := conn.Prepare("test", "select $1::hstore[]") + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + formatCode int16 + }{ + {name: "TextFormat", formatCode: pgx.TextFormatCode}, + {name: "BinaryFormat", formatCode: pgx.BinaryFormatCode}, + } + + for _, fc := range formats { + ps.FieldDescriptions[0].FormatCode = fc.formatCode + vEncoder := forceEncoder(src, fc.formatCode) + if vEncoder == nil { + t.Logf("%#v does not implement %v", src, fc.name) + continue + } + + var result pgtype.HstoreArray + err := conn.QueryRow("test", vEncoder).Scan(&result) + if err != nil { + t.Errorf("%v: %v", fc.name, err) + continue + } + + if result.Status != src.Status { + t.Errorf("%v: expected Status %v, got %v", fc.formatCode, src.Status, result.Status) + continue + } + + if len(result.Elements) != len(src.Elements) { + t.Errorf("%v: expected %v elements, got %v", fc.formatCode, len(src.Elements), len(result.Elements)) + continue + } + + for i := range result.Elements { + a := src.Elements[i] + b := result.Elements[i] + + if a.Status != b.Status { + t.Errorf("%v element idx %d: expected status %v, got %v", fc.formatCode, i, a.Status, b.Status) + } + + if len(a.Map) != len(b.Map) { + t.Errorf("%v element idx %d: expected %v pairs, got %v", fc.formatCode, i, len(a.Map), len(b.Map)) + } + + for k := range a.Map { + if a.Map[k] != b.Map[k] { + t.Errorf("%v element idx %d: expected key %v to be %v, got %v", fc.formatCode, i, k, a.Map[k], b.Map[k]) + } + } + } + } +} + +func TestHstoreArraySet(t *testing.T) { + successfulTests := []struct { + src []map[string]string + result pgtype.HstoreArray + }{ + { + src: []map[string]string{map[string]string{"foo": "bar"}}, + result: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": pgtype.Text{String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + }, + } + + for i, tt := range successfulTests { + var dst pgtype.HstoreArray + err := dst.Set(tt.src) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(dst, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.src, tt.result, dst) + } + } +} + +func TestHstoreArrayAssignTo(t *testing.T) { + var m []map[string]string + + simpleTests := []struct { + src pgtype.HstoreArray + dst *[]map[string]string + expected []map[string]string + }{ + { + src: pgtype.HstoreArray{ + Elements: []pgtype.Hstore{ + { + Map: map[string]pgtype.Text{"foo": pgtype.Text{String: "bar", Status: pgtype.Present}}, + Status: pgtype.Present, + }, + }, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &m, + expected: []map[string]string{{"foo": "bar"}}}, + {src: pgtype.HstoreArray{Status: pgtype.Null}, dst: &m, expected: (([]map[string]string)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(*tt.dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, *tt.dst) + } + } +} diff --git a/pgtype/inet_array.go b/pgtype/inet_array.go index 6cad82e7..d893a724 100644 --- a/pgtype/inet_array.go +++ b/pgtype/inet_array.go @@ -270,10 +270,6 @@ func (src *InetArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *InetArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, InetOid) -} - -func (src *InetArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -282,10 +278,15 @@ func (src *InetArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("inet"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "inet") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/int2_array.go b/pgtype/int2_array.go index 2bf1c237..b93a4fa3 100644 --- a/pgtype/int2_array.go +++ b/pgtype/int2_array.go @@ -269,10 +269,6 @@ func (src *Int2Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *Int2Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, Int2Oid) -} - -func (src *Int2Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -281,10 +277,15 @@ func (src *Int2Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("int2"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "int2") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/int4_array.go b/pgtype/int4_array.go index dda88eaf..0b96b7a4 100644 --- a/pgtype/int4_array.go +++ b/pgtype/int4_array.go @@ -269,10 +269,6 @@ func (src *Int4Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *Int4Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, Int4Oid) -} - -func (src *Int4Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -281,10 +277,15 @@ func (src *Int4Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("int4"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "int4") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/int8_array.go b/pgtype/int8_array.go index 468c126b..02a240f4 100644 --- a/pgtype/int8_array.go +++ b/pgtype/int8_array.go @@ -269,10 +269,6 @@ func (src *Int8Array) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *Int8Array) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, Int8Oid) -} - -func (src *Int8Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -281,10 +277,15 @@ func (src *Int8Array) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("int8"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "int8") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/text_array.go b/pgtype/text_array.go index 6e89708f..9f25727e 100644 --- a/pgtype/text_array.go +++ b/pgtype/text_array.go @@ -238,10 +238,6 @@ func (src *TextArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *TextArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, TextOid) -} - -func (src *TextArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -250,10 +246,15 @@ func (src *TextArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("text"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "text") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/timestamp_array.go b/pgtype/timestamp_array.go index 064ad483..bb19e502 100644 --- a/pgtype/timestamp_array.go +++ b/pgtype/timestamp_array.go @@ -239,10 +239,6 @@ func (src *TimestampArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *TimestampArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, TimestampOid) -} - -func (src *TimestampArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -251,10 +247,15 @@ func (src *TimestampArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid in } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("timestamp"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "timestamp") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/timestamptz_array.go b/pgtype/timestamptz_array.go index 4af1460b..6a85cefa 100644 --- a/pgtype/timestamptz_array.go +++ b/pgtype/timestamptz_array.go @@ -239,10 +239,6 @@ func (src *TimestamptzArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) } func (src *TimestamptzArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, TimestamptzOid) -} - -func (src *TimestamptzArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -251,10 +247,15 @@ func (src *TimestamptzArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("timestamptz"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "timestamptz") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/typed_array.go.erb b/pgtype/typed_array.go.erb index 2a46a658..2b81666e 100644 --- a/pgtype/typed_array.go.erb +++ b/pgtype/typed_array.go.erb @@ -237,10 +237,6 @@ func (src *<%= pgtype_array_type %>) EncodeText(ci *ConnInfo, w io.Writer) (bool } func (src *<%= pgtype_array_type %>) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, <%= element_oid %>) -} - -func (src *<%= pgtype_array_type %>) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -249,10 +245,15 @@ func (src *<%= pgtype_array_type %>) encodeBinary(ci *ConnInfo, w io.Writer, ele } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("<%= element_type_name %>"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "<%= element_type_name %>") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index 5fde32aa..166f8802 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -1,15 +1,16 @@ -erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_oid=Int2Oid text_null=NULL typed_array.go.erb > int2_array.go -erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_oid=Int4Oid text_null=NULL typed_array.go.erb > int4_array.go -erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_oid=Int8Oid text_null=NULL typed_array.go.erb > int8_array.go -erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_oid=BoolOid text_null=NULL typed_array.go.erb > bool_array.go -erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_oid=DateOid text_null=NULL typed_array.go.erb > date_array.go -erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_oid=TimestamptzOid text_null=NULL typed_array.go.erb > timestamptz_array.go -erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_oid=TimestampOid text_null=NULL typed_array.go.erb > timestamp_array.go -erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_oid=Float4Oid text_null=NULL typed_array.go.erb > float4_array.go -erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_oid=Float8Oid text_null=NULL typed_array.go.erb > float8_array.go -erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_oid=InetOid text_null=NULL typed_array.go.erb > inet_array.go -erb pgtype_array_type=CidrArray pgtype_element_type=Cidr go_array_types=[]*net.IPNet,[]net.IP element_oid=CidrOid text_null=NULL typed_array.go.erb > cidr_array.go -erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_oid=TextOid text_null='"NULL"' typed_array.go.erb > text_array.go -erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_oid=VarcharOid text_null='"NULL"' typed_array.go.erb > varchar_array.go -erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_oid=ByteaOid text_null=NULL typed_array.go.erb > bytea_array.go -erb pgtype_array_type=AclitemArray pgtype_element_type=Aclitem go_array_types=[]string element_oid=AclitemOid text_null=NULL typed_array.go.erb > aclitem_array.go +erb pgtype_array_type=Int2Array pgtype_element_type=Int2 go_array_types=[]int16,[]uint16 element_type_name=int2 text_null=NULL typed_array.go.erb > int2_array.go +erb pgtype_array_type=Int4Array pgtype_element_type=Int4 go_array_types=[]int32,[]uint32 element_type_name=int4 text_null=NULL typed_array.go.erb > int4_array.go +erb pgtype_array_type=Int8Array pgtype_element_type=Int8 go_array_types=[]int64,[]uint64 element_type_name=int8 text_null=NULL typed_array.go.erb > int8_array.go +erb pgtype_array_type=BoolArray pgtype_element_type=Bool go_array_types=[]bool element_type_name=bool text_null=NULL typed_array.go.erb > bool_array.go +erb pgtype_array_type=DateArray pgtype_element_type=Date go_array_types=[]time.Time element_type_name=date text_null=NULL typed_array.go.erb > date_array.go +erb pgtype_array_type=TimestamptzArray pgtype_element_type=Timestamptz go_array_types=[]time.Time element_type_name=timestamptz text_null=NULL typed_array.go.erb > timestamptz_array.go +erb pgtype_array_type=TimestampArray pgtype_element_type=Timestamp go_array_types=[]time.Time element_type_name=timestamp text_null=NULL typed_array.go.erb > timestamp_array.go +erb pgtype_array_type=Float4Array pgtype_element_type=Float4 go_array_types=[]float32 element_type_name=float4 text_null=NULL typed_array.go.erb > float4_array.go +erb pgtype_array_type=Float8Array pgtype_element_type=Float8 go_array_types=[]float64 element_type_name=float8 text_null=NULL typed_array.go.erb > float8_array.go +erb pgtype_array_type=InetArray pgtype_element_type=Inet go_array_types=[]*net.IPNet,[]net.IP element_type_name=inet text_null=NULL typed_array.go.erb > inet_array.go +erb pgtype_array_type=CidrArray pgtype_element_type=Cidr go_array_types=[]*net.IPNet,[]net.IP element_type_name=cidr text_null=NULL typed_array.go.erb > cidr_array.go +erb pgtype_array_type=TextArray pgtype_element_type=Text go_array_types=[]string element_type_name=text text_null='"NULL"' typed_array.go.erb > text_array.go +erb pgtype_array_type=VarcharArray pgtype_element_type=Varchar go_array_types=[]string element_type_name=varchar text_null='"NULL"' typed_array.go.erb > varchar_array.go +erb pgtype_array_type=ByteaArray pgtype_element_type=Bytea go_array_types=[][]byte element_type_name=bytea text_null=NULL typed_array.go.erb > bytea_array.go +erb pgtype_array_type=AclitemArray pgtype_element_type=Aclitem go_array_types=[]string element_type_name=aclitem text_null=NULL typed_array.go.erb > aclitem_array.go +erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL typed_array.go.erb > hstore_array.go diff --git a/pgtype/varchar_array.go b/pgtype/varchar_array.go index 21e9ccff..158ece94 100644 --- a/pgtype/varchar_array.go +++ b/pgtype/varchar_array.go @@ -238,10 +238,6 @@ func (src *VarcharArray) EncodeText(ci *ConnInfo, w io.Writer) (bool, error) { } func (src *VarcharArray) EncodeBinary(ci *ConnInfo, w io.Writer) (bool, error) { - return src.encodeBinary(ci, w, VarcharOid) -} - -func (src *VarcharArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int32) (bool, error) { switch src.Status { case Null: return true, nil @@ -250,10 +246,15 @@ func (src *VarcharArray) encodeBinary(ci *ConnInfo, w io.Writer, elementOid int3 } arrayHeader := ArrayHeader{ - ElementOid: elementOid, Dimensions: src.Dimensions, } + if dt, ok := ci.DataTypeForName("varchar"); ok { + arrayHeader.ElementOid = int32(dt.Oid) + } else { + return false, fmt.Errorf("unable to find oid for type name %v", "varchar") + } + for i := range src.Elements { if src.Elements[i].Status == Null { arrayHeader.ContainsNull = true