mirror of https://github.com/jackc/pgx.git
Add AfterConnect hook to pool
parent
9008387300
commit
7558b8d05f
22
pool/pool.go
22
pool/pool.go
|
@ -18,6 +18,7 @@ var defaultHealthCheckPeriod = time.Minute
|
||||||
|
|
||||||
type Pool struct {
|
type Pool struct {
|
||||||
p *puddle.Pool
|
p *puddle.Pool
|
||||||
|
afterConnect func(context.Context, *pgx.Conn) error
|
||||||
beforeAcquire func(*pgx.Conn) bool
|
beforeAcquire func(*pgx.Conn) bool
|
||||||
afterRelease func(*pgx.Conn) bool
|
afterRelease func(*pgx.Conn) bool
|
||||||
maxConnLifetime time.Duration
|
maxConnLifetime time.Duration
|
||||||
|
@ -30,6 +31,9 @@ type Pool struct {
|
||||||
type Config struct {
|
type Config struct {
|
||||||
ConnConfig *pgx.ConnConfig
|
ConnConfig *pgx.ConnConfig
|
||||||
|
|
||||||
|
// AfterConnect is called after a connection is established, but before it is added to the pool.
|
||||||
|
AfterConnect func(context.Context, *pgx.Conn) error
|
||||||
|
|
||||||
// BeforeAcquire is called before before a connection is acquired from the pool. It must return true to allow the
|
// BeforeAcquire is called before before a connection is acquired from the pool. It must return true to allow the
|
||||||
// acquision or false to indicate that the connection should be destroyed and a different connection should be
|
// acquision or false to indicate that the connection should be destroyed and a different connection should be
|
||||||
// acquired.
|
// acquired.
|
||||||
|
@ -64,6 +68,7 @@ func Connect(ctx context.Context, connString string) (*Pool, error) {
|
||||||
// connection.
|
// connection.
|
||||||
func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) {
|
func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) {
|
||||||
p := &Pool{
|
p := &Pool{
|
||||||
|
afterConnect: config.AfterConnect,
|
||||||
beforeAcquire: config.BeforeAcquire,
|
beforeAcquire: config.BeforeAcquire,
|
||||||
afterRelease: config.AfterRelease,
|
afterRelease: config.AfterRelease,
|
||||||
maxConnLifetime: config.MaxConnLifetime,
|
maxConnLifetime: config.MaxConnLifetime,
|
||||||
|
@ -72,7 +77,22 @@ func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
p.p = puddle.NewPool(
|
p.p = puddle.NewPool(
|
||||||
func(ctx context.Context) (interface{}, error) { return pgx.ConnectConfig(ctx, config.ConnConfig) },
|
func(ctx context.Context) (interface{}, error) {
|
||||||
|
conn, err := pgx.ConnectConfig(ctx, config.ConnConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.afterConnect != nil {
|
||||||
|
err = p.afterConnect(ctx, conn)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close(ctx)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
},
|
||||||
func(value interface{}) {
|
func(value interface{}) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
value.(*pgx.Conn).Close(ctx)
|
value.(*pgx.Conn).Close(ctx)
|
||||||
|
|
|
@ -51,6 +51,27 @@ func TestPoolAcquireAndConnRelease(t *testing.T) {
|
||||||
c.Release()
|
c.Release()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPoolAfterConnect(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
config, err := pool.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config.AfterConnect = func(ctx context.Context, c *pgx.Conn) error {
|
||||||
|
_, err := c.Prepare(ctx, "ps1", "select 1")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
db, err := pool.ConnectConfig(context.Background(), config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
var n int32
|
||||||
|
err = db.QueryRow(context.Background(), "ps1").Scan(&n)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 1, n)
|
||||||
|
}
|
||||||
|
|
||||||
func TestPoolBeforeAcquire(t *testing.T) {
|
func TestPoolBeforeAcquire(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue