CompositeType fields contain name and oid

This commit is contained in:
Jack Christensen 2020-05-13 07:09:52 -05:00
parent 0e2bc3467a
commit ee0e207ee4
3 changed files with 102 additions and 43 deletions

View File

@ -5,6 +5,7 @@ import (
"github.com/jackc/pgio" "github.com/jackc/pgio"
"github.com/jackc/pgtype" "github.com/jackc/pgtype"
"github.com/stretchr/testify/require"
) )
type MyCompositeRaw struct { type MyCompositeRaw struct {
@ -83,7 +84,11 @@ func BenchmarkBinaryEncodingComposite(b *testing.B) {
ci := pgtype.NewConnInfo() ci := pgtype.NewConnInfo()
f1 := 2 f1 := 2
f2 := ptrS("bar") f2 := ptrS("bar")
c := pgtype.NewCompositeType("test", &pgtype.Int4{}, &pgtype.Text{}) c, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
{"a", pgtype.Int4OID},
{"b", pgtype.TextOID},
}, ci)
require.NoError(b, err)
b.ResetTimer() b.ResetTimer()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
@ -146,7 +151,11 @@ func BenchmarkBinaryDecodingCompositeScan(b *testing.B) {
var f1 int var f1 int
var f2 *string var f2 *string
c := pgtype.NewCompositeType("test", &pgtype.Int4{}, &pgtype.Text{}) c, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
{"a", pgtype.Int4OID},
{"b", pgtype.TextOID},
}, ci)
require.NoError(b, err)
b.ResetTimer() b.ResetTimer()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {

View File

@ -9,30 +9,59 @@ import (
errors "golang.org/x/xerrors" errors "golang.org/x/xerrors"
) )
type CompositeTypeField struct {
Name string
OID uint32
}
type CompositeType struct { type CompositeType struct {
status Status status Status
typeName string typeName string
fields []ValueTranscoder
fields []CompositeTypeField
valueTranscoders []ValueTranscoder
} }
// NewCompositeType creates a Composite object, which acts as a "schema" for // NewCompositeType creates a CompositeType from fields and ci. ci is used to find the ValueTranscoders used
// SQL composite values. // for fields. All field OIDs must be previously registered in ci.
// To pass Composite as SQL parameter first set it's fields, either by func NewCompositeType(typeName string, fields []CompositeTypeField, ci *ConnInfo) (*CompositeType, error) {
// passing initialized Value{} instances to NewCompositeType or by calling valueTranscoders := make([]ValueTranscoder, len(fields))
// SetFields method
// To read composite fields back pass result of Scan() method for i := range fields {
// to query Scan function. dt, ok := ci.DataTypeForOID(fields[i].OID)
func NewCompositeType(typeName string, fields ...ValueTranscoder) *CompositeType { if !ok {
return &CompositeType{typeName: typeName, fields: fields} return nil, errors.Errorf("no data type registered for oid: %d", fields[i].OID)
}
value := NewValue(dt.Value)
valueTranscoder, ok := value.(ValueTranscoder)
if !ok {
return nil, errors.Errorf("data type for oid does not implement ValueTranscoder: %d", fields[i].OID)
}
valueTranscoders[i] = valueTranscoder
}
return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: valueTranscoders}, nil
}
// NewCompositeTypeValues creates a CompositeType from fields and values. fields and values must have the same length.
// Prefer NewCompositeType unless overriding the transcoding of fields is required.
func NewCompositeTypeValues(typeName string, fields []CompositeTypeField, values []ValueTranscoder) (*CompositeType, error) {
if len(fields) != len(values) {
return nil, errors.New("fields and valueTranscoders must have same length")
}
return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: values}, nil
} }
func (src CompositeType) Get() interface{} { func (src CompositeType) Get() interface{} {
switch src.status { switch src.status {
case Present: case Present:
results := make([]interface{}, len(src.fields)) results := make([]interface{}, len(src.valueTranscoders))
for i := range results { for i := range results {
results[i] = src.fields[i].Get() results[i] = src.valueTranscoders[i].Get()
} }
return results return results
case Null: case Null:
@ -44,12 +73,13 @@ func (src CompositeType) Get() interface{} {
func (ct *CompositeType) NewTypeValue() Value { func (ct *CompositeType) NewTypeValue() Value {
a := &CompositeType{ a := &CompositeType{
typeName: ct.typeName, typeName: ct.typeName,
fields: make([]ValueTranscoder, len(ct.fields)), fields: ct.fields,
valueTranscoders: make([]ValueTranscoder, len(ct.valueTranscoders)),
} }
for i := range ct.fields { for i := range ct.valueTranscoders {
a.fields[i] = NewValue(ct.fields[i]).(ValueTranscoder) a.valueTranscoders[i] = NewValue(ct.valueTranscoders[i]).(ValueTranscoder)
} }
return a return a
@ -59,6 +89,10 @@ func (ct *CompositeType) TypeName() string {
return ct.typeName return ct.typeName
} }
func (ct *CompositeType) Fields() []CompositeTypeField {
return ct.fields
}
func (dst *CompositeType) Set(src interface{}) error { func (dst *CompositeType) Set(src interface{}) error {
if src == nil { if src == nil {
dst.status = Null dst.status = Null
@ -67,11 +101,11 @@ func (dst *CompositeType) Set(src interface{}) error {
switch value := src.(type) { switch value := src.(type) {
case []interface{}: case []interface{}:
if len(value) != len(dst.fields) { if len(value) != len(dst.valueTranscoders) {
return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(dst.fields)) return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(dst.valueTranscoders))
} }
for i, v := range value { for i, v := range value {
if err := dst.fields[i].Set(v); err != nil { if err := dst.valueTranscoders[i].Set(v); err != nil {
return err return err
} }
} }
@ -95,15 +129,15 @@ func (src CompositeType) AssignTo(dst interface{}) error {
case Present: case Present:
switch v := dst.(type) { switch v := dst.(type) {
case []interface{}: case []interface{}:
if len(v) != len(src.fields) { if len(v) != len(src.valueTranscoders) {
return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(src.fields)) return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(src.valueTranscoders))
} }
for i := range src.fields { for i := range src.valueTranscoders {
if v[i] == nil { if v[i] == nil {
continue continue
} }
err := assignToOrSet(src.fields[i], v[i]) err := assignToOrSet(src.valueTranscoders[i], v[i])
if err != nil { if err != nil {
return errors.Errorf("unable to assign to dst[%d]: %v", i, err) return errors.Errorf("unable to assign to dst[%d]: %v", i, err)
} }
@ -169,12 +203,12 @@ func (src CompositeType) assignToPtrStruct(dst interface{}) (bool, error) {
} }
} }
if len(exportedFields) != len(src.fields) { if len(exportedFields) != len(src.valueTranscoders) {
return false, nil return false, nil
} }
for i := range exportedFields { for i := range exportedFields {
err := assignToOrSet(src.fields[i], dstElemValue.Field(exportedFields[i]).Addr().Interface()) err := assignToOrSet(src.valueTranscoders[i], dstElemValue.Field(exportedFields[i]).Addr().Interface())
if err != nil { if err != nil {
return true, errors.Errorf("unable to assign to field %s: %v", dstElemType.Field(exportedFields[i]).Name, err) return true, errors.Errorf("unable to assign to field %s: %v", dstElemType.Field(exportedFields[i]).Name, err)
} }
@ -192,13 +226,8 @@ func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte,
} }
b := NewCompositeBinaryBuilder(ci, buf) b := NewCompositeBinaryBuilder(ci, buf)
for _, f := range src.fields { for i := range src.valueTranscoders {
dt, ok := ci.DataTypeForValue(f) b.AppendEncoder(src.fields[i].OID, src.valueTranscoders[i])
if !ok {
return nil, errors.Errorf("unknown oid")
}
b.AppendEncoder(dt.OID, f)
} }
return b.Finish() return b.Finish()
@ -216,7 +245,7 @@ func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error {
scanner := NewCompositeBinaryScanner(ci, buf) scanner := NewCompositeBinaryScanner(ci, buf)
for _, f := range dst.fields { for _, f := range dst.valueTranscoders {
scanner.ScanDecoder(f) scanner.ScanDecoder(f)
} }
@ -237,7 +266,7 @@ func (dst *CompositeType) DecodeText(ci *ConnInfo, buf []byte) error {
scanner := NewCompositeTextScanner(ci, buf) scanner := NewCompositeTextScanner(ci, buf)
for _, f := range dst.fields { for _, f := range dst.valueTranscoders {
scanner.ScanDecoder(f) scanner.ScanDecoder(f)
} }
@ -259,7 +288,7 @@ func (src CompositeType) EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, er
} }
b := NewCompositeTextBuilder(ci, buf) b := NewCompositeTextBuilder(ci, buf)
for _, f := range src.fields { for _, f := range src.valueTranscoders {
b.AppendEncoder(f) b.AppendEncoder(f)
} }

View File

@ -14,7 +14,12 @@ import (
) )
func TestCompositeTypeSetAndGet(t *testing.T) { func TestCompositeTypeSetAndGet(t *testing.T) {
ct := pgtype.NewCompositeType("test", &pgtype.Text{}, &pgtype.Int4{}) ci := pgtype.NewConnInfo()
ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
{"a", pgtype.TextOID},
{"b", pgtype.Int4OID},
}, ci)
require.NoError(t, err)
assert.Equal(t, pgtype.Undefined, ct.Get()) assert.Equal(t, pgtype.Undefined, ct.Get())
nilTests := []struct { nilTests := []struct {
@ -56,7 +61,12 @@ func TestCompositeTypeSetAndGet(t *testing.T) {
} }
func TestCompositeTypeAssignTo(t *testing.T) { func TestCompositeTypeAssignTo(t *testing.T) {
ct := pgtype.NewCompositeType("test", &pgtype.Text{}, &pgtype.Int4{}) ci := pgtype.NewConnInfo()
ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
{"a", pgtype.TextOID},
{"b", pgtype.Int4OID},
}, ci)
require.NoError(t, err)
{ {
err := ct.Set([]interface{}{"foo", int32(42)}) err := ct.Set([]interface{}{"foo", int32(42)})
@ -168,8 +178,12 @@ create type ct_test as (
defer conn.Exec(context.Background(), "drop type ct_test") defer conn.Exec(context.Background(), "drop type ct_test")
ct := pgtype.NewCompositeType("ct_test", &pgtype.Text{}, &pgtype.Int4{}) ct, err := pgtype.NewCompositeType("ct_test", []pgtype.CompositeTypeField{
conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: "ct_test", OID: oid}) {"a", pgtype.TextOID},
{"b", pgtype.Int4OID},
}, conn.ConnInfo())
require.NoError(t, err)
conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid})
// Use simple protocol to force text or binary encoding // Use simple protocol to force text or binary encoding
simpleProtocols := []bool{true, false} simpleProtocols := []bool{true, false}
@ -221,8 +235,15 @@ func Example_composite() {
return return
} }
c := pgtype.NewCompositeType("mytype", &pgtype.Int4{}, &pgtype.Text{}) ct, err := pgtype.NewCompositeType("mytype", []pgtype.CompositeTypeField{
conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: c, Name: "mytype", OID: oid}) {"a", pgtype.Int4OID},
{"b", pgtype.TextOID},
}, conn.ConnInfo())
if err != nil {
fmt.Println(err)
return
}
conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid})
var a int var a int
var b *string var b *string