Happy-path batch query mode

batch-wip
Jack Christensen 2017-05-29 19:15:42 -05:00
parent dfe250c13b
commit fe0af9b357
5 changed files with 375 additions and 11 deletions

211
batch.go Normal file
View File

@ -0,0 +1,211 @@
package pgx
import (
"context"
"github.com/jackc/pgx/pgproto3"
"github.com/jackc/pgx/pgtype"
)
type batchItem struct {
query string
arguments []interface{}
parameterOids []pgtype.Oid
resultFormatCodes []int16
}
type Batch struct {
conn *Conn
items []*batchItem
resultsRead int
sent bool
}
// Begin starts a transaction with the default transaction mode for the
// current connection. To use a specific transaction mode see BeginEx.
func (c *Conn) BeginBatch() *Batch {
// TODO - the type stuff below
// err = c.waitForPreviousCancelQuery(ctx)
// if err != nil {
// return nil, err
// }
// if err := c.ensureConnectionReadyForQuery(); err != nil {
// return nil, err
// }
// c.lastActivityTime = time.Now()
// rows = c.getRows(sql, args)
// if err := c.lock(); err != nil {
// rows.fatal(err)
// return rows, err
// }
// rows.unlockConn = true
// err = c.initContext(ctx)
// if err != nil {
// rows.fatal(err)
// return rows, rows.err
// }
// if options != nil && options.SimpleProtocol {
// err = c.sanitizeAndSendSimpleQuery(sql, args...)
// if err != nil {
// rows.fatal(err)
// return rows, err
// }
// return rows, nil
// }
return &Batch{conn: c}
}
func (b *Batch) Conn() *Conn {
return b.conn
}
func (b *Batch) Queue(query string, arguments []interface{}, parameterOids []pgtype.Oid, resultFormatCodes []int16) {
b.items = append(b.items, &batchItem{
query: query,
arguments: arguments,
parameterOids: parameterOids,
resultFormatCodes: resultFormatCodes,
})
}
func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error {
buf := appendQuery(b.conn.wbuf, txOptions.beginSQL())
for _, bi := range b.items {
// TODO - don't parse if named prepared statement
buf = appendParse(buf, "", bi.query, bi.parameterOids)
var err error
buf, err = appendBind(buf, "", "", b.conn.ConnInfo, bi.parameterOids, bi.arguments, bi.resultFormatCodes)
if err != nil {
return err
}
buf = appendDescribe(buf, 'P', "")
buf = appendExecute(buf, "", 0)
}
buf = appendSync(buf)
buf = appendQuery(buf, "commit")
n, err := b.conn.conn.Write(buf)
if err != nil {
if fatalWriteErr(n, err) {
b.conn.die(err)
}
return err
}
// expect ReadyForQuery from sync and from commit
b.conn.pendingReadyForQueryCount = b.conn.pendingReadyForQueryCount + 2
b.sent = true
for {
msg, err := b.conn.rxMsg()
if err != nil {
return err
}
switch msg := msg.(type) {
case *pgproto3.ReadyForQuery:
return nil
default:
if err := b.conn.processContextFreeMsg(msg); err != nil {
return err
}
}
}
return nil
}
func (b *Batch) ExecResults() (CommandTag, error) {
b.resultsRead++
for {
msg, err := b.conn.rxMsg()
if err != nil {
return "", err
}
switch msg := msg.(type) {
case *pgproto3.CommandComplete:
return CommandTag(msg.CommandTag), nil
default:
if err := b.conn.processContextFreeMsg(msg); err != nil {
return "", err
}
}
}
}
func (b *Batch) QueryResults() (*Rows, error) {
b.resultsRead++
rows := b.conn.getRows("batch query", nil)
fieldDescriptions, err := b.conn.readUntilRowDescription()
if err != nil {
rows.fatal(err)
return nil, err
}
rows.fields = fieldDescriptions
return rows, nil
}
func (b *Batch) QueryRowResults() *Row {
rows, _ := b.QueryResults()
return (*Row)(rows)
}
func (b *Batch) Finish() error {
for i := b.resultsRead; i < len(b.items); i++ {
_, err := b.ExecResults()
if err != nil {
return err
}
}
// readyForQueryCount := 0
// for {
// msg, err := b.conn.rxMsg()
// if err != nil {
// return "", err
// }
// switch msg := msg.(type) {
// case *pgproto3.ReadyForQuery:
// c.rxReadyForQuery(msg)
// default:
// if err := b.conn.processContextFreeMsg(msg); err != nil {
// return "", err
// }
// }
// }
// switch msg := msg.(type) {
// case *pgproto3.ErrorResponse:
// return c.rxErrorResponse(msg)
// case *pgproto3.NotificationResponse:
// c.rxNotificationResponse(msg)
// case *pgproto3.ReadyForQuery:
// c.rxReadyForQuery(msg)
// case *pgproto3.ParameterStatus:
// c.rxParameterStatus(msg)
// }
return nil
}

150
batch_test.go Normal file
View File

@ -0,0 +1,150 @@
package pgx_test
import (
"context"
"testing"
"github.com/jackc/pgx"
"github.com/jackc/pgx/pgtype"
)
func TestConnBeginBatch(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
sql := `create temporary table ledger(
id serial primary key,
description varchar not null,
amount int not null
);`
mustExec(t, conn, sql)
batch := conn.BeginBatch()
batch.Queue("insert into ledger(description, amount) values($1, $2)",
[]interface{}{"q1", 1},
[]pgtype.Oid{pgtype.VarcharOid, pgtype.Int4Oid},
nil,
)
batch.Queue("insert into ledger(description, amount) values($1, $2)",
[]interface{}{"q2", 2},
[]pgtype.Oid{pgtype.VarcharOid, pgtype.Int4Oid},
nil,
)
batch.Queue("insert into ledger(description, amount) values($1, $2)",
[]interface{}{"q3", 3},
[]pgtype.Oid{pgtype.VarcharOid, pgtype.Int4Oid},
nil,
)
batch.Queue("select id, description, amount from ledger order by id",
nil,
nil,
[]int16{pgx.BinaryFormatCode, pgx.TextFormatCode, pgx.BinaryFormatCode},
)
batch.Queue("select sum(amount) from ledger",
nil,
nil,
[]int16{pgx.BinaryFormatCode},
)
err := batch.Send(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
ct, err := batch.ExecResults()
if err != nil {
t.Error(err)
}
if ct.RowsAffected() != 1 {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
}
ct, err = batch.ExecResults()
if err != nil {
t.Error(err)
}
if ct.RowsAffected() != 1 {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
}
rows, err := batch.QueryResults()
if err != nil {
t.Error(err)
}
var id int32
var description string
var amount int32
if !rows.Next() {
t.Fatal("expected a row to be available")
}
if err := rows.Scan(&id, &description, &amount); err != nil {
t.Fatal(err)
}
if id != 1 {
t.Errorf("id => %v, want %v", id, 1)
}
if description != "q1" {
t.Errorf("description => %v, want %v", description, "q1")
}
if amount != 1 {
t.Errorf("amount => %v, want %v", amount, 1)
}
if !rows.Next() {
t.Fatal("expected a row to be available")
}
if err := rows.Scan(&id, &description, &amount); err != nil {
t.Fatal(err)
}
if id != 2 {
t.Errorf("id => %v, want %v", id, 2)
}
if description != "q2" {
t.Errorf("description => %v, want %v", description, "q2")
}
if amount != 2 {
t.Errorf("amount => %v, want %v", amount, 2)
}
if !rows.Next() {
t.Fatal("expected a row to be available")
}
if err := rows.Scan(&id, &description, &amount); err != nil {
t.Fatal(err)
}
if id != 3 {
t.Errorf("id => %v, want %v", id, 3)
}
if description != "q3" {
t.Errorf("description => %v, want %v", description, "q3")
}
if amount != 3 {
t.Errorf("amount => %v, want %v", amount, 3)
}
if rows.Next() {
t.Fatal("did not expect a row to be available")
}
if rows.Err() != nil {
t.Fatal(rows.Err())
}
err = batch.QueryRowResults().Scan(&amount)
if err != nil {
t.Error(err)
}
if amount != 6 {
t.Errorf("amount => %v, want %v", amount, 6)
}
err = batch.Finish()
if err != nil {
t.Fatal(err)
}
ensureConnValid(t, conn)
}

20
conn.go
View File

@ -107,9 +107,9 @@ type Conn struct {
status byte // One of connStatus* constants
causeOfDeath error
readyForQuery bool // connection has received ReadyForQuery message since last query was sent
cancelQueryInProgress int32
cancelQueryCompleted chan struct{}
pendingReadyForQueryCount int // numer of ReadyForQuery messages expected
cancelQueryInProgress int32
cancelQueryCompleted chan struct{}
// context support
ctxInProgress bool
@ -329,6 +329,8 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
return err
}
c.pendingReadyForQueryCount = 1
for {
msg, err := c.rxMsg()
if err != nil {
@ -782,7 +784,7 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
}
return nil, err
}
c.readyForQuery = false
c.pendingReadyForQueryCount++
ps = &PreparedStatement{Name: name, SQL: sql}
@ -1004,7 +1006,7 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error {
c.die(err)
return err
}
c.readyForQuery = false
c.pendingReadyForQueryCount++
return nil
}
@ -1045,7 +1047,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
}
return err
}
c.readyForQuery = false
c.pendingReadyForQueryCount++
return nil
}
@ -1167,7 +1169,7 @@ func (c *Conn) rxBackendKeyData(msg *pgproto3.BackendKeyData) {
}
func (c *Conn) rxReadyForQuery(msg *pgproto3.ReadyForQuery) {
c.readyForQuery = true
c.pendingReadyForQueryCount--
c.txStatus = msg.TxStatus
}
@ -1429,7 +1431,7 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions,
c.die(err)
return "", err
}
c.readyForQuery = false
c.pendingReadyForQueryCount++
} else {
if len(arguments) > 0 {
ps, ok := c.preparedStatements[sql]
@ -1563,7 +1565,7 @@ func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error {
}
func (c *Conn) ensureConnectionReadyForQuery() error {
for !c.readyForQuery {
for c.pendingReadyForQueryCount > 0 {
msg, err := c.rxMsg()
if err != nil {
return err

View File

@ -1,8 +1,9 @@
package pgx_test
import (
"github.com/jackc/pgx"
"testing"
"github.com/jackc/pgx"
)
func mustConnect(t testing.TB, config pgx.ConnConfig) *pgx.Conn {

View File

@ -409,7 +409,7 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
c.die(err)
return nil, err
}
c.readyForQuery = false
c.pendingReadyForQueryCount++
fieldDescriptions, err := c.readUntilRowDescription()
if err != nil {