diff --git a/conn.go b/conn.go index 9041cafc..073694fa 100644 --- a/conn.go +++ b/conn.go @@ -419,6 +419,45 @@ where ( c.ConnInfo = pgtype.NewConnInfo() c.ConnInfo.InitializeDataTypes(nameOIDs) + + return c.initConnInfoEnumArray() +} + +// initConnInfoEnumArray introspects for arrays of enums and registers a data type for them. +func (c *Conn) initConnInfoEnumArray() error { + nameOIDs := make(map[string]pgtype.OID, 16) + + rows, err := c.Query(`select t.oid, t.typname +from pg_type t + join pg_type base_type on t.typelem=base_type.oid +where t.typtype = 'b' + and base_type.typtype = 'e'`) + if err != nil { + return err + } + + for rows.Next() { + var oid pgtype.OID + var name pgtype.Text + if err := rows.Scan(&oid, &name); err != nil { + return err + } + + nameOIDs[name.String] = oid + } + + if rows.Err() != nil { + return rows.Err() + } + + for name, oid := range nameOIDs { + c.ConnInfo.RegisterDataType(pgtype.DataType{ + &pgtype.EnumArray{}, + name, + oid, + }) + } + return nil } diff --git a/pgmock/pgmock.go b/pgmock/pgmock.go index e65a25fb..71f18852 100644 --- a/pgmock/pgmock.go +++ b/pgmock/pgmock.go @@ -438,6 +438,47 @@ where ( steps = append(steps, SendMessage(&pgproto3.CommandComplete{CommandTag: "SELECT 163"})) steps = append(steps, SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'})) + steps = append(steps, []Step{ + ExpectMessage(&pgproto3.Parse{ + Query: "select t.oid, t.typname\nfrom pg_type t\n join pg_type base_type on t.typelem=base_type.oid\nwhere t.typtype = 'b'\n and base_type.typtype = 'e'", + }), + ExpectMessage(&pgproto3.Describe{ + ObjectType: 'S', + }), + ExpectMessage(&pgproto3.Sync{}), + SendMessage(&pgproto3.ParseComplete{}), + SendMessage(&pgproto3.ParameterDescription{}), + SendMessage(&pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + {Name: "oid", + TableOID: 1247, + TableAttributeNumber: 65534, + DataTypeOID: 26, + DataTypeSize: 4, + TypeModifier: 4294967295, + Format: 0, + }, + {Name: "typname", + TableOID: 1247, + TableAttributeNumber: 1, + DataTypeOID: 19, + DataTypeSize: 64, + TypeModifier: 4294967295, + Format: 0, + }, + }, + }), + SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + ExpectMessage(&pgproto3.Bind{ + ResultFormatCodes: []int16{1, 1}, + }), + ExpectMessage(&pgproto3.Execute{}), + ExpectMessage(&pgproto3.Sync{}), + SendMessage(&pgproto3.BindComplete{}), + SendMessage(&pgproto3.CommandComplete{CommandTag: "SELECT 0"}), + SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + }...) + return steps } diff --git a/pgtype/enum_array.go b/pgtype/enum_array.go new file mode 100644 index 00000000..3a948015 --- /dev/null +++ b/pgtype/enum_array.go @@ -0,0 +1,212 @@ +package pgtype + +import ( + "database/sql/driver" + + "github.com/pkg/errors" +) + +type EnumArray struct { + Elements []GenericText + Dimensions []ArrayDimension + Status Status +} + +func (dst *EnumArray) Set(src interface{}) error { + // untyped nil and typed nil interfaces are different + if src == nil { + *dst = EnumArray{Status: Null} + return nil + } + + switch value := src.(type) { + + case []string: + if value == nil { + *dst = EnumArray{Status: Null} + } else if len(value) == 0 { + *dst = EnumArray{Status: Present} + } else { + elements := make([]GenericText, len(value)) + for i := range value { + if err := elements[i].Set(value[i]); err != nil { + return err + } + } + *dst = EnumArray{ + Elements: elements, + Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, + Status: Present, + } + } + + default: + if originalSrc, ok := underlyingSliceType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to EnumArray", value) + } + + return nil +} + +func (dst *EnumArray) Get() interface{} { + switch dst.Status { + case Present: + return dst + case Null: + return nil + default: + return dst.Status + } +} + +func (src *EnumArray) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + + case *[]string: + *v = make([]string, len(src.Elements)) + for i := range src.Elements { + if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { + return err + } + } + return nil + + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %v into %T", src, dst) +} + +func (dst *EnumArray) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + *dst = EnumArray{Status: Null} + return nil + } + + uta, err := ParseUntypedTextArray(string(src)) + if err != nil { + return err + } + + var elements []GenericText + + if len(uta.Elements) > 0 { + elements = make([]GenericText, len(uta.Elements)) + + for i, s := range uta.Elements { + var elem GenericText + var elemSrc []byte + if s != "NULL" { + elemSrc = []byte(s) + } + err = elem.DecodeText(ci, elemSrc) + if err != nil { + return err + } + + elements[i] = elem + } + } + + *dst = EnumArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} + + return nil +} + +func (src *EnumArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + if len(src.Dimensions) == 0 { + return append(buf, '{', '}'), nil + } + + buf = EncodeTextArrayDimensions(buf, src.Dimensions) + + // dimElemCounts is the multiples of elements that each array lies on. For + // example, a single dimension array of length 4 would have a dimElemCounts of + // [4]. A multi-dimensional array of lengths [3,5,2] would have a + // dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' + // or '}'. + dimElemCounts := make([]int, len(src.Dimensions)) + dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) + for i := len(src.Dimensions) - 2; i > -1; i-- { + dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] + } + + inElemBuf := make([]byte, 0, 32) + for i, elem := range src.Elements { + if i > 0 { + buf = append(buf, ',') + } + + for _, dec := range dimElemCounts { + if i%dec == 0 { + buf = append(buf, '{') + } + } + + elemBuf, err := elem.EncodeText(ci, inElemBuf) + if err != nil { + return nil, err + } + if elemBuf == nil { + buf = append(buf, `NULL`...) + } else { + buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) + } + + for _, dec := range dimElemCounts { + if (i+1)%dec == 0 { + buf = append(buf, '}') + } + } + } + + return buf, nil +} + +// Scan implements the database/sql Scanner interface. +func (dst *EnumArray) Scan(src interface{}) error { + if src == nil { + return dst.DecodeText(nil, nil) + } + + switch src := src.(type) { + case string: + return dst.DecodeText(nil, []byte(src)) + case []byte: + srcCopy := make([]byte, len(src)) + copy(srcCopy, src) + return dst.DecodeText(nil, srcCopy) + } + + return errors.Errorf("cannot scan %T", src) +} + +// Value implements the database/sql/driver Valuer interface. +func (src *EnumArray) Value() (driver.Value, error) { + buf, err := src.EncodeText(nil, nil) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + + return string(buf), nil +} diff --git a/pgtype/enum_array_test.go b/pgtype/enum_array_test.go new file mode 100644 index 00000000..94774e1e --- /dev/null +++ b/pgtype/enum_array_test.go @@ -0,0 +1,150 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgx/pgtype/testutil" +) + +func TestEnumArrayTranscode(t *testing.T) { + setupConn := testutil.MustConnectPgx(t) + defer testutil.MustClose(t, setupConn) + + if _, err := setupConn.Exec("drop type if exists color"); err != nil { + t.Fatal(err) + } + if _, err := setupConn.Exec("create type color as enum ('red', 'green', 'blue')"); err != nil { + t.Fatal(err) + } + + testutil.TestSuccessfulTranscode(t, "color[]", []interface{}{ + &pgtype.EnumArray{ + Elements: nil, + Dimensions: nil, + Status: pgtype.Present, + }, + &pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + pgtype.GenericText{String: "red", Status: pgtype.Present}, + pgtype.GenericText{Status: pgtype.Null}, + }, + Dimensions: []pgtype.ArrayDimension{{Length: 2, LowerBound: 1}}, + Status: pgtype.Present, + }, + &pgtype.EnumArray{Status: pgtype.Null}, + &pgtype.EnumArray{ + Elements: []pgtype.GenericText{ + pgtype.GenericText{String: "red", Status: pgtype.Present}, + pgtype.GenericText{String: "green", Status: pgtype.Present}, + pgtype.GenericText{String: "blue", Status: pgtype.Present}, + pgtype.GenericText{String: "red", Status: pgtype.Present}, + }, + Dimensions: []pgtype.ArrayDimension{ + {Length: 2, LowerBound: 4}, + {Length: 2, LowerBound: 2}, + }, + Status: pgtype.Present, + }, + }) +} + +func TestEnumArrayArraySet(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.EnumArray + }{ + { + source: []string{"foo"}, + result: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present}, + }, + { + source: (([]string)(nil)), + result: pgtype.EnumArray{Status: pgtype.Null}, + }, + } + + for i, tt := range successfulTests { + var r pgtype.EnumArray + err := r.Set(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestEnumArrayArrayAssignTo(t *testing.T) { + var stringSlice []string + type _stringSlice []string + var namedStringSlice _stringSlice + + simpleTests := []struct { + src pgtype.EnumArray + dst interface{} + expected interface{} + }{ + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "foo", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + expected: []string{"foo"}, + }, + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{String: "bar", Status: pgtype.Present}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &namedStringSlice, + expected: _stringSlice{"bar"}, + }, + { + src: pgtype.EnumArray{Status: pgtype.Null}, + dst: &stringSlice, + expected: (([]string)(nil)), + }, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } + + errorTests := []struct { + src pgtype.EnumArray + dst interface{} + }{ + { + src: pgtype.EnumArray{ + Elements: []pgtype.GenericText{{Status: pgtype.Null}}, + Dimensions: []pgtype.ArrayDimension{{LowerBound: 1, Length: 1}}, + Status: pgtype.Present, + }, + dst: &stringSlice, + }, + } + + for i, tt := range errorTests { + err := tt.src.AssignTo(tt.dst) + if err == nil { + t.Errorf("%d: expected error but none was returned (%v -> %v)", i, tt.src, tt.dst) + } + } +} diff --git a/pgtype/typed_array_gen.sh b/pgtype/typed_array_gen.sh index 80ece93c..2a1eab99 100644 --- a/pgtype/typed_array_gen.sh +++ b/pgtype/typed_array_gen.sh @@ -16,4 +16,8 @@ erb pgtype_array_type=ACLItemArray pgtype_element_type=ACLItem go_array_types=[] erb pgtype_array_type=HstoreArray pgtype_element_type=Hstore go_array_types=[]map[string]string element_type_name=hstore text_null=NULL binary_format=true typed_array.go.erb > hstore_array.go erb pgtype_array_type=NumericArray pgtype_element_type=Numeric go_array_types=[]float32,[]float64 element_type_name=numeric text_null=NULL binary_format=true typed_array.go.erb > numeric_array.go erb pgtype_array_type=UUIDArray pgtype_element_type=UUID go_array_types=[][16]byte,[][]byte,[]string element_type_name=uuid text_null=NULL binary_format=true typed_array.go.erb > uuid_array.go + +# While the binary format is theoretically possible it is only practical to use the text format. In addition, the text format for NULL enums is unquoted so TextArray or a possible GenericTextArray cannot be used. +erb pgtype_array_type=EnumArray pgtype_element_type=GenericText go_array_types=[]string text_null='NULL' binary_format=false typed_array.go.erb > enum_array.go + goimports -w *_array.go