CompositeType fields contain name and oid

This commit is contained in:
Jack Christensen 2020-05-13 07:09:52 -05:00
parent 0e2bc3467a
commit ee0e207ee4
3 changed files with 102 additions and 43 deletions

View File

@ -5,6 +5,7 @@ import (
"github.com/jackc/pgio"
"github.com/jackc/pgtype"
"github.com/stretchr/testify/require"
)
type MyCompositeRaw struct {
@ -83,7 +84,11 @@ func BenchmarkBinaryEncodingComposite(b *testing.B) {
ci := pgtype.NewConnInfo()
f1 := 2
f2 := ptrS("bar")
c := pgtype.NewCompositeType("test", &pgtype.Int4{}, &pgtype.Text{})
c, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
{"a", pgtype.Int4OID},
{"b", pgtype.TextOID},
}, ci)
require.NoError(b, err)
b.ResetTimer()
for n := 0; n < b.N; n++ {
@ -146,7 +151,11 @@ func BenchmarkBinaryDecodingCompositeScan(b *testing.B) {
var f1 int
var f2 *string
c := pgtype.NewCompositeType("test", &pgtype.Int4{}, &pgtype.Text{})
c, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
{"a", pgtype.Int4OID},
{"b", pgtype.TextOID},
}, ci)
require.NoError(b, err)
b.ResetTimer()
for n := 0; n < b.N; n++ {

View File

@ -9,30 +9,59 @@ import (
errors "golang.org/x/xerrors"
)
type CompositeTypeField struct {
Name string
OID uint32
}
type CompositeType struct {
status Status
typeName string
fields []ValueTranscoder
fields []CompositeTypeField
valueTranscoders []ValueTranscoder
}
// NewCompositeType creates a Composite object, which acts as a "schema" for
// SQL composite values.
// To pass Composite as SQL parameter first set it's fields, either by
// passing initialized Value{} instances to NewCompositeType or by calling
// SetFields method
// To read composite fields back pass result of Scan() method
// to query Scan function.
func NewCompositeType(typeName string, fields ...ValueTranscoder) *CompositeType {
return &CompositeType{typeName: typeName, fields: fields}
// NewCompositeType creates a CompositeType from fields and ci. ci is used to find the ValueTranscoders used
// for fields. All field OIDs must be previously registered in ci.
func NewCompositeType(typeName string, fields []CompositeTypeField, ci *ConnInfo) (*CompositeType, error) {
valueTranscoders := make([]ValueTranscoder, len(fields))
for i := range fields {
dt, ok := ci.DataTypeForOID(fields[i].OID)
if !ok {
return nil, errors.Errorf("no data type registered for oid: %d", fields[i].OID)
}
value := NewValue(dt.Value)
valueTranscoder, ok := value.(ValueTranscoder)
if !ok {
return nil, errors.Errorf("data type for oid does not implement ValueTranscoder: %d", fields[i].OID)
}
valueTranscoders[i] = valueTranscoder
}
return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: valueTranscoders}, nil
}
// NewCompositeTypeValues creates a CompositeType from fields and values. fields and values must have the same length.
// Prefer NewCompositeType unless overriding the transcoding of fields is required.
func NewCompositeTypeValues(typeName string, fields []CompositeTypeField, values []ValueTranscoder) (*CompositeType, error) {
if len(fields) != len(values) {
return nil, errors.New("fields and valueTranscoders must have same length")
}
return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: values}, nil
}
func (src CompositeType) Get() interface{} {
switch src.status {
case Present:
results := make([]interface{}, len(src.fields))
results := make([]interface{}, len(src.valueTranscoders))
for i := range results {
results[i] = src.fields[i].Get()
results[i] = src.valueTranscoders[i].Get()
}
return results
case Null:
@ -44,12 +73,13 @@ func (src CompositeType) Get() interface{} {
func (ct *CompositeType) NewTypeValue() Value {
a := &CompositeType{
typeName: ct.typeName,
fields: make([]ValueTranscoder, len(ct.fields)),
typeName: ct.typeName,
fields: ct.fields,
valueTranscoders: make([]ValueTranscoder, len(ct.valueTranscoders)),
}
for i := range ct.fields {
a.fields[i] = NewValue(ct.fields[i]).(ValueTranscoder)
for i := range ct.valueTranscoders {
a.valueTranscoders[i] = NewValue(ct.valueTranscoders[i]).(ValueTranscoder)
}
return a
@ -59,6 +89,10 @@ func (ct *CompositeType) TypeName() string {
return ct.typeName
}
func (ct *CompositeType) Fields() []CompositeTypeField {
return ct.fields
}
func (dst *CompositeType) Set(src interface{}) error {
if src == nil {
dst.status = Null
@ -67,11 +101,11 @@ func (dst *CompositeType) Set(src interface{}) error {
switch value := src.(type) {
case []interface{}:
if len(value) != len(dst.fields) {
return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(dst.fields))
if len(value) != len(dst.valueTranscoders) {
return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(dst.valueTranscoders))
}
for i, v := range value {
if err := dst.fields[i].Set(v); err != nil {
if err := dst.valueTranscoders[i].Set(v); err != nil {
return err
}
}
@ -95,15 +129,15 @@ func (src CompositeType) AssignTo(dst interface{}) error {
case Present:
switch v := dst.(type) {
case []interface{}:
if len(v) != len(src.fields) {
return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(src.fields))
if len(v) != len(src.valueTranscoders) {
return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(src.valueTranscoders))
}
for i := range src.fields {
for i := range src.valueTranscoders {
if v[i] == nil {
continue
}
err := assignToOrSet(src.fields[i], v[i])
err := assignToOrSet(src.valueTranscoders[i], v[i])
if err != nil {
return errors.Errorf("unable to assign to dst[%d]: %v", i, err)
}
@ -169,12 +203,12 @@ func (src CompositeType) assignToPtrStruct(dst interface{}) (bool, error) {
}
}
if len(exportedFields) != len(src.fields) {
if len(exportedFields) != len(src.valueTranscoders) {
return false, nil
}
for i := range exportedFields {
err := assignToOrSet(src.fields[i], dstElemValue.Field(exportedFields[i]).Addr().Interface())
err := assignToOrSet(src.valueTranscoders[i], dstElemValue.Field(exportedFields[i]).Addr().Interface())
if err != nil {
return true, errors.Errorf("unable to assign to field %s: %v", dstElemType.Field(exportedFields[i]).Name, err)
}
@ -192,13 +226,8 @@ func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte,
}
b := NewCompositeBinaryBuilder(ci, buf)
for _, f := range src.fields {
dt, ok := ci.DataTypeForValue(f)
if !ok {
return nil, errors.Errorf("unknown oid")
}
b.AppendEncoder(dt.OID, f)
for i := range src.valueTranscoders {
b.AppendEncoder(src.fields[i].OID, src.valueTranscoders[i])
}
return b.Finish()
@ -216,7 +245,7 @@ func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error {
scanner := NewCompositeBinaryScanner(ci, buf)
for _, f := range dst.fields {
for _, f := range dst.valueTranscoders {
scanner.ScanDecoder(f)
}
@ -237,7 +266,7 @@ func (dst *CompositeType) DecodeText(ci *ConnInfo, buf []byte) error {
scanner := NewCompositeTextScanner(ci, buf)
for _, f := range dst.fields {
for _, f := range dst.valueTranscoders {
scanner.ScanDecoder(f)
}
@ -259,7 +288,7 @@ func (src CompositeType) EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, er
}
b := NewCompositeTextBuilder(ci, buf)
for _, f := range src.fields {
for _, f := range src.valueTranscoders {
b.AppendEncoder(f)
}

View File

@ -14,7 +14,12 @@ import (
)
func TestCompositeTypeSetAndGet(t *testing.T) {
ct := pgtype.NewCompositeType("test", &pgtype.Text{}, &pgtype.Int4{})
ci := pgtype.NewConnInfo()
ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
{"a", pgtype.TextOID},
{"b", pgtype.Int4OID},
}, ci)
require.NoError(t, err)
assert.Equal(t, pgtype.Undefined, ct.Get())
nilTests := []struct {
@ -56,7 +61,12 @@ func TestCompositeTypeSetAndGet(t *testing.T) {
}
func TestCompositeTypeAssignTo(t *testing.T) {
ct := pgtype.NewCompositeType("test", &pgtype.Text{}, &pgtype.Int4{})
ci := pgtype.NewConnInfo()
ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
{"a", pgtype.TextOID},
{"b", pgtype.Int4OID},
}, ci)
require.NoError(t, err)
{
err := ct.Set([]interface{}{"foo", int32(42)})
@ -168,8 +178,12 @@ create type ct_test as (
defer conn.Exec(context.Background(), "drop type ct_test")
ct := pgtype.NewCompositeType("ct_test", &pgtype.Text{}, &pgtype.Int4{})
conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: "ct_test", OID: oid})
ct, err := pgtype.NewCompositeType("ct_test", []pgtype.CompositeTypeField{
{"a", pgtype.TextOID},
{"b", pgtype.Int4OID},
}, conn.ConnInfo())
require.NoError(t, err)
conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid})
// Use simple protocol to force text or binary encoding
simpleProtocols := []bool{true, false}
@ -221,8 +235,15 @@ func Example_composite() {
return
}
c := pgtype.NewCompositeType("mytype", &pgtype.Int4{}, &pgtype.Text{})
conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: c, Name: "mytype", OID: oid})
ct, err := pgtype.NewCompositeType("mytype", []pgtype.CompositeTypeField{
{"a", pgtype.Int4OID},
{"b", pgtype.TextOID},
}, conn.ConnInfo())
if err != nil {
fmt.Println(err)
return
}
conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid})
var a int
var b *string