mirror of https://github.com/jackc/pgx.git
Move LoadDataType to pgx.Conn
parent
f5c3eeb813
commit
b5bf9d7bb9
90
conn.go
90
conn.go
|
@ -862,3 +862,93 @@ func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string,
|
|||
|
||||
return sanitize.SanitizeSQL(sql, valueArgs...)
|
||||
}
|
||||
|
||||
// LoadDataType inspects the database for typeName and produces a pgtype.DataType suitable for
|
||||
// registration.
|
||||
func (c *Conn) LoadDataType(ctx context.Context, typeName string) (*pgtype.DataType, error) {
|
||||
var oid uint32
|
||||
|
||||
err := c.QueryRow(ctx, "select $1::text::regtype::oid;", typeName).Scan(&oid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var typtype string
|
||||
|
||||
err = c.QueryRow(ctx, "select typtype::text from pg_type where oid=$1", oid).Scan(&typtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch typtype {
|
||||
case "b": // array
|
||||
elementOID, err := c.getArrayElementOID(ctx, oid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var elementCodec pgtype.Codec
|
||||
if dt, ok := c.ConnInfo().DataTypeForOID(elementOID); ok {
|
||||
if dt.Codec == nil {
|
||||
return nil, errors.New("array element OID not registered with Codec")
|
||||
}
|
||||
elementCodec = dt.Codec
|
||||
}
|
||||
|
||||
return &pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.ArrayCodec{ElementOID: elementOID, ElementCodec: elementCodec}}, nil
|
||||
case "c": // composite
|
||||
fields, err := c.getCompositeFields(ctx, oid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}}, nil
|
||||
case "e": // enum
|
||||
return &pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil
|
||||
default:
|
||||
return &pgtype.DataType{}, errors.New("unknown typtype")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) getArrayElementOID(ctx context.Context, oid uint32) (uint32, error) {
|
||||
var typelem uint32
|
||||
|
||||
err := c.QueryRow(ctx, "select typelem from pg_type where oid=$1", oid).Scan(&typelem)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return typelem, nil
|
||||
}
|
||||
|
||||
func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.CompositeCodecField, error) {
|
||||
var typrelid uint32
|
||||
|
||||
err := c.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var fields []pgtype.CompositeCodecField
|
||||
var fieldName string
|
||||
var fieldOID uint32
|
||||
_, err = c.QueryFunc(ctx, `select attname, atttypid
|
||||
from pg_attribute
|
||||
where attrelid=$1
|
||||
order by attnum`,
|
||||
[]interface{}{typrelid},
|
||||
[]interface{}{&fieldName, &fieldOID},
|
||||
func(qfr QueryFuncRow) error {
|
||||
dt, ok := c.ConnInfo().DataTypeForOID(fieldOID)
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown composite type field OID: %v", fieldOID)
|
||||
}
|
||||
fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, DataType: dt})
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return fields, nil
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package pgtype_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
pgx "github.com/jackc/pgx/v5"
|
||||
|
@ -23,34 +24,9 @@ create type ct_test as (
|
|||
require.NoError(t, err)
|
||||
defer conn.Exec(context.Background(), "drop type ct_test")
|
||||
|
||||
var oid uint32
|
||||
err = conn.QueryRow(context.Background(), `select 'ct_test'::regtype::oid`).Scan(&oid)
|
||||
dt, err := conn.LoadDataType(context.Background(), "ct_test")
|
||||
require.NoError(t, err)
|
||||
|
||||
defer conn.Exec(context.Background(), "drop type ct_test")
|
||||
|
||||
textDataType, ok := conn.ConnInfo().DataTypeForOID(pgtype.TextOID)
|
||||
require.True(t, ok)
|
||||
|
||||
int4DataType, ok := conn.ConnInfo().DataTypeForOID(pgtype.Int4OID)
|
||||
require.True(t, ok)
|
||||
|
||||
conn.ConnInfo().RegisterDataType(pgtype.DataType{
|
||||
Name: "ct_test",
|
||||
OID: oid,
|
||||
Codec: &pgtype.CompositeCodec{
|
||||
Fields: []pgtype.CompositeCodecField{
|
||||
{
|
||||
Name: "a",
|
||||
DataType: textDataType,
|
||||
},
|
||||
{
|
||||
Name: "b",
|
||||
DataType: int4DataType,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
conn.ConnInfo().RegisterDataType(*dt)
|
||||
|
||||
formats := []struct {
|
||||
name string
|
||||
|
@ -74,3 +50,76 @@ create type ct_test as (
|
|||
require.EqualValuesf(t, 42, b, "%v", format.name)
|
||||
}
|
||||
}
|
||||
|
||||
type point3d struct {
|
||||
X, Y, Z float64
|
||||
}
|
||||
|
||||
func (p point3d) IsNull() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (p point3d) Index(i int) interface{} {
|
||||
switch i {
|
||||
case 0:
|
||||
return p.X
|
||||
case 1:
|
||||
return p.Y
|
||||
case 2:
|
||||
return p.Z
|
||||
default:
|
||||
panic("invalid index")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *point3d) ScanNull() error {
|
||||
return fmt.Errorf("cannot scan NULL into point3d")
|
||||
}
|
||||
|
||||
func (p *point3d) ScanIndex(i int) interface{} {
|
||||
switch i {
|
||||
case 0:
|
||||
return &p.X
|
||||
case 1:
|
||||
return &p.Y
|
||||
case 2:
|
||||
return &p.Z
|
||||
default:
|
||||
panic("invalid index")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompositeCodecTranscodeStruct(t *testing.T) {
|
||||
conn := testutil.MustConnectPgx(t)
|
||||
defer testutil.MustCloseContext(t, conn)
|
||||
|
||||
_, err := conn.Exec(context.Background(), `drop type if exists point3d;
|
||||
|
||||
create type point3d as (
|
||||
x float8,
|
||||
y float8,
|
||||
z float8
|
||||
);`)
|
||||
require.NoError(t, err)
|
||||
defer conn.Exec(context.Background(), "drop type point3d")
|
||||
|
||||
dt, err := conn.LoadDataType(context.Background(), "point3d")
|
||||
require.NoError(t, err)
|
||||
conn.ConnInfo().RegisterDataType(*dt)
|
||||
|
||||
formats := []struct {
|
||||
name string
|
||||
code int16
|
||||
}{
|
||||
{name: "TextFormat", code: pgx.TextFormatCode},
|
||||
{name: "BinaryFormat", code: pgx.BinaryFormatCode},
|
||||
}
|
||||
|
||||
for _, format := range formats {
|
||||
input := point3d{X: 1, Y: 2, Z: 3}
|
||||
var output point3d
|
||||
err := conn.QueryRow(context.Background(), "select $1::point3d", pgx.QueryResultFormats{format.code}, input).Scan(&output)
|
||||
require.NoErrorf(t, err, "%v", format.name)
|
||||
require.Equalf(t, input, output, "%v", format.name)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
# pgxtype
|
||||
|
||||
pgxtype is a helper module that connects pgx and pgtype. This package is not currently covered by semantic version guarantees. i.e. The interfaces may change without a major version release of pgtype.
|
|
@ -1,114 +0,0 @@
|
|||
package pgxtype
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
type Querier interface {
|
||||
Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error)
|
||||
Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error)
|
||||
QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row
|
||||
}
|
||||
|
||||
// LoadDataType uses conn to inspect the database for typeName and produces a pgtype.DataType suitable for
|
||||
// registration on ci.
|
||||
func LoadDataType(ctx context.Context, conn Querier, ci *pgtype.ConnInfo, typeName string) (pgtype.DataType, error) {
|
||||
var oid uint32
|
||||
|
||||
err := conn.QueryRow(ctx, "select $1::text::regtype::oid;", typeName).Scan(&oid)
|
||||
if err != nil {
|
||||
return pgtype.DataType{}, err
|
||||
}
|
||||
|
||||
var typtype string
|
||||
|
||||
err = conn.QueryRow(ctx, "select typtype::text from pg_type where oid=$1", oid).Scan(&typtype)
|
||||
if err != nil {
|
||||
return pgtype.DataType{}, err
|
||||
}
|
||||
|
||||
switch typtype {
|
||||
case "b": // array
|
||||
elementOID, err := GetArrayElementOID(ctx, conn, oid)
|
||||
if err != nil {
|
||||
return pgtype.DataType{}, err
|
||||
}
|
||||
|
||||
var elementCodec pgtype.Codec
|
||||
if dt, ok := ci.DataTypeForOID(elementOID); ok {
|
||||
if dt.Codec == nil {
|
||||
return pgtype.DataType{}, errors.New("array element OID not registered with Codec")
|
||||
}
|
||||
elementCodec = dt.Codec
|
||||
}
|
||||
|
||||
return pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.ArrayCodec{ElementOID: elementOID, ElementCodec: elementCodec}}, nil
|
||||
case "c": // composite
|
||||
panic("TODO - restore composite support")
|
||||
// fields, err := GetCompositeFields(ctx, conn, oid)
|
||||
// if err != nil {
|
||||
// return pgtype.DataType{}, err
|
||||
// }
|
||||
// ct, err := pgtype.NewCompositeType(typeName, fields, ci)
|
||||
// if err != nil {
|
||||
// return pgtype.DataType{}, err
|
||||
// }
|
||||
// return pgtype.DataType{Value: ct, Name: typeName, OID: oid}, nil
|
||||
case "e": // enum
|
||||
return pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil
|
||||
default:
|
||||
return pgtype.DataType{}, errors.New("unknown typtype")
|
||||
}
|
||||
}
|
||||
|
||||
func GetArrayElementOID(ctx context.Context, conn Querier, oid uint32) (uint32, error) {
|
||||
var typelem uint32
|
||||
|
||||
err := conn.QueryRow(ctx, "select typelem from pg_type where oid=$1", oid).Scan(&typelem)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return typelem, nil
|
||||
}
|
||||
|
||||
// TODO - restore composite support
|
||||
// GetCompositeFields gets the fields of a composite type.
|
||||
// func GetCompositeFields(ctx context.Context, conn Querier, oid uint32) ([]pgtype.CompositeTypeField, error) {
|
||||
// var typrelid uint32
|
||||
|
||||
// err := conn.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
// var fields []pgtype.CompositeTypeField
|
||||
|
||||
// rows, err := conn.Query(ctx, `select attname, atttypid
|
||||
// from pg_attribute
|
||||
// where attrelid=$1
|
||||
// order by attnum`, typrelid)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
// for rows.Next() {
|
||||
// var f pgtype.CompositeTypeField
|
||||
// err := rows.Scan(&f.Name, &f.OID)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// fields = append(fields, f)
|
||||
// }
|
||||
|
||||
// if rows.Err() != nil {
|
||||
// return nil, rows.Err()
|
||||
// }
|
||||
|
||||
// return fields, nil
|
||||
// }
|
Loading…
Reference in New Issue