SendBatch supports default QueryExecMode

pull/1170/head
Jack Christensen 2022-03-12 15:06:13 -06:00
parent 1390a11fe2
commit cb721dfb5b
3 changed files with 500 additions and 469 deletions

View File

@ -15,230 +15,227 @@ import (
func TestConnSendBatch(t *testing.T) { func TestConnSendBatch(t *testing.T) {
t.Parallel() t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
defer closeConn(t, conn) skipCockroachDB(t, conn, "Server serial type is incompatible with test")
skipCockroachDB(t, conn, "Server serial type is incompatible with test") sql := `create temporary table ledger(
sql := `create temporary table ledger(
id serial primary key, id serial primary key,
description varchar not null, description varchar not null,
amount int not null amount int not null
);` );`
mustExec(t, conn, sql) mustExec(t, conn, sql)
batch := &pgx.Batch{} batch := &pgx.Batch{}
batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1)
batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2)
batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3)
batch.Queue("select id, description, amount from ledger order by id")
batch.Queue("select id, description, amount from ledger order by id")
batch.Queue("select * from ledger where false")
batch.Queue("select sum(amount) from ledger")
br := conn.SendBatch(context.Background(), batch)
ct, err := br.Exec()
if err != nil {
t.Error(err)
}
if ct.RowsAffected() != 1 {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
}
ct, err = br.Exec()
if err != nil {
t.Error(err)
}
if ct.RowsAffected() != 1 {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
}
ct, err = br.Exec()
if err != nil {
t.Error(err)
}
if ct.RowsAffected() != 1 {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
}
selectFromLedgerExpectedRows := []struct {
id int32
description string
amount int32
}{
{1, "q1", 1},
{2, "q2", 2},
{3, "q3", 3},
}
rows, err := br.Query()
if err != nil {
t.Error(err)
}
var id int32
var description string
var amount int32
rowCount := 0
for rows.Next() {
if rowCount >= len(selectFromLedgerExpectedRows) {
t.Fatalf("got too many rows: %d", rowCount)
}
if err := rows.Scan(&id, &description, &amount); err != nil {
t.Fatalf("row %d: %v", rowCount, err)
}
if id != selectFromLedgerExpectedRows[rowCount].id {
t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id)
}
if description != selectFromLedgerExpectedRows[rowCount].description {
t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description)
}
if amount != selectFromLedgerExpectedRows[rowCount].amount {
t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount)
}
rowCount++
}
if rows.Err() != nil {
t.Fatal(rows.Err())
}
rowCount = 0
_, err = br.QueryFunc([]interface{}{&id, &description, &amount}, func(pgx.QueryFuncRow) error {
if id != selectFromLedgerExpectedRows[rowCount].id {
t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id)
}
if description != selectFromLedgerExpectedRows[rowCount].description {
t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description)
}
if amount != selectFromLedgerExpectedRows[rowCount].amount {
t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount)
}
rowCount++
return nil
})
if err != nil {
t.Error(err)
}
err = br.QueryRow().Scan(&id, &description, &amount)
if !errors.Is(err, pgx.ErrNoRows) {
t.Errorf("expected pgx.ErrNoRows but got: %v", err)
}
err = br.QueryRow().Scan(&amount)
if err != nil {
t.Error(err)
}
if amount != 6 {
t.Errorf("amount => %v, want %v", amount, 6)
}
err = br.Close()
if err != nil {
t.Fatal(err)
}
ensureConnValid(t, conn)
}
func TestConnSendBatchMany(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
sql := `create temporary table ledger(
id serial primary key,
description varchar not null,
amount int not null
);`
mustExec(t, conn, sql)
batch := &pgx.Batch{}
numInserts := 1000
for i := 0; i < numInserts; i++ {
batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1) batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1)
} batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2)
batch.Queue("select count(*) from ledger") batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3)
batch.Queue("select id, description, amount from ledger order by id")
batch.Queue("select id, description, amount from ledger order by id")
batch.Queue("select * from ledger where false")
batch.Queue("select sum(amount) from ledger")
br := conn.SendBatch(context.Background(), batch) br := conn.SendBatch(context.Background(), batch)
for i := 0; i < numInserts; i++ {
ct, err := br.Exec() ct, err := br.Exec()
assert.NoError(t, err) if err != nil {
assert.EqualValues(t, 1, ct.RowsAffected()) t.Error(err)
} }
if ct.RowsAffected() != 1 {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
}
var actualInserts int ct, err = br.Exec()
err := br.QueryRow().Scan(&actualInserts) if err != nil {
assert.NoError(t, err) t.Error(err)
assert.EqualValues(t, numInserts, actualInserts) }
if ct.RowsAffected() != 1 {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
}
err = br.Close() ct, err = br.Exec()
require.NoError(t, err) if err != nil {
t.Error(err)
}
if ct.RowsAffected() != 1 {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
}
ensureConnValid(t, conn) selectFromLedgerExpectedRows := []struct {
} id int32
description string
amount int32
}{
{1, "q1", 1},
{2, "q2", 2},
{3, "q3", 3},
}
func TestConnSendBatchWithPreparedStatement(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
skipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
_, err := conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n")
if err != nil {
t.Fatal(err)
}
batch := &pgx.Batch{}
queryCount := 3
for i := 0; i < queryCount; i++ {
batch.Queue("ps1", 5)
}
br := conn.SendBatch(context.Background(), batch)
for i := 0; i < queryCount; i++ {
rows, err := br.Query() rows, err := br.Query()
if err != nil { if err != nil {
t.Fatal(err) t.Error(err)
} }
for k := 0; rows.Next(); k++ { var id int32
var n int var description string
if err := rows.Scan(&n); err != nil { var amount int32
t.Fatal(err) rowCount := 0
for rows.Next() {
if rowCount >= len(selectFromLedgerExpectedRows) {
t.Fatalf("got too many rows: %d", rowCount)
} }
if n != k {
t.Fatalf("n => %v, want %v", n, k) if err := rows.Scan(&id, &description, &amount); err != nil {
t.Fatalf("row %d: %v", rowCount, err)
} }
if id != selectFromLedgerExpectedRows[rowCount].id {
t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id)
}
if description != selectFromLedgerExpectedRows[rowCount].description {
t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description)
}
if amount != selectFromLedgerExpectedRows[rowCount].amount {
t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount)
}
rowCount++
} }
if rows.Err() != nil { if rows.Err() != nil {
t.Fatal(rows.Err()) t.Fatal(rows.Err())
} }
}
err = br.Close() rowCount = 0
if err != nil { _, err = br.QueryFunc([]interface{}{&id, &description, &amount}, func(pgx.QueryFuncRow) error {
t.Fatal(err) if id != selectFromLedgerExpectedRows[rowCount].id {
} t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id)
}
if description != selectFromLedgerExpectedRows[rowCount].description {
t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description)
}
if amount != selectFromLedgerExpectedRows[rowCount].amount {
t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount)
}
ensureConnValid(t, conn) rowCount++
return nil
})
if err != nil {
t.Error(err)
}
err = br.QueryRow().Scan(&id, &description, &amount)
if !errors.Is(err, pgx.ErrNoRows) {
t.Errorf("expected pgx.ErrNoRows but got: %v", err)
}
err = br.QueryRow().Scan(&amount)
if err != nil {
t.Error(err)
}
if amount != 6 {
t.Errorf("amount => %v, want %v", amount, 6)
}
err = br.Close()
if err != nil {
t.Fatal(err)
}
})
}
func TestConnSendBatchMany(t *testing.T) {
t.Parallel()
testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
sql := `create temporary table ledger(
id serial primary key,
description varchar not null,
amount int not null
);`
mustExec(t, conn, sql)
batch := &pgx.Batch{}
numInserts := 1000
for i := 0; i < numInserts; i++ {
batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1)
}
batch.Queue("select count(*) from ledger")
br := conn.SendBatch(context.Background(), batch)
for i := 0; i < numInserts; i++ {
ct, err := br.Exec()
assert.NoError(t, err)
assert.EqualValues(t, 1, ct.RowsAffected())
}
var actualInserts int
err := br.QueryRow().Scan(&actualInserts)
assert.NoError(t, err)
assert.EqualValues(t, numInserts, actualInserts)
err = br.Close()
require.NoError(t, err)
})
}
func TestConnSendBatchWithPreparedStatement(t *testing.T) {
t.Parallel()
modes := []pgx.QueryExecMode{
pgx.QueryExecModeCacheStatement,
pgx.QueryExecModeCacheDescribe,
pgx.QueryExecModeDescribeExec,
pgx.QueryExecModeExec,
// Don't test simple mode with prepared statements.
}
testWithQueryExecModes(t, modes, func(t *testing.T, conn *pgx.Conn) {
skipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
_, err := conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n")
if err != nil {
t.Fatal(err)
}
batch := &pgx.Batch{}
queryCount := 3
for i := 0; i < queryCount; i++ {
batch.Queue("ps1", 5)
}
br := conn.SendBatch(context.Background(), batch)
for i := 0; i < queryCount; i++ {
rows, err := br.Query()
if err != nil {
t.Fatal(err)
}
for k := 0; rows.Next(); k++ {
var n int
if err := rows.Scan(&n); err != nil {
t.Fatal(err)
}
if n != k {
t.Fatalf("n => %v, want %v", n, k)
}
}
if rows.Err() != nil {
t.Fatal(rows.Err())
}
}
err = br.Close()
if err != nil {
t.Fatal(err)
}
})
} }
// https://github.com/jackc/pgx/issues/856 // https://github.com/jackc/pgx/issues/856
@ -303,316 +300,308 @@ func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.
func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) { func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) {
t.Parallel() t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
defer closeConn(t, conn)
batch := &pgx.Batch{} batch := &pgx.Batch{}
batch.Queue("select n from generate_series(0,5) n") batch.Queue("select n from generate_series(0,5) n")
batch.Queue("select n from generate_series(0,5) n") batch.Queue("select n from generate_series(0,5) n")
br := conn.SendBatch(context.Background(), batch) br := conn.SendBatch(context.Background(), batch)
rows, err := br.Query() rows, err := br.Query()
if err != nil { if err != nil {
t.Error(err)
}
for i := 0; i < 3; i++ {
if !rows.Next() {
t.Error("expected a row to be available")
}
var n int
if err := rows.Scan(&n); err != nil {
t.Error(err) t.Error(err)
} }
if n != i {
t.Errorf("n => %v, want %v", n, i) for i := 0; i < 3; i++ {
if !rows.Next() {
t.Error("expected a row to be available")
}
var n int
if err := rows.Scan(&n); err != nil {
t.Error(err)
}
if n != i {
t.Errorf("n => %v, want %v", n, i)
}
} }
}
rows.Close() rows.Close()
rows, err = br.Query() rows, err = br.Query()
if err != nil { if err != nil {
t.Error(err)
}
for i := 0; rows.Next(); i++ {
var n int
if err := rows.Scan(&n); err != nil {
t.Error(err) t.Error(err)
} }
if n != i {
t.Errorf("n => %v, want %v", n, i) for i := 0; rows.Next(); i++ {
var n int
if err := rows.Scan(&n); err != nil {
t.Error(err)
}
if n != i {
t.Errorf("n => %v, want %v", n, i)
}
} }
}
if rows.Err() != nil { if rows.Err() != nil {
t.Error(rows.Err()) t.Error(rows.Err())
} }
err = br.Close() err = br.Close()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
ensureConnValid(t, conn) })
} }
func TestConnSendBatchQueryError(t *testing.T) { func TestConnSendBatchQueryError(t *testing.T) {
t.Parallel() t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
defer closeConn(t, conn)
batch := &pgx.Batch{} batch := &pgx.Batch{}
batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0") batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0")
batch.Queue("select n from generate_series(0,5) n") batch.Queue("select n from generate_series(0,5) n")
br := conn.SendBatch(context.Background(), batch) br := conn.SendBatch(context.Background(), batch)
rows, err := br.Query() rows, err := br.Query()
if err != nil { if err != nil {
t.Error(err)
}
for i := 0; rows.Next(); i++ {
var n int
if err := rows.Scan(&n); err != nil {
t.Error(err) t.Error(err)
} }
if n != i {
t.Errorf("n => %v, want %v", n, i) for i := 0; rows.Next(); i++ {
var n int
if err := rows.Scan(&n); err != nil {
t.Error(err)
}
if n != i {
t.Errorf("n => %v, want %v", n, i)
}
} }
}
if pgErr, ok := rows.Err().(*pgconn.PgError); !(ok && pgErr.Code == "22012") { if pgErr, ok := rows.Err().(*pgconn.PgError); !(ok && pgErr.Code == "22012") {
t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012) t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012)
} }
err = br.Close() err = br.Close()
if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") { if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") {
t.Errorf("rows.Err() => %v, want error code %v", err, 22012) t.Errorf("rows.Err() => %v, want error code %v", err, 22012)
} }
ensureConnValid(t, conn) })
} }
func TestConnSendBatchQuerySyntaxError(t *testing.T) { func TestConnSendBatchQuerySyntaxError(t *testing.T) {
t.Parallel() t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
defer closeConn(t, conn)
batch := &pgx.Batch{} batch := &pgx.Batch{}
batch.Queue("select 1 1") batch.Queue("select 1 1")
br := conn.SendBatch(context.Background(), batch) br := conn.SendBatch(context.Background(), batch)
var n int32 var n int32
err := br.QueryRow().Scan(&n) err := br.QueryRow().Scan(&n)
if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "42601") { if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "42601") {
t.Errorf("rows.Err() => %v, want error code %v", err, 42601) t.Errorf("rows.Err() => %v, want error code %v", err, 42601)
} }
err = br.Close() err = br.Close()
if err == nil { if err == nil {
t.Error("Expected error") t.Error("Expected error")
} }
ensureConnValid(t, conn) })
} }
func TestConnSendBatchQueryRowInsert(t *testing.T) { func TestConnSendBatchQueryRowInsert(t *testing.T) {
t.Parallel() t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
defer closeConn(t, conn)
sql := `create temporary table ledger( sql := `create temporary table ledger(
id serial primary key, id serial primary key,
description varchar not null, description varchar not null,
amount int not null amount int not null
);` );`
mustExec(t, conn, sql) mustExec(t, conn, sql)
batch := &pgx.Batch{} batch := &pgx.Batch{}
batch.Queue("select 1") batch.Queue("select 1")
batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1) batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1)
br := conn.SendBatch(context.Background(), batch) br := conn.SendBatch(context.Background(), batch)
var value int var value int
err := br.QueryRow().Scan(&value) err := br.QueryRow().Scan(&value)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
ct, err := br.Exec() ct, err := br.Exec()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if ct.RowsAffected() != 2 { if ct.RowsAffected() != 2 {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2) t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2)
} }
br.Close() br.Close()
ensureConnValid(t, conn) })
} }
func TestConnSendBatchQueryPartialReadInsert(t *testing.T) { func TestConnSendBatchQueryPartialReadInsert(t *testing.T) {
t.Parallel() t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
defer closeConn(t, conn)
sql := `create temporary table ledger( sql := `create temporary table ledger(
id serial primary key, id serial primary key,
description varchar not null, description varchar not null,
amount int not null amount int not null
);` );`
mustExec(t, conn, sql) mustExec(t, conn, sql)
batch := &pgx.Batch{} batch := &pgx.Batch{}
batch.Queue("select 1 union all select 2 union all select 3") batch.Queue("select 1 union all select 2 union all select 3")
batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1) batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1)
br := conn.SendBatch(context.Background(), batch) br := conn.SendBatch(context.Background(), batch)
rows, err := br.Query() rows, err := br.Query()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
rows.Close() rows.Close()
ct, err := br.Exec() ct, err := br.Exec()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if ct.RowsAffected() != 2 { if ct.RowsAffected() != 2 {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2) t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2)
} }
br.Close() br.Close()
ensureConnValid(t, conn) })
} }
func TestTxSendBatch(t *testing.T) { func TestTxSendBatch(t *testing.T) {
t.Parallel() t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
defer closeConn(t, conn)
sql := `create temporary table ledger1( sql := `create temporary table ledger1(
id serial primary key, id serial primary key,
description varchar not null description varchar not null
);` );`
mustExec(t, conn, sql) mustExec(t, conn, sql)
sql = `create temporary table ledger2( sql = `create temporary table ledger2(
id int primary key, id int primary key,
amount int not null amount int not null
);` );`
mustExec(t, conn, sql) mustExec(t, conn, sql)
tx, _ := conn.Begin(context.Background()) tx, _ := conn.Begin(context.Background())
batch := &pgx.Batch{} batch := &pgx.Batch{}
batch.Queue("insert into ledger1(description) values($1) returning id", "q1") batch.Queue("insert into ledger1(description) values($1) returning id", "q1")
br := tx.SendBatch(context.Background(), batch) br := tx.SendBatch(context.Background(), batch)
var id int var id int
err := br.QueryRow().Scan(&id) err := br.QueryRow().Scan(&id)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
br.Close() br.Close()
batch = &pgx.Batch{} batch = &pgx.Batch{}
batch.Queue("insert into ledger2(id,amount) values($1, $2)", id, 2) batch.Queue("insert into ledger2(id,amount) values($1, $2)", id, 2)
batch.Queue("select amount from ledger2 where id = $1", id) batch.Queue("select amount from ledger2 where id = $1", id)
br = tx.SendBatch(context.Background(), batch) br = tx.SendBatch(context.Background(), batch)
ct, err := br.Exec() ct, err := br.Exec()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if ct.RowsAffected() != 1 { if ct.RowsAffected() != 1 {
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
} }
var amount int var amount int
err = br.QueryRow().Scan(&amount) err = br.QueryRow().Scan(&amount)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
br.Close() br.Close()
tx.Commit(context.Background()) tx.Commit(context.Background())
var count int var count int
conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id).Scan(&count) conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id).Scan(&count)
if count != 1 { if count != 1 {
t.Errorf("count => %v, want %v", count, 1) t.Errorf("count => %v, want %v", count, 1)
} }
err = br.Close() err = br.Close()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
ensureConnValid(t, conn) })
} }
func TestTxSendBatchRollback(t *testing.T) { func TestTxSendBatchRollback(t *testing.T) {
t.Parallel() t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
defer closeConn(t, conn)
sql := `create temporary table ledger1( sql := `create temporary table ledger1(
id serial primary key, id serial primary key,
description varchar not null description varchar not null
);` );`
mustExec(t, conn, sql) mustExec(t, conn, sql)
tx, _ := conn.Begin(context.Background()) tx, _ := conn.Begin(context.Background())
batch := &pgx.Batch{} batch := &pgx.Batch{}
batch.Queue("insert into ledger1(description) values($1) returning id", "q1") batch.Queue("insert into ledger1(description) values($1) returning id", "q1")
br := tx.SendBatch(context.Background(), batch) br := tx.SendBatch(context.Background(), batch)
var id int var id int
err := br.QueryRow().Scan(&id) err := br.QueryRow().Scan(&id)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
br.Close() br.Close()
tx.Rollback(context.Background()) tx.Rollback(context.Background())
row := conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id) row := conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id)
var count int var count int
row.Scan(&count) row.Scan(&count)
if count != 0 { if count != 0 {
t.Errorf("count => %v, want %v", count, 0) t.Errorf("count => %v, want %v", count, 0)
} }
ensureConnValid(t, conn) })
} }
func TestConnBeginBatchDeferredError(t *testing.T) { func TestConnBeginBatchDeferredError(t *testing.T) {
t.Parallel() t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
defer closeConn(t, conn)
skipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") skipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
mustExec(t, conn, `create temporary table t ( mustExec(t, conn, `create temporary table t (
id text primary key, id text primary key,
n int not null, n int not null,
unique (n) deferrable initially deferred unique (n) deferrable initially deferred
@ -620,36 +609,36 @@ func TestConnBeginBatchDeferredError(t *testing.T) {
insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`) insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`)
batch := &pgx.Batch{} batch := &pgx.Batch{}
batch.Queue(`update t set n=n+1 where id='b' returning *`) batch.Queue(`update t set n=n+1 where id='b' returning *`)
br := conn.SendBatch(context.Background(), batch) br := conn.SendBatch(context.Background(), batch)
rows, err := br.Query() rows, err := br.Query()
if err != nil {
t.Error(err)
}
for rows.Next() {
var id string
var n int32
err = rows.Scan(&id, &n)
if err != nil { if err != nil {
t.Fatal(err) t.Error(err)
} }
}
err = br.Close() for rows.Next() {
if err == nil { var id string
t.Fatal("expected error 23505 but got none") var n int32
} err = rows.Scan(&id, &n)
if err != nil {
t.Fatal(err)
}
}
if err, ok := err.(*pgconn.PgError); !ok || err.Code != "23505" { err = br.Close()
t.Fatalf("expected error 23505, got %v", err) if err == nil {
} t.Fatal("expected error 23505 but got none")
}
ensureConnValid(t, conn) if err, ok := err.(*pgconn.PgError); !ok || err.Code != "23505" {
t.Fatalf("expected error 23505, got %v", err)
}
})
} }
func TestConnSendBatchNoStatementCache(t *testing.T) { func TestConnSendBatchNoStatementCache(t *testing.T) {

143
conn.go
View File

@ -861,9 +861,10 @@ func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, sc
// explicit transaction control statements are executed. The returned BatchResults must be closed before the connection // explicit transaction control statements are executed. The returned BatchResults must be closed before the connection
// is used again. // is used again.
func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
simpleProtocol := c.config.DefaultQueryExecMode == QueryExecModeSimpleProtocol mode := c.config.DefaultQueryExecMode
var sb strings.Builder
if simpleProtocol { if mode == QueryExecModeSimpleProtocol {
var sb strings.Builder
for i, bi := range b.items { for i, bi := range b.items {
if i > 0 { if i > 0 {
sb.WriteByte(';') sb.WriteByte(';')
@ -884,66 +885,102 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
} }
} }
distinctUnpreparedQueries := map[string]struct{}{}
for _, bi := range b.items {
if _, ok := c.preparedStatements[bi.query]; ok {
continue
}
distinctUnpreparedQueries[bi.query] = struct{}{}
}
var stmtCache stmtcache.Cache
if len(distinctUnpreparedQueries) > 0 {
if c.statementCache != nil && c.statementCache.Cap() >= len(distinctUnpreparedQueries) {
stmtCache = c.statementCache
} else {
stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries))
}
for sql, _ := range distinctUnpreparedQueries {
_, err := stmtCache.Get(ctx, sql)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
}
}
batch := &pgconn.Batch{} batch := &pgconn.Batch{}
for _, bi := range b.items { if mode == QueryExecModeExec {
c.eqb.Reset() for _, bi := range b.items {
c.eqb.Reset()
anynil.NormalizeSlice(bi.arguments)
sd := c.preparedStatements[bi.query] sd := c.preparedStatements[bi.query]
if sd == nil { if sd != nil {
var err error if len(sd.ParamOIDs) != len(bi.arguments) {
sd, err = stmtCache.Get(ctx, bi.query) return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")}
if err != nil { }
return &batchResults{ctx: ctx, conn: c, err: err}
for i := range bi.arguments {
err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], bi.arguments[i])
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
}
for i := range sd.Fields {
c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID))
}
batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats)
} else {
err := c.appendParamsForQueryExecModeExec(bi.arguments)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
batch.ExecParams(bi.query, c.eqb.paramValues, nil, c.eqb.paramFormats, c.eqb.resultFormats)
}
}
} else {
distinctUnpreparedQueries := map[string]struct{}{}
for _, bi := range b.items {
if _, ok := c.preparedStatements[bi.query]; ok {
continue
}
distinctUnpreparedQueries[bi.query] = struct{}{}
}
var stmtCache stmtcache.Cache
if len(distinctUnpreparedQueries) > 0 {
if mode == QueryExecModeCacheStatement && c.statementCache != nil && c.statementCache.Cap() >= len(distinctUnpreparedQueries) {
stmtCache = c.statementCache
} else if mode == QueryExecModeCacheStatement && c.descriptionCache != nil && c.descriptionCache.Cap() >= len(distinctUnpreparedQueries) {
stmtCache = c.descriptionCache
} else {
stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries))
}
for sql, _ := range distinctUnpreparedQueries {
_, err := stmtCache.Get(ctx, sql)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
} }
} }
if len(sd.ParamOIDs) != len(bi.arguments) { for _, bi := range b.items {
return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")} c.eqb.Reset()
}
anynil.NormalizeSlice(bi.arguments) sd := c.preparedStatements[bi.query]
if sd == nil {
for i := range bi.arguments { var err error
err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], bi.arguments[i]) sd, err = stmtCache.Get(ctx, bi.query)
if err != nil { if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err} return &batchResults{ctx: ctx, conn: c, err: err}
}
} }
}
for i := range sd.Fields { if len(sd.ParamOIDs) != len(bi.arguments) {
c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID)) return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")}
} }
if sd.Name == "" { anynil.NormalizeSlice(bi.arguments)
batch.ExecParams(bi.query, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats)
} else { for i := range bi.arguments {
batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats) err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], bi.arguments[i])
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
}
for i := range sd.Fields {
c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID))
}
if sd.Name == "" {
batch.ExecParams(bi.query, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats)
} else {
batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats)
}
} }
} }

View File

@ -13,13 +13,18 @@ import (
) )
func testWithAllQueryExecModes(t *testing.T, f func(t *testing.T, conn *pgx.Conn)) { func testWithAllQueryExecModes(t *testing.T, f func(t *testing.T, conn *pgx.Conn)) {
for _, mode := range []pgx.QueryExecMode{ modes := []pgx.QueryExecMode{
pgx.QueryExecModeCacheStatement, pgx.QueryExecModeCacheStatement,
pgx.QueryExecModeCacheDescribe, pgx.QueryExecModeCacheDescribe,
pgx.QueryExecModeDescribeExec, pgx.QueryExecModeDescribeExec,
pgx.QueryExecModeExec, pgx.QueryExecModeExec,
pgx.QueryExecModeSimpleProtocol, pgx.QueryExecModeSimpleProtocol,
} { }
testWithQueryExecModes(t, modes, f)
}
func testWithQueryExecModes(t *testing.T, modes []pgx.QueryExecMode, f func(t *testing.T, conn *pgx.Conn)) {
for _, mode := range modes {
t.Run(mode.String(), t.Run(mode.String(),
func(t *testing.T) { func(t *testing.T) {
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))