From 5b6fb75669c9410fac55c599290b2b11283734bf Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Nov 2022 08:10:46 -0600 Subject: [PATCH] Conn.LoadType supports domain types If the underlying type is registered then use the same Codec. fixes https://github.com/jackc/pgx/issues/1373 --- conn.go | 10 +++++++++- conn_test.go | 15 ++++++--------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/conn.go b/conn.go index cb5034bc..2ab009ff 100644 --- a/conn.go +++ b/conn.go @@ -1147,8 +1147,9 @@ func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, err } var typtype string + var typbasetype uint32 - err = c.QueryRow(ctx, "select typtype::text from pg_type where oid=$1", oid).Scan(&typtype) + err = c.QueryRow(ctx, "select typtype::text, typbasetype from pg_type where oid=$1", oid).Scan(&typtype, &typbasetype) if err != nil { return nil, err } @@ -1173,6 +1174,13 @@ func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, err } return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}}, nil + case "d": // domain + dt, ok := c.TypeMap().TypeForOID(typbasetype) + if !ok { + return nil, errors.New("domain base type OID not registered") + } + + return &pgtype.Type{Name: typeName, OID: oid, Codec: dt.Codec}, nil case "e": // enum return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil default: diff --git a/conn_test.go b/conn_test.go index 204ff615..9cf5fd58 100644 --- a/conn_test.go +++ b/conn_test.go @@ -837,24 +837,21 @@ func TestDomainType(t *testing.T) { // uint64 but a result OID of the underlying numeric. var s string - err := conn.QueryRow(context.Background(), "select $1::uint64", "24").Scan(&s) + err := conn.QueryRow(ctx, "select $1::uint64", "24").Scan(&s) require.NoError(t, err) require.Equal(t, "24", s) // Register type - var uint64OID uint32 - err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='uint64';").Scan(&uint64OID) - if err != nil { - t.Fatalf("did not find uint64 OID, %v", err) - } - conn.TypeMap().RegisterType(&pgtype.Type{Name: "uint64", OID: uint64OID, Codec: pgtype.NumericCodec{}}) + uint64Type, err := conn.LoadType(ctx, "uint64") + require.NoError(t, err) + conn.TypeMap().RegisterType(uint64Type) var n uint64 - err = conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n) + err = conn.QueryRow(ctx, "select $1::uint64", uint64(24)).Scan(&n) require.NoError(t, err) // String is still an acceptable argument after registration - err = conn.QueryRow(context.Background(), "select $1::uint64", "7").Scan(&n) + err = conn.QueryRow(ctx, "select $1::uint64", "7").Scan(&n) if err != nil { t.Fatal(err) }