New beginBatch on transaction object

A batch on a tx object does not open and close a transaction itself and
instead use the tx object to ensure the transactionality of the batch

remove unused boolean 'sent' in batch struct
pull/333/head
Gaspard Douady 2017-09-18 18:38:47 +02:00
parent fd7b776540
commit 38e2b9d449
2 changed files with 244 additions and 11 deletions

View File

@ -21,9 +21,9 @@ type Batch struct {
connPool *ConnPool
items []*batchItem
resultsRead int
sent bool
ctx context.Context
err error
inTx bool
}
// BeginBatch returns a *Batch query for c.
@ -31,6 +31,10 @@ func (c *Conn) BeginBatch() *Batch {
return &Batch{conn: c}
}
func (tx *Tx) BeginBatch() *Batch {
return &Batch{conn: tx.conn, inTx: true}
}
// Conn returns the underlying connection that b will or was performed on.
func (b *Batch) Conn() *Conn {
return b.conn
@ -48,7 +52,8 @@ func (b *Batch) Queue(query string, arguments []interface{}, parameterOIDs []pgt
})
}
// Send sends all queued queries to the server at once. All queries are wrapped
// Send sends all queued queries to the server at once.
// If the batch is created from a conn Object then All queries are wrapped
// in a transaction. The transaction can optionally be configured with
// txOptions. The context is in effect until the Batch is closed.
func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error {
@ -67,13 +72,16 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error {
return err
}
buf := b.conn.wbuf
if !b.inTx {
buf = appendQuery(buf, txOptions.beginSQL())
}
err = b.conn.initContext(ctx)
if err != nil {
return err
}
buf := appendQuery(b.conn.wbuf, txOptions.beginSQL())
for _, bi := range b.items {
var psName string
var psParameterOIDs []pgtype.OID
@ -97,7 +105,12 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error {
}
buf = appendSync(buf)
buf = appendQuery(buf, "commit")
b.conn.pendingReadyForQueryCount++
if !b.inTx {
buf = appendQuery(buf, "commit")
b.conn.pendingReadyForQueryCount++
}
n, err := b.conn.conn.Write(buf)
if err != nil {
@ -107,12 +120,7 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error {
return err
}
// expect ReadyForQuery from sync and from commit
b.conn.pendingReadyForQueryCount = b.conn.pendingReadyForQueryCount + 2
b.sent = true
for {
for !b.inTx {
msg, err := b.conn.rxMsg()
if err != nil {
return err

View File

@ -476,3 +476,228 @@ func TestConnBeginBatchQuerySyntaxError(t *testing.T) {
t.Error("conn should be dead, but was alive")
}
}
func TestConnBeginBatchQueryRowInsert(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
sql := `create temporary table ledger(
id serial primary key,
description varchar not null,
amount int not null
);`
mustExec(t, conn, sql)
batch := conn.BeginBatch()
batch.Queue("select 1",
nil,
nil,
[]int16{pgx.BinaryFormatCode},
)
batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)",
[]interface{}{"q1", 1},
[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
nil,
)
err := batch.Send(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
var value int
err = batch.QueryRowResults().Scan(&value)
if err != nil {
t.Error(err)
}
ct, err := batch.ExecResults()
if err != nil {
t.Error(err)
}
if ct.RowsAffected() != 2 {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected, 2)
}
batch.Close()
ensureConnValid(t, conn)
}
func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
sql := `create temporary table ledger(
id serial primary key,
description varchar not null,
amount int not null
);`
mustExec(t, conn, sql)
batch := conn.BeginBatch()
batch.Queue("select 1 union all select 2 union all select 3",
nil,
nil,
[]int16{pgx.BinaryFormatCode},
)
batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)",
[]interface{}{"q1", 1},
[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
nil,
)
err := batch.Send(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
rows, err := batch.QueryResults()
if err != nil {
t.Error(err)
}
rows.Close()
ct, err := batch.ExecResults()
if err != nil {
t.Error(err)
}
if ct.RowsAffected() != 2 {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected, 2)
}
batch.Close()
ensureConnValid(t, conn)
}
func TestTxBeginBatch(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
sql := `create temporary table ledger1(
id serial primary key,
description varchar not null
);`
mustExec(t, conn, sql)
sql = `create temporary table ledger2(
id int primary key,
amount int not null
);`
mustExec(t, conn, sql)
tx, _ := conn.Begin()
batch := tx.BeginBatch()
batch.Queue("insert into ledger1(description) values($1) returning id",
[]interface{}{"q1"},
[]pgtype.OID{pgtype.VarcharOID},
[]int16{pgx.BinaryFormatCode},
)
err := batch.Send(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
var id int
err = batch.QueryRowResults().Scan(&id)
if err != nil {
t.Error(err)
}
batch.Close()
batch = tx.BeginBatch()
batch.Queue("insert into ledger2(id,amount) values($1, $2)",
[]interface{}{id, 2},
[]pgtype.OID{pgtype.Int4OID, pgtype.Int4OID},
nil,
)
batch.Queue("select amount from ledger2 where id = $1",
[]interface{}{id},
[]pgtype.OID{pgtype.Int4OID},
nil,
)
err = batch.Send(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
ct, err := batch.ExecResults()
if err != nil {
t.Error(err)
}
if ct.RowsAffected() != 1 {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
}
var amout int
err = batch.QueryRowResults().Scan(&amout)
if err != nil {
t.Error(err)
}
batch.Close()
tx.Commit()
var count int
conn.QueryRow("select count(1) from ledger1 where id = $1", id).Scan(&count)
if count != 1 {
t.Errorf("count => %v, want %v", count, 1)
}
err = batch.Close()
if err != nil {
t.Fatal(err)
}
ensureConnValid(t, conn)
}
func TestTxBeginBatchRollback(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
sql := `create temporary table ledger1(
id serial primary key,
description varchar not null
);`
mustExec(t, conn, sql)
tx, _ := conn.Begin()
batch := tx.BeginBatch()
batch.Queue("insert into ledger1(description) values($1) returning id",
[]interface{}{"q1"},
[]pgtype.OID{pgtype.VarcharOID},
[]int16{pgx.BinaryFormatCode},
)
err := batch.Send(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
var id int
err = batch.QueryRowResults().Scan(&id)
if err != nil {
t.Error(err)
}
batch.Close()
tx.Rollback()
row := conn.QueryRow("select count(1) from ledger1 where id = $1", id)
var count int
row.Scan(&count)
if count != 0 {
t.Errorf("count => %v, want %v", count, 0)
}
ensureConnValid(t, conn)
}