From b5bf9d7bb9fc3d84e14e91778e30009fce72809f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 29 Jan 2022 16:32:05 -0600 Subject: [PATCH] Move LoadDataType to pgx.Conn --- conn.go | 90 ++++++++++++++++++++++++++++++ pgtype/composite_test.go | 103 +++++++++++++++++++++++++--------- pgtype/pgxtype/README.md | 3 - pgtype/pgxtype/pgxtype.go | 114 -------------------------------------- 4 files changed, 166 insertions(+), 144 deletions(-) delete mode 100644 pgtype/pgxtype/README.md delete mode 100644 pgtype/pgxtype/pgxtype.go diff --git a/conn.go b/conn.go index 4412e174..11f275a6 100644 --- a/conn.go +++ b/conn.go @@ -862,3 +862,93 @@ func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, return sanitize.SanitizeSQL(sql, valueArgs...) } + +// LoadDataType inspects the database for typeName and produces a pgtype.DataType suitable for +// registration. +func (c *Conn) LoadDataType(ctx context.Context, typeName string) (*pgtype.DataType, error) { + var oid uint32 + + err := c.QueryRow(ctx, "select $1::text::regtype::oid;", typeName).Scan(&oid) + if err != nil { + return nil, err + } + + var typtype string + + err = c.QueryRow(ctx, "select typtype::text from pg_type where oid=$1", oid).Scan(&typtype) + if err != nil { + return nil, err + } + + switch typtype { + case "b": // array + elementOID, err := c.getArrayElementOID(ctx, oid) + if err != nil { + return nil, err + } + + var elementCodec pgtype.Codec + if dt, ok := c.ConnInfo().DataTypeForOID(elementOID); ok { + if dt.Codec == nil { + return nil, errors.New("array element OID not registered with Codec") + } + elementCodec = dt.Codec + } + + return &pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.ArrayCodec{ElementOID: elementOID, ElementCodec: elementCodec}}, nil + case "c": // composite + fields, err := c.getCompositeFields(ctx, oid) + if err != nil { + return nil, err + } + + return &pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}}, nil + case "e": // enum + return &pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil + default: + return &pgtype.DataType{}, errors.New("unknown typtype") + } +} + +func (c *Conn) getArrayElementOID(ctx context.Context, oid uint32) (uint32, error) { + var typelem uint32 + + err := c.QueryRow(ctx, "select typelem from pg_type where oid=$1", oid).Scan(&typelem) + if err != nil { + return 0, err + } + + return typelem, nil +} + +func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.CompositeCodecField, error) { + var typrelid uint32 + + err := c.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid) + if err != nil { + return nil, err + } + + var fields []pgtype.CompositeCodecField + var fieldName string + var fieldOID uint32 + _, err = c.QueryFunc(ctx, `select attname, atttypid +from pg_attribute +where attrelid=$1 +order by attnum`, + []interface{}{typrelid}, + []interface{}{&fieldName, &fieldOID}, + func(qfr QueryFuncRow) error { + dt, ok := c.ConnInfo().DataTypeForOID(fieldOID) + if !ok { + return fmt.Errorf("unknown composite type field OID: %v", fieldOID) + } + fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, DataType: dt}) + return nil + }) + if err != nil { + return nil, err + } + + return fields, nil +} diff --git a/pgtype/composite_test.go b/pgtype/composite_test.go index ba91de80..c9319c2d 100644 --- a/pgtype/composite_test.go +++ b/pgtype/composite_test.go @@ -2,6 +2,7 @@ package pgtype_test import ( "context" + "fmt" "testing" pgx "github.com/jackc/pgx/v5" @@ -23,34 +24,9 @@ create type ct_test as ( require.NoError(t, err) defer conn.Exec(context.Background(), "drop type ct_test") - var oid uint32 - err = conn.QueryRow(context.Background(), `select 'ct_test'::regtype::oid`).Scan(&oid) + dt, err := conn.LoadDataType(context.Background(), "ct_test") require.NoError(t, err) - - defer conn.Exec(context.Background(), "drop type ct_test") - - textDataType, ok := conn.ConnInfo().DataTypeForOID(pgtype.TextOID) - require.True(t, ok) - - int4DataType, ok := conn.ConnInfo().DataTypeForOID(pgtype.Int4OID) - require.True(t, ok) - - conn.ConnInfo().RegisterDataType(pgtype.DataType{ - Name: "ct_test", - OID: oid, - Codec: &pgtype.CompositeCodec{ - Fields: []pgtype.CompositeCodecField{ - { - Name: "a", - DataType: textDataType, - }, - { - Name: "b", - DataType: int4DataType, - }, - }, - }, - }) + conn.ConnInfo().RegisterDataType(*dt) formats := []struct { name string @@ -74,3 +50,76 @@ create type ct_test as ( require.EqualValuesf(t, 42, b, "%v", format.name) } } + +type point3d struct { + X, Y, Z float64 +} + +func (p point3d) IsNull() bool { + return false +} + +func (p point3d) Index(i int) interface{} { + switch i { + case 0: + return p.X + case 1: + return p.Y + case 2: + return p.Z + default: + panic("invalid index") + } +} + +func (p *point3d) ScanNull() error { + return fmt.Errorf("cannot scan NULL into point3d") +} + +func (p *point3d) ScanIndex(i int) interface{} { + switch i { + case 0: + return &p.X + case 1: + return &p.Y + case 2: + return &p.Z + default: + panic("invalid index") + } +} + +func TestCompositeCodecTranscodeStruct(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + _, err := conn.Exec(context.Background(), `drop type if exists point3d; + +create type point3d as ( + x float8, + y float8, + z float8 +);`) + require.NoError(t, err) + defer conn.Exec(context.Background(), "drop type point3d") + + dt, err := conn.LoadDataType(context.Background(), "point3d") + require.NoError(t, err) + conn.ConnInfo().RegisterDataType(*dt) + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + for _, format := range formats { + input := point3d{X: 1, Y: 2, Z: 3} + var output point3d + err := conn.QueryRow(context.Background(), "select $1::point3d", pgx.QueryResultFormats{format.code}, input).Scan(&output) + require.NoErrorf(t, err, "%v", format.name) + require.Equalf(t, input, output, "%v", format.name) + } +} diff --git a/pgtype/pgxtype/README.md b/pgtype/pgxtype/README.md deleted file mode 100644 index a070111f..00000000 --- a/pgtype/pgxtype/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# pgxtype - -pgxtype is a helper module that connects pgx and pgtype. This package is not currently covered by semantic version guarantees. i.e. The interfaces may change without a major version release of pgtype. diff --git a/pgtype/pgxtype/pgxtype.go b/pgtype/pgxtype/pgxtype.go deleted file mode 100644 index 6436f01b..00000000 --- a/pgtype/pgxtype/pgxtype.go +++ /dev/null @@ -1,114 +0,0 @@ -package pgxtype - -import ( - "context" - "errors" - - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgtype" -) - -type Querier interface { - Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) - Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error) - QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row -} - -// LoadDataType uses conn to inspect the database for typeName and produces a pgtype.DataType suitable for -// registration on ci. -func LoadDataType(ctx context.Context, conn Querier, ci *pgtype.ConnInfo, typeName string) (pgtype.DataType, error) { - var oid uint32 - - err := conn.QueryRow(ctx, "select $1::text::regtype::oid;", typeName).Scan(&oid) - if err != nil { - return pgtype.DataType{}, err - } - - var typtype string - - err = conn.QueryRow(ctx, "select typtype::text from pg_type where oid=$1", oid).Scan(&typtype) - if err != nil { - return pgtype.DataType{}, err - } - - switch typtype { - case "b": // array - elementOID, err := GetArrayElementOID(ctx, conn, oid) - if err != nil { - return pgtype.DataType{}, err - } - - var elementCodec pgtype.Codec - if dt, ok := ci.DataTypeForOID(elementOID); ok { - if dt.Codec == nil { - return pgtype.DataType{}, errors.New("array element OID not registered with Codec") - } - elementCodec = dt.Codec - } - - return pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.ArrayCodec{ElementOID: elementOID, ElementCodec: elementCodec}}, nil - case "c": // composite - panic("TODO - restore composite support") - // fields, err := GetCompositeFields(ctx, conn, oid) - // if err != nil { - // return pgtype.DataType{}, err - // } - // ct, err := pgtype.NewCompositeType(typeName, fields, ci) - // if err != nil { - // return pgtype.DataType{}, err - // } - // return pgtype.DataType{Value: ct, Name: typeName, OID: oid}, nil - case "e": // enum - return pgtype.DataType{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil - default: - return pgtype.DataType{}, errors.New("unknown typtype") - } -} - -func GetArrayElementOID(ctx context.Context, conn Querier, oid uint32) (uint32, error) { - var typelem uint32 - - err := conn.QueryRow(ctx, "select typelem from pg_type where oid=$1", oid).Scan(&typelem) - if err != nil { - return 0, err - } - - return typelem, nil -} - -// TODO - restore composite support -// GetCompositeFields gets the fields of a composite type. -// func GetCompositeFields(ctx context.Context, conn Querier, oid uint32) ([]pgtype.CompositeTypeField, error) { -// var typrelid uint32 - -// err := conn.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid) -// if err != nil { -// return nil, err -// } - -// var fields []pgtype.CompositeTypeField - -// rows, err := conn.Query(ctx, `select attname, atttypid -// from pg_attribute -// where attrelid=$1 -// order by attnum`, typrelid) -// if err != nil { -// return nil, err -// } - -// for rows.Next() { -// var f pgtype.CompositeTypeField -// err := rows.Scan(&f.Name, &f.OID) -// if err != nil { -// return nil, err -// } -// fields = append(fields, f) -// } - -// if rows.Err() != nil { -// return nil, rows.Err() -// } - -// return fields, nil -// }