diff --git a/conn.go b/conn.go index d86c4025..a1542b1c 100644 --- a/conn.go +++ b/conn.go @@ -72,8 +72,9 @@ type ConnConfig struct { Logger Logger LogLevel int Dial DialFunc - RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) - OnNotice NoticeHandler // Callback function called when a notice response is received. + RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) + OnNotice NoticeHandler // Callback function called when a notice response is received. + CustomConnInfo func(*Conn) (*pgtype.ConnInfo, error) // Callback function to implement connection strategies for different backends. crate, pgbouncer, pgpool, etc. } func (cc *ConnConfig) networkAddress() (network, address string) { @@ -382,10 +383,9 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } } -func (c *Conn) initConnInfo() error { - nameOIDs := make(map[string]pgtype.OID, 256) - - rows, err := c.Query(`select t.oid, +func initPostgresql(c *Conn) (*pgtype.ConnInfo, error) { + const ( + namedOIDQuery = `select t.oid, case when nsp.nspname in ('pg_catalog', 'public') then t.typname else nsp.nspname||'.'||t.typname end @@ -395,45 +395,53 @@ left join pg_namespace nsp on t.typnamespace=nsp.oid where ( t.typtype in('b', 'p', 'r', 'e') and (base_type.oid is null or base_type.typtype in('b', 'p', 'r')) - )`) - isCrateDB := false + )` + ) + + nameOIDs, err := connInfoFromRows(c.Query(namedOIDQuery)) if err != nil { - // Check if CrateDB specific approach might still allow us to connect. - if rows, err = c.crateDBTypesQuery(err); err != nil { - return err - } - isCrateDB = true + return nil, err } - for rows.Next() { - var oid pgtype.OID - var name pgtype.Text - if err := rows.Scan(&oid, &name); err != nil { + cinfo := pgtype.NewConnInfo() + cinfo.InitializeDataTypes(nameOIDs) + + if err = c.initConnInfoEnumArray(cinfo); err != nil { + return nil, err + } + + return cinfo, nil +} + +func (c *Conn) initConnInfo() (err error) { + var ( + connInfo *pgtype.ConnInfo + ) + + if c.config.CustomConnInfo != nil { + if c.ConnInfo, err = c.config.CustomConnInfo(c); err != nil { return err } - nameOIDs[name.String] = oid - } - - if rows.Err() != nil { - return rows.Err() - } - - c.ConnInfo = pgtype.NewConnInfo() - c.ConnInfo.InitializeDataTypes(nameOIDs) - - if isCrateDB { - // CrateDB does not support enums (initConnInfoEnumArray), so we return - // early here. return nil } - return c.initConnInfoEnumArray() + + if connInfo, err = initPostgresql(c); err == nil { + c.ConnInfo = connInfo + return err + } + + // Check if CrateDB specific approach might still allow us to connect. + if connInfo, err = c.crateDBTypesQuery(err); err == nil { + c.ConnInfo = connInfo + } + + return err } // initConnInfoEnumArray introspects for arrays of enums and registers a data type for them. -func (c *Conn) initConnInfoEnumArray() error { +func (c *Conn) initConnInfoEnumArray(cinfo *pgtype.ConnInfo) error { nameOIDs := make(map[string]pgtype.OID, 16) - rows, err := c.Query(`select t.oid, t.typname from pg_type t join pg_type base_type on t.typelem=base_type.oid @@ -458,10 +466,10 @@ where t.typtype = 'b' } for name, oid := range nameOIDs { - c.ConnInfo.RegisterDataType(pgtype.DataType{ - &pgtype.EnumArray{}, - name, - oid, + cinfo.RegisterDataType(pgtype.DataType{ + Value: &pgtype.EnumArray{}, + Name: name, + OID: oid, }) } @@ -472,7 +480,7 @@ where t.typtype = 'b' // CrateDB not implementing the pg_types table correctly. If yes, a CrateDB // specific query against pg_types is executed and its results are returned. If // not, the original error is returned. -func (c *Conn) crateDBTypesQuery(err error) (*Rows, error) { +func (c *Conn) crateDBTypesQuery(err error) (*pgtype.ConnInfo, error) { // CrateDB 2.1.6 is a database that implements the PostgreSQL wire protocol, // but not perfectly. In particular, the pg_catalog schema containing the // pg_type table is not visible by default and the pg_type.typtype column is @@ -521,8 +529,20 @@ func (c *Conn) crateDBTypesQuery(err error) (*Rows, error) { (pgErr.Code == "XX000" || strings.Contains(pgErr.Message, "TableUnknownException") || strings.Contains(pgErr.Message, "ColumnUnknownException")) { - return c.Query(`select oid, typname from pg_catalog.pg_type`) + var ( + nameOIDs map[string]pgtype.OID + ) + + if nameOIDs, err = connInfoFromRows(c.Query(`select oid, typname from pg_catalog.pg_type`)); err != nil { + return nil, err + } + + cinfo := pgtype.NewConnInfo() + cinfo.InitializeDataTypes(nameOIDs) + + return cinfo, err } + return nil, err } @@ -1744,3 +1764,27 @@ func (c *Conn) ensureConnectionReadyForQuery() error { return nil } + +func connInfoFromRows(rows *Rows, err error) (map[string]pgtype.OID, error) { + if err != nil { + return nil, err + } + defer rows.Close() + + nameOIDs := make(map[string]pgtype.OID, 256) + for rows.Next() { + var oid pgtype.OID + var name pgtype.Text + if err = rows.Scan(&oid, &name); err != nil { + return nil, err + } + + nameOIDs[name.String] = oid + } + + if err = rows.Err(); err != nil { + return nil, err + } + + return nameOIDs, err +}