From 73f496d7de2ee5bc59155e42b145743cebdb2614 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Jun 2017 11:49:27 -0500 Subject: [PATCH] Finish core batch operations --- batch.go | 180 ++++++++++++++++------------ batch_test.go | 296 +++++++++++++++++++++++++++++++++++++++++++++- bench_test.go | 90 ++++++++++++++ conn_pool.go | 7 ++ conn_pool_test.go | 67 +++++++++++ query.go | 5 + v3.md | 2 + 7 files changed, 572 insertions(+), 75 deletions(-) diff --git a/batch.go b/batch.go index 722ce340..a2e8e042 100644 --- a/batch.go +++ b/batch.go @@ -14,60 +14,31 @@ type batchItem struct { resultFormatCodes []int16 } +// Batch queries are a way of bundling multiple queries together to avoid +// unnecessary network round trips. type Batch struct { conn *Conn + connPool *ConnPool items []*batchItem resultsRead int sent bool + ctx context.Context + err error } -// Begin starts a transaction with the default transaction mode for the -// current connection. To use a specific transaction mode see BeginEx. +// BeginBatch returns a *Batch query for c. func (c *Conn) BeginBatch() *Batch { - // TODO - the type stuff below - - // err = c.waitForPreviousCancelQuery(ctx) - // if err != nil { - // return nil, err - // } - - // if err := c.ensureConnectionReadyForQuery(); err != nil { - // return nil, err - // } - - // c.lastActivityTime = time.Now() - - // rows = c.getRows(sql, args) - - // if err := c.lock(); err != nil { - // rows.fatal(err) - // return rows, err - // } - // rows.unlockConn = true - - // err = c.initContext(ctx) - // if err != nil { - // rows.fatal(err) - // return rows, rows.err - // } - - // if options != nil && options.SimpleProtocol { - // err = c.sanitizeAndSendSimpleQuery(sql, args...) - // if err != nil { - // rows.fatal(err) - // return rows, err - // } - - // return rows, nil - // } - return &Batch{conn: c} } +// Conn returns the underlying connection that b will or was performed on. func (b *Batch) Conn() *Conn { return b.conn } +// Queue queues a query to batch b. parameterOids are required if there are +// parameters and query is not the name of a prepared statement. +// resultFormatCodes are required if there is a result. func (b *Batch) Queue(query string, arguments []interface{}, parameterOids []pgtype.Oid, resultFormatCodes []int16) { b.items = append(b.items, &batchItem{ query: query, @@ -77,15 +48,46 @@ func (b *Batch) Queue(query string, arguments []interface{}, parameterOids []pgt }) } +// Send sends all queued queries to the server at once. All queries are wrapped +// in a transaction. The transaction can optionally be configured with +// txOptions. The context is in effect until the Batch is closed. 
 func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error {
+	if b.err != nil {
+		return b.err
+	}
+
+	b.ctx = ctx
+
+	err := b.conn.waitForPreviousCancelQuery(ctx)
+	if err != nil {
+		return err
+	}
+
+	if err := b.conn.ensureConnectionReadyForQuery(); err != nil {
+		return err
+	}
+
+	err = b.conn.initContext(ctx)
+	if err != nil {
+		return err
+	}
+
 	buf := appendQuery(b.conn.wbuf, txOptions.beginSQL())
 
 	for _, bi := range b.items {
-		// TODO - don't parse if named prepared statement
-		buf = appendParse(buf, "", bi.query, bi.parameterOids)
+		var psName string
+		var psParameterOids []pgtype.Oid
+
+		if ps, ok := b.conn.preparedStatements[bi.query]; ok {
+			psName = ps.Name
+			psParameterOids = ps.ParameterOids
+		} else {
+			psParameterOids = bi.parameterOids
+			buf = appendParse(buf, "", bi.query, psParameterOids)
+		}
 
 		var err error
-		buf, err = appendBind(buf, "", "", b.conn.ConnInfo, bi.parameterOids, bi.arguments, bi.resultFormatCodes)
+		buf, err = appendBind(buf, "", psName, b.conn.ConnInfo, psParameterOids, bi.arguments, bi.resultFormatCodes)
 		if err != nil {
 			return err
 		}
@@ -129,7 +131,20 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error {
 	return nil
 }
 
+// ExecResults reads the results from the next query in the batch as if the
+// query has been sent with Exec.
 func (b *Batch) ExecResults() (CommandTag, error) {
+	if b.err != nil {
+		return "", b.err
+	}
+
+	select {
+	case <-b.ctx.Done():
+		b.die(b.ctx.Err())
+		return "", b.ctx.Err()
+	default:
+	}
+
 	b.resultsRead++
 
 	for {
@@ -149,63 +164,80 @@ func (b *Batch) ExecResults() (CommandTag, error) {
 	}
 }
 
+// QueryResults reads the results from the next query in the batch as if the
+// query has been sent with Query.
 func (b *Batch) QueryResults() (*Rows, error) {
+	if b.err != nil {
+		return nil, b.err
+	}
+
+	select {
+	case <-b.ctx.Done():
+		b.die(b.ctx.Err())
+		return nil, b.ctx.Err()
+	default:
+	}
+
 	b.resultsRead++
 
 	rows := b.conn.getRows("batch query", nil)
 
 	fieldDescriptions, err := b.conn.readUntilRowDescription()
 	if err != nil {
-		rows.fatal(err)
+		b.die(err)
 		return nil, err
 	}
 
+	rows.batch = b
 	rows.fields = fieldDescriptions
 	return rows, nil
 }
 
+// QueryRowResults reads the results from the next query in the batch as if the
+// query has been sent with QueryRow.
 func (b *Batch) QueryRowResults() *Row {
 	rows, _ := b.QueryResults()
 	return (*Row)(rows)
}
 
-func (b *Batch) Finish() error {
+// Close closes the batch operation. Any error that occurred during a batch
+// operation may have made it impossible to resynchronize the connection with the
+// server. In this case the underlying connection will have been closed.
+func (b *Batch) Close() (err error) { + if b.err != nil { + return b.err + } + + defer func() { + err = b.conn.termContext(err) + if b.conn != nil && b.connPool != nil { + b.connPool.Release(b.conn) + } + }() + for i := b.resultsRead; i < len(b.items); i++ { - _, err := b.ExecResults() - if err != nil { + if _, err = b.ExecResults(); err != nil { return err } } - // readyForQueryCount := 0 - - // for { - // msg, err := b.conn.rxMsg() - // if err != nil { - // return "", err - // } - - // switch msg := msg.(type) { - // case *pgproto3.ReadyForQuery: - // c.rxReadyForQuery(msg) - // default: - // if err := b.conn.processContextFreeMsg(msg); err != nil { - // return "", err - // } - // } - // } - - // switch msg := msg.(type) { - // case *pgproto3.ErrorResponse: - // return c.rxErrorResponse(msg) - // case *pgproto3.NotificationResponse: - // c.rxNotificationResponse(msg) - // case *pgproto3.ReadyForQuery: - // c.rxReadyForQuery(msg) - // case *pgproto3.ParameterStatus: - // c.rxParameterStatus(msg) - // } + if err = b.conn.ensureConnectionReadyForQuery(); err != nil { + return err + } return nil } + +func (b *Batch) die(err error) { + if b.err != nil { + return + } + + b.err = err + b.conn.die(err) + + if b.conn != nil && b.connPool != nil { + b.connPool.Release(b.conn) + } +} diff --git a/batch_test.go b/batch_test.go index aeef52f4..bccf9a20 100644 --- a/batch_test.go +++ b/batch_test.go @@ -141,10 +141,304 @@ func TestConnBeginBatch(t *testing.T) { t.Errorf("amount => %v, want %v", amount, 6) } - err = batch.Finish() + err = batch.Close() if err != nil { t.Fatal(err) } ensureConnValid(t, conn) } + +func TestConnBeginBatchWithPreparedStatement(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + _, err := conn.Prepare("ps1", "select n from generate_series(0,$1::int) n") + if err != nil { + t.Fatal(err) + } + + batch := conn.BeginBatch() + + queryCount := 3 + for i := 0; i < queryCount; i++ { + batch.Queue("ps1", + []interface{}{5}, + nil, + []int16{pgx.BinaryFormatCode}, + ) + } + + err = batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < queryCount; i++ { + rows, err := batch.QueryResults() + if err != nil { + t.Fatal(err) + } + + for k := 0; rows.Next(); k++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Fatal(err) + } + if n != k { + t.Fatalf("n => %v, want %v", n, k) + } + } + + if rows.Err() != nil { + t.Fatal(rows.Err()) + } + } + + err = batch.Close() + if err != nil { + t.Fatal(err) + } + + ensureConnValid(t, conn) +} + +func TestConnBeginBatchContextCancelBeforeExecResults(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + + sql := `create temporary table ledger( + id serial primary key, + description varchar not null, + amount int not null +);` + mustExec(t, conn, sql) + + batch := conn.BeginBatch() + batch.Queue("insert into ledger(description, amount) values($1, $2)", + []interface{}{"q1", 1}, + []pgtype.Oid{pgtype.VarcharOid, pgtype.Int4Oid}, + nil, + ) + batch.Queue("select pg_sleep(2)", + nil, + nil, + nil, + ) + + ctx, cancelFn := context.WithCancel(context.Background()) + + err := batch.Send(ctx, nil) + if err != nil { + t.Fatal(err) + } + + cancelFn() + + _, err = batch.ExecResults() + if err != context.Canceled { + t.Errorf("err => %v, want %v", err, context.Canceled) + } + + if conn.IsAlive() { + t.Error("conn should be dead, but was alive") + } +} + +func TestConnBeginBatchContextCancelBeforeQueryResults(t *testing.T) { + 
t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + + batch := conn.BeginBatch() + batch.Queue("select pg_sleep(2)", + nil, + nil, + nil, + ) + batch.Queue("select pg_sleep(2)", + nil, + nil, + nil, + ) + + ctx, cancelFn := context.WithCancel(context.Background()) + + err := batch.Send(ctx, nil) + if err != nil { + t.Fatal(err) + } + + cancelFn() + + _, err = batch.QueryResults() + if err != context.Canceled { + t.Errorf("err => %v, want %v", err, context.Canceled) + } + + if conn.IsAlive() { + t.Error("conn should be dead, but was alive") + } +} + +func TestConnBeginBatchContextCancelBeforeFinish(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + + batch := conn.BeginBatch() + batch.Queue("select pg_sleep(2)", + nil, + nil, + nil, + ) + batch.Queue("select pg_sleep(2)", + nil, + nil, + nil, + ) + + ctx, cancelFn := context.WithCancel(context.Background()) + + err := batch.Send(ctx, nil) + if err != nil { + t.Fatal(err) + } + + cancelFn() + + err = batch.Close() + if err != context.Canceled { + t.Errorf("err => %v, want %v", err, context.Canceled) + } + + if conn.IsAlive() { + t.Error("conn should be dead, but was alive") + } +} + +func TestConnBeginBatchCloseRowsPartiallyRead(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + batch := conn.BeginBatch() + batch.Queue("select n from generate_series(0,5) n", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + batch.Queue("select n from generate_series(0,5) n", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + rows, err := batch.QueryResults() + if err != nil { + t.Error(err) + } + + for i := 0; i < 3; i++ { + if !rows.Next() { + t.Error("expected a row to be available") + } + + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } + } + + rows.Close() + + rows, err = batch.QueryResults() + if err != nil { + t.Error(err) + } + + for i := 0; rows.Next(); i++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } + } + + if rows.Err() != nil { + t.Error(rows.Err()) + } + + err = batch.Close() + if err != nil { + t.Fatal(err) + } + + ensureConnValid(t, conn) +} + +func TestConnBeginBatchQueryError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + batch := conn.BeginBatch() + batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + batch.Queue("select n from generate_series(0,5) n", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + rows, err := batch.QueryResults() + if err != nil { + t.Error(err) + } + + for i := 0; rows.Next(); i++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } + } + + if pgErr, ok := rows.Err().(pgx.PgError); !(ok && pgErr.Code == "22012") { + t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012) + } + + err = batch.Close() + if pgErr, ok := err.(pgx.PgError); !(ok && pgErr.Code == "22012") { + t.Errorf("rows.Err() => %v, want error code %v", err, 22012) + } + + if conn.IsAlive() { + t.Error("conn should be dead, but was alive") + } +} diff --git a/bench_test.go b/bench_test.go index 
d3525df5..7f82891e 100644 --- a/bench_test.go +++ b/bench_test.go @@ -2,6 +2,7 @@ package pgx_test import ( "bytes" + "context" "fmt" "strings" "testing" @@ -609,3 +610,92 @@ func BenchmarkWrite10000RowsViaMultiInsert(b *testing.B) { func BenchmarkWrite10000RowsViaCopy(b *testing.B) { benchmarkWriteNRowsViaCopy(b, 10000) } + +func BenchmarkMultipleQueriesNonBatch(b *testing.B) { + config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 5} + pool, err := pgx.NewConnPool(config) + if err != nil { + b.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + queryCount := 3 + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for j := 0; j < queryCount; j++ { + rows, err := pool.Query("select n from generate_series(0, 5) n") + if err != nil { + b.Fatal(err) + } + + for k := 0; rows.Next(); k++ { + var n int + if err := rows.Scan(&n); err != nil { + b.Fatal(err) + } + if n != k { + b.Fatalf("n => %v, want %v", n, k) + } + } + + if rows.Err() != nil { + b.Fatal(rows.Err()) + } + } + } +} + +func BenchmarkMultipleQueriesBatch(b *testing.B) { + config := pgx.ConnPoolConfig{ConnConfig: *defaultConnConfig, MaxConnections: 5} + pool, err := pgx.NewConnPool(config) + if err != nil { + b.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + queryCount := 3 + + b.ResetTimer() + for i := 0; i < b.N; i++ { + batch := pool.BeginBatch() + for j := 0; j < queryCount; j++ { + batch.Queue("select n from generate_series(0,5) n", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + } + + err := batch.Send(context.Background(), nil) + if err != nil { + b.Fatal(err) + } + + for j := 0; j < queryCount; j++ { + rows, err := batch.QueryResults() + if err != nil { + b.Fatal(err) + } + + for k := 0; rows.Next(); k++ { + var n int + if err := rows.Scan(&n); err != nil { + b.Fatal(err) + } + if n != k { + b.Fatalf("n => %v, want %v", n, k) + } + } + + if rows.Err() != nil { + b.Fatal(rows.Err()) + } + } + + err = batch.Close() + if err != nil { + b.Fatal(err) + } + } +} diff --git a/conn_pool.go b/conn_pool.go index 42200b85..fdfc70f5 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -536,3 +536,10 @@ func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc C return c.CopyFrom(tableName, columnNames, rowSrc) } + +// BeginBatch acquires a connection and begins a batch on that connection. When +// *Batch is finished, the connection is released automatically. 
+func (p *ConnPool) BeginBatch() *Batch { + c, err := p.Acquire() + return &Batch{conn: c, connPool: p, err: err} +} diff --git a/conn_pool_test.go b/conn_pool_test.go index 560ab3ae..4e0dc199 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -981,3 +981,70 @@ func TestConnPoolPrepareWhenConnIsAlreadyAcquired(t *testing.T) { t.Errorf("Expected error calling deallocated prepared statement, but got: %v", err) } } + +func TestConnPoolBeginBatch(t *testing.T) { + t.Parallel() + + pool := createConnPool(t, 2) + defer pool.Close() + + batch := pool.BeginBatch() + batch.Queue("select n from generate_series(0,5) n", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + batch.Queue("select n from generate_series(0,5) n", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + rows, err := batch.QueryResults() + if err != nil { + t.Error(err) + } + + for i := 0; rows.Next(); i++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } + } + + if rows.Err() != nil { + t.Error(rows.Err()) + } + + rows, err = batch.QueryResults() + if err != nil { + t.Error(err) + } + + for i := 0; rows.Next(); i++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } + } + + if rows.Err() != nil { + t.Error(rows.Err()) + } + + err = batch.Close() + if err != nil { + t.Fatal(err) + } +} diff --git a/query.go b/query.go index a3903a22..6c9f6ab0 100644 --- a/query.go +++ b/query.go @@ -43,6 +43,7 @@ func (r *Row) Scan(dest ...interface{}) (err error) { type Rows struct { conn *Conn connPool *ConnPool + batch *Batch values [][]byte fields []FieldDescription rowCount int @@ -84,6 +85,10 @@ func (rows *Rows) Close() { rows.conn.log(LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)}) } + if rows.batch != nil && rows.err != nil { + rows.batch.die(rows.err) + } + if rows.connPool != nil { rows.connPool.Release(rows.conn) } diff --git a/v3.md b/v3.md index 33a27d2d..b369be18 100644 --- a/v3.md +++ b/v3.md @@ -50,6 +50,8 @@ Removed Tx.Conn() Added ctx parameter to (Conn/Tx/ConnPool).PrepareEx +Added batch operations + ## TODO / Possible / Investigate Organize errors better
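
Usage sketch (illustrative, not part of the diff above): the program below strings the new batch API together the way batch_test.go does, with Queue, Send, ExecResults, QueryRowResults, and Close. Connection settings come from the standard PG* environment variables via pgx.ParseEnvLibpq; the ledger table and all literal values are made up for the example.

package main

import (
	"context"
	"log"

	"github.com/jackc/pgx"
	"github.com/jackc/pgx/pgtype"
)

func main() {
	connConfig, err := pgx.ParseEnvLibpq() // PGHOST, PGUSER, PGDATABASE, ...
	if err != nil {
		log.Fatal(err)
	}

	conn, err := pgx.Connect(connConfig)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	if _, err := conn.Exec("create temporary table ledger(description varchar not null, amount int not null)"); err != nil {
		log.Fatal(err)
	}

	batch := conn.BeginBatch()

	// Parameterized queries that are not prepared statements need parameterOids.
	batch.Queue("insert into ledger(description, amount) values($1, $2)",
		[]interface{}{"q1", 1},
		[]pgtype.Oid{pgtype.VarcharOid, pgtype.Int4Oid},
		nil,
	)
	batch.Queue("insert into ledger(description, amount) values($1, $2)",
		[]interface{}{"q2", 2},
		[]pgtype.Oid{pgtype.VarcharOid, pgtype.Int4Oid},
		nil,
	)
	// Queries that return rows need resultFormatCodes.
	batch.Queue("select sum(amount) from ledger",
		nil,
		nil,
		[]int16{pgx.BinaryFormatCode},
	)

	// One network round trip for everything queued above; the queries run
	// inside a transaction. A nil *TxOptions uses the default transaction mode.
	if err := batch.Send(context.Background(), nil); err != nil {
		log.Fatal(err)
	}

	// Results are read back in the order the queries were queued.
	for i := 0; i < 2; i++ {
		if _, err := batch.ExecResults(); err != nil {
			log.Fatal(err)
		}
	}

	var sum int
	if err := batch.QueryRowResults().Scan(&sum); err != nil {
		log.Fatal(err)
	}
	log.Printf("sum = %d", sum)

	// Close drains any unread results. If an earlier error broke the
	// connection, the connection will already have been closed.
	if err := batch.Close(); err != nil {
		log.Fatal(err)
	}
}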
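
Send now recognizes when a queued query is the name of a prepared statement and binds to it instead of re-parsing, so parameterOids can be omitted in that case. A sketch following TestConnBeginBatchWithPreparedStatement, using the same imports as the previous sketch; the function name is illustrative.

func batchWithPreparedStatement(conn *pgx.Conn) error {
	// Prepare once; the batch recognizes "ps1" as a prepared statement name.
	if _, err := conn.Prepare("ps1", "select n from generate_series(0,$1::int) n"); err != nil {
		return err
	}

	batch := conn.BeginBatch()
	for i := 0; i < 3; i++ {
		// parameterOids are nil because the statement is already prepared.
		batch.Queue("ps1",
			[]interface{}{5},
			nil,
			[]int16{pgx.BinaryFormatCode},
		)
	}

	if err := batch.Send(context.Background(), nil); err != nil {
		return err
	}

	for i := 0; i < 3; i++ {
		rows, err := batch.QueryResults()
		if err != nil {
			return err
		}
		for rows.Next() {
			var n int
			if err := rows.Scan(&n); err != nil {
				return err
			}
		}
		if rows.Err() != nil {
			return rows.Err()
		}
	}

	return batch.Close()
}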
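
ConnPool.BeginBatch from conn_pool.go works the same way, except the connection is acquired from the pool and released automatically when the batch is closed or dies. A minimal sketch under the assumption that the pool was created with pgx.NewConnPool as in bench_test.go; the accounts table and function name are made up.

func poolBatchExample(pool *pgx.ConnPool) error {
	// Acquire happens here; if it fails, the error surfaces from Send.
	batch := pool.BeginBatch()

	batch.Queue("update accounts set balance = balance + 1 where id = $1",
		[]interface{}{1},
		[]pgtype.Oid{pgtype.Int4Oid},
		nil,
	)
	batch.Queue("update accounts set balance = balance - 1 where id = $1",
		[]interface{}{2},
		[]pgtype.Oid{pgtype.Int4Oid},
		nil,
	)

	if err := batch.Send(context.Background(), nil); err != nil {
		return err
	}

	for i := 0; i < 2; i++ {
		if _, err := batch.ExecResults(); err != nil {
			return err
		}
	}

	// Close releases the connection back to the pool.
	return batch.Close()
}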