mirror of https://github.com/jackc/pgx.git
Happy-path batch query mode
parent
dfe250c13b
commit
fe0af9b357
|
@ -0,0 +1,211 @@
|
||||||
|
package pgx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/pgproto3"
|
||||||
|
"github.com/jackc/pgx/pgtype"
|
||||||
|
)
|
||||||
|
|
||||||
|
type batchItem struct {
|
||||||
|
query string
|
||||||
|
arguments []interface{}
|
||||||
|
parameterOids []pgtype.Oid
|
||||||
|
resultFormatCodes []int16
|
||||||
|
}
|
||||||
|
|
||||||
|
type Batch struct {
|
||||||
|
conn *Conn
|
||||||
|
items []*batchItem
|
||||||
|
resultsRead int
|
||||||
|
sent bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Begin starts a transaction with the default transaction mode for the
|
||||||
|
// current connection. To use a specific transaction mode see BeginEx.
|
||||||
|
func (c *Conn) BeginBatch() *Batch {
|
||||||
|
// TODO - the type stuff below
|
||||||
|
|
||||||
|
// err = c.waitForPreviousCancelQuery(ctx)
|
||||||
|
// if err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if err := c.ensureConnectionReadyForQuery(); err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
|
||||||
|
// c.lastActivityTime = time.Now()
|
||||||
|
|
||||||
|
// rows = c.getRows(sql, args)
|
||||||
|
|
||||||
|
// if err := c.lock(); err != nil {
|
||||||
|
// rows.fatal(err)
|
||||||
|
// return rows, err
|
||||||
|
// }
|
||||||
|
// rows.unlockConn = true
|
||||||
|
|
||||||
|
// err = c.initContext(ctx)
|
||||||
|
// if err != nil {
|
||||||
|
// rows.fatal(err)
|
||||||
|
// return rows, rows.err
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if options != nil && options.SimpleProtocol {
|
||||||
|
// err = c.sanitizeAndSendSimpleQuery(sql, args...)
|
||||||
|
// if err != nil {
|
||||||
|
// rows.fatal(err)
|
||||||
|
// return rows, err
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return rows, nil
|
||||||
|
// }
|
||||||
|
|
||||||
|
return &Batch{conn: c}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Batch) Conn() *Conn {
|
||||||
|
return b.conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Batch) Queue(query string, arguments []interface{}, parameterOids []pgtype.Oid, resultFormatCodes []int16) {
|
||||||
|
b.items = append(b.items, &batchItem{
|
||||||
|
query: query,
|
||||||
|
arguments: arguments,
|
||||||
|
parameterOids: parameterOids,
|
||||||
|
resultFormatCodes: resultFormatCodes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error {
|
||||||
|
buf := appendQuery(b.conn.wbuf, txOptions.beginSQL())
|
||||||
|
|
||||||
|
for _, bi := range b.items {
|
||||||
|
// TODO - don't parse if named prepared statement
|
||||||
|
buf = appendParse(buf, "", bi.query, bi.parameterOids)
|
||||||
|
|
||||||
|
var err error
|
||||||
|
buf, err = appendBind(buf, "", "", b.conn.ConnInfo, bi.parameterOids, bi.arguments, bi.resultFormatCodes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
buf = appendDescribe(buf, 'P', "")
|
||||||
|
buf = appendExecute(buf, "", 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf = appendSync(buf)
|
||||||
|
buf = appendQuery(buf, "commit")
|
||||||
|
|
||||||
|
n, err := b.conn.conn.Write(buf)
|
||||||
|
if err != nil {
|
||||||
|
if fatalWriteErr(n, err) {
|
||||||
|
b.conn.die(err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// expect ReadyForQuery from sync and from commit
|
||||||
|
b.conn.pendingReadyForQueryCount = b.conn.pendingReadyForQueryCount + 2
|
||||||
|
|
||||||
|
b.sent = true
|
||||||
|
|
||||||
|
for {
|
||||||
|
msg, err := b.conn.rxMsg()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch msg := msg.(type) {
|
||||||
|
case *pgproto3.ReadyForQuery:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
if err := b.conn.processContextFreeMsg(msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Batch) ExecResults() (CommandTag, error) {
|
||||||
|
b.resultsRead++
|
||||||
|
|
||||||
|
for {
|
||||||
|
msg, err := b.conn.rxMsg()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch msg := msg.(type) {
|
||||||
|
case *pgproto3.CommandComplete:
|
||||||
|
return CommandTag(msg.CommandTag), nil
|
||||||
|
default:
|
||||||
|
if err := b.conn.processContextFreeMsg(msg); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Batch) QueryResults() (*Rows, error) {
|
||||||
|
b.resultsRead++
|
||||||
|
|
||||||
|
rows := b.conn.getRows("batch query", nil)
|
||||||
|
|
||||||
|
fieldDescriptions, err := b.conn.readUntilRowDescription()
|
||||||
|
if err != nil {
|
||||||
|
rows.fatal(err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rows.fields = fieldDescriptions
|
||||||
|
return rows, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Batch) QueryRowResults() *Row {
|
||||||
|
rows, _ := b.QueryResults()
|
||||||
|
return (*Row)(rows)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Batch) Finish() error {
|
||||||
|
for i := b.resultsRead; i < len(b.items); i++ {
|
||||||
|
_, err := b.ExecResults()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// readyForQueryCount := 0
|
||||||
|
|
||||||
|
// for {
|
||||||
|
// msg, err := b.conn.rxMsg()
|
||||||
|
// if err != nil {
|
||||||
|
// return "", err
|
||||||
|
// }
|
||||||
|
|
||||||
|
// switch msg := msg.(type) {
|
||||||
|
// case *pgproto3.ReadyForQuery:
|
||||||
|
// c.rxReadyForQuery(msg)
|
||||||
|
// default:
|
||||||
|
// if err := b.conn.processContextFreeMsg(msg); err != nil {
|
||||||
|
// return "", err
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// switch msg := msg.(type) {
|
||||||
|
// case *pgproto3.ErrorResponse:
|
||||||
|
// return c.rxErrorResponse(msg)
|
||||||
|
// case *pgproto3.NotificationResponse:
|
||||||
|
// c.rxNotificationResponse(msg)
|
||||||
|
// case *pgproto3.ReadyForQuery:
|
||||||
|
// c.rxReadyForQuery(msg)
|
||||||
|
// case *pgproto3.ParameterStatus:
|
||||||
|
// c.rxParameterStatus(msg)
|
||||||
|
// }
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,150 @@
|
||||||
|
package pgx_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx"
|
||||||
|
"github.com/jackc/pgx/pgtype"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConnBeginBatch(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
conn := mustConnect(t, *defaultConnConfig)
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
sql := `create temporary table ledger(
|
||||||
|
id serial primary key,
|
||||||
|
description varchar not null,
|
||||||
|
amount int not null
|
||||||
|
);`
|
||||||
|
mustExec(t, conn, sql)
|
||||||
|
|
||||||
|
batch := conn.BeginBatch()
|
||||||
|
batch.Queue("insert into ledger(description, amount) values($1, $2)",
|
||||||
|
[]interface{}{"q1", 1},
|
||||||
|
[]pgtype.Oid{pgtype.VarcharOid, pgtype.Int4Oid},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
batch.Queue("insert into ledger(description, amount) values($1, $2)",
|
||||||
|
[]interface{}{"q2", 2},
|
||||||
|
[]pgtype.Oid{pgtype.VarcharOid, pgtype.Int4Oid},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
batch.Queue("insert into ledger(description, amount) values($1, $2)",
|
||||||
|
[]interface{}{"q3", 3},
|
||||||
|
[]pgtype.Oid{pgtype.VarcharOid, pgtype.Int4Oid},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
batch.Queue("select id, description, amount from ledger order by id",
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
[]int16{pgx.BinaryFormatCode, pgx.TextFormatCode, pgx.BinaryFormatCode},
|
||||||
|
)
|
||||||
|
batch.Queue("select sum(amount) from ledger",
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
[]int16{pgx.BinaryFormatCode},
|
||||||
|
)
|
||||||
|
|
||||||
|
err := batch.Send(context.Background(), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ct, err := batch.ExecResults()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if ct.RowsAffected() != 1 {
|
||||||
|
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
ct, err = batch.ExecResults()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if ct.RowsAffected() != 1 {
|
||||||
|
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := batch.QueryResults()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var id int32
|
||||||
|
var description string
|
||||||
|
var amount int32
|
||||||
|
if !rows.Next() {
|
||||||
|
t.Fatal("expected a row to be available")
|
||||||
|
}
|
||||||
|
if err := rows.Scan(&id, &description, &amount); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if id != 1 {
|
||||||
|
t.Errorf("id => %v, want %v", id, 1)
|
||||||
|
}
|
||||||
|
if description != "q1" {
|
||||||
|
t.Errorf("description => %v, want %v", description, "q1")
|
||||||
|
}
|
||||||
|
if amount != 1 {
|
||||||
|
t.Errorf("amount => %v, want %v", amount, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !rows.Next() {
|
||||||
|
t.Fatal("expected a row to be available")
|
||||||
|
}
|
||||||
|
if err := rows.Scan(&id, &description, &amount); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if id != 2 {
|
||||||
|
t.Errorf("id => %v, want %v", id, 2)
|
||||||
|
}
|
||||||
|
if description != "q2" {
|
||||||
|
t.Errorf("description => %v, want %v", description, "q2")
|
||||||
|
}
|
||||||
|
if amount != 2 {
|
||||||
|
t.Errorf("amount => %v, want %v", amount, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !rows.Next() {
|
||||||
|
t.Fatal("expected a row to be available")
|
||||||
|
}
|
||||||
|
if err := rows.Scan(&id, &description, &amount); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if id != 3 {
|
||||||
|
t.Errorf("id => %v, want %v", id, 3)
|
||||||
|
}
|
||||||
|
if description != "q3" {
|
||||||
|
t.Errorf("description => %v, want %v", description, "q3")
|
||||||
|
}
|
||||||
|
if amount != 3 {
|
||||||
|
t.Errorf("amount => %v, want %v", amount, 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows.Next() {
|
||||||
|
t.Fatal("did not expect a row to be available")
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows.Err() != nil {
|
||||||
|
t.Fatal(rows.Err())
|
||||||
|
}
|
||||||
|
|
||||||
|
err = batch.QueryRowResults().Scan(&amount)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if amount != 6 {
|
||||||
|
t.Errorf("amount => %v, want %v", amount, 6)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = batch.Finish()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
16
conn.go
16
conn.go
|
@ -107,7 +107,7 @@ type Conn struct {
|
||||||
status byte // One of connStatus* constants
|
status byte // One of connStatus* constants
|
||||||
causeOfDeath error
|
causeOfDeath error
|
||||||
|
|
||||||
readyForQuery bool // connection has received ReadyForQuery message since last query was sent
|
pendingReadyForQueryCount int // numer of ReadyForQuery messages expected
|
||||||
cancelQueryInProgress int32
|
cancelQueryInProgress int32
|
||||||
cancelQueryCompleted chan struct{}
|
cancelQueryCompleted chan struct{}
|
||||||
|
|
||||||
|
@ -329,6 +329,8 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.pendingReadyForQueryCount = 1
|
||||||
|
|
||||||
for {
|
for {
|
||||||
msg, err := c.rxMsg()
|
msg, err := c.rxMsg()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -782,7 +784,7 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
c.readyForQuery = false
|
c.pendingReadyForQueryCount++
|
||||||
|
|
||||||
ps = &PreparedStatement{Name: name, SQL: sql}
|
ps = &PreparedStatement{Name: name, SQL: sql}
|
||||||
|
|
||||||
|
@ -1004,7 +1006,7 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error {
|
||||||
c.die(err)
|
c.die(err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c.readyForQuery = false
|
c.pendingReadyForQueryCount++
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -1045,7 +1047,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c.readyForQuery = false
|
c.pendingReadyForQueryCount++
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -1167,7 +1169,7 @@ func (c *Conn) rxBackendKeyData(msg *pgproto3.BackendKeyData) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) rxReadyForQuery(msg *pgproto3.ReadyForQuery) {
|
func (c *Conn) rxReadyForQuery(msg *pgproto3.ReadyForQuery) {
|
||||||
c.readyForQuery = true
|
c.pendingReadyForQueryCount--
|
||||||
c.txStatus = msg.TxStatus
|
c.txStatus = msg.TxStatus
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1429,7 +1431,7 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions,
|
||||||
c.die(err)
|
c.die(err)
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
c.readyForQuery = false
|
c.pendingReadyForQueryCount++
|
||||||
} else {
|
} else {
|
||||||
if len(arguments) > 0 {
|
if len(arguments) > 0 {
|
||||||
ps, ok := c.preparedStatements[sql]
|
ps, ok := c.preparedStatements[sql]
|
||||||
|
@ -1563,7 +1565,7 @@ func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) ensureConnectionReadyForQuery() error {
|
func (c *Conn) ensureConnectionReadyForQuery() error {
|
||||||
for !c.readyForQuery {
|
for c.pendingReadyForQueryCount > 0 {
|
||||||
msg, err := c.rxMsg()
|
msg, err := c.rxMsg()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
package pgx_test
|
package pgx_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/jackc/pgx"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx"
|
||||||
)
|
)
|
||||||
|
|
||||||
func mustConnect(t testing.TB, config pgx.ConnConfig) *pgx.Conn {
|
func mustConnect(t testing.TB, config pgx.ConnConfig) *pgx.Conn {
|
||||||
|
|
2
query.go
2
query.go
|
@ -409,7 +409,7 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions,
|
||||||
c.die(err)
|
c.die(err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
c.readyForQuery = false
|
c.pendingReadyForQueryCount++
|
||||||
|
|
||||||
fieldDescriptions, err := c.readUntilRowDescription()
|
fieldDescriptions, err := c.readUntilRowDescription()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Reference in New Issue