mirror of https://github.com/jackc/pgx.git
Add SendBatch to pool
parent
00d123a944
commit
7b1272d254
39
batch.go
39
batch.go
|
@ -31,15 +31,29 @@ func (b *Batch) Queue(query string, arguments []interface{}, parameterOIDs []pgt
|
|||
})
|
||||
}
|
||||
|
||||
type BatchResults struct {
|
||||
type BatchResults interface {
|
||||
// ExecResults reads the results from the next query in the batch as if the query has been sent with Exec.
|
||||
ExecResults() (pgconn.CommandTag, error)
|
||||
|
||||
// QueryResults reads the results from the next query in the batch as if the query has been sent with Query.
|
||||
QueryResults() (Rows, error)
|
||||
|
||||
// QueryRowResults reads the results from the next query in the batch as if the query has been sent with QueryRow.
|
||||
QueryRowResults() Row
|
||||
|
||||
// Close closes the batch operation. Any error that occured during a batch operation may have made it impossible to
|
||||
// resyncronize the connection with the server. In this case the underlying connection will have been closed.
|
||||
Close() error
|
||||
}
|
||||
|
||||
type batchResults struct {
|
||||
conn *Conn
|
||||
mrr *pgconn.MultiResultReader
|
||||
err error
|
||||
}
|
||||
|
||||
// ExecResults reads the results from the next query in the batch as if the
|
||||
// query has been sent with Exec.
|
||||
func (br *BatchResults) ExecResults() (pgconn.CommandTag, error) {
|
||||
// ExecResults reads the results from the next query in the batch as if the query has been sent with Exec.
|
||||
func (br *batchResults) ExecResults() (pgconn.CommandTag, error) {
|
||||
if br.err != nil {
|
||||
return nil, br.err
|
||||
}
|
||||
|
@ -55,9 +69,8 @@ func (br *BatchResults) ExecResults() (pgconn.CommandTag, error) {
|
|||
return br.mrr.ResultReader().Close()
|
||||
}
|
||||
|
||||
// QueryResults reads the results from the next query in the batch as if the
|
||||
// query has been sent with Query.
|
||||
func (br *BatchResults) QueryResults() (Rows, error) {
|
||||
// QueryResults reads the results from the next query in the batch as if the query has been sent with Query.
|
||||
func (br *batchResults) QueryResults() (Rows, error) {
|
||||
rows := br.conn.getRows("batch query", nil)
|
||||
|
||||
if br.err != nil {
|
||||
|
@ -79,18 +92,16 @@ func (br *BatchResults) QueryResults() (Rows, error) {
|
|||
return rows, nil
|
||||
}
|
||||
|
||||
// QueryRowResults reads the results from the next query in the batch as if the
|
||||
// query has been sent with QueryRow.
|
||||
func (br *BatchResults) QueryRowResults() Row {
|
||||
// QueryRowResults reads the results from the next query in the batch as if the query has been sent with QueryRow.
|
||||
func (br *batchResults) QueryRowResults() Row {
|
||||
rows, _ := br.QueryResults()
|
||||
return (*connRow)(rows.(*connRows))
|
||||
|
||||
}
|
||||
|
||||
// Close closes the batch operation. Any error that occured during a batch
|
||||
// operation may have made it impossible to resyncronize the connection with the
|
||||
// server. In this case the underlying connection will have been closed.
|
||||
func (br *BatchResults) Close() error {
|
||||
// Close closes the batch operation. Any error that occured during a batch operation may have made it impossible to
|
||||
// resyncronize the connection with the server. In this case the underlying connection will have been closed.
|
||||
func (br *batchResults) Close() error {
|
||||
if br.err != nil {
|
||||
return br.err
|
||||
}
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
"github.com/jackc/pgx/v4"
|
||||
)
|
||||
|
||||
func TestConnBeginBatch(t *testing.T) {
|
||||
func TestConnSendBatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
|
@ -156,7 +156,7 @@ func TestConnBeginBatch(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnBeginBatchWithPreparedStatement(t *testing.T) {
|
||||
func TestConnSendBatchWithPreparedStatement(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
|
@ -209,7 +209,7 @@ func TestConnBeginBatchWithPreparedStatement(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) {
|
||||
func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
|
@ -277,7 +277,7 @@ func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnBeginBatchQueryError(t *testing.T) {
|
||||
func TestConnSendBatchQueryError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
|
@ -324,7 +324,7 @@ func TestConnBeginBatchQueryError(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnBeginBatchQuerySyntaxError(t *testing.T) {
|
||||
func TestConnSendBatchQuerySyntaxError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
|
@ -353,7 +353,7 @@ func TestConnBeginBatchQuerySyntaxError(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnBeginBatchQueryRowInsert(t *testing.T) {
|
||||
func TestConnSendBatchQueryRowInsert(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
|
@ -399,7 +399,7 @@ func TestConnBeginBatchQueryRowInsert(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) {
|
||||
func TestConnSendBatchQueryPartialReadInsert(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
|
|
8
conn.go
8
conn.go
|
@ -683,7 +683,7 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Ro
|
|||
|
||||
// SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless
|
||||
// explicit transaction control statements are executed.
|
||||
func (c *Conn) SendBatch(ctx context.Context, b *Batch) *BatchResults {
|
||||
func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
||||
batch := &pgconn.Batch{}
|
||||
|
||||
for _, bi := range b.items {
|
||||
|
@ -698,7 +698,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) *BatchResults {
|
|||
|
||||
args, err := convertDriverValuers(bi.arguments)
|
||||
if err != nil {
|
||||
return &BatchResults{err: err}
|
||||
return &batchResults{err: err}
|
||||
}
|
||||
|
||||
paramFormats := make([]int16, len(args))
|
||||
|
@ -707,7 +707,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) *BatchResults {
|
|||
paramFormats[i] = chooseParameterFormatCode(c.ConnInfo, parameterOIDs[i], args[i])
|
||||
paramValues[i], err = newencodePreparedStatementArgument(c.ConnInfo, parameterOIDs[i], args[i])
|
||||
if err != nil {
|
||||
return &BatchResults{err: err}
|
||||
return &batchResults{err: err}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -739,7 +739,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) *BatchResults {
|
|||
|
||||
mrr := c.pgConn.ExecBatch(ctx, batch)
|
||||
|
||||
return &BatchResults{
|
||||
return &batchResults{
|
||||
conn: c,
|
||||
mrr: mrr,
|
||||
}
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
package pool
|
||||
|
||||
import (
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgx/v4"
|
||||
)
|
||||
|
||||
type errBatchResults struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (br errBatchResults) ExecResults() (pgconn.CommandTag, error) {
|
||||
return nil, br.err
|
||||
}
|
||||
|
||||
func (br errBatchResults) QueryResults() (pgx.Rows, error) {
|
||||
return errRows{err: br.err}, br.err
|
||||
}
|
||||
|
||||
func (br errBatchResults) QueryRowResults() pgx.Row {
|
||||
return errRow{err: br.err}
|
||||
}
|
||||
|
||||
func (br errBatchResults) Close() error {
|
||||
return br.err
|
||||
}
|
||||
|
||||
type poolBatchResults struct {
|
||||
br pgx.BatchResults
|
||||
c *Conn
|
||||
}
|
||||
|
||||
func (br *poolBatchResults) ExecResults() (pgconn.CommandTag, error) {
|
||||
return br.br.ExecResults()
|
||||
}
|
||||
|
||||
func (br *poolBatchResults) QueryResults() (pgx.Rows, error) {
|
||||
return br.br.QueryResults()
|
||||
}
|
||||
|
||||
func (br *poolBatchResults) QueryRowResults() pgx.Row {
|
||||
return br.br.QueryRowResults()
|
||||
}
|
||||
|
||||
func (br *poolBatchResults) Close() error {
|
||||
err := br.br.Close()
|
||||
if br.c != nil {
|
||||
br.c.Release()
|
||||
br.c = nil
|
||||
}
|
||||
return err
|
||||
}
|
|
@ -69,3 +69,28 @@ func testQueryRow(t *testing.T, db queryRower) {
|
|||
assert.Equal(t, "hello", what)
|
||||
assert.Equal(t, "world", who)
|
||||
}
|
||||
|
||||
type sendBatcher interface {
|
||||
SendBatch(context.Context, *pgx.Batch) pgx.BatchResults
|
||||
}
|
||||
|
||||
func testSendBatch(t *testing.T, db sendBatcher) {
|
||||
batch := &pgx.Batch{}
|
||||
batch.Queue("select 1", nil, nil, nil)
|
||||
batch.Queue("select 2", nil, nil, nil)
|
||||
|
||||
br := db.SendBatch(context.Background(), batch)
|
||||
|
||||
var err error
|
||||
var n int32
|
||||
err = br.QueryRowResults().Scan(&n)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 1, n)
|
||||
|
||||
err = br.QueryRowResults().Scan(&n)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 2, n)
|
||||
|
||||
err = br.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
|
|
@ -61,6 +61,10 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) pg
|
|||
return c.Conn().QueryRow(ctx, sql, args...)
|
||||
}
|
||||
|
||||
func (c *Conn) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
|
||||
return c.Conn().SendBatch(ctx, b)
|
||||
}
|
||||
|
||||
func (c *Conn) Begin(ctx context.Context, txOptions *pgx.TxOptions) (*pgx.Tx, error) {
|
||||
return c.Conn().Begin(ctx, txOptions)
|
||||
}
|
||||
|
|
|
@ -44,3 +44,15 @@ func TestConnQueryRow(t *testing.T) {
|
|||
|
||||
testQueryRow(t, c)
|
||||
}
|
||||
|
||||
func TestConnSendBatch(t *testing.T) {
|
||||
pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
c, err := pool.Acquire(context.Background())
|
||||
require.NoError(t, err)
|
||||
defer c.Release()
|
||||
|
||||
testSendBatch(t, c)
|
||||
}
|
||||
|
|
10
pool/pool.go
10
pool/pool.go
|
@ -127,6 +127,16 @@ func (p *Pool) QueryRow(ctx context.Context, sql string, args ...interface{}) pg
|
|||
return &poolRow{r: row, c: c}
|
||||
}
|
||||
|
||||
func (p *Pool) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
|
||||
c, err := p.Acquire(ctx)
|
||||
if err != nil {
|
||||
return errBatchResults{err: err}
|
||||
}
|
||||
|
||||
br := c.SendBatch(ctx, b)
|
||||
return &poolBatchResults{br: br, c: c}
|
||||
}
|
||||
|
||||
func (p *Pool) Begin(ctx context.Context, txOptions *pgx.TxOptions) (*Tx, error) {
|
||||
c, err := p.Acquire(ctx)
|
||||
if err != nil {
|
||||
|
|
|
@ -90,6 +90,19 @@ func TestPoolQueryRow(t *testing.T) {
|
|||
assert.EqualValues(t, 1, stats.TotalConns())
|
||||
}
|
||||
|
||||
func TestPoolSendBatch(t *testing.T) {
|
||||
pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
testSendBatch(t, pool)
|
||||
waitForReleaseToComplete()
|
||||
|
||||
stats := pool.Stat()
|
||||
assert.EqualValues(t, 0, stats.AcquiredConns())
|
||||
assert.EqualValues(t, 1, stats.TotalConns())
|
||||
}
|
||||
|
||||
func TestConnReleaseRollsBackFailedTransaction(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
func (p *ConnPool) BeginBatch() *Batch
|
||||
func (p *ConnPool) Close()
|
||||
func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error)
|
||||
func (p *ConnPool) Deallocate(name string) (err error)
|
||||
func (p *ConnPool) Prepare(ctx context.Context, name, sql string) (*PreparedStatement, error)
|
||||
|
|
|
@ -45,3 +45,7 @@ func (tx *Tx) Query(ctx context.Context, sql string, args ...interface{}) (pgx.R
|
|||
func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row {
|
||||
return tx.c.QueryRow(ctx, sql, args...)
|
||||
}
|
||||
|
||||
func (tx *Tx) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
|
||||
return tx.c.SendBatch(ctx, b)
|
||||
}
|
||||
|
|
|
@ -44,3 +44,15 @@ func TestTxQueryRow(t *testing.T) {
|
|||
|
||||
testQueryRow(t, tx)
|
||||
}
|
||||
|
||||
func TestTxSendBatch(t *testing.T) {
|
||||
pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
tx, err := pool.Begin(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback(context.Background())
|
||||
|
||||
testSendBatch(t, tx)
|
||||
}
|
||||
|
|
4
tx.go
4
tx.go
|
@ -186,9 +186,9 @@ func (tx *Tx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []
|
|||
}
|
||||
|
||||
// SendBatch delegates to the underlying *Conn
|
||||
func (tx *Tx) SendBatch(ctx context.Context, b *Batch) *BatchResults {
|
||||
func (tx *Tx) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
||||
if tx.status != TxStatusInProgress {
|
||||
return &BatchResults{err: ErrTxClosed}
|
||||
return &batchResults{err: ErrTxClosed}
|
||||
}
|
||||
|
||||
return tx.conn.SendBatch(ctx, b)
|
||||
|
|
Loading…
Reference in New Issue