diff --git a/enum_type.go b/enum_type.go new file mode 100644 index 00000000..44095cc7 --- /dev/null +++ b/enum_type.go @@ -0,0 +1,163 @@ +package pgtype + +import errors "golang.org/x/xerrors" + +// EnumType represents an enum type. In the normal pgtype model a Go type maps to a PostgreSQL type and an instance +// of a Go type maps to a PostgreSQL value of that type. EnumType is different in that an instance of EnumType +// represents a PostgreSQL type. The zero value is not usable -- NewEnumType must be used as a constructor. In general, +// an EnumType should not be used to represent a value. It should only be used as an encoder and decoder internal to +// ConnInfo. +type EnumType struct { + String string + Status Status + + pgTypeName string // PostgreSQL type name + members []string // enum members + membersMap map[string]string // map to quickly lookup member and reuse string instead of allocating +} + +// NewEnumType initializes a new EnumType. It retains a read-only reference to members. members must not be changed. +func NewEnumType(pgTypeName string, members []string) *EnumType { + et := &EnumType{pgTypeName: pgTypeName, members: members} + et.membersMap = make(map[string]string, len(members)) + for _, m := range members { + et.membersMap[m] = m + } + return et +} + +func (et *EnumType) CloneTypeValue() Value { + return &EnumType{ + String: et.String, + Status: et.Status, + + pgTypeName: et.pgTypeName, + members: et.members, + membersMap: et.membersMap, + } +} + +func (et *EnumType) PgTypeName() string { + return et.pgTypeName +} + +func (et *EnumType) Members() []string { + return et.members +} + +// Set assigns src to dst. Set purposely does not check that src is a member. This allows continued error free +// operation in the event the PostgreSQL enum type is modified during a connection. +func (dst *EnumType) Set(src interface{}) error { + if src == nil { + dst.Status = Null + return nil + } + + if value, ok := src.(interface{ Get() interface{} }); ok { + value2 := value.Get() + if value2 != value { + return dst.Set(value2) + } + } + + switch value := src.(type) { + case string: + dst.String = value + dst.Status = Present + case *string: + if value == nil { + dst.Status = Null + } else { + dst.String = *value + dst.Status = Present + } + case []byte: + if value == nil { + dst.Status = Null + } else { + dst.String = string(value) + dst.Status = Present + } + default: + if originalSrc, ok := underlyingStringType(src); ok { + return dst.Set(originalSrc) + } + return errors.Errorf("cannot convert %v to enum %s", value, dst.pgTypeName) + } + + return nil +} + +func (dst EnumType) Get() interface{} { + switch dst.Status { + case Present: + return dst.String + case Null: + return nil + default: + return dst.Status + } +} + +func (src *EnumType) AssignTo(dst interface{}) error { + switch src.Status { + case Present: + switch v := dst.(type) { + case *string: + *v = src.String + return nil + case *[]byte: + *v = make([]byte, len(src.String)) + copy(*v, src.String) + return nil + default: + if nextDst, retry := GetAssignToDstType(dst); retry { + return src.AssignTo(nextDst) + } + return errors.Errorf("unable to assign to %T", dst) + } + case Null: + return NullAssignTo(dst) + } + + return errors.Errorf("cannot decode %#v into %T", src, dst) +} + +func (dst *EnumType) DecodeText(ci *ConnInfo, src []byte) error { + if src == nil { + dst.Status = Null + return nil + } + + // Lookup the string in membersMap to avoid an allocation. + if s, found := dst.membersMap[string(src)]; found { + dst.String = s + } else { + // If an enum type is modified after the initial connection it is possible to receive an unexpected value. + // Gracefully handle this situation. Purposely NOT modifying members and membersMap to allow for sharing members + // and membersMap between connections. + dst.String = string(src) + } + dst.Status = Present + + return nil +} + +func (dst *EnumType) DecodeBinary(ci *ConnInfo, src []byte) error { + return dst.DecodeText(ci, src) +} + +func (src EnumType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { + switch src.Status { + case Null: + return nil, nil + case Undefined: + return nil, errUndefined + } + + return append(buf, src.String...), nil +} + +func (src EnumType) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { + return src.EncodeText(ci, buf) +} diff --git a/enum_type_test.go b/enum_type_test.go new file mode 100644 index 00000000..4dd88f2a --- /dev/null +++ b/enum_type_test.go @@ -0,0 +1,148 @@ +package pgtype_test + +import ( + "bytes" + "context" + "testing" + + "github.com/jackc/pgtype" + "github.com/jackc/pgtype/testutil" + "github.com/jackc/pgx/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setupEnum(t *testing.T, conn *pgx.Conn) *pgtype.EnumType { + _, err := conn.Exec(context.Background(), "drop type if exists pgtype_enum_color;") + require.NoError(t, err) + + _, err = conn.Exec(context.Background(), "create type pgtype_enum_color as enum ('blue', 'green', 'purple');") + require.NoError(t, err) + + var oid uint32 + err = conn.QueryRow(context.Background(), "select oid from pg_type where typname=$1;", "pgtype_enum_color").Scan(&oid) + require.NoError(t, err) + + et := pgtype.NewEnumType("pgtype_enum_color", []string{"blue", "green", "purple"}) + conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: et, Name: "pgtype_enum_color", OID: oid}) + + return et +} + +func cleanupEnum(t *testing.T, conn *pgx.Conn) { + _, err := conn.Exec(context.Background(), "drop type if exists pgtype_enum_color;") + require.NoError(t, err) +} + +func TestEnumTypeTranscode(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + setupEnum(t, conn) + defer cleanupEnum(t, conn) + + var dst string + err := conn.QueryRow(context.Background(), "select $1::pgtype_enum_color", "blue").Scan(&dst) + require.NoError(t, err) + require.EqualValues(t, "blue", dst) +} + +func TestEnumTypeSet(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + enumType := setupEnum(t, conn) + defer cleanupEnum(t, conn) + + successfulTests := []struct { + source interface{} + result interface{} + }{ + {source: "blue", result: "blue"}, + {source: _string("green"), result: "green"}, + {source: (*string)(nil), result: nil}, + } + + for i, tt := range successfulTests { + err := enumType.Set(tt.source) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.result, enumType.Get(), "%d", i) + } +} + +func TestEnumTypeAssignTo(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + enumType := setupEnum(t, conn) + defer cleanupEnum(t, conn) + + { + var s string + + err := enumType.Set("blue") + require.NoError(t, err) + + err = enumType.AssignTo(&s) + require.NoError(t, err) + + assert.EqualValues(t, "blue", s) + } + + { + var ps *string + + err := enumType.Set("blue") + require.NoError(t, err) + + err = enumType.AssignTo(&ps) + require.NoError(t, err) + + assert.EqualValues(t, "blue", *ps) + } + + { + var ps *string + + err := enumType.Set(nil) + require.NoError(t, err) + + err = enumType.AssignTo(&ps) + require.NoError(t, err) + + assert.EqualValues(t, (*string)(nil), ps) + } + + var buf []byte + bytesTests := []struct { + src interface{} + dst *[]byte + expected []byte + }{ + {src: "blue", dst: &buf, expected: []byte("blue")}, + {src: nil, dst: &buf, expected: nil}, + } + + for i, tt := range bytesTests { + err := enumType.Set(tt.src) + require.NoError(t, err, "%d", i) + + err = enumType.AssignTo(tt.dst) + require.NoError(t, err, "%d", i) + + if bytes.Compare(*tt.dst, tt.expected) != 0 { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, tt.dst) + } + } + + { + var s string + + err := enumType.Set(nil) + require.NoError(t, err) + + err = enumType.AssignTo(&s) + require.Error(t, err) + } + +} diff --git a/pgtype.go b/pgtype.go index bb0a99af..997899d8 100644 --- a/pgtype.go +++ b/pgtype.go @@ -125,6 +125,22 @@ type Value interface { AssignTo(dst interface{}) error } +// TypeValue represents values where instances represent a type. In the normal pgtype model a Go type maps to a +// PostgreSQL type and an instance of a Go type maps to a PostgreSQL value of that type. Implementors of TypeValue +// are different in that an instance represents a PostgreSQL type. This can be useful for representing types such +// as enums, composites, and arrays. +// +// In general, instances of TypeValue should not be used to directly represent a value. It should only be used as an +// encoder and decoder internal to ConnInfo. +type TypeValue interface { + // CloneTypeValue duplicates a TypeValue including references to internal type information. e.g. the list of members + // in an EnumType. + CloneTypeValue() Value + + // PgTypeName returns the PostgreSQL name of this type. + PgTypeName() string +} + type BinaryDecoder interface { // DecodeBinary decodes src into BinaryDecoder. If src is nil then the // original SQL value is NULL. BinaryDecoder takes ownership of src. The @@ -270,9 +286,16 @@ func (ci *ConnInfo) InitializeDataTypes(nameOIDs map[string]uint32) { } func (ci *ConnInfo) RegisterDataType(t DataType) { + tv, _ := t.Value.(TypeValue) + if tv != nil { + t.Value = tv.CloneTypeValue() + } + ci.oidToDataType[t.OID] = &t ci.nameToDataType[t.Name] = &t - ci.reflectTypeToDataType[reflect.ValueOf(t.Value).Type()] = &t + if tv == nil { + ci.reflectTypeToDataType[reflect.ValueOf(t.Value).Type()] = &t + } { var formatCode int16 @@ -310,6 +333,11 @@ func (ci *ConnInfo) DataTypeForName(name string) (*DataType, bool) { } func (ci *ConnInfo) DataTypeForValue(v Value) (*DataType, bool) { + if tv, ok := v.(TypeValue); ok { + dt, ok := ci.nameToDataType[tv.PgTypeName()] + return dt, ok + } + dt, ok := ci.reflectTypeToDataType[reflect.ValueOf(v).Type()] return dt, ok } @@ -336,11 +364,20 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo { oidToDataType: make(map[uint32]*DataType, len(ci.oidToDataType)), nameToDataType: make(map[string]*DataType, len(ci.nameToDataType)), reflectTypeToDataType: make(map[reflect.Type]*DataType, len(ci.reflectTypeToDataType)), + oidToParamFormatCode: make(map[uint32]int16, len(ci.oidToParamFormatCode)), + oidToResultFormatCode: make(map[uint32]int16, len(ci.oidToResultFormatCode)), } for _, dt := range ci.oidToDataType { + var value Value + if tv, ok := dt.Value.(TypeValue); ok { + value = tv.CloneTypeValue() + } else { + value = reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(Value) + } + ci2.RegisterDataType(DataType{ - Value: reflect.New(reflect.ValueOf(dt.Value).Elem().Type()).Interface().(Value), + Value: value, Name: dt.Name, OID: dt.OID, })