diff --git a/stdlib/opendb.go b/stdlib/opendb.go index 215aa146..ad4e979b 100644 --- a/stdlib/opendb.go +++ b/stdlib/opendb.go @@ -14,7 +14,7 @@ import ( type OptionOpenDB func(*connector) // OptionAfterConnect provide a callback for after connect. -func OptionAfterConnect(ac func(*pgx.Conn) error) OptionOpenDB { +func OptionAfterConnect(ac func(context.Context, *pgx.Conn) error) OptionOpenDB { return func(dc *connector) { dc.AfterConnect = ac } @@ -23,7 +23,7 @@ func OptionAfterConnect(ac func(*pgx.Conn) error) OptionOpenDB { func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB { c := connector{ ConnConfig: config, - AfterConnect: func(*pgx.Conn) error { return nil }, // noop after connect by default + AfterConnect: func(context.Context, *pgx.Conn) error { return nil }, // noop after connect by default driver: pgxDriver, } @@ -36,7 +36,7 @@ func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB { type connector struct { pgx.ConnConfig - AfterConnect func(*pgx.Conn) error // function to call on every new connection + AfterConnect func(context.Context, *pgx.Conn) error // function to call on every new connection driver *Driver } @@ -51,7 +51,7 @@ func (c connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } - if err = c.AfterConnect(conn); err != nil { + if err = c.AfterConnect(ctx, conn); err != nil { return nil, err }