diff --git a/conn.go b/conn.go index f549e03e..540ff31f 100644 --- a/conn.go +++ b/conn.go @@ -384,9 +384,13 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl func (c *Conn) initConnInfo() error { nameOIDs := make(map[string]pgtype.OID, 256) - rows, err := c.Query(`select t.oid, t.typname + rows, err := c.Query(`select t.oid, + case when nsp.nspname in ('pg_catalog', 'public') then t.typname + else nsp.nspname||'.'||t.typname + end from pg_type t left join pg_type base_type on t.typelem=base_type.oid +left join pg_namespace nsp on t.typnamespace=nsp.oid where ( t.typtype in('b', 'p', 'r', 'e') and (base_type.oid is null or base_type.typtype in('b', 'p', 'r')) diff --git a/conn_test.go b/conn_test.go index 1996f814..557a86f1 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1953,3 +1953,31 @@ end$$;`) ensureConnValid(t, conn) } + +func TestConnInitConnInfo(t *testing.T) { + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + // spot check that the standard postgres type names aren't qualified + nameOIDs := map[string]pgtype.OID{ + "_int8": pgtype.Int8ArrayOID, + "int8": pgtype.Int8OID, + "json": pgtype.JSONOID, + "text": pgtype.TextOID, + } + for name, oid := range nameOIDs { + dtByName, ok := conn.ConnInfo.DataTypeForName(name) + if !ok { + t.Fatalf("Expected type named %v to be present", name) + } + dtByOID, ok := conn.ConnInfo.DataTypeForOID(oid) + if !ok { + t.Fatalf("Expected type OID %v to be present", oid) + } + if dtByName != dtByOID { + t.Fatalf("Expected type named %v to be the same as type OID %v", name, oid) + } + } + + ensureConnValid(t, conn) +} diff --git a/pgmock/pgmock.go b/pgmock/pgmock.go index fe78b009..e65a25fb 100644 --- a/pgmock/pgmock.go +++ b/pgmock/pgmock.go @@ -204,7 +204,17 @@ func AcceptUnauthenticatedConnRequestSteps() []Step { func PgxInitSteps() []Step { steps := []Step{ ExpectMessage(&pgproto3.Parse{ - Query: "select t.oid, t.typname\nfrom pg_type t\nleft join pg_type base_type on t.typelem=base_type.oid\nwhere (\n\t t.typtype in('b', 'p', 'r', 'e')\n\t and (base_type.oid is null or base_type.typtype in('b', 'p', 'r'))\n\t)", + Query: `select t.oid, + case when nsp.nspname in ('pg_catalog', 'public') then t.typname + else nsp.nspname||'.'||t.typname + end +from pg_type t +left join pg_type base_type on t.typelem=base_type.oid +left join pg_namespace nsp on t.typnamespace=nsp.oid +where ( + t.typtype in('b', 'p', 'r', 'e') + and (base_type.oid is null or base_type.typtype in('b', 'p', 'r')) + )`, }), ExpectMessage(&pgproto3.Describe{ ObjectType: 'S',