From 7b1272d2542d8cccf29b4a49ff0268118e259bee Mon Sep 17 00:00:00 2001
From: Jack Christensen <jack@jackchristensen.com>
Date: Thu, 25 Apr 2019 15:07:35 -0500
Subject: [PATCH] Add SendBatch to pool

---
 batch.go              | 39 ++++++++++++++++++++------------
 batch_test.go         | 14 ++++++------
 conn.go               |  8 +++----
 pool/batch_results.go | 52 +++++++++++++++++++++++++++++++++++++++++++
 pool/common_test.go   | 25 +++++++++++++++++++++
 pool/conn.go          |  4 ++++
 pool/conn_test.go     | 12 ++++++++++
 pool/pool.go          | 10 +++++++++
 pool/pool_test.go     | 13 +++++++++++
 pool/todo.txt         |  2 --
 pool/tx.go            |  4 ++++
 pool/tx_test.go       | 12 ++++++++++
 tx.go                 |  4 ++--
 13 files changed, 170 insertions(+), 29 deletions(-)
 create mode 100644 pool/batch_results.go

diff --git a/batch.go b/batch.go
index 9dc45847..c652c2a5 100644
--- a/batch.go
+++ b/batch.go
@@ -31,15 +31,29 @@ func (b *Batch) Queue(query string, arguments []interface{}, parameterOIDs []pgt
 	})
 }
 
-type BatchResults struct {
+type BatchResults interface {
+	// ExecResults reads the results from the next query in the batch as if the query has been sent with Exec.
+	ExecResults() (pgconn.CommandTag, error)
+
+	// QueryResults reads the results from the next query in the batch as if the query has been sent with Query.
+	QueryResults() (Rows, error)
+
+	// QueryRowResults reads the results from the next query in the batch as if the query has been sent with QueryRow.
+	QueryRowResults() Row
+
+	// Close closes the batch operation. Any error that occured during a batch operation may have made it impossible to
+	// resyncronize the connection with the server. In this case the underlying connection will have been closed.
+	Close() error
+}
+
+type batchResults struct {
 	conn *Conn
 	mrr  *pgconn.MultiResultReader
 	err  error
 }
 
-// ExecResults reads the results from the next query in the batch as if the
-// query has been sent with Exec.
-func (br *BatchResults) ExecResults() (pgconn.CommandTag, error) {
+// ExecResults reads the results from the next query in the batch as if the query has been sent with Exec.
+func (br *batchResults) ExecResults() (pgconn.CommandTag, error) {
 	if br.err != nil {
 		return nil, br.err
 	}
@@ -55,9 +69,8 @@ func (br *BatchResults) ExecResults() (pgconn.CommandTag, error) {
 	return br.mrr.ResultReader().Close()
 }
 
-// QueryResults reads the results from the next query in the batch as if the
-// query has been sent with Query.
-func (br *BatchResults) QueryResults() (Rows, error) {
+// QueryResults reads the results from the next query in the batch as if the query has been sent with Query.
+func (br *batchResults) QueryResults() (Rows, error) {
 	rows := br.conn.getRows("batch query", nil)
 
 	if br.err != nil {
@@ -79,18 +92,16 @@ func (br *BatchResults) QueryResults() (Rows, error) {
 	return rows, nil
 }
 
-// QueryRowResults reads the results from the next query in the batch as if the
-// query has been sent with QueryRow.
-func (br *BatchResults) QueryRowResults() Row {
+// QueryRowResults reads the results from the next query in the batch as if the query has been sent with QueryRow.
+func (br *batchResults) QueryRowResults() Row {
 	rows, _ := br.QueryResults()
 	return (*connRow)(rows.(*connRows))
 
 }
 
-// Close closes the batch operation. Any error that occured during a batch
-// operation may have made it impossible to resyncronize the connection with the
-// server. In this case the underlying connection will have been closed.
-func (br *BatchResults) Close() error {
+// Close closes the batch operation. Any error that occured during a batch operation may have made it impossible to
+// resyncronize the connection with the server. In this case the underlying connection will have been closed.
+func (br *batchResults) Close() error {
 	if br.err != nil {
 		return br.err
 	}
diff --git a/batch_test.go b/batch_test.go
index 74e04a60..cb2da08e 100644
--- a/batch_test.go
+++ b/batch_test.go
@@ -10,7 +10,7 @@ import (
 	"github.com/jackc/pgx/v4"
 )
 
-func TestConnBeginBatch(t *testing.T) {
+func TestConnSendBatch(t *testing.T) {
 	t.Parallel()
 
 	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
@@ -156,7 +156,7 @@ func TestConnBeginBatch(t *testing.T) {
 	ensureConnValid(t, conn)
 }
 
-func TestConnBeginBatchWithPreparedStatement(t *testing.T) {
+func TestConnSendBatchWithPreparedStatement(t *testing.T) {
 	t.Parallel()
 
 	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
@@ -209,7 +209,7 @@ func TestConnBeginBatchWithPreparedStatement(t *testing.T) {
 	ensureConnValid(t, conn)
 }
 
-func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) {
+func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) {
 	t.Parallel()
 
 	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
@@ -277,7 +277,7 @@ func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) {
 	ensureConnValid(t, conn)
 }
 
-func TestConnBeginBatchQueryError(t *testing.T) {
+func TestConnSendBatchQueryError(t *testing.T) {
 	t.Parallel()
 
 	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
@@ -324,7 +324,7 @@ func TestConnBeginBatchQueryError(t *testing.T) {
 	ensureConnValid(t, conn)
 }
 
-func TestConnBeginBatchQuerySyntaxError(t *testing.T) {
+func TestConnSendBatchQuerySyntaxError(t *testing.T) {
 	t.Parallel()
 
 	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
@@ -353,7 +353,7 @@ func TestConnBeginBatchQuerySyntaxError(t *testing.T) {
 	ensureConnValid(t, conn)
 }
 
-func TestConnBeginBatchQueryRowInsert(t *testing.T) {
+func TestConnSendBatchQueryRowInsert(t *testing.T) {
 	t.Parallel()
 
 	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
@@ -399,7 +399,7 @@ func TestConnBeginBatchQueryRowInsert(t *testing.T) {
 	ensureConnValid(t, conn)
 }
 
-func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) {
+func TestConnSendBatchQueryPartialReadInsert(t *testing.T) {
 	t.Parallel()
 
 	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
diff --git a/conn.go b/conn.go
index af8cf389..80d040db 100644
--- a/conn.go
+++ b/conn.go
@@ -683,7 +683,7 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Ro
 
 // SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless
 // explicit transaction control statements are executed.
-func (c *Conn) SendBatch(ctx context.Context, b *Batch) *BatchResults {
+func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
 	batch := &pgconn.Batch{}
 
 	for _, bi := range b.items {
@@ -698,7 +698,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) *BatchResults {
 
 		args, err := convertDriverValuers(bi.arguments)
 		if err != nil {
-			return &BatchResults{err: err}
+			return &batchResults{err: err}
 		}
 
 		paramFormats := make([]int16, len(args))
@@ -707,7 +707,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) *BatchResults {
 			paramFormats[i] = chooseParameterFormatCode(c.ConnInfo, parameterOIDs[i], args[i])
 			paramValues[i], err = newencodePreparedStatementArgument(c.ConnInfo, parameterOIDs[i], args[i])
 			if err != nil {
-				return &BatchResults{err: err}
+				return &batchResults{err: err}
 			}
 
 		}
@@ -739,7 +739,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) *BatchResults {
 
 	mrr := c.pgConn.ExecBatch(ctx, batch)
 
-	return &BatchResults{
+	return &batchResults{
 		conn: c,
 		mrr:  mrr,
 	}
diff --git a/pool/batch_results.go b/pool/batch_results.go
new file mode 100644
index 00000000..949d42b4
--- /dev/null
+++ b/pool/batch_results.go
@@ -0,0 +1,52 @@
+package pool
+
+import (
+	"github.com/jackc/pgconn"
+	"github.com/jackc/pgx/v4"
+)
+
+type errBatchResults struct {
+	err error
+}
+
+func (br errBatchResults) ExecResults() (pgconn.CommandTag, error) {
+	return nil, br.err
+}
+
+func (br errBatchResults) QueryResults() (pgx.Rows, error) {
+	return errRows{err: br.err}, br.err
+}
+
+func (br errBatchResults) QueryRowResults() pgx.Row {
+	return errRow{err: br.err}
+}
+
+func (br errBatchResults) Close() error {
+	return br.err
+}
+
+type poolBatchResults struct {
+	br pgx.BatchResults
+	c  *Conn
+}
+
+func (br *poolBatchResults) ExecResults() (pgconn.CommandTag, error) {
+	return br.br.ExecResults()
+}
+
+func (br *poolBatchResults) QueryResults() (pgx.Rows, error) {
+	return br.br.QueryResults()
+}
+
+func (br *poolBatchResults) QueryRowResults() pgx.Row {
+	return br.br.QueryRowResults()
+}
+
+func (br *poolBatchResults) Close() error {
+	err := br.br.Close()
+	if br.c != nil {
+		br.c.Release()
+		br.c = nil
+	}
+	return err
+}
diff --git a/pool/common_test.go b/pool/common_test.go
index b5a0682f..d0a8fa4a 100644
--- a/pool/common_test.go
+++ b/pool/common_test.go
@@ -69,3 +69,28 @@ func testQueryRow(t *testing.T, db queryRower) {
 	assert.Equal(t, "hello", what)
 	assert.Equal(t, "world", who)
 }
+
+type sendBatcher interface {
+	SendBatch(context.Context, *pgx.Batch) pgx.BatchResults
+}
+
+func testSendBatch(t *testing.T, db sendBatcher) {
+	batch := &pgx.Batch{}
+	batch.Queue("select 1", nil, nil, nil)
+	batch.Queue("select 2", nil, nil, nil)
+
+	br := db.SendBatch(context.Background(), batch)
+
+	var err error
+	var n int32
+	err = br.QueryRowResults().Scan(&n)
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, n)
+
+	err = br.QueryRowResults().Scan(&n)
+	assert.NoError(t, err)
+	assert.EqualValues(t, 2, n)
+
+	err = br.Close()
+	assert.NoError(t, err)
+}
diff --git a/pool/conn.go b/pool/conn.go
index 273be0aa..7ccb00f6 100644
--- a/pool/conn.go
+++ b/pool/conn.go
@@ -61,6 +61,10 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) pg
 	return c.Conn().QueryRow(ctx, sql, args...)
 }
 
+func (c *Conn) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
+	return c.Conn().SendBatch(ctx, b)
+}
+
 func (c *Conn) Begin(ctx context.Context, txOptions *pgx.TxOptions) (*pgx.Tx, error) {
 	return c.Conn().Begin(ctx, txOptions)
 }
diff --git a/pool/conn_test.go b/pool/conn_test.go
index de39dd7b..af6cc8cb 100644
--- a/pool/conn_test.go
+++ b/pool/conn_test.go
@@ -44,3 +44,15 @@ func TestConnQueryRow(t *testing.T) {
 
 	testQueryRow(t, c)
 }
+
+func TestConnSendBatch(t *testing.T) {
+	pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
+	require.NoError(t, err)
+	defer pool.Close()
+
+	c, err := pool.Acquire(context.Background())
+	require.NoError(t, err)
+	defer c.Release()
+
+	testSendBatch(t, c)
+}
diff --git a/pool/pool.go b/pool/pool.go
index ed459735..993077f1 100644
--- a/pool/pool.go
+++ b/pool/pool.go
@@ -127,6 +127,16 @@ func (p *Pool) QueryRow(ctx context.Context, sql string, args ...interface{}) pg
 	return &poolRow{r: row, c: c}
 }
 
+func (p *Pool) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
+	c, err := p.Acquire(ctx)
+	if err != nil {
+		return errBatchResults{err: err}
+	}
+
+	br := c.SendBatch(ctx, b)
+	return &poolBatchResults{br: br, c: c}
+}
+
 func (p *Pool) Begin(ctx context.Context, txOptions *pgx.TxOptions) (*Tx, error) {
 	c, err := p.Acquire(ctx)
 	if err != nil {
diff --git a/pool/pool_test.go b/pool/pool_test.go
index 8393c3f9..86767103 100644
--- a/pool/pool_test.go
+++ b/pool/pool_test.go
@@ -90,6 +90,19 @@ func TestPoolQueryRow(t *testing.T) {
 	assert.EqualValues(t, 1, stats.TotalConns())
 }
 
+func TestPoolSendBatch(t *testing.T) {
+	pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
+	require.NoError(t, err)
+	defer pool.Close()
+
+	testSendBatch(t, pool)
+	waitForReleaseToComplete()
+
+	stats := pool.Stat()
+	assert.EqualValues(t, 0, stats.AcquiredConns())
+	assert.EqualValues(t, 1, stats.TotalConns())
+}
+
 func TestConnReleaseRollsBackFailedTransaction(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 	defer cancel()
diff --git a/pool/todo.txt b/pool/todo.txt
index f12b1052..3b9d1813 100644
--- a/pool/todo.txt
+++ b/pool/todo.txt
@@ -1,5 +1,3 @@
-func (p *ConnPool) BeginBatch() *Batch
-func (p *ConnPool) Close()
 func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error)
 func (p *ConnPool) Deallocate(name string) (err error)
 func (p *ConnPool) Prepare(ctx context.Context, name, sql string) (*PreparedStatement, error)
diff --git a/pool/tx.go b/pool/tx.go
index c9c15b5f..3f0618e0 100644
--- a/pool/tx.go
+++ b/pool/tx.go
@@ -45,3 +45,7 @@ func (tx *Tx) Query(ctx context.Context, sql string, args ...interface{}) (pgx.R
 func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row {
 	return tx.c.QueryRow(ctx, sql, args...)
 }
+
+func (tx *Tx) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
+	return tx.c.SendBatch(ctx, b)
+}
diff --git a/pool/tx_test.go b/pool/tx_test.go
index 3ec4a0ce..0ea4b8b6 100644
--- a/pool/tx_test.go
+++ b/pool/tx_test.go
@@ -44,3 +44,15 @@ func TestTxQueryRow(t *testing.T) {
 
 	testQueryRow(t, tx)
 }
+
+func TestTxSendBatch(t *testing.T) {
+	pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
+	require.NoError(t, err)
+	defer pool.Close()
+
+	tx, err := pool.Begin(context.Background(), nil)
+	require.NoError(t, err)
+	defer tx.Rollback(context.Background())
+
+	testSendBatch(t, tx)
+}
diff --git a/tx.go b/tx.go
index ec89c2d5..45477da0 100644
--- a/tx.go
+++ b/tx.go
@@ -186,9 +186,9 @@ func (tx *Tx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []
 }
 
 // SendBatch delegates to the underlying *Conn
-func (tx *Tx) SendBatch(ctx context.Context, b *Batch) *BatchResults {
+func (tx *Tx) SendBatch(ctx context.Context, b *Batch) BatchResults {
 	if tx.status != TxStatusInProgress {
-		return &BatchResults{err: ErrTxClosed}
+		return &batchResults{err: ErrTxClosed}
 	}
 
 	return tx.conn.SendBatch(ctx, b)