Conn.LoadType supports domain types

If the underlying type is registered then use the same Codec.

fixes https://github.com/jackc/pgx/issues/1373
pull/1379/head
Jack Christensen 2022-11-12 08:10:46 -06:00
parent b265fedd75
commit 5b6fb75669
2 changed files with 15 additions and 10 deletions

10
conn.go
View File

@ -1147,8 +1147,9 @@ func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, err
}
var typtype string
var typbasetype uint32
err = c.QueryRow(ctx, "select typtype::text from pg_type where oid=$1", oid).Scan(&typtype)
err = c.QueryRow(ctx, "select typtype::text, typbasetype from pg_type where oid=$1", oid).Scan(&typtype, &typbasetype)
if err != nil {
return nil, err
}
@ -1173,6 +1174,13 @@ func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, err
}
return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}}, nil
case "d": // domain
dt, ok := c.TypeMap().TypeForOID(typbasetype)
if !ok {
return nil, errors.New("domain base type OID not registered")
}
return &pgtype.Type{Name: typeName, OID: oid, Codec: dt.Codec}, nil
case "e": // enum
return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil
default:

View File

@ -837,24 +837,21 @@ func TestDomainType(t *testing.T) {
// uint64 but a result OID of the underlying numeric.
var s string
err := conn.QueryRow(context.Background(), "select $1::uint64", "24").Scan(&s)
err := conn.QueryRow(ctx, "select $1::uint64", "24").Scan(&s)
require.NoError(t, err)
require.Equal(t, "24", s)
// Register type
var uint64OID uint32
err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='uint64';").Scan(&uint64OID)
if err != nil {
t.Fatalf("did not find uint64 OID, %v", err)
}
conn.TypeMap().RegisterType(&pgtype.Type{Name: "uint64", OID: uint64OID, Codec: pgtype.NumericCodec{}})
uint64Type, err := conn.LoadType(ctx, "uint64")
require.NoError(t, err)
conn.TypeMap().RegisterType(uint64Type)
var n uint64
err = conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n)
err = conn.QueryRow(ctx, "select $1::uint64", uint64(24)).Scan(&n)
require.NoError(t, err)
// String is still an acceptable argument after registration
err = conn.QueryRow(context.Background(), "select $1::uint64", "7").Scan(&n)
err = conn.QueryRow(ctx, "select $1::uint64", "7").Scan(&n)
if err != nil {
t.Fatal(err)
}