pgx/batch_test.go

1099 lines
27 KiB
Go

package pgx_test
import (
"context"
"errors"
"fmt"
"os"
"testing"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxtest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestConnSendBatch queues three inserts, two identical selects, a select that
// matches nothing, and an aggregate in one batch. It then consumes each result
// in queue order using Exec, Query with manual iteration, Query with
// pgx.ForEachRow, and QueryRow.
func TestConnSendBatch(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test")

		sql := `create temporary table ledger(
id serial primary key,
description varchar not null,
amount int not null
);`
		mustExec(t, conn, sql)

		batch := &pgx.Batch{}
		batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1)
		batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2)
		batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3)
		batch.Queue("select id, description, amount from ledger order by id")
		batch.Queue("select id, description, amount from ledger order by id")
		batch.Queue("select * from ledger where false")
		batch.Queue("select sum(amount) from ledger")

		br := conn.SendBatch(ctx, batch)

		// Results are read in queue order: first the three single-row inserts.
		for i := 0; i < 3; i++ {
			ct, err := br.Exec()
			if err != nil {
				t.Error(err)
			}
			if ct.RowsAffected() != 1 {
				t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
			}
		}

		selectFromLedgerExpectedRows := []struct {
			id          int32
			description string
			amount      int32
		}{
			{1, "q1", 1},
			{2, "q2", 2},
			{3, "q3", 3},
		}

		var id int32
		var description string
		var amount int32

		// checkRow compares the most recently scanned values with expected row i.
		checkRow := func(i int) {
			want := selectFromLedgerExpectedRows[i]
			if id != want.id {
				t.Errorf("id => %v, want %v", id, want.id)
			}
			if description != want.description {
				t.Errorf("description => %v, want %v", description, want.description)
			}
			if amount != want.amount {
				t.Errorf("amount => %v, want %v", amount, want.amount)
			}
		}

		// First select: iterate the rows manually.
		rows, err := br.Query()
		if err != nil {
			t.Error(err)
		}
		rowCount := 0
		for rows.Next() {
			if rowCount >= len(selectFromLedgerExpectedRows) {
				t.Fatalf("got too many rows: %d", rowCount)
			}
			if err := rows.Scan(&id, &description, &amount); err != nil {
				t.Fatalf("row %d: %v", rowCount, err)
			}
			checkRow(rowCount)
			rowCount++
		}
		if rows.Err() != nil {
			t.Fatal(rows.Err())
		}
		// The loop above cannot detect a short result set; check it explicitly.
		if rowCount != len(selectFromLedgerExpectedRows) {
			t.Errorf("rowCount => %v, want %v", rowCount, len(selectFromLedgerExpectedRows))
		}

		// Second (identical) select: iterate with pgx.ForEachRow.
		rowCount = 0
		rows, err = br.Query()
		if err != nil {
			t.Error(err)
		}
		_, err = pgx.ForEachRow(rows, []any{&id, &description, &amount}, func() error {
			// Report extra rows as an error instead of indexing past the expected slice.
			if rowCount >= len(selectFromLedgerExpectedRows) {
				return fmt.Errorf("got too many rows: %d", rowCount)
			}
			checkRow(rowCount)
			rowCount++
			return nil
		})
		if err != nil {
			t.Error(err)
		}
		if rowCount != len(selectFromLedgerExpectedRows) {
			t.Errorf("rowCount => %v, want %v", rowCount, len(selectFromLedgerExpectedRows))
		}

		// A select that matches nothing surfaces pgx.ErrNoRows through QueryRow.
		err = br.QueryRow().Scan(&id, &description, &amount)
		if !errors.Is(err, pgx.ErrNoRows) {
			t.Errorf("expected pgx.ErrNoRows but got: %v", err)
		}

		err = br.QueryRow().Scan(&amount)
		if err != nil {
			t.Error(err)
		}
		if amount != 6 {
			t.Errorf("amount => %v, want %v", amount, 6)
		}

		err = br.Close()
		if err != nil {
			t.Fatal(err)
		}
	})
}
// TestConnSendBatchQueuedQuery is the callback-based counterpart of
// TestConnSendBatch. Each queued query registers an Exec, Query, or QueryRow
// callback. The results are only consumed by calling Close on the batch
// results, which is expected to invoke each callback in queue order.
func TestConnSendBatchQueuedQuery(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test")

		sql := `create temporary table ledger(
id serial primary key,
description varchar not null,
amount int not null
);`
		mustExec(t, conn, sql)

		batch := &pgx.Batch{}

		// assertOneRowAffected is the Exec callback shared by the three inserts.
		assertOneRowAffected := func(ct pgconn.CommandTag) error {
			assert.EqualValues(t, 1, ct.RowsAffected())
			return nil
		}
		batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1).Exec(assertOneRowAffected)
		batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2).Exec(assertOneRowAffected)
		batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3).Exec(assertOneRowAffected)

		selectFromLedgerExpectedRows := []struct {
			id          int32
			description string
			amount      int32
		}{
			{1, "q1", 1},
			{2, "q2", 2},
			{3, "q3", 3},
		}

		// checkLedgerRows is the Query callback for both selects. It compares
		// every row with selectFromLedgerExpectedRows and asserts that exactly
		// that many rows were returned.
		checkLedgerRows := func(rows pgx.Rows) error {
			rowCount := 0
			var id int32
			var description string
			var amount int32
			_, err := pgx.ForEachRow(rows, []any{&id, &description, &amount}, func() error {
				// Guard the index so extra rows fail the test instead of panicking.
				if rowCount >= len(selectFromLedgerExpectedRows) {
					return fmt.Errorf("got too many rows: %d", rowCount+1)
				}
				assert.Equal(t, selectFromLedgerExpectedRows[rowCount].id, id)
				assert.Equal(t, selectFromLedgerExpectedRows[rowCount].description, description)
				assert.Equal(t, selectFromLedgerExpectedRows[rowCount].amount, amount)
				rowCount++
				return nil
			})
			assert.NoError(t, err)
			assert.Equal(t, len(selectFromLedgerExpectedRows), rowCount)
			return nil
		}
		batch.Queue("select id, description, amount from ledger order by id").Query(checkLedgerRows)
		batch.Queue("select id, description, amount from ledger order by id").Query(checkLedgerRows)

		batch.Queue("select * from ledger where false").QueryRow(func(row pgx.Row) error {
			err := row.Scan(nil, nil, nil)
			assert.ErrorIs(t, err, pgx.ErrNoRows)
			return nil
		})

		batch.Queue("select sum(amount) from ledger").QueryRow(func(row pgx.Row) error {
			var sumAmount int32
			err := row.Scan(&sumAmount)
			assert.NoError(t, err)
			assert.EqualValues(t, 6, sumAmount)
			return nil
		})

		err := conn.SendBatch(ctx, batch).Close()
		assert.NoError(t, err)
	})
}
// TestConnSendBatchMany sends one batch containing a thousand inserts followed
// by a count query. It checks each insert's command tag and the final count.
func TestConnSendBatchMany(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		mustExec(t, conn, `create temporary table ledger(
id serial primary key,
description varchar not null,
amount int not null
);`)

		const (
			insertCount = 1000
			insertSQL   = "insert into ledger(description, amount) values($1, $2)"
		)

		b := &pgx.Batch{}
		for n := 0; n < insertCount; n++ {
			b.Queue(insertSQL, "q1", 1)
		}
		b.Queue("select count(*) from ledger")

		results := conn.SendBatch(ctx, b)

		for n := 0; n < insertCount; n++ {
			tag, execErr := results.Exec()
			assert.NoError(t, execErr)
			assert.EqualValues(t, 1, tag.RowsAffected())
		}

		var counted int
		scanErr := results.QueryRow().Scan(&counted)
		assert.NoError(t, scanErr)
		assert.EqualValues(t, insertCount, counted)

		require.NoError(t, results.Close())
	})
}
// TestConnSendBatchReadResultsWhenNothingQueued checks two things about an
// empty batch: reading a result returns the "no more results in batch" error,
// and the batch can still be closed without error.
//
// https://github.com/jackc/pgx/issues/1801#issuecomment-2203784178
func TestConnSendBatchReadResultsWhenNothingQueued(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		results := conn.SendBatch(ctx, &pgx.Batch{})

		tag, execErr := results.Exec()
		require.Equal(t, "", tag.String())
		require.EqualError(t, execErr, "no more results in batch")

		require.NoError(t, results.Close())
	})
}
// TestConnSendBatchReadMoreResultsThanQueriesSent reads one result more than
// was queued. The extra read must report "no more results in batch", and
// closing the batch afterwards must succeed.
func TestConnSendBatchReadMoreResultsThanQueriesSent(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		b := &pgx.Batch{}
		b.Queue("select 1")
		results := conn.SendBatch(ctx, b)

		// The single queued query.
		tag, execErr := results.Exec()
		require.Equal(t, "SELECT 1", tag.String())
		require.NoError(t, execErr)

		// One read past the end of the batch.
		tag, execErr = results.Exec()
		require.Equal(t, "", tag.String())
		require.EqualError(t, execErr, "no more results in batch")

		require.NoError(t, results.Close())
	})
}
// TestConnSendBatchWithPreparedStatement prepares a named statement, queues it
// several times in one batch, and checks that every result set is exactly the
// sequence 0..5. Simple protocol mode is left out of the tested modes.
func TestConnSendBatchWithPreparedStatement(t *testing.T) {
	t.Parallel()

	modes := []pgx.QueryExecMode{
		pgx.QueryExecModeCacheStatement,
		pgx.QueryExecModeCacheDescribe,
		pgx.QueryExecModeDescribeExec,
		pgx.QueryExecModeExec,
		// Don't test simple mode with prepared statements.
	}

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, modes, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		pgxtest.SkipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")

		_, err := conn.Prepare(ctx, "ps1", "select n from generate_series(0,$1::int) n")
		if err != nil {
			t.Fatal(err)
		}

		batch := &pgx.Batch{}
		queryCount := 3
		for i := 0; i < queryCount; i++ {
			batch.Queue("ps1", 5)
		}
		// generate_series(0,5) yields the six values 0 through 5.
		wantRows := 6

		br := conn.SendBatch(ctx, batch)

		for i := 0; i < queryCount; i++ {
			rows, err := br.Query()
			if err != nil {
				t.Fatal(err)
			}
			rowCount := 0
			for rows.Next() {
				var n int
				if err := rows.Scan(&n); err != nil {
					t.Fatal(err)
				}
				if n != rowCount {
					t.Fatalf("n => %v, want %v", n, rowCount)
				}
				rowCount++
			}
			if rows.Err() != nil {
				t.Fatal(rows.Err())
			}
			// Without this check, an empty or truncated result set would pass.
			if rowCount != wantRows {
				t.Fatalf("query %d: got %d rows, want %d", i, rowCount, wantRows)
			}
		}

		err = br.Close()
		if err != nil {
			t.Fatal(err)
		}
	})
}
// TestConnSendBatchWithQueryRewriter queues three queries whose SQL text is a
// placeholder. The rewriter passed as the sole argument supplies the real SQL
// and arguments, and the test checks each rewritten query's result in order.
func TestConnSendBatchWithQueryRewriter(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		b := &pgx.Batch{}
		b.Queue("something to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{1}})
		b.Queue("something else to be replaced", &testQueryRewriter{sql: "select $1::text", args: []any{"hello"}})
		b.Queue("more to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{3}})

		results := conn.SendBatch(ctx, b)

		var first int32
		require.NoError(t, results.QueryRow().Scan(&first))
		require.EqualValues(t, 1, first)

		var second string
		require.NoError(t, results.QueryRow().Scan(&second))
		require.Equal(t, "hello", second)

		var third int32
		require.NoError(t, results.QueryRow().Scan(&third))
		require.EqualValues(t, 3, third)

		require.NoError(t, results.Close())
	})
}
// TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled runs a
// batch of repeated named-prepared-statement calls on a connection in
// describe-exec mode with both the statement and description caches disabled.
// It checks every result set and then that the connection is still usable.
//
// https://github.com/jackc/pgx/issues/856
func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)

	config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec
	config.StatementCacheCapacity = 0
	config.DescriptionCacheCapacity = 0

	conn := mustConnect(t, config)
	defer closeConn(t, conn)

	pgxtest.SkipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")

	_, err = conn.Prepare(ctx, "ps1", "select n from generate_series(0,$1::int) n")
	if err != nil {
		t.Fatal(err)
	}

	batch := &pgx.Batch{}
	queryCount := 3
	for i := 0; i < queryCount; i++ {
		batch.Queue("ps1", 5)
	}
	// generate_series(0,5) yields the six values 0 through 5.
	wantRows := 6

	br := conn.SendBatch(ctx, batch)

	for i := 0; i < queryCount; i++ {
		rows, err := br.Query()
		if err != nil {
			t.Fatal(err)
		}
		rowCount := 0
		for rows.Next() {
			var n int
			if err := rows.Scan(&n); err != nil {
				t.Fatal(err)
			}
			if n != rowCount {
				t.Fatalf("n => %v, want %v", n, rowCount)
			}
			rowCount++
		}
		if rows.Err() != nil {
			t.Fatal(rows.Err())
		}
		// Without this check, an empty or truncated result set would pass.
		if rowCount != wantRows {
			t.Fatalf("query %d: got %d rows, want %d", i, rowCount, wantRows)
		}
	}

	err = br.Close()
	if err != nil {
		t.Fatal(err)
	}

	ensureConnValid(t, conn)
}
// TestConnSendBatchCloseRowsPartiallyRead reads only part of the first result
// set, closes those rows early, and checks that the second result set is still
// read completely and correctly.
func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		batch := &pgx.Batch{}
		batch.Queue("select n from generate_series(0,5) n")
		batch.Queue("select n from generate_series(0,5) n")

		br := conn.SendBatch(ctx, batch)

		rows, err := br.Query()
		if err != nil {
			t.Error(err)
		}

		// Consume only the first three of the six rows.
		for i := 0; i < 3; i++ {
			if !rows.Next() {
				// Scanning without a current row would only produce noise; stop here.
				t.Fatal("expected a row to be available")
			}
			var n int
			if err := rows.Scan(&n); err != nil {
				t.Error(err)
			}
			if n != i {
				t.Errorf("n => %v, want %v", n, i)
			}
		}
		rows.Close()

		// The second result set must be unaffected by the early close.
		rows, err = br.Query()
		if err != nil {
			t.Error(err)
		}
		rowCount := 0
		for rows.Next() {
			var n int
			if err := rows.Scan(&n); err != nil {
				t.Error(err)
			}
			if n != rowCount {
				t.Errorf("n => %v, want %v", n, rowCount)
			}
			rowCount++
		}
		if rows.Err() != nil {
			t.Error(rows.Err())
		}
		if rowCount != 6 {
			t.Errorf("rowCount => %v, want %v", rowCount, 6)
		}

		err = br.Close()
		if err != nil {
			t.Fatal(err)
		}
	})
}
// TestConnSendBatchQueryError checks that a runtime error in the first query
// (division by zero when n reaches 5) is reported with SQLSTATE 22012 both by
// rows.Err() and by closing the batch.
func TestConnSendBatchQueryError(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		batch := &pgx.Batch{}
		batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0")
		batch.Queue("select n from generate_series(0,5) n")

		br := conn.SendBatch(ctx, batch)

		rows, err := br.Query()
		if err != nil {
			t.Error(err)
		}
		for i := 0; rows.Next(); i++ {
			var n int
			if err := rows.Scan(&n); err != nil {
				t.Error(err)
			}
			if n != i {
				t.Errorf("n => %v, want %v", n, i)
			}
		}

		// errors.As also matches a *pgconn.PgError that has been wrapped,
		// which a direct type assertion would miss.
		var rowsPgErr *pgconn.PgError
		if !errors.As(rows.Err(), &rowsPgErr) || rowsPgErr.Code != "22012" {
			t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012)
		}

		err = br.Close()
		var closePgErr *pgconn.PgError
		if !errors.As(err, &closePgErr) || closePgErr.Code != "22012" {
			t.Errorf("br.Close() => %v, want error code %v", err, 22012)
		}
	})
}
// TestConnSendBatchQuerySyntaxError checks two things for a batch whose only
// query is malformed. The QueryRow scan must report SQLSTATE 42601, and
// closing the batch must return an error.
func TestConnSendBatchQuerySyntaxError(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		batch := &pgx.Batch{}
		batch.Queue("select 1 1")

		br := conn.SendBatch(ctx, batch)

		var n int32
		err := br.QueryRow().Scan(&n)
		// errors.As also matches a wrapped *pgconn.PgError. The message names
		// the call that produced err; there is no rows value here.
		var pgErr *pgconn.PgError
		if !errors.As(err, &pgErr) || pgErr.Code != "42601" {
			t.Errorf("QueryRow().Scan() => %v, want error code %v", err, 42601)
		}

		err = br.Close()
		if err == nil {
			t.Error("Expected error")
		}
	})
}
// TestConnSendBatchQueryRowInsert reads a select through QueryRow and then a
// two-row insert through Exec from the same batch. It verifies both results
// and the batch close.
func TestConnSendBatchQueryRowInsert(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		sql := `create temporary table ledger(
id serial primary key,
description varchar not null,
amount int not null
);`
		mustExec(t, conn, sql)

		batch := &pgx.Batch{}
		batch.Queue("select 1")
		batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1)

		br := conn.SendBatch(ctx, batch)

		var value int
		err := br.QueryRow().Scan(&value)
		if err != nil {
			t.Error(err)
		}
		if value != 1 {
			t.Errorf("value => %v, want %v", value, 1)
		}

		ct, err := br.Exec()
		if err != nil {
			t.Error(err)
		}
		if ct.RowsAffected() != 2 {
			t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2)
		}

		// Close can surface errors from any result, so it must be checked too.
		err = br.Close()
		if err != nil {
			t.Fatal(err)
		}
	})
}
// TestConnSendBatchQueryPartialReadInsert closes the rows of a multi-row
// select without reading them. It then checks that the following insert
// result is still read correctly and that the batch closes cleanly.
func TestConnSendBatchQueryPartialReadInsert(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		sql := `create temporary table ledger(
id serial primary key,
description varchar not null,
amount int not null
);`
		mustExec(t, conn, sql)

		batch := &pgx.Batch{}
		batch.Queue("select 1 union all select 2 union all select 3")
		batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1)

		br := conn.SendBatch(ctx, batch)

		rows, err := br.Query()
		if err != nil {
			t.Error(err)
		}
		// Close without consuming any rows.
		rows.Close()

		ct, err := br.Exec()
		if err != nil {
			t.Error(err)
		}
		if ct.RowsAffected() != 2 {
			t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2)
		}

		err = br.Close()
		if err != nil {
			t.Fatal(err)
		}
	})
}
// TestTxSendBatch sends two batches inside one transaction, where the second
// batch uses the id returned by the first. It then commits and checks, through
// the connection, that the row inserted into ledger1 is visible.
func TestTxSendBatch(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		sql := `create temporary table ledger1(
id serial primary key,
description varchar not null
);`
		mustExec(t, conn, sql)

		sql = `create temporary table ledger2(
id int primary key,
amount int not null
);`
		mustExec(t, conn, sql)

		tx, err := conn.Begin(ctx)
		require.NoError(t, err)

		batch := &pgx.Batch{}
		batch.Queue("insert into ledger1(description) values($1) returning id", "q1")
		// Use the test ctx, not context.Background(), so this send is bounded by
		// the same timeout as the rest of the test.
		br := tx.SendBatch(ctx, batch)

		var id int
		err = br.QueryRow().Scan(&id)
		if err != nil {
			t.Error(err)
		}
		require.NoError(t, br.Close())

		batch = &pgx.Batch{}
		batch.Queue("insert into ledger2(id,amount) values($1, $2)", id, 2)
		batch.Queue("select amount from ledger2 where id = $1", id)
		br = tx.SendBatch(ctx, batch)

		ct, err := br.Exec()
		if err != nil {
			t.Error(err)
		}
		if ct.RowsAffected() != 1 {
			t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
		}

		var amount int
		err = br.QueryRow().Scan(&amount)
		if err != nil {
			t.Error(err)
		}
		if amount != 2 {
			t.Errorf("amount => %v, want %v", amount, 2)
		}
		// Close the second batch exactly once, and check the result.
		require.NoError(t, br.Close())

		require.NoError(t, tx.Commit(ctx))

		var count int
		err = conn.QueryRow(ctx, "select count(1) from ledger1 where id = $1", id).Scan(&count)
		require.NoError(t, err)
		if count != 1 {
			t.Errorf("count => %v, want %v", count, 1)
		}
	})
}
// TestTxSendBatchRollback inserts a row through a batch inside a transaction,
// rolls the transaction back, and checks that the row is not visible.
func TestTxSendBatchRollback(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		sql := `create temporary table ledger1(
id serial primary key,
description varchar not null
);`
		mustExec(t, conn, sql)

		tx, err := conn.Begin(ctx)
		require.NoError(t, err)

		batch := &pgx.Batch{}
		batch.Queue("insert into ledger1(description) values($1) returning id", "q1")
		br := tx.SendBatch(ctx, batch)

		var id int
		err = br.QueryRow().Scan(&id)
		if err != nil {
			t.Error(err)
		}
		require.NoError(t, br.Close())

		require.NoError(t, tx.Rollback(ctx))

		// A failed Scan would leave count at 0 and pass the test; check its error.
		var count int
		err = conn.QueryRow(ctx, "select count(1) from ledger1 where id = $1", id).Scan(&count)
		require.NoError(t, err)
		if count != 0 {
			t.Errorf("count => %v, want %v", count, 0)
		}
	})
}
// TestSendBatchErrorWhileReadingResultsWithoutCallback triggers a division by
// zero in a batched query that has no callback. It checks that Close returns
// the same error Exec did and that the connection can run another query
// afterwards.
//
// https://github.com/jackc/pgx/issues/1578
func TestSendBatchErrorWhileReadingResultsWithoutCallback(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		b := &pgx.Batch{}
		b.Queue("select 4 / $1::int", 0)

		results := conn.SendBatch(ctx, b)
		_, firstErr := results.Exec()
		require.Error(t, firstErr)
		require.Equal(t, firstErr, results.Close())

		// Try to use the connection.
		_, followUpErr := conn.Exec(ctx, "select 1")
		require.NoError(t, followUpErr)
	})
}
// TestSendBatchErrorWhileReadingResultsWithExecWhereSomeRowsAreReturned runs a
// query through Exec that returns some rows before failing (division by zero
// at n = 0). It checks that Close repeats the Exec error and that the
// connection remains usable.
func TestSendBatchErrorWhileReadingResultsWithExecWhereSomeRowsAreReturned(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		b := &pgx.Batch{}
		b.Queue("select 4 / n from generate_series(-2, 2) n")

		results := conn.SendBatch(ctx, b)
		_, firstErr := results.Exec()
		require.Error(t, firstErr)
		require.Equal(t, firstErr, results.Close())

		// Try to use the connection.
		_, followUpErr := conn.Exec(ctx, "select 1")
		require.NoError(t, followUpErr)
	})
}
// TestConnBeginBatchDeferredError updates a row so that a deferrable unique
// constraint is violated. It expects closing the batch to report SQLSTATE
// 23505.
func TestConnBeginBatchDeferredError(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		pgxtest.SkipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")

		mustExec(t, conn, `create temporary table t (
id text primary key,
n int not null,
unique (n) deferrable initially deferred
);
insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`)

		batch := &pgx.Batch{}
		batch.Queue(`update t set n=n+1 where id='b' returning *`)

		br := conn.SendBatch(ctx, batch)

		rows, err := br.Query()
		if err != nil {
			t.Error(err)
		}
		for rows.Next() {
			var id string
			var n int32
			err = rows.Scan(&id, &n)
			if err != nil {
				t.Fatal(err)
			}
		}

		err = br.Close()
		if err == nil {
			t.Fatal("expected error 23505 but got none")
		}
		// The original `err, ok := err.(*pgconn.PgError)` shadowed err. When the
		// assertion failed, the message printed a nil *PgError instead of the
		// real error. errors.As keeps err intact and also matches wrapped errors.
		var pgErr *pgconn.PgError
		if !errors.As(err, &pgErr) || pgErr.Code != "23505" {
			t.Fatalf("expected error 23505, got %v", err)
		}
	})
}
// TestConnSendBatchNoStatementCache runs testConnSendBatch in describe-exec
// mode with both the statement cache and the description cache disabled.
func TestConnSendBatchNoStatementCache(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	cfg := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
	cfg.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec
	cfg.StatementCacheCapacity, cfg.DescriptionCacheCapacity = 0, 0

	c := mustConnect(t, cfg)
	defer closeConn(t, c)

	const queries = 3
	testConnSendBatch(t, ctx, c, queries)
}
// TestConnSendBatchPrepareStatementCache runs testConnSendBatch in
// cache-statement mode with a statement cache capacity of 32.
func TestConnSendBatchPrepareStatementCache(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	cfg := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
	cfg.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement
	cfg.StatementCacheCapacity = 32

	c := mustConnect(t, cfg)
	defer closeConn(t, c)

	const queries = 3
	testConnSendBatch(t, ctx, c, queries)
}
// TestConnSendBatchDescribeStatementCache runs testConnSendBatch in
// cache-describe mode with a description cache capacity of 32.
func TestConnSendBatchDescribeStatementCache(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	cfg := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
	cfg.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe
	cfg.DescriptionCacheCapacity = 32

	c := mustConnect(t, cfg)
	defer closeConn(t, c)

	const queries = 3
	testConnSendBatch(t, ctx, c, queries)
}
// testConnSendBatch sends queryCount copies of a generate_series query in one
// batch on conn. It requires every result set to be exactly 0..5 in order and
// the batch to close without error.
func testConnSendBatch(t *testing.T, ctx context.Context, conn *pgx.Conn, queryCount int) {
	batch := &pgx.Batch{}
	for j := 0; j < queryCount; j++ {
		batch.Queue("select n from generate_series(0,5) n")
	}

	br := conn.SendBatch(ctx, batch)

	for j := 0; j < queryCount; j++ {
		rows, err := br.Query()
		require.NoError(t, err)

		rowCount := 0
		for rows.Next() {
			var n int
			require.NoError(t, rows.Scan(&n))
			require.Equal(t, rowCount, n)
			rowCount++
		}
		require.NoError(t, rows.Err())
		// generate_series(0,5) produces six rows; an empty or truncated result must not pass.
		require.Equal(t, 6, rowCount)
	}

	require.NoError(t, br.Close())
}
// TestSendBatchSimpleProtocol sends a batch in simple protocol mode where the
// second queued entry holds two statements. It checks that three separate
// single-row results (1, 2, 3) are read back, then closes the batch.
func TestSendBatchSimpleProtocol(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
	config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol

	conn := mustConnect(t, config)
	defer closeConn(t, conn)

	var batch pgx.Batch
	batch.Queue("SELECT 1::int")
	batch.Queue("SELECT 2::int; SELECT $1::int", 3)

	results := conn.SendBatch(ctx, &batch)

	for _, want := range []int{1, 2, 3} {
		rows, err := results.Query()
		assert.NoError(t, err)
		// require, not assert: values[0] below would panic without a row.
		require.True(t, rows.Next())
		values, err := rows.Values()
		assert.NoError(t, err)
		require.NotEmpty(t, values)
		assert.EqualValues(t, want, values[0])
		assert.False(t, rows.Next())
		assert.NoError(t, rows.Err())
	}

	// Close the batch results so the connection is not left mid-batch.
	require.NoError(t, results.Close())
}
// TestConnSendBatchErrorDoesNotLeaveOrphanedPreparedStatement sends a batch
// whose second query references a missing table. It then creates that table
// and sends the same batch again, which must now succeed.
//
// https://github.com/jackc/pgx/issues/1847#issuecomment-2347858887
func TestConnSendBatchErrorDoesNotLeaveOrphanedPreparedStatement(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test")

		mustExec(t, conn, `create temporary table foo(col1 text primary key);`)

		// newBatch builds the same two-query batch for both attempts.
		newBatch := func() *pgx.Batch {
			b := &pgx.Batch{}
			b.Queue("select col1 from foo")
			b.Queue("select col1 from baz")
			return b
		}

		firstErr := conn.SendBatch(ctx, newBatch()).Close()
		require.EqualError(t, firstErr, `ERROR: relation "baz" does not exist (SQLSTATE 42P01)`)

		mustExec(t, conn, `create temporary table baz(col1 text primary key);`)

		// Since table baz now exists, the batch should succeed.
		require.NoError(t, conn.SendBatch(ctx, newBatch()).Close())
	})
}
// ExampleConn_SendBatch queues three queries with QueryRow callbacks that print
// each result. It then sends the batch and closes it, which runs the callbacks.
func ExampleConn_SendBatch() {
	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
	if err != nil {
		fmt.Printf("Unable to establish connection: %v", err)
		return
	}
	// Release the connection when the example returns. Deferred calls run in
	// reverse order, so this runs before cancel while ctx is still live.
	defer conn.Close(ctx)

	// printInt32 is the QueryRow callback shared by every queued query. It
	// scans one int32 column and prints it.
	printInt32 := func(row pgx.Row) error {
		var n int32
		if err := row.Scan(&n); err != nil {
			return err
		}
		fmt.Println(n)
		return nil
	}

	batch := &pgx.Batch{}
	batch.Queue("select 1 + 1").QueryRow(printInt32)
	batch.Queue("select 1 + 2").QueryRow(printInt32)
	batch.Queue("select 2 + 3").QueryRow(printInt32)

	err = conn.SendBatch(ctx, batch).Close()
	if err != nil {
		fmt.Printf("SendBatch error: %v", err)
		return
	}

	// Output:
	// 2
	// 3
	// 5
}