From 38e2b9d4497eae0896bc877ffa3fae34e9b9b9b6 Mon Sep 17 00:00:00 2001 From: Gaspard Douady Date: Mon, 18 Sep 2017 18:38:47 +0200 Subject: [PATCH] New beginBatch on transaction object A batch on a tx object does not open and close a transaction itself and instead use the tx object to ensure the transactionality of the batch remove unused boolean 'sent' in batch struct --- batch.go | 30 ++++--- batch_test.go | 225 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 244 insertions(+), 11 deletions(-) diff --git a/batch.go b/batch.go index fc6f0d03..67dafc29 100644 --- a/batch.go +++ b/batch.go @@ -21,9 +21,9 @@ type Batch struct { connPool *ConnPool items []*batchItem resultsRead int - sent bool ctx context.Context err error + inTx bool } // BeginBatch returns a *Batch query for c. @@ -31,6 +31,10 @@ func (c *Conn) BeginBatch() *Batch { return &Batch{conn: c} } +func (tx *Tx) BeginBatch() *Batch { + return &Batch{conn: tx.conn, inTx: true} +} + // Conn returns the underlying connection that b will or was performed on. func (b *Batch) Conn() *Conn { return b.conn @@ -48,7 +52,8 @@ func (b *Batch) Queue(query string, arguments []interface{}, parameterOIDs []pgt }) } -// Send sends all queued queries to the server at once. All queries are wrapped +// Send sends all queued queries to the server at once. +// If the batch is created from a conn Object then All queries are wrapped // in a transaction. The transaction can optionally be configured with // txOptions. The context is in effect until the Batch is closed. func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error { @@ -67,13 +72,16 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error { return err } + buf := b.conn.wbuf + if !b.inTx { + buf = appendQuery(buf, txOptions.beginSQL()) + } + err = b.conn.initContext(ctx) if err != nil { return err } - buf := appendQuery(b.conn.wbuf, txOptions.beginSQL()) - for _, bi := range b.items { var psName string var psParameterOIDs []pgtype.OID @@ -97,7 +105,12 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error { } buf = appendSync(buf) - buf = appendQuery(buf, "commit") + b.conn.pendingReadyForQueryCount++ + + if !b.inTx { + buf = appendQuery(buf, "commit") + b.conn.pendingReadyForQueryCount++ + } n, err := b.conn.conn.Write(buf) if err != nil { @@ -107,12 +120,7 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error { return err } - // expect ReadyForQuery from sync and from commit - b.conn.pendingReadyForQueryCount = b.conn.pendingReadyForQueryCount + 2 - - b.sent = true - - for { + for !b.inTx { msg, err := b.conn.rxMsg() if err != nil { return err diff --git a/batch_test.go b/batch_test.go index e12e4f32..54785f79 100644 --- a/batch_test.go +++ b/batch_test.go @@ -476,3 +476,228 @@ func TestConnBeginBatchQuerySyntaxError(t *testing.T) { t.Error("conn should be dead, but was alive") } } + +func TestConnBeginBatchQueryRowInsert(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + sql := `create temporary table ledger( + id serial primary key, + description varchar not null, + amount int not null +);` + mustExec(t, conn, sql) + + batch := conn.BeginBatch() + batch.Queue("select 1", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", + []interface{}{"q1", 1}, + []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, + nil, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + var value int + err = batch.QueryRowResults().Scan(&value) + if err != nil { + t.Error(err) + } + + ct, err := batch.ExecResults() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 2 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected, 2) + } + + batch.Close() + + ensureConnValid(t, conn) +} + +func TestConnBeginBatchQueryPartialReadInsert(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + sql := `create temporary table ledger( + id serial primary key, + description varchar not null, + amount int not null +);` + mustExec(t, conn, sql) + + batch := conn.BeginBatch() + batch.Queue("select 1 union all select 2 union all select 3", + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", + []interface{}{"q1", 1}, + []pgtype.OID{pgtype.VarcharOID, pgtype.Int4OID}, + nil, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + rows, err := batch.QueryResults() + if err != nil { + t.Error(err) + } + rows.Close() + + ct, err := batch.ExecResults() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 2 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected, 2) + } + + batch.Close() + + ensureConnValid(t, conn) +} + +func TestTxBeginBatch(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + sql := `create temporary table ledger1( + id serial primary key, + description varchar not null +);` + mustExec(t, conn, sql) + + sql = `create temporary table ledger2( + id int primary key, + amount int not null +);` + mustExec(t, conn, sql) + + tx, _ := conn.Begin() + batch := tx.BeginBatch() + batch.Queue("insert into ledger1(description) values($1) returning id", + []interface{}{"q1"}, + []pgtype.OID{pgtype.VarcharOID}, + []int16{pgx.BinaryFormatCode}, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + var id int + err = batch.QueryRowResults().Scan(&id) + if err != nil { + t.Error(err) + } + batch.Close() + + batch = tx.BeginBatch() + batch.Queue("insert into ledger2(id,amount) values($1, $2)", + []interface{}{id, 2}, + []pgtype.OID{pgtype.Int4OID, pgtype.Int4OID}, + nil, + ) + + batch.Queue("select amount from ledger2 where id = $1", + []interface{}{id}, + []pgtype.OID{pgtype.Int4OID}, + nil, + ) + + err = batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + ct, err := batch.ExecResults() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 1 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) + } + + var amout int + err = batch.QueryRowResults().Scan(&amout) + if err != nil { + t.Error(err) + } + + batch.Close() + tx.Commit() + + var count int + conn.QueryRow("select count(1) from ledger1 where id = $1", id).Scan(&count) + if count != 1 { + t.Errorf("count => %v, want %v", count, 1) + } + + err = batch.Close() + if err != nil { + t.Fatal(err) + } + + ensureConnValid(t, conn) +} + +func TestTxBeginBatchRollback(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + sql := `create temporary table ledger1( + id serial primary key, + description varchar not null +);` + mustExec(t, conn, sql) + + tx, _ := conn.Begin() + batch := tx.BeginBatch() + batch.Queue("insert into ledger1(description) values($1) returning id", + []interface{}{"q1"}, + []pgtype.OID{pgtype.VarcharOID}, + []int16{pgx.BinaryFormatCode}, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + var id int + err = batch.QueryRowResults().Scan(&id) + if err != nil { + t.Error(err) + } + batch.Close() + tx.Rollback() + + row := conn.QueryRow("select count(1) from ledger1 where id = $1", id) + var count int + row.Scan(&count) + if count != 0 { + t.Errorf("count => %v, want %v", count, 0) + } + + ensureConnValid(t, conn) +}