From 7b1272d2542d8cccf29b4a49ff0268118e259bee Mon Sep 17 00:00:00 2001 From: Jack Christensen <jack@jackchristensen.com> Date: Thu, 25 Apr 2019 15:07:35 -0500 Subject: [PATCH] Add SendBatch to pool --- batch.go | 39 ++++++++++++++++++++------------ batch_test.go | 14 ++++++------ conn.go | 8 +++---- pool/batch_results.go | 52 +++++++++++++++++++++++++++++++++++++++++++ pool/common_test.go | 25 +++++++++++++++++++++ pool/conn.go | 4 ++++ pool/conn_test.go | 12 ++++++++++ pool/pool.go | 10 +++++++++ pool/pool_test.go | 13 +++++++++++ pool/todo.txt | 2 -- pool/tx.go | 4 ++++ pool/tx_test.go | 12 ++++++++++ tx.go | 4 ++-- 13 files changed, 170 insertions(+), 29 deletions(-) create mode 100644 pool/batch_results.go diff --git a/batch.go b/batch.go index 9dc45847..c652c2a5 100644 --- a/batch.go +++ b/batch.go @@ -31,15 +31,29 @@ func (b *Batch) Queue(query string, arguments []interface{}, parameterOIDs []pgt }) } -type BatchResults struct { +type BatchResults interface { + // ExecResults reads the results from the next query in the batch as if the query has been sent with Exec. + ExecResults() (pgconn.CommandTag, error) + + // QueryResults reads the results from the next query in the batch as if the query has been sent with Query. + QueryResults() (Rows, error) + + // QueryRowResults reads the results from the next query in the batch as if the query has been sent with QueryRow. + QueryRowResults() Row + + // Close closes the batch operation. Any error that occured during a batch operation may have made it impossible to + // resyncronize the connection with the server. In this case the underlying connection will have been closed. + Close() error +} + +type batchResults struct { conn *Conn mrr *pgconn.MultiResultReader err error } -// ExecResults reads the results from the next query in the batch as if the -// query has been sent with Exec. -func (br *BatchResults) ExecResults() (pgconn.CommandTag, error) { +// ExecResults reads the results from the next query in the batch as if the query has been sent with Exec. +func (br *batchResults) ExecResults() (pgconn.CommandTag, error) { if br.err != nil { return nil, br.err } @@ -55,9 +69,8 @@ func (br *BatchResults) ExecResults() (pgconn.CommandTag, error) { return br.mrr.ResultReader().Close() } -// QueryResults reads the results from the next query in the batch as if the -// query has been sent with Query. -func (br *BatchResults) QueryResults() (Rows, error) { +// QueryResults reads the results from the next query in the batch as if the query has been sent with Query. +func (br *batchResults) QueryResults() (Rows, error) { rows := br.conn.getRows("batch query", nil) if br.err != nil { @@ -79,18 +92,16 @@ func (br *BatchResults) QueryResults() (Rows, error) { return rows, nil } -// QueryRowResults reads the results from the next query in the batch as if the -// query has been sent with QueryRow. -func (br *BatchResults) QueryRowResults() Row { +// QueryRowResults reads the results from the next query in the batch as if the query has been sent with QueryRow. +func (br *batchResults) QueryRowResults() Row { rows, _ := br.QueryResults() return (*connRow)(rows.(*connRows)) } -// Close closes the batch operation. Any error that occured during a batch -// operation may have made it impossible to resyncronize the connection with the -// server. In this case the underlying connection will have been closed. -func (br *BatchResults) Close() error { +// Close closes the batch operation. Any error that occured during a batch operation may have made it impossible to +// resyncronize the connection with the server. In this case the underlying connection will have been closed. +func (br *batchResults) Close() error { if br.err != nil { return br.err } diff --git a/batch_test.go b/batch_test.go index 74e04a60..cb2da08e 100644 --- a/batch_test.go +++ b/batch_test.go @@ -10,7 +10,7 @@ import ( "github.com/jackc/pgx/v4" ) -func TestConnBeginBatch(t *testing.T) { +func TestConnSendBatch(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) @@ -156,7 +156,7 @@ func TestConnBeginBatch(t *testing.T) { ensureConnValid(t, conn) } -func TestConnBeginBatchWithPreparedStatement(t *testing.T) { +func TestConnSendBatchWithPreparedStatement(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) @@ -209,7 +209,7 @@ func TestConnBeginBatchWithPreparedStatement(t *testing.T) { ensureConnValid(t, conn) } -func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) { +func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) @@ -277,7 +277,7 @@ func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) { ensureConnValid(t, conn) } -func TestConnBeginBatchQueryError(t *testing.T) { +func TestConnSendBatchQueryError(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) @@ -324,7 +324,7 @@ func TestConnBeginBatchQueryError(t *testing.T) { ensureConnValid(t, conn) } -func TestConnBeginBatchQuerySyntaxError(t *testing.T) { +func TestConnSendBatchQuerySyntaxError(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) @@ -353,7 +353,7 @@ func TestConnBeginBatchQuerySyntaxError(t *testing.T) { ensureConnValid(t, conn) } -func TestConnBeginBatchQueryRowInsert(t *testing.T) { +func TestConnSendBatchQueryRowInsert(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) @@ -399,7 +399,7 @@ func TestConnBeginBatchQueryRowInsert(t *testing.T) { ensureConnValid(t, conn) } -func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) { +func TestConnSendBatchQueryPartialReadInsert(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) diff --git a/conn.go b/conn.go index af8cf389..80d040db 100644 --- a/conn.go +++ b/conn.go @@ -683,7 +683,7 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Ro // SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless // explicit transaction control statements are executed. -func (c *Conn) SendBatch(ctx context.Context, b *Batch) *BatchResults { +func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { batch := &pgconn.Batch{} for _, bi := range b.items { @@ -698,7 +698,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) *BatchResults { args, err := convertDriverValuers(bi.arguments) if err != nil { - return &BatchResults{err: err} + return &batchResults{err: err} } paramFormats := make([]int16, len(args)) @@ -707,7 +707,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) *BatchResults { paramFormats[i] = chooseParameterFormatCode(c.ConnInfo, parameterOIDs[i], args[i]) paramValues[i], err = newencodePreparedStatementArgument(c.ConnInfo, parameterOIDs[i], args[i]) if err != nil { - return &BatchResults{err: err} + return &batchResults{err: err} } } @@ -739,7 +739,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) *BatchResults { mrr := c.pgConn.ExecBatch(ctx, batch) - return &BatchResults{ + return &batchResults{ conn: c, mrr: mrr, } diff --git a/pool/batch_results.go b/pool/batch_results.go new file mode 100644 index 00000000..949d42b4 --- /dev/null +++ b/pool/batch_results.go @@ -0,0 +1,52 @@ +package pool + +import ( + "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4" +) + +type errBatchResults struct { + err error +} + +func (br errBatchResults) ExecResults() (pgconn.CommandTag, error) { + return nil, br.err +} + +func (br errBatchResults) QueryResults() (pgx.Rows, error) { + return errRows{err: br.err}, br.err +} + +func (br errBatchResults) QueryRowResults() pgx.Row { + return errRow{err: br.err} +} + +func (br errBatchResults) Close() error { + return br.err +} + +type poolBatchResults struct { + br pgx.BatchResults + c *Conn +} + +func (br *poolBatchResults) ExecResults() (pgconn.CommandTag, error) { + return br.br.ExecResults() +} + +func (br *poolBatchResults) QueryResults() (pgx.Rows, error) { + return br.br.QueryResults() +} + +func (br *poolBatchResults) QueryRowResults() pgx.Row { + return br.br.QueryRowResults() +} + +func (br *poolBatchResults) Close() error { + err := br.br.Close() + if br.c != nil { + br.c.Release() + br.c = nil + } + return err +} diff --git a/pool/common_test.go b/pool/common_test.go index b5a0682f..d0a8fa4a 100644 --- a/pool/common_test.go +++ b/pool/common_test.go @@ -69,3 +69,28 @@ func testQueryRow(t *testing.T, db queryRower) { assert.Equal(t, "hello", what) assert.Equal(t, "world", who) } + +type sendBatcher interface { + SendBatch(context.Context, *pgx.Batch) pgx.BatchResults +} + +func testSendBatch(t *testing.T, db sendBatcher) { + batch := &pgx.Batch{} + batch.Queue("select 1", nil, nil, nil) + batch.Queue("select 2", nil, nil, nil) + + br := db.SendBatch(context.Background(), batch) + + var err error + var n int32 + err = br.QueryRowResults().Scan(&n) + assert.NoError(t, err) + assert.EqualValues(t, 1, n) + + err = br.QueryRowResults().Scan(&n) + assert.NoError(t, err) + assert.EqualValues(t, 2, n) + + err = br.Close() + assert.NoError(t, err) +} diff --git a/pool/conn.go b/pool/conn.go index 273be0aa..7ccb00f6 100644 --- a/pool/conn.go +++ b/pool/conn.go @@ -61,6 +61,10 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) pg return c.Conn().QueryRow(ctx, sql, args...) } +func (c *Conn) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { + return c.Conn().SendBatch(ctx, b) +} + func (c *Conn) Begin(ctx context.Context, txOptions *pgx.TxOptions) (*pgx.Tx, error) { return c.Conn().Begin(ctx, txOptions) } diff --git a/pool/conn_test.go b/pool/conn_test.go index de39dd7b..af6cc8cb 100644 --- a/pool/conn_test.go +++ b/pool/conn_test.go @@ -44,3 +44,15 @@ func TestConnQueryRow(t *testing.T) { testQueryRow(t, c) } + +func TestConnSendBatch(t *testing.T) { + pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(context.Background()) + require.NoError(t, err) + defer c.Release() + + testSendBatch(t, c) +} diff --git a/pool/pool.go b/pool/pool.go index ed459735..993077f1 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -127,6 +127,16 @@ func (p *Pool) QueryRow(ctx context.Context, sql string, args ...interface{}) pg return &poolRow{r: row, c: c} } +func (p *Pool) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { + c, err := p.Acquire(ctx) + if err != nil { + return errBatchResults{err: err} + } + + br := c.SendBatch(ctx, b) + return &poolBatchResults{br: br, c: c} +} + func (p *Pool) Begin(ctx context.Context, txOptions *pgx.TxOptions) (*Tx, error) { c, err := p.Acquire(ctx) if err != nil { diff --git a/pool/pool_test.go b/pool/pool_test.go index 8393c3f9..86767103 100644 --- a/pool/pool_test.go +++ b/pool/pool_test.go @@ -90,6 +90,19 @@ func TestPoolQueryRow(t *testing.T) { assert.EqualValues(t, 1, stats.TotalConns()) } +func TestPoolSendBatch(t *testing.T) { + pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + testSendBatch(t, pool) + waitForReleaseToComplete() + + stats := pool.Stat() + assert.EqualValues(t, 0, stats.AcquiredConns()) + assert.EqualValues(t, 1, stats.TotalConns()) +} + func TestConnReleaseRollsBackFailedTransaction(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/pool/todo.txt b/pool/todo.txt index f12b1052..3b9d1813 100644 --- a/pool/todo.txt +++ b/pool/todo.txt @@ -1,5 +1,3 @@ -func (p *ConnPool) BeginBatch() *Batch -func (p *ConnPool) Close() func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) func (p *ConnPool) Deallocate(name string) (err error) func (p *ConnPool) Prepare(ctx context.Context, name, sql string) (*PreparedStatement, error) diff --git a/pool/tx.go b/pool/tx.go index c9c15b5f..3f0618e0 100644 --- a/pool/tx.go +++ b/pool/tx.go @@ -45,3 +45,7 @@ func (tx *Tx) Query(ctx context.Context, sql string, args ...interface{}) (pgx.R func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { return tx.c.QueryRow(ctx, sql, args...) } + +func (tx *Tx) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { + return tx.c.SendBatch(ctx, b) +} diff --git a/pool/tx_test.go b/pool/tx_test.go index 3ec4a0ce..0ea4b8b6 100644 --- a/pool/tx_test.go +++ b/pool/tx_test.go @@ -44,3 +44,15 @@ func TestTxQueryRow(t *testing.T) { testQueryRow(t, tx) } + +func TestTxSendBatch(t *testing.T) { + pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + tx, err := pool.Begin(context.Background(), nil) + require.NoError(t, err) + defer tx.Rollback(context.Background()) + + testSendBatch(t, tx) +} diff --git a/tx.go b/tx.go index ec89c2d5..45477da0 100644 --- a/tx.go +++ b/tx.go @@ -186,9 +186,9 @@ func (tx *Tx) CopyFrom(ctx context.Context, tableName Identifier, columnNames [] } // SendBatch delegates to the underlying *Conn -func (tx *Tx) SendBatch(ctx context.Context, b *Batch) *BatchResults { +func (tx *Tx) SendBatch(ctx context.Context, b *Batch) BatchResults { if tx.status != TxStatusInProgress { - return &BatchResults{err: ErrTxClosed} + return &batchResults{err: ErrTxClosed} } return tx.conn.SendBatch(ctx, b)