From 373bb84e9df4febbb8e77bc54bc52d956688ef56 Mon Sep 17 00:00:00 2001
From: Jack Christensen <jack@jackchristensen.com>
Date: Sat, 20 Feb 2021 17:13:14 -0600
Subject: [PATCH] Add *pgxpool.AcquireFunc

refs #821
---
 pgxpool/pool.go      | 13 +++++++++++++
 pgxpool/pool_test.go | 29 +++++++++++++++++++++++++++++
 2 files changed, 42 insertions(+)

diff --git a/pgxpool/pool.go b/pgxpool/pool.go
index a036049b..8efb9265 100644
--- a/pgxpool/pool.go
+++ b/pgxpool/pool.go
@@ -386,6 +386,19 @@ func (p *Pool) Acquire(ctx context.Context) (*Conn, error) {
 	}
 }
 
+// AcquireFunc acquires a *Conn and calls f with that *Conn. ctx will only affect the Acquire. It has no effect on the
+// call of f. The return value is either an error acquiring the *Conn or the return value of f. The *Conn is
+// automatically released after the call of f.
+func (p *Pool) AcquireFunc(ctx context.Context, f func(*Conn) error) error {
+	conn, err := p.Acquire(ctx)
+	if err != nil {
+		return err
+	}
+	defer conn.Release()
+
+	return f(conn)
+}
+
 // AcquireAllIdle atomically acquires all currently idle connections. Its intended use is for health check and
 // keep-alive functionality. It does not update pool statistics.
 func (p *Pool) AcquireAllIdle(ctx context.Context) []*Conn {
diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go
index 55e931cb..12f92c0a 100644
--- a/pgxpool/pool_test.go
+++ b/pgxpool/pool_test.go
@@ -2,6 +2,7 @@ package pgxpool_test
 
 import (
 	"context"
+	"fmt"
 	"os"
 	"testing"
 	"time"
@@ -112,6 +113,34 @@ func TestPoolAcquireAndConnRelease(t *testing.T) {
 	c.Release()
 }
 
+func TestPoolAcquireFunc(t *testing.T) {
+	t.Parallel()
+
+	pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
+	require.NoError(t, err)
+	defer pool.Close()
+
+	var n int32
+	err = pool.AcquireFunc(context.Background(), func(c *pgxpool.Conn) error {
+		return c.QueryRow(context.Background(), "select 1").Scan(&n)
+	})
+	require.NoError(t, err)
+	require.EqualValues(t, 1, n)
+}
+
+func TestPoolAcquireFuncReturnsFnError(t *testing.T) {
+	t.Parallel()
+
+	pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
+	require.NoError(t, err)
+	defer pool.Close()
+
+	err = pool.AcquireFunc(context.Background(), func(c *pgxpool.Conn) error {
+		return fmt.Errorf("some error")
+	})
+	require.EqualError(t, err, "some error")
+}
+
 func TestPoolBeforeConnect(t *testing.T) {
 	t.Parallel()