mirror of
https://github.com/jackc/pgx.git
synced 2025-05-31 11:42:24 +00:00
CompositeType fields contain name and oid
This commit is contained in:
parent
0e2bc3467a
commit
ee0e207ee4
@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
"github.com/jackc/pgtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type MyCompositeRaw struct {
|
||||
@ -83,7 +84,11 @@ func BenchmarkBinaryEncodingComposite(b *testing.B) {
|
||||
ci := pgtype.NewConnInfo()
|
||||
f1 := 2
|
||||
f2 := ptrS("bar")
|
||||
c := pgtype.NewCompositeType("test", &pgtype.Int4{}, &pgtype.Text{})
|
||||
c, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
|
||||
{"a", pgtype.Int4OID},
|
||||
{"b", pgtype.TextOID},
|
||||
}, ci)
|
||||
require.NoError(b, err)
|
||||
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
@ -146,7 +151,11 @@ func BenchmarkBinaryDecodingCompositeScan(b *testing.B) {
|
||||
var f1 int
|
||||
var f2 *string
|
||||
|
||||
c := pgtype.NewCompositeType("test", &pgtype.Int4{}, &pgtype.Text{})
|
||||
c, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
|
||||
{"a", pgtype.Int4OID},
|
||||
{"b", pgtype.TextOID},
|
||||
}, ci)
|
||||
require.NoError(b, err)
|
||||
|
||||
b.ResetTimer()
|
||||
for n := 0; n < b.N; n++ {
|
||||
|
@ -9,30 +9,59 @@ import (
|
||||
errors "golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
type CompositeTypeField struct {
|
||||
Name string
|
||||
OID uint32
|
||||
}
|
||||
|
||||
type CompositeType struct {
|
||||
status Status
|
||||
|
||||
typeName string
|
||||
fields []ValueTranscoder
|
||||
|
||||
fields []CompositeTypeField
|
||||
valueTranscoders []ValueTranscoder
|
||||
}
|
||||
|
||||
// NewCompositeType creates a Composite object, which acts as a "schema" for
|
||||
// SQL composite values.
|
||||
// To pass Composite as SQL parameter first set it's fields, either by
|
||||
// passing initialized Value{} instances to NewCompositeType or by calling
|
||||
// SetFields method
|
||||
// To read composite fields back pass result of Scan() method
|
||||
// to query Scan function.
|
||||
func NewCompositeType(typeName string, fields ...ValueTranscoder) *CompositeType {
|
||||
return &CompositeType{typeName: typeName, fields: fields}
|
||||
// NewCompositeType creates a CompositeType from fields and ci. ci is used to find the ValueTranscoders used
|
||||
// for fields. All field OIDs must be previously registered in ci.
|
||||
func NewCompositeType(typeName string, fields []CompositeTypeField, ci *ConnInfo) (*CompositeType, error) {
|
||||
valueTranscoders := make([]ValueTranscoder, len(fields))
|
||||
|
||||
for i := range fields {
|
||||
dt, ok := ci.DataTypeForOID(fields[i].OID)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("no data type registered for oid: %d", fields[i].OID)
|
||||
}
|
||||
|
||||
value := NewValue(dt.Value)
|
||||
valueTranscoder, ok := value.(ValueTranscoder)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("data type for oid does not implement ValueTranscoder: %d", fields[i].OID)
|
||||
}
|
||||
|
||||
valueTranscoders[i] = valueTranscoder
|
||||
}
|
||||
|
||||
return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: valueTranscoders}, nil
|
||||
}
|
||||
|
||||
// NewCompositeTypeValues creates a CompositeType from fields and values. fields and values must have the same length.
|
||||
// Prefer NewCompositeType unless overriding the transcoding of fields is required.
|
||||
func NewCompositeTypeValues(typeName string, fields []CompositeTypeField, values []ValueTranscoder) (*CompositeType, error) {
|
||||
if len(fields) != len(values) {
|
||||
return nil, errors.New("fields and valueTranscoders must have same length")
|
||||
}
|
||||
|
||||
return &CompositeType{typeName: typeName, fields: fields, valueTranscoders: values}, nil
|
||||
}
|
||||
|
||||
func (src CompositeType) Get() interface{} {
|
||||
switch src.status {
|
||||
case Present:
|
||||
results := make([]interface{}, len(src.fields))
|
||||
results := make([]interface{}, len(src.valueTranscoders))
|
||||
for i := range results {
|
||||
results[i] = src.fields[i].Get()
|
||||
results[i] = src.valueTranscoders[i].Get()
|
||||
}
|
||||
return results
|
||||
case Null:
|
||||
@ -44,12 +73,13 @@ func (src CompositeType) Get() interface{} {
|
||||
|
||||
func (ct *CompositeType) NewTypeValue() Value {
|
||||
a := &CompositeType{
|
||||
typeName: ct.typeName,
|
||||
fields: make([]ValueTranscoder, len(ct.fields)),
|
||||
typeName: ct.typeName,
|
||||
fields: ct.fields,
|
||||
valueTranscoders: make([]ValueTranscoder, len(ct.valueTranscoders)),
|
||||
}
|
||||
|
||||
for i := range ct.fields {
|
||||
a.fields[i] = NewValue(ct.fields[i]).(ValueTranscoder)
|
||||
for i := range ct.valueTranscoders {
|
||||
a.valueTranscoders[i] = NewValue(ct.valueTranscoders[i]).(ValueTranscoder)
|
||||
}
|
||||
|
||||
return a
|
||||
@ -59,6 +89,10 @@ func (ct *CompositeType) TypeName() string {
|
||||
return ct.typeName
|
||||
}
|
||||
|
||||
func (ct *CompositeType) Fields() []CompositeTypeField {
|
||||
return ct.fields
|
||||
}
|
||||
|
||||
func (dst *CompositeType) Set(src interface{}) error {
|
||||
if src == nil {
|
||||
dst.status = Null
|
||||
@ -67,11 +101,11 @@ func (dst *CompositeType) Set(src interface{}) error {
|
||||
|
||||
switch value := src.(type) {
|
||||
case []interface{}:
|
||||
if len(value) != len(dst.fields) {
|
||||
return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(dst.fields))
|
||||
if len(value) != len(dst.valueTranscoders) {
|
||||
return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(dst.valueTranscoders))
|
||||
}
|
||||
for i, v := range value {
|
||||
if err := dst.fields[i].Set(v); err != nil {
|
||||
if err := dst.valueTranscoders[i].Set(v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -95,15 +129,15 @@ func (src CompositeType) AssignTo(dst interface{}) error {
|
||||
case Present:
|
||||
switch v := dst.(type) {
|
||||
case []interface{}:
|
||||
if len(v) != len(src.fields) {
|
||||
return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(src.fields))
|
||||
if len(v) != len(src.valueTranscoders) {
|
||||
return errors.Errorf("Number of fields don't match. CompositeType has %d fields", len(src.valueTranscoders))
|
||||
}
|
||||
for i := range src.fields {
|
||||
for i := range src.valueTranscoders {
|
||||
if v[i] == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
err := assignToOrSet(src.fields[i], v[i])
|
||||
err := assignToOrSet(src.valueTranscoders[i], v[i])
|
||||
if err != nil {
|
||||
return errors.Errorf("unable to assign to dst[%d]: %v", i, err)
|
||||
}
|
||||
@ -169,12 +203,12 @@ func (src CompositeType) assignToPtrStruct(dst interface{}) (bool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if len(exportedFields) != len(src.fields) {
|
||||
if len(exportedFields) != len(src.valueTranscoders) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for i := range exportedFields {
|
||||
err := assignToOrSet(src.fields[i], dstElemValue.Field(exportedFields[i]).Addr().Interface())
|
||||
err := assignToOrSet(src.valueTranscoders[i], dstElemValue.Field(exportedFields[i]).Addr().Interface())
|
||||
if err != nil {
|
||||
return true, errors.Errorf("unable to assign to field %s: %v", dstElemType.Field(exportedFields[i]).Name, err)
|
||||
}
|
||||
@ -192,13 +226,8 @@ func (src CompositeType) EncodeBinary(ci *ConnInfo, buf []byte) (newBuf []byte,
|
||||
}
|
||||
|
||||
b := NewCompositeBinaryBuilder(ci, buf)
|
||||
for _, f := range src.fields {
|
||||
dt, ok := ci.DataTypeForValue(f)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("unknown oid")
|
||||
}
|
||||
|
||||
b.AppendEncoder(dt.OID, f)
|
||||
for i := range src.valueTranscoders {
|
||||
b.AppendEncoder(src.fields[i].OID, src.valueTranscoders[i])
|
||||
}
|
||||
|
||||
return b.Finish()
|
||||
@ -216,7 +245,7 @@ func (dst *CompositeType) DecodeBinary(ci *ConnInfo, buf []byte) error {
|
||||
|
||||
scanner := NewCompositeBinaryScanner(ci, buf)
|
||||
|
||||
for _, f := range dst.fields {
|
||||
for _, f := range dst.valueTranscoders {
|
||||
scanner.ScanDecoder(f)
|
||||
}
|
||||
|
||||
@ -237,7 +266,7 @@ func (dst *CompositeType) DecodeText(ci *ConnInfo, buf []byte) error {
|
||||
|
||||
scanner := NewCompositeTextScanner(ci, buf)
|
||||
|
||||
for _, f := range dst.fields {
|
||||
for _, f := range dst.valueTranscoders {
|
||||
scanner.ScanDecoder(f)
|
||||
}
|
||||
|
||||
@ -259,7 +288,7 @@ func (src CompositeType) EncodeText(ci *ConnInfo, buf []byte) (newBuf []byte, er
|
||||
}
|
||||
|
||||
b := NewCompositeTextBuilder(ci, buf)
|
||||
for _, f := range src.fields {
|
||||
for _, f := range src.valueTranscoders {
|
||||
b.AppendEncoder(f)
|
||||
}
|
||||
|
||||
|
@ -14,7 +14,12 @@ import (
|
||||
)
|
||||
|
||||
func TestCompositeTypeSetAndGet(t *testing.T) {
|
||||
ct := pgtype.NewCompositeType("test", &pgtype.Text{}, &pgtype.Int4{})
|
||||
ci := pgtype.NewConnInfo()
|
||||
ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
|
||||
{"a", pgtype.TextOID},
|
||||
{"b", pgtype.Int4OID},
|
||||
}, ci)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, pgtype.Undefined, ct.Get())
|
||||
|
||||
nilTests := []struct {
|
||||
@ -56,7 +61,12 @@ func TestCompositeTypeSetAndGet(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCompositeTypeAssignTo(t *testing.T) {
|
||||
ct := pgtype.NewCompositeType("test", &pgtype.Text{}, &pgtype.Int4{})
|
||||
ci := pgtype.NewConnInfo()
|
||||
ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
|
||||
{"a", pgtype.TextOID},
|
||||
{"b", pgtype.Int4OID},
|
||||
}, ci)
|
||||
require.NoError(t, err)
|
||||
|
||||
{
|
||||
err := ct.Set([]interface{}{"foo", int32(42)})
|
||||
@ -168,8 +178,12 @@ create type ct_test as (
|
||||
|
||||
defer conn.Exec(context.Background(), "drop type ct_test")
|
||||
|
||||
ct := pgtype.NewCompositeType("ct_test", &pgtype.Text{}, &pgtype.Int4{})
|
||||
conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: "ct_test", OID: oid})
|
||||
ct, err := pgtype.NewCompositeType("ct_test", []pgtype.CompositeTypeField{
|
||||
{"a", pgtype.TextOID},
|
||||
{"b", pgtype.Int4OID},
|
||||
}, conn.ConnInfo())
|
||||
require.NoError(t, err)
|
||||
conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid})
|
||||
|
||||
// Use simple protocol to force text or binary encoding
|
||||
simpleProtocols := []bool{true, false}
|
||||
@ -221,8 +235,15 @@ func Example_composite() {
|
||||
return
|
||||
}
|
||||
|
||||
c := pgtype.NewCompositeType("mytype", &pgtype.Int4{}, &pgtype.Text{})
|
||||
conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: c, Name: "mytype", OID: oid})
|
||||
ct, err := pgtype.NewCompositeType("mytype", []pgtype.CompositeTypeField{
|
||||
{"a", pgtype.Int4OID},
|
||||
{"b", pgtype.TextOID},
|
||||
}, conn.ConnInfo())
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid})
|
||||
|
||||
var a int
|
||||
var b *string
|
||||
|
Loading…
x
Reference in New Issue
Block a user