mirror of https://github.com/jackc/pgx.git
Merge pull request #2046 from nicois/nicois/load-types
Faster/easier loading of typespull/2067/head
commit
161ce73ec1
7
conn.go
7
conn.go
|
@ -107,8 +107,10 @@ var (
|
|||
ErrTooManyRows = errors.New("too many rows in result set")
|
||||
)
|
||||
|
||||
var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
|
||||
var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
|
||||
var (
|
||||
errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
|
||||
errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
|
||||
)
|
||||
|
||||
// Connect establishes a connection with a PostgreSQL server with a connection string. See
|
||||
// pgconn.Connect for details.
|
||||
|
@ -843,7 +845,6 @@ func (c *Conn) getStatementDescription(
|
|||
mode QueryExecMode,
|
||||
sql string,
|
||||
) (sd *pgconn.StatementDescription, err error) {
|
||||
|
||||
switch mode {
|
||||
case QueryExecModeCacheStatement:
|
||||
if c.statementCache == nil {
|
||||
|
|
|
@ -0,0 +1,260 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
/*
|
||||
buildLoadDerivedTypesSQL generates the correct query for retrieving type information.
|
||||
|
||||
pgVersion: the major version of the PostgreSQL server
|
||||
typeNames: the names of the types to load. If nil, load all types.
|
||||
*/
|
||||
func buildLoadDerivedTypesSQL(pgVersion int64, typeNames []string) string {
|
||||
supportsMultirange := (pgVersion >= 14)
|
||||
var typeNamesClause string
|
||||
|
||||
if typeNames == nil {
|
||||
// This should not occur; this will not return any types
|
||||
typeNamesClause = "= ''"
|
||||
} else {
|
||||
typeNamesClause = "= ANY($1)"
|
||||
}
|
||||
parts := make([]string, 0, 10)
|
||||
|
||||
// Each of the type names provided might be found in pg_class or pg_type.
|
||||
// Additionally, it may or may not include a schema portion.
|
||||
parts = append(parts, `
|
||||
WITH RECURSIVE
|
||||
-- find the OIDs in pg_class which match one of the provided type names
|
||||
selected_classes(oid,reltype) AS (
|
||||
-- this query uses the namespace search path, so will match type names without a schema prefix
|
||||
SELECT pg_class.oid, pg_class.reltype
|
||||
FROM pg_catalog.pg_class
|
||||
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = pg_class.relnamespace
|
||||
WHERE pg_catalog.pg_table_is_visible(pg_class.oid)
|
||||
AND relname `, typeNamesClause, `
|
||||
UNION ALL
|
||||
-- this query will only match type names which include the schema prefix
|
||||
SELECT pg_class.oid, pg_class.reltype
|
||||
FROM pg_class
|
||||
INNER JOIN pg_namespace ON (pg_class.relnamespace = pg_namespace.oid)
|
||||
WHERE nspname || '.' || relname `, typeNamesClause, `
|
||||
),
|
||||
selected_types(oid) AS (
|
||||
-- collect the OIDs from pg_types which correspond to the selected classes
|
||||
SELECT reltype AS oid
|
||||
FROM selected_classes
|
||||
UNION ALL
|
||||
-- as well as any other type names which match our criteria
|
||||
SELECT pg_type.oid
|
||||
FROM pg_type
|
||||
LEFT OUTER JOIN pg_namespace ON (pg_type.typnamespace = pg_namespace.oid)
|
||||
WHERE typname `, typeNamesClause, `
|
||||
OR nspname || '.' || typname `, typeNamesClause, `
|
||||
),
|
||||
-- this builds a parent/child mapping of objects, allowing us to know
|
||||
-- all the child (ie: dependent) types that a parent (type) requires
|
||||
-- As can be seen, there are 3 ways this can occur (the last of which
|
||||
-- is due to being a composite class, where the composite fields are children)
|
||||
pc(parent, child) AS (
|
||||
SELECT parent.oid, parent.typelem
|
||||
FROM pg_type parent
|
||||
WHERE parent.typtype = 'b' AND parent.typelem != 0
|
||||
UNION ALL
|
||||
SELECT parent.oid, parent.typbasetype
|
||||
FROM pg_type parent
|
||||
WHERE parent.typtypmod = -1 AND parent.typbasetype != 0
|
||||
UNION ALL
|
||||
SELECT pg_type.oid, atttypid
|
||||
FROM pg_attribute
|
||||
INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid)
|
||||
INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype)
|
||||
WHERE NOT attisdropped
|
||||
AND attnum > 0
|
||||
),
|
||||
-- Now construct a recursive query which includes a 'depth' element.
|
||||
-- This is used to ensure that the "youngest" children are registered before
|
||||
-- their parents.
|
||||
relationships(parent, child, depth) AS (
|
||||
SELECT DISTINCT 0::OID, selected_types.oid, 0
|
||||
FROM selected_types
|
||||
UNION ALL
|
||||
SELECT pg_type.oid AS parent, pg_attribute.atttypid AS child, 1
|
||||
FROM selected_classes c
|
||||
inner join pg_type ON (c.reltype = pg_type.oid)
|
||||
inner join pg_attribute on (c.oid = pg_attribute.attrelid)
|
||||
UNION ALL
|
||||
SELECT pc.parent, pc.child, relationships.depth + 1
|
||||
FROM pc
|
||||
INNER JOIN relationships ON (pc.parent = relationships.child)
|
||||
),
|
||||
-- composite fields need to be encapsulated as a couple of arrays to provide the required information for registration
|
||||
composite AS (
|
||||
SELECT pg_type.oid, ARRAY_AGG(attname ORDER BY attnum) AS attnames, ARRAY_AGG(atttypid ORDER BY ATTNUM) AS atttypids
|
||||
FROM pg_attribute
|
||||
INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid)
|
||||
INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype)
|
||||
WHERE NOT attisdropped
|
||||
AND attnum > 0
|
||||
GROUP BY pg_type.oid
|
||||
)
|
||||
-- Bring together this information, showing all the information which might possibly be required
|
||||
-- to complete the registration, applying filters to only show the items which relate to the selected
|
||||
-- types/classes.
|
||||
SELECT typname,
|
||||
pg_namespace.nspname,
|
||||
typtype,
|
||||
typbasetype,
|
||||
typelem,
|
||||
pg_type.oid,`)
|
||||
if supportsMultirange {
|
||||
parts = append(parts, `
|
||||
COALESCE(multirange.rngtypid, 0) AS rngtypid,`)
|
||||
} else {
|
||||
parts = append(parts, `
|
||||
0 AS rngtypid,`)
|
||||
}
|
||||
parts = append(parts, `
|
||||
COALESCE(pg_range.rngsubtype, 0) AS rngsubtype,
|
||||
attnames, atttypids
|
||||
FROM relationships
|
||||
INNER JOIN pg_type ON (pg_type.oid = relationships.child)
|
||||
LEFT OUTER JOIN pg_range ON (pg_type.oid = pg_range.rngtypid)`)
|
||||
if supportsMultirange {
|
||||
parts = append(parts, `
|
||||
LEFT OUTER JOIN pg_range multirange ON (pg_type.oid = multirange.rngmultitypid)`)
|
||||
}
|
||||
|
||||
parts = append(parts, `
|
||||
LEFT OUTER JOIN composite USING (oid)
|
||||
LEFT OUTER JOIN pg_namespace ON (pg_type.typnamespace = pg_namespace.oid)
|
||||
WHERE NOT (typtype = 'b' AND typelem = 0)`)
|
||||
parts = append(parts, `
|
||||
GROUP BY typname, pg_namespace.nspname, typtype, typbasetype, typelem, pg_type.oid, pg_range.rngsubtype,`)
|
||||
if supportsMultirange {
|
||||
parts = append(parts, `
|
||||
multirange.rngtypid,`)
|
||||
}
|
||||
parts = append(parts, `
|
||||
attnames, atttypids
|
||||
ORDER BY MAX(depth) desc, typname;`)
|
||||
return strings.Join(parts, "")
|
||||
}
|
||||
|
||||
type derivedTypeInfo struct {
|
||||
Oid, Typbasetype, Typelem, Rngsubtype, Rngtypid uint32
|
||||
TypeName, Typtype, NspName string
|
||||
Attnames []string
|
||||
Atttypids []uint32
|
||||
}
|
||||
|
||||
// LoadTypes performs a single (complex) query, returning all the required
|
||||
// information to register the named types, as well as any other types directly
|
||||
// or indirectly required to complete the registration.
|
||||
// The result of this call can be passed into RegisterTypes to complete the process.
|
||||
func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Type, error) {
|
||||
m := c.TypeMap()
|
||||
if typeNames == nil || len(typeNames) == 0 {
|
||||
return nil, fmt.Errorf("No type names were supplied.")
|
||||
}
|
||||
|
||||
// Disregard server version errors. This will result in
|
||||
// the SQL not support recent structures such as multirange
|
||||
serverVersion, _ := serverVersion(c)
|
||||
sql := buildLoadDerivedTypesSQL(serverVersion, typeNames)
|
||||
var rows Rows
|
||||
var err error
|
||||
if typeNames == nil {
|
||||
rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol)
|
||||
} else {
|
||||
rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("While generating load types query: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
result := make([]*pgtype.Type, 0, 100)
|
||||
for rows.Next() {
|
||||
ti := derivedTypeInfo{}
|
||||
err = rows.Scan(&ti.TypeName, &ti.NspName, &ti.Typtype, &ti.Typbasetype, &ti.Typelem, &ti.Oid, &ti.Rngtypid, &ti.Rngsubtype, &ti.Attnames, &ti.Atttypids)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("While scanning type information: %w", err)
|
||||
}
|
||||
var type_ *pgtype.Type
|
||||
switch ti.Typtype {
|
||||
case "b": // array
|
||||
dt, ok := m.TypeForOID(ti.Typelem)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Array element OID %v not registered while loading pgtype %q", ti.Typelem, ti.TypeName)
|
||||
}
|
||||
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.ArrayCodec{ElementType: dt}}
|
||||
case "c": // composite
|
||||
var fields []pgtype.CompositeCodecField
|
||||
for i, fieldName := range ti.Attnames {
|
||||
dt, ok := m.TypeForOID(ti.Atttypids[i])
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Unknown field for composite type %q: field %q (OID %v) is not already registered.", ti.TypeName, fieldName, ti.Atttypids[i])
|
||||
}
|
||||
fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt})
|
||||
}
|
||||
|
||||
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.CompositeCodec{Fields: fields}}
|
||||
case "d": // domain
|
||||
dt, ok := m.TypeForOID(ti.Typbasetype)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Domain base type OID %v was not already registered, needed for %q", ti.Typbasetype, ti.TypeName)
|
||||
}
|
||||
|
||||
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: dt.Codec}
|
||||
case "e": // enum
|
||||
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.EnumCodec{}}
|
||||
case "r": // range
|
||||
dt, ok := m.TypeForOID(ti.Rngsubtype)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Range element OID %v was not already registered, needed for %q", ti.Rngsubtype, ti.TypeName)
|
||||
}
|
||||
|
||||
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.RangeCodec{ElementType: dt}}
|
||||
case "m": // multirange
|
||||
dt, ok := m.TypeForOID(ti.Rngtypid)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Multirange element OID %v was not already registered, needed for %q", ti.Rngtypid, ti.TypeName)
|
||||
}
|
||||
|
||||
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}}
|
||||
default:
|
||||
return nil, fmt.Errorf("Unknown typtype %q was found while registering %q", ti.Typtype, ti.TypeName)
|
||||
}
|
||||
if type_ != nil {
|
||||
m.RegisterType(type_)
|
||||
if ti.NspName != "" {
|
||||
m.RegisterType(&pgtype.Type{Name: ti.NspName + "." + type_.Name, OID: type_.OID, Codec: type_.Codec})
|
||||
}
|
||||
result = append(result, type_)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// serverVersion returns the postgresql server version.
|
||||
func serverVersion(c *Conn) (int64, error) {
|
||||
serverVersionStr := c.PgConn().ParameterStatus("server_version")
|
||||
serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr)
|
||||
// if not PostgreSQL do nothing
|
||||
if serverVersionStr == "" {
|
||||
return 0, fmt.Errorf("Cannot identify server version in %q", serverVersionStr)
|
||||
}
|
||||
|
||||
version, err := strconv.ParseInt(serverVersionStr, 10, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("postgres version parsing failed: %w", err)
|
||||
}
|
||||
return version, nil
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCompositeCodecTranscodeWithLoadTypes(t *testing.T) {
|
||||
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
|
||||
|
||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
_, err := conn.Exec(ctx, `
|
||||
drop type if exists dtype_test;
|
||||
drop domain if exists anotheruint64;
|
||||
|
||||
create domain anotheruint64 as numeric(20,0);
|
||||
create type dtype_test as (
|
||||
a text,
|
||||
b int4,
|
||||
c anotheruint64,
|
||||
d anotheruint64[]
|
||||
);`)
|
||||
require.NoError(t, err)
|
||||
defer conn.Exec(ctx, "drop type dtype_test")
|
||||
defer conn.Exec(ctx, "drop domain anotheruint64")
|
||||
|
||||
types, err := conn.LoadTypes(ctx, []string{"dtype_test"})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, types, 3)
|
||||
require.Equal(t, types[0].Name, "anotheruint64")
|
||||
require.Equal(t, types[1].Name, "_anotheruint64")
|
||||
require.Equal(t, types[2].Name, "dtype_test")
|
||||
})
|
||||
}
|
4
go.sum
4
go.sum
|
@ -4,10 +4,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
|||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA=
|
||||
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
package pgtype_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
pgx "github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDerivedTypes(t *testing.T) {
|
||||
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
|
||||
|
||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
_, err := conn.Exec(ctx, `
|
||||
drop type if exists dt_test;
|
||||
drop domain if exists dt_uint64;
|
||||
|
||||
create domain dt_uint64 as numeric(20,0);
|
||||
create type dt_test as (
|
||||
a text,
|
||||
b dt_uint64,
|
||||
c dt_uint64[]
|
||||
);`)
|
||||
require.NoError(t, err)
|
||||
defer conn.Exec(ctx, "drop domain dt_uint64")
|
||||
defer conn.Exec(ctx, "drop type dt_test")
|
||||
|
||||
dtypes, err := conn.LoadTypes(ctx, []string{"dt_test"})
|
||||
require.Len(t, dtypes, 3)
|
||||
require.Equal(t, dtypes[0].Name, "dt_uint64")
|
||||
require.Equal(t, dtypes[1].Name, "_dt_uint64")
|
||||
require.Equal(t, dtypes[2].Name, "dt_test")
|
||||
require.NoError(t, err)
|
||||
conn.TypeMap().RegisterTypes(dtypes)
|
||||
|
||||
formats := []struct {
|
||||
name string
|
||||
code int16
|
||||
}{
|
||||
{name: "TextFormat", code: pgx.TextFormatCode},
|
||||
{name: "BinaryFormat", code: pgx.BinaryFormatCode},
|
||||
}
|
||||
|
||||
for _, format := range formats {
|
||||
var a string
|
||||
var b uint64
|
||||
var c *[]uint64
|
||||
|
||||
row := conn.QueryRow(ctx, "select $1::dt_test", pgx.QueryResultFormats{format.code}, pgtype.CompositeFields{"hi", uint64(42), []uint64{10, 20, 30}})
|
||||
err := row.Scan(pgtype.CompositeFields{&a, &b, &c})
|
||||
require.NoError(t, err)
|
||||
require.EqualValuesf(t, "hi", a, "%v", format.name)
|
||||
require.EqualValuesf(t, 42, b, "%v", format.name)
|
||||
}
|
||||
})
|
||||
}
|
|
@ -214,6 +214,15 @@ type Map struct {
|
|||
TryWrapScanPlanFuncs []TryWrapScanPlanFunc
|
||||
}
|
||||
|
||||
// Copy returns a new Map containing the same registered types.
|
||||
func (m *Map) Copy() *Map {
|
||||
newMap := NewMap()
|
||||
for _, type_ := range m.oidToType {
|
||||
newMap.RegisterType(type_)
|
||||
}
|
||||
return newMap
|
||||
}
|
||||
|
||||
func NewMap() *Map {
|
||||
defaultMapInitOnce.Do(initDefaultMap)
|
||||
|
||||
|
@ -248,6 +257,13 @@ func NewMap() *Map {
|
|||
}
|
||||
}
|
||||
|
||||
// RegisterTypes registers multiple data types in the sequence they are provided.
|
||||
func (m *Map) RegisterTypes(types []*Type) {
|
||||
for _, t := range types {
|
||||
m.RegisterType(t)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterType registers a data type with the Map. t must not be mutated after it is registered.
|
||||
func (m *Map) RegisterType(t *Type) {
|
||||
m.oidToType[t.OID] = t
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
)
|
||||
|
||||
func skipCockroachDB(t testing.TB, msg string) {
|
||||
conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close(context.Background())
|
||||
|
||||
if conn.PgConn().ParameterStatus("crdb_version") != "" {
|
||||
t.Skip(msg)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue