mirror of https://github.com/jackc/pgx.git
New beginBatch on transaction object
A batch on a tx object does not open and close a transaction itself and instead use the tx object to ensure the transactionality of the batch remove unused boolean 'sent' in batch structpull/333/head
parent
fd7b776540
commit
38e2b9d449
30
batch.go
30
batch.go
|
@ -21,9 +21,9 @@ type Batch struct {
|
|||
connPool *ConnPool
|
||||
items []*batchItem
|
||||
resultsRead int
|
||||
sent bool
|
||||
ctx context.Context
|
||||
err error
|
||||
inTx bool
|
||||
}
|
||||
|
||||
// BeginBatch returns a *Batch query for c.
|
||||
|
@ -31,6 +31,10 @@ func (c *Conn) BeginBatch() *Batch {
|
|||
return &Batch{conn: c}
|
||||
}
|
||||
|
||||
func (tx *Tx) BeginBatch() *Batch {
|
||||
return &Batch{conn: tx.conn, inTx: true}
|
||||
}
|
||||
|
||||
// Conn returns the underlying connection that b will or was performed on.
|
||||
func (b *Batch) Conn() *Conn {
|
||||
return b.conn
|
||||
|
@ -48,7 +52,8 @@ func (b *Batch) Queue(query string, arguments []interface{}, parameterOIDs []pgt
|
|||
})
|
||||
}
|
||||
|
||||
// Send sends all queued queries to the server at once. All queries are wrapped
|
||||
// Send sends all queued queries to the server at once.
|
||||
// If the batch is created from a conn Object then All queries are wrapped
|
||||
// in a transaction. The transaction can optionally be configured with
|
||||
// txOptions. The context is in effect until the Batch is closed.
|
||||
func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error {
|
||||
|
@ -67,13 +72,16 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error {
|
|||
return err
|
||||
}
|
||||
|
||||
buf := b.conn.wbuf
|
||||
if !b.inTx {
|
||||
buf = appendQuery(buf, txOptions.beginSQL())
|
||||
}
|
||||
|
||||
err = b.conn.initContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := appendQuery(b.conn.wbuf, txOptions.beginSQL())
|
||||
|
||||
for _, bi := range b.items {
|
||||
var psName string
|
||||
var psParameterOIDs []pgtype.OID
|
||||
|
@ -97,7 +105,12 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error {
|
|||
}
|
||||
|
||||
buf = appendSync(buf)
|
||||
buf = appendQuery(buf, "commit")
|
||||
b.conn.pendingReadyForQueryCount++
|
||||
|
||||
if !b.inTx {
|
||||
buf = appendQuery(buf, "commit")
|
||||
b.conn.pendingReadyForQueryCount++
|
||||
}
|
||||
|
||||
n, err := b.conn.conn.Write(buf)
|
||||
if err != nil {
|
||||
|
@ -107,12 +120,7 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// expect ReadyForQuery from sync and from commit
|
||||
b.conn.pendingReadyForQueryCount = b.conn.pendingReadyForQueryCount + 2
|
||||
|
||||
b.sent = true
|
||||
|
||||
for {
|
||||
for !b.inTx {
|
||||
msg, err := b.conn.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
225
batch_test.go
225
batch_test.go
|
@ -476,3 +476,228 @@ func TestConnBeginBatchQuerySyntaxError(t *testing.T) {
|
|||
t.Error("conn should be dead, but was alive")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnBeginBatchQueryRowInsert(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
sql := `create temporary table ledger(
|
||||
id serial primary key,
|
||||
description varchar not null,
|
||||
amount int not null
|
||||
);`
|
||||
mustExec(t, conn, sql)
|
||||
|
||||
batch := conn.BeginBatch()
|
||||
batch.Queue("select 1",
|
||||
nil,
|
||||
nil,
|
||||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)",
|
||||
[]interface{}{"q1", 1},
|
||||
[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
|
||||
nil,
|
||||
)
|
||||
|
||||
err := batch.Send(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var value int
|
||||
err = batch.QueryRowResults().Scan(&value)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
ct, err := batch.ExecResults()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if ct.RowsAffected() != 2 {
|
||||
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected, 2)
|
||||
}
|
||||
|
||||
batch.Close()
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
sql := `create temporary table ledger(
|
||||
id serial primary key,
|
||||
description varchar not null,
|
||||
amount int not null
|
||||
);`
|
||||
mustExec(t, conn, sql)
|
||||
|
||||
batch := conn.BeginBatch()
|
||||
batch.Queue("select 1 union all select 2 union all select 3",
|
||||
nil,
|
||||
nil,
|
||||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)",
|
||||
[]interface{}{"q1", 1},
|
||||
[]pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID},
|
||||
nil,
|
||||
)
|
||||
|
||||
err := batch.Send(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rows, err := batch.QueryResults()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
ct, err := batch.ExecResults()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if ct.RowsAffected() != 2 {
|
||||
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected, 2)
|
||||
}
|
||||
|
||||
batch.Close()
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestTxBeginBatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
sql := `create temporary table ledger1(
|
||||
id serial primary key,
|
||||
description varchar not null
|
||||
);`
|
||||
mustExec(t, conn, sql)
|
||||
|
||||
sql = `create temporary table ledger2(
|
||||
id int primary key,
|
||||
amount int not null
|
||||
);`
|
||||
mustExec(t, conn, sql)
|
||||
|
||||
tx, _ := conn.Begin()
|
||||
batch := tx.BeginBatch()
|
||||
batch.Queue("insert into ledger1(description) values($1) returning id",
|
||||
[]interface{}{"q1"},
|
||||
[]pgtype.OID{pgtype.VarcharOID},
|
||||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
|
||||
err := batch.Send(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var id int
|
||||
err = batch.QueryRowResults().Scan(&id)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
batch.Close()
|
||||
|
||||
batch = tx.BeginBatch()
|
||||
batch.Queue("insert into ledger2(id,amount) values($1, $2)",
|
||||
[]interface{}{id, 2},
|
||||
[]pgtype.OID{pgtype.Int4OID, pgtype.Int4OID},
|
||||
nil,
|
||||
)
|
||||
|
||||
batch.Queue("select amount from ledger2 where id = $1",
|
||||
[]interface{}{id},
|
||||
[]pgtype.OID{pgtype.Int4OID},
|
||||
nil,
|
||||
)
|
||||
|
||||
err = batch.Send(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ct, err := batch.ExecResults()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if ct.RowsAffected() != 1 {
|
||||
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
|
||||
}
|
||||
|
||||
var amout int
|
||||
err = batch.QueryRowResults().Scan(&amout)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
batch.Close()
|
||||
tx.Commit()
|
||||
|
||||
var count int
|
||||
conn.QueryRow("select count(1) from ledger1 where id = $1", id).Scan(&count)
|
||||
if count != 1 {
|
||||
t.Errorf("count => %v, want %v", count, 1)
|
||||
}
|
||||
|
||||
err = batch.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestTxBeginBatchRollback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
sql := `create temporary table ledger1(
|
||||
id serial primary key,
|
||||
description varchar not null
|
||||
);`
|
||||
mustExec(t, conn, sql)
|
||||
|
||||
tx, _ := conn.Begin()
|
||||
batch := tx.BeginBatch()
|
||||
batch.Queue("insert into ledger1(description) values($1) returning id",
|
||||
[]interface{}{"q1"},
|
||||
[]pgtype.OID{pgtype.VarcharOID},
|
||||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
|
||||
err := batch.Send(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var id int
|
||||
err = batch.QueryRowResults().Scan(&id)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
batch.Close()
|
||||
tx.Rollback()
|
||||
|
||||
row := conn.QueryRow("select count(1) from ledger1 where id = $1", id)
|
||||
var count int
|
||||
row.Scan(&count)
|
||||
if count != 0 {
|
||||
t.Errorf("count => %v, want %v", count, 0)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue