Move LoadDataType to pgx.Conn

query-exec-mode
Jack Christensen 2022-01-29 16:32:05 -06:00
parent f5c3eeb813
commit b5bf9d7bb9
4 changed files with 166 additions and 144 deletions

90
conn.go
View File

@ -862,3 +862,93 @@ func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string,
return sanitize.SanitizeSQL(sql, valueArgs...)
}
// LoadDataType inspects the database for typeName and produces a pgtype.DataType suitable for
// registration.
func (c *Conn) LoadDataType(ctx context.Context, typeName string) (*pgtype.DataType, error) {
var oid uint32
err := c.QueryRow(ctx, "select $1::text::regtype::oid;", typeName).Scan(&oid)
if err != nil {
return nil, err
}
var typtype string
err = c.QueryRow(ctx, "select typtype::text from pg_type where oid=$1", oid).Scan(&typtype)
if err != nil {
return nil, err
}
switch typtype {
case "b": // array
elementOID, err := c.getArrayElementOID(ctx, oid)
if err != nil {
return nil, err
}
var elementCodec pgtype.Codec
if dt, ok := c.ConnInfo().DataTypeForOID(elementOID); ok {
if dt.Codec == nil {
return nil, errors.New("array element OID not registered with Codec")
}
elementCodec = dt.Codec
}
return &pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.ArrayCodec{ElementOID: elementOID, ElementCodec: elementCodec}}, nil
case "c": // composite
fields, err := c.getCompositeFields(ctx, oid)
if err != nil {
return nil, err
}
return &pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}}, nil
case "e": // enum
return &pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil
default:
return &pgtype.DataType{}, errors.New("unknown typtype")
}
}
func (c *Conn) getArrayElementOID(ctx context.Context, oid uint32) (uint32, error) {
var typelem uint32
err := c.QueryRow(ctx, "select typelem from pg_type where oid=$1", oid).Scan(&typelem)
if err != nil {
return 0, err
}
return typelem, nil
}
func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.CompositeCodecField, error) {
var typrelid uint32
err := c.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid)
if err != nil {
return nil, err
}
var fields []pgtype.CompositeCodecField
var fieldName string
var fieldOID uint32
_, err = c.QueryFunc(ctx, `select attname, atttypid
from pg_attribute
where attrelid=$1
order by attnum`,
[]interface{}{typrelid},
[]interface{}{&fieldName, &fieldOID},
func(qfr QueryFuncRow) error {
dt, ok := c.ConnInfo().DataTypeForOID(fieldOID)
if !ok {
return fmt.Errorf("unknown composite type field OID: %v", fieldOID)
}
fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, DataType: dt})
return nil
})
if err != nil {
return nil, err
}
return fields, nil
}

View File

@ -2,6 +2,7 @@ package pgtype_test
import (
"context"
"fmt"
"testing"
pgx "github.com/jackc/pgx/v5"
@ -23,34 +24,9 @@ create type ct_test as (
require.NoError(t, err)
defer conn.Exec(context.Background(), "drop type ct_test")
var oid uint32
err = conn.QueryRow(context.Background(), `select 'ct_test'::regtype::oid`).Scan(&oid)
dt, err := conn.LoadDataType(context.Background(), "ct_test")
require.NoError(t, err)
defer conn.Exec(context.Background(), "drop type ct_test")
textDataType, ok := conn.ConnInfo().DataTypeForOID(pgtype.TextOID)
require.True(t, ok)
int4DataType, ok := conn.ConnInfo().DataTypeForOID(pgtype.Int4OID)
require.True(t, ok)
conn.ConnInfo().RegisterDataType(pgtype.DataType{
Name: "ct_test",
OID: oid,
Codec: &pgtype.CompositeCodec{
Fields: []pgtype.CompositeCodecField{
{
Name: "a",
DataType: textDataType,
},
{
Name: "b",
DataType: int4DataType,
},
},
},
})
conn.ConnInfo().RegisterDataType(*dt)
formats := []struct {
name string
@ -74,3 +50,76 @@ create type ct_test as (
require.EqualValuesf(t, 42, b, "%v", format.name)
}
}
type point3d struct {
X, Y, Z float64
}
func (p point3d) IsNull() bool {
return false
}
func (p point3d) Index(i int) interface{} {
switch i {
case 0:
return p.X
case 1:
return p.Y
case 2:
return p.Z
default:
panic("invalid index")
}
}
func (p *point3d) ScanNull() error {
return fmt.Errorf("cannot scan NULL into point3d")
}
func (p *point3d) ScanIndex(i int) interface{} {
switch i {
case 0:
return &p.X
case 1:
return &p.Y
case 2:
return &p.Z
default:
panic("invalid index")
}
}
func TestCompositeCodecTranscodeStruct(t *testing.T) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)
_, err := conn.Exec(context.Background(), `drop type if exists point3d;
create type point3d as (
x float8,
y float8,
z float8
);`)
require.NoError(t, err)
defer conn.Exec(context.Background(), "drop type point3d")
dt, err := conn.LoadDataType(context.Background(), "point3d")
require.NoError(t, err)
conn.ConnInfo().RegisterDataType(*dt)
formats := []struct {
name string
code int16
}{
{name: "TextFormat", code: pgx.TextFormatCode},
{name: "BinaryFormat", code: pgx.BinaryFormatCode},
}
for _, format := range formats {
input := point3d{X: 1, Y: 2, Z: 3}
var output point3d
err := conn.QueryRow(context.Background(), "select $1::point3d", pgx.QueryResultFormats{format.code}, input).Scan(&output)
require.NoErrorf(t, err, "%v", format.name)
require.Equalf(t, input, output, "%v", format.name)
}
}

View File

@ -1,3 +0,0 @@
# pgxtype
pgxtype is a helper module that connects pgx and pgtype. This package is not currently covered by semantic version guarantees. i.e. The interfaces may change without a major version release of pgtype.

View File

@ -1,114 +0,0 @@
package pgxtype
import (
"context"
"errors"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
)
type Querier interface {
Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error)
Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error)
QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row
}
// LoadDataType uses conn to inspect the database for typeName and produces a pgtype.DataType suitable for
// registration on ci.
func LoadDataType(ctx context.Context, conn Querier, ci *pgtype.ConnInfo, typeName string) (pgtype.DataType, error) {
var oid uint32
err := conn.QueryRow(ctx, "select $1::text::regtype::oid;", typeName).Scan(&oid)
if err != nil {
return pgtype.DataType{}, err
}
var typtype string
err = conn.QueryRow(ctx, "select typtype::text from pg_type where oid=$1", oid).Scan(&typtype)
if err != nil {
return pgtype.DataType{}, err
}
switch typtype {
case "b": // array
elementOID, err := GetArrayElementOID(ctx, conn, oid)
if err != nil {
return pgtype.DataType{}, err
}
var elementCodec pgtype.Codec
if dt, ok := ci.DataTypeForOID(elementOID); ok {
if dt.Codec == nil {
return pgtype.DataType{}, errors.New("array element OID not registered with Codec")
}
elementCodec = dt.Codec
}
return pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.ArrayCodec{ElementOID: elementOID, ElementCodec: elementCodec}}, nil
case "c": // composite
panic("TODO - restore composite support")
// fields, err := GetCompositeFields(ctx, conn, oid)
// if err != nil {
// return pgtype.DataType{}, err
// }
// ct, err := pgtype.NewCompositeType(typeName, fields, ci)
// if err != nil {
// return pgtype.DataType{}, err
// }
// return pgtype.DataType{Value: ct, Name: typeName, OID: oid}, nil
case "e": // enum
return pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil
default:
return pgtype.DataType{}, errors.New("unknown typtype")
}
}
func GetArrayElementOID(ctx context.Context, conn Querier, oid uint32) (uint32, error) {
var typelem uint32
err := conn.QueryRow(ctx, "select typelem from pg_type where oid=$1", oid).Scan(&typelem)
if err != nil {
return 0, err
}
return typelem, nil
}
// TODO - restore composite support
// GetCompositeFields gets the fields of a composite type.
// func GetCompositeFields(ctx context.Context, conn Querier, oid uint32) ([]pgtype.CompositeTypeField, error) {
// var typrelid uint32
// err := conn.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid)
// if err != nil {
// return nil, err
// }
// var fields []pgtype.CompositeTypeField
// rows, err := conn.Query(ctx, `select attname, atttypid
// from pg_attribute
// where attrelid=$1
// order by attnum`, typrelid)
// if err != nil {
// return nil, err
// }
// for rows.Next() {
// var f pgtype.CompositeTypeField
// err := rows.Scan(&f.Name, &f.OID)
// if err != nil {
// return nil, err
// }
// fields = append(fields, f)
// }
// if rows.Err() != nil {
// return nil, rows.Err()
// }
// return fields, nil
// }