diff --git a/composite_bench_test.go b/composite_bench_test.go index cff9d518..7aef8c4f 100644 --- a/composite_bench_test.go +++ b/composite_bench_test.go @@ -5,6 +5,7 @@ import ( "github.com/jackc/pgio" "github.com/jackc/pgtype" + "github.com/stretchr/testify/require" ) type MyCompositeRaw struct { @@ -83,7 +84,11 @@ func BenchmarkBinaryEncodingComposite(b *testing.B) { ci := pgtype.NewConnInfo() f1 := 2 f2 := ptrS("bar") - c := pgtype.NewCompositeType("test", &pgtype.Int4{}, &pgtype.Text{}) + c, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ + {"a", pgtype.Int4OID}, + {"b", pgtype.TextOID}, + }, ci) + require.NoError(b, err) b.ResetTimer() for n := 0; n < b.N; n++ { @@ -146,7 +151,11 @@ func BenchmarkBinaryDecodingCompositeScan(b *testing.B) { var f1 int var f2 *string - c := pgtype.NewCompositeType("test", &pgtype.Int4{}, &pgtype.Text{}) + c, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ + {"a", pgtype.Int4OID}, + {"b", pgtype.TextOID}, + }, ci) + require.NoError(b, err) b.ResetTimer() for n := 0; n < b.N; n++ { diff --git a/composite_type.go b/composite_type.go index 7f5ae694..389bf178 100644 --- a/composite_type.go +++ b/composite_type.go @@ -9,30 +9,59 @@ import ( errors "golang.org/x/xerrors" ) +type CompositeTypeField struct { + Name string + OID uint32 +} + type CompositeType struct { status Status typeName string - fields []ValueTranscoder + + fields []CompositeTypeField + valueTranscoders []ValueTranscoder } -// NewCompositeType creates a Composite object, which acts as a "schema" for -// SQL composite values. -// To pass Composite as SQL parameter first set it's fields, either by -// passing initialized Value{} instances to NewCompositeType or by calling -// SetFields method -// To read composite fields back pass result of Scan() method -// to query Scan function. -func NewCompositeType(typeName string, fields ...ValueTranscoder) *CompositeType { - return &CompositeType{typeName: typeName, fields: fields} +// NewCompositeType creates a CompositeType from fields and ci. ci is used to find the ValueTranscoders used +// for fields. All field OIDs must be previously registered in ci. +func NewCompositeType(typeName string, fields []CompositeTypeField, ci *ConnInfo) (*CompositeType, error) { + valueTranscoders := make([]ValueTranscoder, len(fields)) + + for i := range fields { + dt, ok := ci.DataTypeForOID(fields[i].OID) + if !ok { + return nil, errors.Errorf("no data type registered for oid: %d", fields[i].OID) + } + + value := NewValue(dt.Value) + valueTranscoder, ok := value.(ValueTranscoder) + if !ok { + return nil, errors.Errorf("data type for oid does not implement ValueTranscoder: %d", fields[i].OID) + } + + valueTranscoders[i] = valueTranscoder + } + + return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: valueTranscoders}, nil +} + +// NewCompositeTypeValues creates a CompositeType from fields and values. fields and values must have the same length. +// Prefer NewCompositeType unless overriding the transcoding of fields is required. +func NewCompositeTypeValues(typeName string, fields []CompositeTypeField, values []ValueTranscoder) (*CompositeType, error) { + if len(fields) != len(values) { + return nil, errors.New("fields and valueTranscoders must have same length") + } + + return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: values}, nil } func (src CompositeType) Get() interface{} { switch src.status { case Present: - results := make([]interface{}, len(src.fields)) + results := make([]interface{}, len(src.valueTranscoders)) for i := range results { - results[i] = src.fields[i].Get() + results[i] = src.valueTranscoders[i].Get() } return results case Null: @@ -44,12 +73,13 @@ func (src CompositeType) Get() interface{} { func (ct *CompositeType) NewTypeValue() Value { a := &CompositeType{ - typeName: ct.typeName, - fields: make([]ValueTranscoder, len(ct.fields)), + typeName: ct.typeName, + fields: ct.fields, + valueTranscoders: make([]ValueTranscoder, len(ct.valueTranscoders)), } - for i := range ct.fields { - a.fields[i] = NewValue(ct.fields[i]).(ValueTranscoder) + for i := range ct.valueTranscoders { + a.valueTranscoders[i] = NewValue(ct.valueTranscoders[i]).(ValueTranscoder) } return a @@ -59,6 +89,10 @@ func (ct *CompositeType) TypeName() string { return ct.typeName } +func (ct *CompositeType) Fields() []CompositeTypeField { + return ct.fields +} + func (dst *CompositeType) Set(src interface{}) error { if src == nil { dst.status = Null @@ -67,11 +101,11 @@ func (dst *CompositeType) Set(src interface{}) error { switch value := src.(type) { case []interface{}: - if len(value) != len(dst.fields) { - return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(dst.fields)) + if len(value) != len(dst.valueTranscoders) { + return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(dst.valueTranscoders)) } for i, v := range value { - if err := dst.fields[i].Set(v); err != nil { + if err := dst.valueTranscoders[i].Set(v); err != nil { return err } } @@ -95,15 +129,15 @@ func (src CompositeType) AssignTo(dst interface{}) error { case Present: switch v := dst.(type) { case []interface{}: - if len(v) != len(src.fields) { - return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(src.fields)) + if len(v) != len(src.valueTranscoders) { + return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(src.valueTranscoders)) } - for i := range src.fields { + for i := range src.valueTranscoders { if v[i] == nil { continue } - err := assignToOrSet(src.fields[i], v[i]) + err := assignToOrSet(src.valueTranscoders[i], v[i]) if err != nil { return errors.Errorf("unable to assign to dst[%d]: %v", i, err) } @@ -169,12 +203,12 @@ func (src CompositeType) assignToPtrStruct(dst interface{}) (bool, error) { } } - if len(exportedFields) != len(src.fields) { + if len(exportedFields) != len(src.valueTranscoders) { return false, nil } for i := range exportedFields { - err := assignToOrSet(src.fields[i], dstElemValue.Field(exportedFields[i]).Addr().Interface()) + err := assignToOrSet(src.valueTranscoders[i], dstElemValue.Field(exportedFields[i]).Addr().Interface()) if err != nil { return true, errors.Errorf("unable to assign to field %s: %v", dstElemType.Field(exportedFields[i]).Name, err) } @@ -192,13 +226,8 @@ func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte, } b := NewCompositeBinaryBuilder(ci, buf) - for _, f := range src.fields { - dt, ok := ci.DataTypeForValue(f) - if !ok { - return nil, errors.Errorf("unknown oid") - } - - b.AppendEncoder(dt.OID, f) + for i := range src.valueTranscoders { + b.AppendEncoder(src.fields[i].OID, src.valueTranscoders[i]) } return b.Finish() @@ -216,7 +245,7 @@ func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error { scanner := NewCompositeBinaryScanner(ci, buf) - for _, f := range dst.fields { + for _, f := range dst.valueTranscoders { scanner.ScanDecoder(f) } @@ -237,7 +266,7 @@ func (dst *CompositeType) DecodeText(ci *ConnInfo, buf []byte) error { scanner := NewCompositeTextScanner(ci, buf) - for _, f := range dst.fields { + for _, f := range dst.valueTranscoders { scanner.ScanDecoder(f) } @@ -259,7 +288,7 @@ func (src CompositeType) EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, er } b := NewCompositeTextBuilder(ci, buf) - for _, f := range src.fields { + for _, f := range src.valueTranscoders { b.AppendEncoder(f) } diff --git a/composite_type_test.go b/composite_type_test.go index 0225e443..b32810ff 100644 --- a/composite_type_test.go +++ b/composite_type_test.go @@ -14,7 +14,12 @@ import ( ) func TestCompositeTypeSetAndGet(t *testing.T) { - ct := pgtype.NewCompositeType("test", &pgtype.Text{}, &pgtype.Int4{}) + ci := pgtype.NewConnInfo() + ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ + {"a", pgtype.TextOID}, + {"b", pgtype.Int4OID}, + }, ci) + require.NoError(t, err) assert.Equal(t, pgtype.Undefined, ct.Get()) nilTests := []struct { @@ -56,7 +61,12 @@ func TestCompositeTypeSetAndGet(t *testing.T) { } func TestCompositeTypeAssignTo(t *testing.T) { - ct := pgtype.NewCompositeType("test", &pgtype.Text{}, &pgtype.Int4{}) + ci := pgtype.NewConnInfo() + ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{ + {"a", pgtype.TextOID}, + {"b", pgtype.Int4OID}, + }, ci) + require.NoError(t, err) { err := ct.Set([]interface{}{"foo", int32(42)}) @@ -168,8 +178,12 @@ create type ct_test as ( defer conn.Exec(context.Background(), "drop type ct_test") - ct := pgtype.NewCompositeType("ct_test", &pgtype.Text{}, &pgtype.Int4{}) - conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: "ct_test", OID: oid}) + ct, err := pgtype.NewCompositeType("ct_test", []pgtype.CompositeTypeField{ + {"a", pgtype.TextOID}, + {"b", pgtype.Int4OID}, + }, conn.ConnInfo()) + require.NoError(t, err) + conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid}) // Use simple protocol to force text or binary encoding simpleProtocols := []bool{true, false} @@ -221,8 +235,15 @@ func Example_composite() { return } - c := pgtype.NewCompositeType("mytype", &pgtype.Int4{}, &pgtype.Text{}) - conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: c, Name: "mytype", OID: oid}) + ct, err := pgtype.NewCompositeType("mytype", []pgtype.CompositeTypeField{ + {"a", pgtype.Int4OID}, + {"b", pgtype.TextOID}, + }, conn.ConnInfo()) + if err != nil { + fmt.Println(err) + return + } + conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid}) var a int var b *string