diff --git a/batch.go b/batch.go index 2396fd71..9dc45847 100644 --- a/batch.go +++ b/batch.go @@ -1,8 +1,6 @@ package pgx import ( - "context" - "github.com/jackc/pgconn" "github.com/jackc/pgtype" errors "golang.org/x/xerrors" @@ -18,21 +16,7 @@ type batchItem struct { // Batch queries are a way of bundling multiple queries together to avoid // unnecessary network round trips. type Batch struct { - conn *Conn items []*batchItem - err error - - mrr *pgconn.MultiResultReader -} - -// BeginBatch returns a *Batch query for c. -func (c *Conn) BeginBatch() *Batch { - return &Batch{conn: c} -} - -// Conn returns the underlying connection that b will or was performed on. -func (b *Batch) Conn() *Conn { - return b.conn } // Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. parameterOIDs and @@ -47,92 +31,43 @@ func (b *Batch) Queue(query string, arguments []interface{}, parameterOIDs []pgt }) } -// Send sends all queued queries to the server at once. All queries are run in an implicit transaction unless explicit -// transaction control statements are executed. -func (b *Batch) Send(ctx context.Context) error { - if b.err != nil { - return b.err - } - - batch := &pgconn.Batch{} - - for _, bi := range b.items { - var parameterOIDs []pgtype.OID - ps := b.conn.preparedStatements[bi.query] - - if ps != nil { - parameterOIDs = ps.ParameterOIDs - } else { - parameterOIDs = bi.parameterOIDs - } - - args, err := convertDriverValuers(bi.arguments) - if err != nil { - return err - } - - paramFormats := make([]int16, len(args)) - paramValues := make([][]byte, len(args)) - for i := range args { - paramFormats[i] = chooseParameterFormatCode(b.conn.ConnInfo, parameterOIDs[i], args[i]) - paramValues[i], err = newencodePreparedStatementArgument(b.conn.ConnInfo, parameterOIDs[i], args[i]) - if err != nil { - return err - } - - } - - if ps != nil { - resultFormats := bi.resultFormatCodes - if resultFormats == nil { - resultFormats = make([]int16, len(ps.FieldDescriptions)) - for i := range resultFormats { - if dt, ok := b.conn.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok { - if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { - resultFormats[i] = BinaryFormatCode - } else { - resultFormats[i] = TextFormatCode - } - } - } - } - - batch.ExecPrepared(ps.Name, paramValues, paramFormats, resultFormats) - } else { - oids := make([]uint32, len(parameterOIDs)) - for i := 0; i < len(parameterOIDs); i++ { - oids[i] = uint32(parameterOIDs[i]) - } - batch.ExecParams(bi.query, paramValues, oids, paramFormats, bi.resultFormatCodes) - } - } - - b.mrr = b.conn.pgConn.ExecBatch(ctx, batch) - - return nil +type BatchResults struct { + conn *Conn + mrr *pgconn.MultiResultReader + err error } // ExecResults reads the results from the next query in the batch as if the // query has been sent with Exec. -func (b *Batch) ExecResults() (pgconn.CommandTag, error) { - if !b.mrr.NextResult() { - err := b.mrr.Close() +func (br *BatchResults) ExecResults() (pgconn.CommandTag, error) { + if br.err != nil { + return nil, br.err + } + + if !br.mrr.NextResult() { + err := br.mrr.Close() if err == nil { err = errors.New("no result") } return nil, err } - return b.mrr.ResultReader().Close() + return br.mrr.ResultReader().Close() } // QueryResults reads the results from the next query in the batch as if the // query has been sent with Query. -func (b *Batch) QueryResults() (Rows, error) { - rows := b.conn.getRows("batch query", nil) +func (br *BatchResults) QueryResults() (Rows, error) { + rows := br.conn.getRows("batch query", nil) - if !b.mrr.NextResult() { - rows.err = b.mrr.Close() + if br.err != nil { + rows.err = br.err + rows.closed = true + return rows, br.err + } + + if !br.mrr.NextResult() { + rows.err = br.mrr.Close() if rows.err == nil { rows.err = errors.New("no result") } @@ -140,14 +75,14 @@ func (b *Batch) QueryResults() (Rows, error) { return rows, rows.err } - rows.resultReader = b.mrr.ResultReader() + rows.resultReader = br.mrr.ResultReader() return rows, nil } // QueryRowResults reads the results from the next query in the batch as if the // query has been sent with QueryRow. -func (b *Batch) QueryRowResults() Row { - rows, _ := b.QueryResults() +func (br *BatchResults) QueryRowResults() Row { + rows, _ := br.QueryResults() return (*connRow)(rows.(*connRows)) } @@ -155,6 +90,10 @@ func (b *Batch) QueryRowResults() Row { // Close closes the batch operation. Any error that occured during a batch // operation may have made it impossible to resyncronize the connection with the // server. In this case the underlying connection will have been closed. -func (b *Batch) Close() (err error) { - return b.mrr.Close() +func (br *BatchResults) Close() error { + if br.err != nil { + return br.err + } + + return br.mrr.Close() } diff --git a/batch_test.go b/batch_test.go index fd65985e..74e04a60 100644 --- a/batch_test.go +++ b/batch_test.go @@ -23,7 +23,7 @@ func TestConnBeginBatch(t *testing.T) { );` mustExec(t, conn, sql) - batch := conn.BeginBatch() + batch := &pgx.Batch{} batch.Queue("insert into ledger(description, amount) values($1, $2)", []interface{}{"q1", 1}, []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, @@ -50,12 +50,9 @@ func TestConnBeginBatch(t *testing.T) { []int16{pgx.BinaryFormatCode}, ) - err := batch.Send(context.Background()) - if err != nil { - t.Fatal(err) - } + br := conn.SendBatch(context.Background(), batch) - ct, err := batch.ExecResults() + ct, err := br.ExecResults() if err != nil { t.Error(err) } @@ -63,7 +60,7 @@ func TestConnBeginBatch(t *testing.T) { t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) } - ct, err = batch.ExecResults() + ct, err = br.ExecResults() if err != nil { t.Error(err) } @@ -71,7 +68,7 @@ func TestConnBeginBatch(t *testing.T) { t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) } - ct, err = batch.ExecResults() + ct, err = br.ExecResults() if err != nil { t.Error(err) } @@ -79,7 +76,7 @@ func TestConnBeginBatch(t *testing.T) { t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) } - rows, err := batch.QueryResults() + rows, err := br.QueryResults() if err != nil { t.Error(err) } @@ -143,7 +140,7 @@ func TestConnBeginBatch(t *testing.T) { t.Fatal(rows.Err()) } - err = batch.QueryRowResults().Scan(&amount) + err = br.QueryRowResults().Scan(&amount) if err != nil { t.Error(err) } @@ -151,7 +148,7 @@ func TestConnBeginBatch(t *testing.T) { t.Errorf("amount => %v, want %v", amount, 6) } - err = batch.Close() + err = br.Close() if err != nil { t.Fatal(err) } @@ -170,7 +167,7 @@ func TestConnBeginBatchWithPreparedStatement(t *testing.T) { t.Fatal(err) } - batch := conn.BeginBatch() + batch := &pgx.Batch{} queryCount := 3 for i := 0; i < queryCount; i++ { @@ -181,13 +178,10 @@ func TestConnBeginBatchWithPreparedStatement(t *testing.T) { ) } - err = batch.Send(context.Background()) - if err != nil { - t.Fatal(err) - } + br := conn.SendBatch(context.Background(), batch) for i := 0; i < queryCount; i++ { - rows, err := batch.QueryResults() + rows, err := br.QueryResults() if err != nil { t.Fatal(err) } @@ -207,7 +201,7 @@ func TestConnBeginBatchWithPreparedStatement(t *testing.T) { } } - err = batch.Close() + err = br.Close() if err != nil { t.Fatal(err) } @@ -221,7 +215,7 @@ func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - batch := conn.BeginBatch() + batch := &pgx.Batch{} batch.Queue("select n from generate_series(0,5) n", nil, nil, @@ -233,12 +227,9 @@ func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) { []int16{pgx.BinaryFormatCode}, ) - err := batch.Send(context.Background()) - if err != nil { - t.Fatal(err) - } + br := conn.SendBatch(context.Background(), batch) - rows, err := batch.QueryResults() + rows, err := br.QueryResults() if err != nil { t.Error(err) } @@ -259,7 +250,7 @@ func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) { rows.Close() - rows, err = batch.QueryResults() + rows, err = br.QueryResults() if err != nil { t.Error(err) } @@ -278,7 +269,7 @@ func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) { t.Error(rows.Err()) } - err = batch.Close() + err = br.Close() if err != nil { t.Fatal(err) } @@ -292,7 +283,7 @@ func TestConnBeginBatchQueryError(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - batch := conn.BeginBatch() + batch := &pgx.Batch{} batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0", nil, nil, @@ -304,12 +295,9 @@ func TestConnBeginBatchQueryError(t *testing.T) { []int16{pgx.BinaryFormatCode}, ) - err := batch.Send(context.Background()) - if err != nil { - t.Fatal(err) - } + br := conn.SendBatch(context.Background(), batch) - rows, err := batch.QueryResults() + rows, err := br.QueryResults() if err != nil { t.Error(err) } @@ -328,7 +316,7 @@ func TestConnBeginBatchQueryError(t *testing.T) { t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012) } - err = batch.Close() + err = br.Close() if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") { t.Errorf("rows.Err() => %v, want error code %v", err, 22012) } @@ -342,25 +330,22 @@ func TestConnBeginBatchQuerySyntaxError(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) - batch := conn.BeginBatch() + batch := &pgx.Batch{} batch.Queue("select 1 1", nil, nil, []int16{pgx.BinaryFormatCode}, ) - err := batch.Send(context.Background()) - if err != nil { - t.Fatal(err) - } + br := conn.SendBatch(context.Background(), batch) var n int32 - err = batch.QueryRowResults().Scan(&n) + err := br.QueryRowResults().Scan(&n) if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "42601") { t.Errorf("rows.Err() => %v, want error code %v", err, 42601) } - err = batch.Close() + err = br.Close() if err == nil { t.Error("Expected error") } @@ -381,7 +366,7 @@ func TestConnBeginBatchQueryRowInsert(t *testing.T) { );` mustExec(t, conn, sql) - batch := conn.BeginBatch() + batch := &pgx.Batch{} batch.Queue("select 1", nil, nil, @@ -393,18 +378,15 @@ func TestConnBeginBatchQueryRowInsert(t *testing.T) { nil, ) - err := batch.Send(context.Background()) - if err != nil { - t.Fatal(err) - } + br := conn.SendBatch(context.Background(), batch) var value int - err = batch.QueryRowResults().Scan(&value) + err := br.QueryRowResults().Scan(&value) if err != nil { t.Error(err) } - ct, err := batch.ExecResults() + ct, err := br.ExecResults() if err != nil { t.Error(err) } @@ -412,7 +394,7 @@ func TestConnBeginBatchQueryRowInsert(t *testing.T) { t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2) } - batch.Close() + br.Close() ensureConnValid(t, conn) } @@ -430,7 +412,7 @@ func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) { );` mustExec(t, conn, sql) - batch := conn.BeginBatch() + batch := &pgx.Batch{} batch.Queue("select 1 union all select 2 union all select 3", nil, nil, @@ -442,18 +424,15 @@ func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) { nil, ) - err := batch.Send(context.Background()) - if err != nil { - t.Fatal(err) - } + br := conn.SendBatch(context.Background(), batch) - rows, err := batch.QueryResults() + rows, err := br.QueryResults() if err != nil { t.Error(err) } rows.Close() - ct, err := batch.ExecResults() + ct, err := br.ExecResults() if err != nil { t.Error(err) } @@ -461,12 +440,12 @@ func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) { t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2) } - batch.Close() + br.Close() ensureConnValid(t, conn) } -func TestTxBeginBatch(t *testing.T) { +func TestTxSendBatch(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) @@ -485,25 +464,23 @@ func TestTxBeginBatch(t *testing.T) { mustExec(t, conn, sql) tx, _ := conn.Begin(context.Background(), nil) - batch := tx.BeginBatch() + batch := &pgx.Batch{} batch.Queue("insert into ledger1(description) values($1) returning id", []interface{}{"q1"}, []pgtype.OID{pgtype.VarcharOID}, []int16{pgx.BinaryFormatCode}, ) - err := batch.Send(context.Background()) - if err != nil { - t.Fatal(err) - } + br := tx.SendBatch(context.Background(), batch) + var id int - err = batch.QueryRowResults().Scan(&id) + err := br.QueryRowResults().Scan(&id) if err != nil { t.Error(err) } - batch.Close() + br.Close() - batch = tx.BeginBatch() + batch = &pgx.Batch{} batch.Queue("insert into ledger2(id,amount) values($1, $2)", []interface{}{id, 2}, []pgtype.OID{pgtype.Int4OID, pgtype.Int4OID}, @@ -516,11 +493,9 @@ func TestTxBeginBatch(t *testing.T) { nil, ) - err = batch.Send(context.Background()) - if err != nil { - t.Fatal(err) - } - ct, err := batch.ExecResults() + br = tx.SendBatch(context.Background(), batch) + + ct, err := br.ExecResults() if err != nil { t.Error(err) } @@ -529,12 +504,12 @@ func TestTxBeginBatch(t *testing.T) { } var amount int - err = batch.QueryRowResults().Scan(&amount) + err = br.QueryRowResults().Scan(&amount) if err != nil { t.Error(err) } - batch.Close() + br.Close() tx.Commit(context.Background()) var count int @@ -543,7 +518,7 @@ func TestTxBeginBatch(t *testing.T) { t.Errorf("count => %v, want %v", count, 1) } - err = batch.Close() + err = br.Close() if err != nil { t.Fatal(err) } @@ -551,7 +526,7 @@ func TestTxBeginBatch(t *testing.T) { ensureConnValid(t, conn) } -func TestTxBeginBatchRollback(t *testing.T) { +func TestTxSendBatchRollback(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) @@ -564,23 +539,21 @@ func TestTxBeginBatchRollback(t *testing.T) { mustExec(t, conn, sql) tx, _ := conn.Begin(context.Background(), nil) - batch := tx.BeginBatch() + batch := &pgx.Batch{} batch.Queue("insert into ledger1(description) values($1) returning id", []interface{}{"q1"}, []pgtype.OID{pgtype.VarcharOID}, []int16{pgx.BinaryFormatCode}, ) - err := batch.Send(context.Background()) - if err != nil { - t.Fatal(err) - } + br := tx.SendBatch(context.Background(), batch) + var id int - err = batch.QueryRowResults().Scan(&id) + err := br.QueryRowResults().Scan(&id) if err != nil { t.Error(err) } - batch.Close() + br.Close() tx.Rollback(context.Background()) row := conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id) diff --git a/bench_test.go b/bench_test.go index 48433ff3..d212c74f 100644 --- a/bench_test.go +++ b/bench_test.go @@ -613,7 +613,7 @@ func BenchmarkMultipleQueriesBatch(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - batch := conn.BeginBatch() + batch := &pgx.Batch{} for j := 0; j < queryCount; j++ { batch.Queue("select n from generate_series(0,5) n", nil, @@ -622,13 +622,10 @@ func BenchmarkMultipleQueriesBatch(b *testing.B) { ) } - err := batch.Send(context.Background()) - if err != nil { - b.Fatal(err) - } + br := conn.SendBatch(context.Background(), batch) for j := 0; j < queryCount; j++ { - rows, err := batch.QueryResults() + rows, err := br.QueryResults() if err != nil { b.Fatal(err) } @@ -648,7 +645,7 @@ func BenchmarkMultipleQueriesBatch(b *testing.B) { } } - err = batch.Close() + err := br.Close() if err != nil { b.Fatal(err) } diff --git a/conn.go b/conn.go index 28b5546f..363b1b07 100644 --- a/conn.go +++ b/conn.go @@ -704,3 +704,67 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Ro rows, _ := c.Query(ctx, sql, args...) return (*connRow)(rows.(*connRows)) } + +// SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless +// explicit transaction control statements are executed. +func (c *Conn) SendBatch(ctx context.Context, b *Batch) *BatchResults { + batch := &pgconn.Batch{} + + for _, bi := range b.items { + var parameterOIDs []pgtype.OID + ps := c.preparedStatements[bi.query] + + if ps != nil { + parameterOIDs = ps.ParameterOIDs + } else { + parameterOIDs = bi.parameterOIDs + } + + args, err := convertDriverValuers(bi.arguments) + if err != nil { + return &BatchResults{err: err} + } + + paramFormats := make([]int16, len(args)) + paramValues := make([][]byte, len(args)) + for i := range args { + paramFormats[i] = chooseParameterFormatCode(c.ConnInfo, parameterOIDs[i], args[i]) + paramValues[i], err = newencodePreparedStatementArgument(c.ConnInfo, parameterOIDs[i], args[i]) + if err != nil { + return &BatchResults{err: err} + } + + } + + if ps != nil { + resultFormats := bi.resultFormatCodes + if resultFormats == nil { + resultFormats = make([]int16, len(ps.FieldDescriptions)) + for i := range resultFormats { + if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok { + if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { + resultFormats[i] = BinaryFormatCode + } else { + resultFormats[i] = TextFormatCode + } + } + } + } + + batch.ExecPrepared(ps.Name, paramValues, paramFormats, resultFormats) + } else { + oids := make([]uint32, len(parameterOIDs)) + for i := 0; i < len(parameterOIDs); i++ { + oids[i] = uint32(parameterOIDs[i]) + } + batch.ExecParams(bi.query, paramValues, oids, paramFormats, bi.resultFormatCodes) + } + } + + mrr := c.pgConn.ExecBatch(ctx, batch) + + return &BatchResults{ + conn: c, + mrr: mrr, + } +} diff --git a/go.mod b/go.mod index 80d3e26e..e9c3b3c3 100644 --- a/go.mod +++ b/go.mod @@ -10,8 +10,6 @@ require ( github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0 github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b - github.com/lib/pq v1.1.0 - github.com/pkg/errors v0.8.1 github.com/rs/zerolog v1.13.0 github.com/satori/go.uuid v1.2.0 github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24 diff --git a/go.sum b/go.sum index 9d1aae79..f3b0ea68 100644 --- a/go.sum +++ b/go.sum @@ -18,7 +18,6 @@ github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db h1:UpaK github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0 h1:mX93v750WifMD1htCt7vqeolcnpaG1gz8URVGjSzcUM= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= -github.com/jackc/pgx v3.3.0+incompatible h1:Wa90/+qsITBAPkAZjiByeIGHFcj3Ztu+VzrrIpHjL90= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b h1:cIcUpcEP55F/QuZWEtXyqHoWk+IV4TBiLjtBkeq/Q1c= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= @@ -31,6 +30,7 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.1.0 h1:/5u4a+KGJptBRqGzPvYQL9p0d/tPR4S31+Tnzj9lEO4= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/tx.go b/tx.go index ab6a45e7..ec89c2d5 100644 --- a/tx.go +++ b/tx.go @@ -185,9 +185,13 @@ func (tx *Tx) CopyFrom(ctx context.Context, tableName Identifier, columnNames [] return tx.conn.CopyFrom(ctx, tableName, columnNames, rowSrc) } -// BeginBatch returns a *Batch query for the tx's connection. -func (tx *Tx) BeginBatch() *Batch { - return &Batch{conn: tx.conn} +// SendBatch delegates to the underlying *Conn +func (tx *Tx) SendBatch(ctx context.Context, b *Batch) *BatchResults { + if tx.status != TxStatusInProgress { + return &BatchResults{err: ErrTxClosed} + } + + return tx.conn.SendBatch(ctx, b) } // Status returns the status of the transaction from the set of