From 64b07f0d6622985b47d7cc5ba7fcf69a24754f73 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 24 Aug 2019 23:34:54 -0500 Subject: [PATCH] Batch uses statement cache. This streamlines Queue's interface as well. --- batch.go | 18 ++-- batch_test.go | 184 ++++++++++++++++++----------------------- bench_test.go | 76 ++++++++++++++--- conn.go | 75 ++++++++++------- pgxpool/common_test.go | 4 +- 5 files changed, 201 insertions(+), 156 deletions(-) diff --git a/batch.go b/batch.go index 453ef5c5..4dd49d3e 100644 --- a/batch.go +++ b/batch.go @@ -8,10 +8,8 @@ import ( ) type batchItem struct { - query string - arguments []interface{} - parameterOIDs []uint32 - resultFormatCodes []int16 + query string + arguments []interface{} } // Batch queries are a way of bundling multiple queries together to avoid @@ -20,15 +18,11 @@ type Batch struct { items []*batchItem } -// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. parameterOIDs and -// resultFormatCodes should be nil if query is a prepared statement. Otherwise, parameterOIDs are required if there are -// parameters and resultFormatCodes are required if there is a result. -func (b *Batch) Queue(query string, arguments []interface{}, parameterOIDs []uint32, resultFormatCodes []int16) { +// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. +func (b *Batch) Queue(query string, arguments ...interface{}) { b.items = append(b.items, &batchItem{ - query: query, - arguments: arguments, - parameterOIDs: parameterOIDs, - resultFormatCodes: resultFormatCodes, + query: query, + arguments: arguments, }) } diff --git a/batch_test.go b/batch_test.go index d8ebf53e..54338f68 100644 --- a/batch_test.go +++ b/batch_test.go @@ -6,8 +6,9 @@ import ( "testing" "github.com/jackc/pgconn" - "github.com/jackc/pgtype" + "github.com/jackc/pgconn/stmtcache" "github.com/jackc/pgx/v4" + "github.com/stretchr/testify/require" ) func TestConnSendBatch(t *testing.T) { @@ -24,31 +25,11 @@ func TestConnSendBatch(t *testing.T) { mustExec(t, conn, sql) batch := &pgx.Batch{} - batch.Queue("insert into ledger(description, amount) values($1, $2)", - []interface{}{"q1", 1}, - []uint32{pgtype.VarcharOID, pgtype.Int4OID}, - nil, - ) - batch.Queue("insert into ledger(description, amount) values($1, $2)", - []interface{}{"q2", 2}, - []uint32{pgtype.VarcharOID, pgtype.Int4OID}, - nil, - ) - batch.Queue("insert into ledger(description, amount) values($1, $2)", - []interface{}{"q3", 3}, - []uint32{pgtype.VarcharOID, pgtype.Int4OID}, - nil, - ) - batch.Queue("select id, description, amount from ledger order by id", - nil, - nil, - []int16{pgx.BinaryFormatCode, pgx.TextFormatCode, pgx.BinaryFormatCode}, - ) - batch.Queue("select sum(amount) from ledger", - nil, - nil, - []int16{pgx.BinaryFormatCode}, - ) + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1) + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2) + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3) + batch.Queue("select id, description, amount from ledger order by id") + batch.Queue("select sum(amount) from ledger") br := conn.SendBatch(context.Background(), batch) @@ -171,11 +152,7 @@ func TestConnSendBatchWithPreparedStatement(t *testing.T) { queryCount := 3 for i := 0; i < queryCount; i++ { - batch.Queue("ps1", - []interface{}{5}, - nil, - nil, - ) + batch.Queue("ps1", 5) } br := conn.SendBatch(context.Background(), batch) @@ -216,16 +193,8 @@ func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) { defer closeConn(t, conn) batch := &pgx.Batch{} - batch.Queue("select n from generate_series(0,5) n", - nil, - nil, - []int16{pgx.BinaryFormatCode}, - ) - batch.Queue("select n from generate_series(0,5) n", - nil, - nil, - []int16{pgx.BinaryFormatCode}, - ) + batch.Queue("select n from generate_series(0,5) n") + batch.Queue("select n from generate_series(0,5) n") br := conn.SendBatch(context.Background(), batch) @@ -284,16 +253,8 @@ func TestConnSendBatchQueryError(t *testing.T) { defer closeConn(t, conn) batch := &pgx.Batch{} - batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0", - nil, - nil, - []int16{pgx.BinaryFormatCode}, - ) - batch.Queue("select n from generate_series(0,5) n", - nil, - nil, - []int16{pgx.BinaryFormatCode}, - ) + batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0") + batch.Queue("select n from generate_series(0,5) n") br := conn.SendBatch(context.Background(), batch) @@ -331,11 +292,7 @@ func TestConnSendBatchQuerySyntaxError(t *testing.T) { defer closeConn(t, conn) batch := &pgx.Batch{} - batch.Queue("select 1 1", - nil, - nil, - []int16{pgx.BinaryFormatCode}, - ) + batch.Queue("select 1 1") br := conn.SendBatch(context.Background(), batch) @@ -367,16 +324,8 @@ func TestConnSendBatchQueryRowInsert(t *testing.T) { mustExec(t, conn, sql) batch := &pgx.Batch{} - batch.Queue("select 1", - nil, - nil, - []int16{pgx.BinaryFormatCode}, - ) - batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", - []interface{}{"q1", 1}, - []uint32{pgtype.VarcharOID, pgtype.Int4OID}, - nil, - ) + batch.Queue("select 1") + batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1) br := conn.SendBatch(context.Background(), batch) @@ -413,16 +362,8 @@ func TestConnSendBatchQueryPartialReadInsert(t *testing.T) { mustExec(t, conn, sql) batch := &pgx.Batch{} - batch.Queue("select 1 union all select 2 union all select 3", - nil, - nil, - []int16{pgx.BinaryFormatCode}, - ) - batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", - []interface{}{"q1", 1}, - []uint32{pgtype.VarcharOID, pgtype.Int4OID}, - nil, - ) + batch.Queue("select 1 union all select 2 union all select 3") + batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1) br := conn.SendBatch(context.Background(), batch) @@ -465,11 +406,7 @@ func TestTxSendBatch(t *testing.T) { tx, _ := conn.Begin(context.Background()) batch := &pgx.Batch{} - batch.Queue("insert into ledger1(description) values($1) returning id", - []interface{}{"q1"}, - []uint32{pgtype.VarcharOID}, - []int16{pgx.BinaryFormatCode}, - ) + batch.Queue("insert into ledger1(description) values($1) returning id", "q1") br := tx.SendBatch(context.Background(), batch) @@ -481,17 +418,8 @@ func TestTxSendBatch(t *testing.T) { br.Close() batch = &pgx.Batch{} - batch.Queue("insert into ledger2(id,amount) values($1, $2)", - []interface{}{id, 2}, - []uint32{pgtype.Int4OID, pgtype.Int4OID}, - nil, - ) - - batch.Queue("select amount from ledger2 where id = $1", - []interface{}{id}, - []uint32{pgtype.Int4OID}, - nil, - ) + batch.Queue("insert into ledger2(id,amount) values($1, $2)", id, 2) + batch.Queue("select amount from ledger2 where id = $1", id) br = tx.SendBatch(context.Background(), batch) @@ -540,11 +468,7 @@ func TestTxSendBatchRollback(t *testing.T) { tx, _ := conn.Begin(context.Background()) batch := &pgx.Batch{} - batch.Queue("insert into ledger1(description) values($1) returning id", - []interface{}{"q1"}, - []uint32{pgtype.VarcharOID}, - []int16{pgx.BinaryFormatCode}, - ) + batch.Queue("insert into ledger1(description) values($1) returning id", "q1") br := tx.SendBatch(context.Background(), batch) @@ -582,11 +506,7 @@ func TestConnBeginBatchDeferredError(t *testing.T) { batch := &pgx.Batch{} - batch.Queue(`update t set n=n+1 where id='b' returning *`, - nil, - nil, - []int16{pgx.BinaryFormatCode}, - ) + batch.Queue(`update t set n=n+1 where id='b' returning *`) br := conn.SendBatch(context.Background(), batch) @@ -615,3 +535,63 @@ func TestConnBeginBatchDeferredError(t *testing.T) { ensureConnValid(t, conn) } + +func TestConnSendBatchNoStatementCache(t *testing.T) { + config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + config.BuildStatementCache = nil + + conn := mustConnect(t, config) + defer closeConn(t, conn) + + testConnSendBatch(t, conn, 3) +} + +func TestConnSendBatchPrepareStatementCache(t *testing.T) { + config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { + return stmtcache.New(conn, stmtcache.ModePrepare, 32) + } + + conn := mustConnect(t, config) + defer closeConn(t, conn) + + testConnSendBatch(t, conn, 3) +} + +func TestConnSendBatchDescribeStatementCache(t *testing.T) { + config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { + return stmtcache.New(conn, stmtcache.ModeDescribe, 32) + } + + conn := mustConnect(t, config) + defer closeConn(t, conn) + + testConnSendBatch(t, conn, 3) +} + +func testConnSendBatch(t *testing.T, conn *pgx.Conn, queryCount int) { + batch := &pgx.Batch{} + for j := 0; j < queryCount; j++ { + batch.Queue("select n from generate_series(0,5) n") + } + + br := conn.SendBatch(context.Background(), batch) + + for j := 0; j < queryCount; j++ { + rows, err := br.QueryResults() + require.NoError(t, err) + + for k := 0; rows.Next(); k++ { + var n int + err := rows.Scan(&n) + require.NoError(t, err) + require.Equal(t, k, n) + } + + require.NoError(t, rows.Err()) + } + + err := br.Close() + require.NoError(t, err) +} diff --git a/bench_test.go b/bench_test.go index bd8fa1ad..b8edfc27 100644 --- a/bench_test.go +++ b/bench_test.go @@ -710,12 +710,41 @@ func BenchmarkWrite10000RowsViaCopy(b *testing.B) { benchmarkWriteNRowsViaCopy(b, 10000) } -func BenchmarkMultipleQueriesNonBatch(b *testing.B) { - conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) +func BenchmarkMultipleQueriesNonBatchNoStatementCache(b *testing.B) { + config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) + config.BuildStatementCache = nil + + conn := mustConnect(b, config) defer closeConn(b, conn) - queryCount := 3 + benchmarkMultipleQueriesNonBatch(b, conn, 3) +} +func BenchmarkMultipleQueriesNonBatchPrepareStatementCache(b *testing.B) { + config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) + config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { + return stmtcache.New(conn, stmtcache.ModePrepare, 32) + } + + conn := mustConnect(b, config) + defer closeConn(b, conn) + + benchmarkMultipleQueriesNonBatch(b, conn, 3) +} + +func BenchmarkMultipleQueriesNonBatchDescribeStatementCache(b *testing.B) { + config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) + config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { + return stmtcache.New(conn, stmtcache.ModeDescribe, 32) + } + + conn := mustConnect(b, config) + defer closeConn(b, conn) + + benchmarkMultipleQueriesNonBatch(b, conn, 3) +} + +func benchmarkMultipleQueriesNonBatch(b *testing.B, conn *pgx.Conn, queryCount int) { b.ResetTimer() for i := 0; i < b.N; i++ { for j := 0; j < queryCount; j++ { @@ -741,21 +770,46 @@ func BenchmarkMultipleQueriesNonBatch(b *testing.B) { } } -func BenchmarkMultipleQueriesBatch(b *testing.B) { - conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) +func BenchmarkMultipleQueriesBatchNoStatementCache(b *testing.B) { + config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) + config.BuildStatementCache = nil + + conn := mustConnect(b, config) defer closeConn(b, conn) - queryCount := 3 + benchmarkMultipleQueriesBatch(b, conn, 3) +} +func BenchmarkMultipleQueriesBatchPrepareStatementCache(b *testing.B) { + config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) + config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { + return stmtcache.New(conn, stmtcache.ModePrepare, 32) + } + + conn := mustConnect(b, config) + defer closeConn(b, conn) + + benchmarkMultipleQueriesBatch(b, conn, 3) +} + +func BenchmarkMultipleQueriesBatchDescribeStatementCache(b *testing.B) { + config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) + config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { + return stmtcache.New(conn, stmtcache.ModeDescribe, 32) + } + + conn := mustConnect(b, config) + defer closeConn(b, conn) + + benchmarkMultipleQueriesBatch(b, conn, 3) +} + +func benchmarkMultipleQueriesBatch(b *testing.B, conn *pgx.Conn, queryCount int) { b.ResetTimer() for i := 0; i < b.N; i++ { batch := &pgx.Batch{} for j := 0; j < queryCount; j++ { - batch.Queue("select n from generate_series(0,5) n", - nil, - nil, - []int16{pgx.BinaryFormatCode}, - ) + batch.Queue("select n from generate_series(0,5) n") } br := conn.SendBatch(context.Background(), batch) diff --git a/conn.go b/conn.go index dc1819e7..36352d1c 100644 --- a/conn.go +++ b/conn.go @@ -715,57 +715,74 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Ro // SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless // explicit transaction control statements are executed. func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { + distinctUnpreparedQueries := map[string]struct{}{} + + for _, bi := range b.items { + if _, ok := c.preparedStatements[bi.query]; ok { + continue + } + distinctUnpreparedQueries[bi.query] = struct{}{} + } + + var stmtCache stmtcache.Cache + if c.stmtcache != nil && c.stmtcache.Cap() >= len(distinctUnpreparedQueries) { + stmtCache = c.stmtcache + } else { + stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries)) + } + + for sql, _ := range distinctUnpreparedQueries { + _, err := stmtCache.Get(ctx, sql) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + } + batch := &pgconn.Batch{} for _, bi := range b.items { c.eqb.Reset() - var parameterOIDs []uint32 sd := c.preparedStatements[bi.query] + if sd == nil { + var err error + sd, err = stmtCache.Get(ctx, bi.query) + if err != nil { + // the stmtCache was prefilled from distinctUnpreparedQueries above so we are guaranteed no errors + panic("BUG: unexpected error from stmtCache") + } + } - if sd != nil { - parameterOIDs = sd.ParamOIDs - } else { - parameterOIDs = bi.parameterOIDs + if len(sd.ParamOIDs) != len(bi.arguments) { + return &batchResults{ctx: ctx, conn: c, err: errors.Errorf("mismatched param and argument count")} } args, err := convertDriverValuers(bi.arguments) if err != nil { - return &batchResults{err: err} + return &batchResults{ctx: ctx, conn: c, err: err} } for i := range args { - err = c.eqb.AppendParam(c.ConnInfo, parameterOIDs[i], args[i]) + err = c.eqb.AppendParam(c.ConnInfo, sd.ParamOIDs[i], args[i]) if err != nil { - return &batchResults{err: err} + return &batchResults{ctx: ctx, conn: c, err: err} } - } - if sd != nil { - resultFormats := bi.resultFormatCodes - if resultFormats == nil { - - for i := range sd.Fields { - if dt, ok := c.ConnInfo.DataTypeForOID(uint32(sd.Fields[i].DataTypeOID)); ok { - if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { - c.eqb.AppendResultFormat(BinaryFormatCode) - } else { - c.eqb.AppendResultFormat(TextFormatCode) - } - } + for i := range sd.Fields { + if dt, ok := c.ConnInfo.DataTypeForOID(uint32(sd.Fields[i].DataTypeOID)); ok { + if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { + c.eqb.AppendResultFormat(BinaryFormatCode) + } else { + c.eqb.AppendResultFormat(TextFormatCode) } - - resultFormats = c.eqb.resultFormats } + } - batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats) + if sd.Name == "" { + batch.ExecParams(bi.query, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats) } else { - oids := make([]uint32, len(parameterOIDs)) - for i := 0; i < len(parameterOIDs); i++ { - oids[i] = uint32(parameterOIDs[i]) - } - batch.ExecParams(bi.query, c.eqb.paramValues, oids, c.eqb.paramFormats, bi.resultFormatCodes) + batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats) } } diff --git a/pgxpool/common_test.go b/pgxpool/common_test.go index fbe96d15..a4ceeb1d 100644 --- a/pgxpool/common_test.go +++ b/pgxpool/common_test.go @@ -68,8 +68,8 @@ type sendBatcher interface { func testSendBatch(t *testing.T, db sendBatcher) { batch := &pgx.Batch{} - batch.Queue("select 1", nil, nil, nil) - batch.Queue("select 2", nil, nil, nil) + batch.Queue("select 1") + batch.Queue("select 2") br := db.SendBatch(context.Background(), batch)