pgx/composite_type_test.go

321 lines
7.6 KiB
Go

package pgtype_test
import (
"context"
"fmt"
"os"
"testing"
"github.com/jackc/pgtype"
"github.com/jackc/pgtype/testutil"
pgx "github.com/jackc/pgx/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCompositeTypeSetAndGet(t *testing.T) {
ci := pgtype.NewConnInfo()
ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
{"a", pgtype.TextOID},
{"b", pgtype.Int4OID},
}, ci)
require.NoError(t, err)
assert.Equal(t, pgtype.Undefined, ct.Get())
nilTests := []struct {
src interface{}
}{
{nil}, // nil interface
{(*[]interface{})(nil)}, // typed nil
}
for i, tt := range nilTests {
err := ct.Set(tt.src)
assert.NoErrorf(t, err, "%d", i)
assert.Equal(t, nil, ct.Get())
}
compatibleValuesTests := []struct {
src []interface{}
expected map[string]interface{}
}{
{
src: []interface{}{"foo", int32(42)},
expected: map[string]interface{}{"a": "foo", "b": int32(42)},
},
{
src: []interface{}{nil, nil},
expected: map[string]interface{}{"a": nil, "b": nil},
},
{
src: []interface{}{&pgtype.Text{String: "hi", Status: pgtype.Present}, &pgtype.Int4{Int: 7, Status: pgtype.Present}},
expected: map[string]interface{}{"a": "hi", "b": int32(7)},
},
}
for i, tt := range compatibleValuesTests {
err := ct.Set(tt.src)
assert.NoErrorf(t, err, "%d", i)
assert.EqualValues(t, tt.expected, ct.Get())
}
}
func TestCompositeTypeAssignTo(t *testing.T) {
ci := pgtype.NewConnInfo()
ct, err := pgtype.NewCompositeType("test", []pgtype.CompositeTypeField{
{"a", pgtype.TextOID},
{"b", pgtype.Int4OID},
}, ci)
require.NoError(t, err)
{
err := ct.Set([]interface{}{"foo", int32(42)})
assert.NoError(t, err)
var a string
var b int32
err = ct.AssignTo([]interface{}{&a, &b})
assert.NoError(t, err)
assert.Equal(t, "foo", a)
assert.Equal(t, int32(42), b)
}
{
err := ct.Set([]interface{}{"foo", int32(42)})
assert.NoError(t, err)
var a pgtype.Text
var b pgtype.Int4
err = ct.AssignTo([]interface{}{&a, &b})
assert.NoError(t, err)
assert.Equal(t, pgtype.Text{String: "foo", Status: pgtype.Present}, a)
assert.Equal(t, pgtype.Int4{Int: 42, Status: pgtype.Present}, b)
}
// Allow nil destination component as no-op
{
err := ct.Set([]interface{}{"foo", int32(42)})
assert.NoError(t, err)
var b int32
err = ct.AssignTo([]interface{}{nil, &b})
assert.NoError(t, err)
assert.Equal(t, int32(42), b)
}
// *[]interface{} dest when null
{
err := ct.Set(nil)
assert.NoError(t, err)
var a pgtype.Text
var b pgtype.Int4
dst := []interface{}{&a, &b}
err = ct.AssignTo(&dst)
assert.NoError(t, err)
assert.Nil(t, dst)
}
// *[]interface{} dest when not null
{
err := ct.Set([]interface{}{"foo", int32(42)})
assert.NoError(t, err)
var a pgtype.Text
var b pgtype.Int4
dst := []interface{}{&a, &b}
err = ct.AssignTo(&dst)
assert.NoError(t, err)
assert.NotNil(t, dst)
assert.Equal(t, pgtype.Text{String: "foo", Status: pgtype.Present}, a)
assert.Equal(t, pgtype.Int4{Int: 42, Status: pgtype.Present}, b)
}
// Struct fields positionally via reflection
{
err := ct.Set([]interface{}{"foo", int32(42)})
assert.NoError(t, err)
s := struct {
A string
B int32
}{}
err = ct.AssignTo(&s)
if assert.NoError(t, err) {
assert.Equal(t, "foo", s.A)
assert.Equal(t, int32(42), s.B)
}
}
}
func TestCompositeTypeTranscode(t *testing.T) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)
_, err := conn.Exec(context.Background(), `drop type if exists ct_test;
create type ct_test as (
a text,
b int4
);`)
require.NoError(t, err)
defer conn.Exec(context.Background(), "drop type ct_test")
var oid uint32
err = conn.QueryRow(context.Background(), `select 'ct_test'::regtype::oid`).Scan(&oid)
require.NoError(t, err)
defer conn.Exec(context.Background(), "drop type ct_test")
ct, err := pgtype.NewCompositeType("ct_test", []pgtype.CompositeTypeField{
{"a", pgtype.TextOID},
{"b", pgtype.Int4OID},
}, conn.ConnInfo())
require.NoError(t, err)
conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid})
// Use simple protocol to force text or binary encoding
simpleProtocols := []bool{true, false}
var a string
var b int32
for _, simpleProtocol := range simpleProtocols {
err := conn.QueryRow(context.Background(), "select $1::ct_test", pgx.QuerySimpleProtocol(simpleProtocol),
pgtype.CompositeFields{"hi", int32(42)},
).Scan(
[]interface{}{&a, &b},
)
if assert.NoErrorf(t, err, "Simple Protocol: %v", simpleProtocol) {
assert.EqualValuesf(t, "hi", a, "Simple Protocol: %v", simpleProtocol)
assert.EqualValuesf(t, 42, b, "Simple Protocol: %v", simpleProtocol)
}
}
}
// https://github.com/jackc/pgx/issues/874
func TestCompositeTypeTextDecodeNested(t *testing.T) {
newCompositeType := func(name string, fieldNames []string, vals ...pgtype.ValueTranscoder) *pgtype.CompositeType {
fields := make([]pgtype.CompositeTypeField, len(fieldNames))
for i, name := range fieldNames {
fields[i] = pgtype.CompositeTypeField{Name: name}
}
rowType, err := pgtype.NewCompositeTypeValues(name, fields, vals)
require.NoError(t, err)
return rowType
}
dimensionsType := func() pgtype.ValueTranscoder {
return newCompositeType(
"dimensions",
[]string{"width", "height"},
&pgtype.Int4{},
&pgtype.Int4{},
)
}
productImageType := func() pgtype.ValueTranscoder {
return newCompositeType(
"product_image_type",
[]string{"source", "dimensions"},
&pgtype.Text{},
dimensionsType(),
)
}
productImageSetType := newCompositeType(
"product_image_set_type",
[]string{"name", "orig_image", "images"},
&pgtype.Text{},
productImageType(),
pgtype.NewArrayType("product_image", 0, func() pgtype.ValueTranscoder {
return productImageType()
}),
)
err := productImageSetType.DecodeText(nil, []byte(`(name,"(img1,""(11,11)"")","{""(img2,\\""(22,22)\\"")"",""(img3,\\""(33,33)\\"")""}")`))
require.NoError(t, err)
}
func Example_composite() {
conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
if err != nil {
fmt.Println(err)
return
}
defer conn.Close(context.Background())
_, err = conn.Exec(context.Background(), `drop type if exists mytype;`)
if err != nil {
fmt.Println(err)
return
}
_, err = conn.Exec(context.Background(), `create type mytype as (
a int4,
b text
);`)
if err != nil {
fmt.Println(err)
return
}
defer conn.Exec(context.Background(), "drop type mytype")
var oid uint32
err = conn.QueryRow(context.Background(), `select 'mytype'::regtype::oid`).Scan(&oid)
if err != nil {
fmt.Println(err)
return
}
ct, err := pgtype.NewCompositeType("mytype", []pgtype.CompositeTypeField{
{"a", pgtype.Int4OID},
{"b", pgtype.TextOID},
}, conn.ConnInfo())
if err != nil {
fmt.Println(err)
return
}
conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: ct, Name: ct.TypeName(), OID: oid})
var a int
var b *string
err = conn.QueryRow(context.Background(), "select $1::mytype", []interface{}{2, "bar"}).Scan([]interface{}{&a, &b})
if err != nil {
fmt.Println(err)
return
}
fmt.Printf("First: a=%d b=%s\n", a, *b)
err = conn.QueryRow(context.Background(), "select (1, NULL)::mytype").Scan([]interface{}{&a, &b})
if err != nil {
fmt.Println(err)
return
}
fmt.Printf("Second: a=%d b=%v\n", a, b)
scanTarget := []interface{}{&a, &b}
err = conn.QueryRow(context.Background(), "select NULL::mytype").Scan(&scanTarget)
E(err)
fmt.Printf("Third: isNull=%v\n", scanTarget == nil)
// Output:
// First: a=2 b=bar
// Second: a=1 b=<nil>
// Third: isNull=true
}