From 373bb84e9df4febbb8e77bc54bc52d956688ef56 Mon Sep 17 00:00:00 2001 From: Jack Christensen <jack@jackchristensen.com> Date: Sat, 20 Feb 2021 17:13:14 -0600 Subject: [PATCH] Add *pgxpool.AcquireFunc refs #821 --- pgxpool/pool.go | 13 +++++++++++++ pgxpool/pool_test.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/pgxpool/pool.go b/pgxpool/pool.go index a036049b..8efb9265 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -386,6 +386,19 @@ func (p *Pool) Acquire(ctx context.Context) (*Conn, error) { } } +// AcquireFunc acquires a *Conn and calls f with that *Conn. ctx will only affect the Acquire. It has no effect on the +// call of f. The return value is either an error acquiring the *Conn or the return value of f. The *Conn is +// automatically released after the call of f. +func (p *Pool) AcquireFunc(ctx context.Context, f func(*Conn) error) error { + conn, err := p.Acquire(ctx) + if err != nil { + return err + } + defer conn.Release() + + return f(conn) +} + // AcquireAllIdle atomically acquires all currently idle connections. Its intended use is for health check and // keep-alive functionality. It does not update pool statistics. func (p *Pool) AcquireAllIdle(ctx context.Context) []*Conn { diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go index 55e931cb..12f92c0a 100644 --- a/pgxpool/pool_test.go +++ b/pgxpool/pool_test.go @@ -2,6 +2,7 @@ package pgxpool_test import ( "context" + "fmt" "os" "testing" "time" @@ -112,6 +113,34 @@ func TestPoolAcquireAndConnRelease(t *testing.T) { c.Release() } +func TestPoolAcquireFunc(t *testing.T) { + t.Parallel() + + pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + var n int32 + err = pool.AcquireFunc(context.Background(), func(c *pgxpool.Conn) error { + return c.QueryRow(context.Background(), "select 1").Scan(&n) + }) + require.NoError(t, err) + require.EqualValues(t, 1, n) +} + +func TestPoolAcquireFuncReturnsFnError(t *testing.T) { + t.Parallel() + + pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + err = pool.AcquireFunc(context.Background(), func(c *pgxpool.Conn) error { + return fmt.Errorf("some error") + }) + require.EqualError(t, err, "some error") +} + func TestPoolBeforeConnect(t *testing.T) { t.Parallel()