mirror of https://github.com/jackc/pgx.git
SendBatch supports default QueryExecMode
parent
1390a11fe2
commit
cb721dfb5b
817
batch_test.go
817
batch_test.go
|
@ -15,230 +15,227 @@ import (
|
||||||
func TestConnSendBatch(t *testing.T) {
|
func TestConnSendBatch(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
|
||||||
defer closeConn(t, conn)
|
skipCockroachDB(t, conn, "Server serial type is incompatible with test")
|
||||||
|
|
||||||
skipCockroachDB(t, conn, "Server serial type is incompatible with test")
|
sql := `create temporary table ledger(
|
||||||
|
|
||||||
sql := `create temporary table ledger(
|
|
||||||
id serial primary key,
|
id serial primary key,
|
||||||
description varchar not null,
|
description varchar not null,
|
||||||
amount int not null
|
amount int not null
|
||||||
);`
|
);`
|
||||||
mustExec(t, conn, sql)
|
mustExec(t, conn, sql)
|
||||||
|
|
||||||
batch := &pgx.Batch{}
|
batch := &pgx.Batch{}
|
||||||
batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1)
|
|
||||||
batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2)
|
|
||||||
batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3)
|
|
||||||
batch.Queue("select id, description, amount from ledger order by id")
|
|
||||||
batch.Queue("select id, description, amount from ledger order by id")
|
|
||||||
batch.Queue("select * from ledger where false")
|
|
||||||
batch.Queue("select sum(amount) from ledger")
|
|
||||||
|
|
||||||
br := conn.SendBatch(context.Background(), batch)
|
|
||||||
|
|
||||||
ct, err := br.Exec()
|
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
if ct.RowsAffected() != 1 {
|
|
||||||
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
ct, err = br.Exec()
|
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
if ct.RowsAffected() != 1 {
|
|
||||||
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
ct, err = br.Exec()
|
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
if ct.RowsAffected() != 1 {
|
|
||||||
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
selectFromLedgerExpectedRows := []struct {
|
|
||||||
id int32
|
|
||||||
description string
|
|
||||||
amount int32
|
|
||||||
}{
|
|
||||||
{1, "q1", 1},
|
|
||||||
{2, "q2", 2},
|
|
||||||
{3, "q3", 3},
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := br.Query()
|
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var id int32
|
|
||||||
var description string
|
|
||||||
var amount int32
|
|
||||||
rowCount := 0
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
if rowCount >= len(selectFromLedgerExpectedRows) {
|
|
||||||
t.Fatalf("got too many rows: %d", rowCount)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := rows.Scan(&id, &description, &amount); err != nil {
|
|
||||||
t.Fatalf("row %d: %v", rowCount, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if id != selectFromLedgerExpectedRows[rowCount].id {
|
|
||||||
t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id)
|
|
||||||
}
|
|
||||||
if description != selectFromLedgerExpectedRows[rowCount].description {
|
|
||||||
t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description)
|
|
||||||
}
|
|
||||||
if amount != selectFromLedgerExpectedRows[rowCount].amount {
|
|
||||||
t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount)
|
|
||||||
}
|
|
||||||
|
|
||||||
rowCount++
|
|
||||||
}
|
|
||||||
|
|
||||||
if rows.Err() != nil {
|
|
||||||
t.Fatal(rows.Err())
|
|
||||||
}
|
|
||||||
|
|
||||||
rowCount = 0
|
|
||||||
_, err = br.QueryFunc([]interface{}{&id, &description, &amount}, func(pgx.QueryFuncRow) error {
|
|
||||||
if id != selectFromLedgerExpectedRows[rowCount].id {
|
|
||||||
t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id)
|
|
||||||
}
|
|
||||||
if description != selectFromLedgerExpectedRows[rowCount].description {
|
|
||||||
t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description)
|
|
||||||
}
|
|
||||||
if amount != selectFromLedgerExpectedRows[rowCount].amount {
|
|
||||||
t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount)
|
|
||||||
}
|
|
||||||
|
|
||||||
rowCount++
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = br.QueryRow().Scan(&id, &description, &amount)
|
|
||||||
if !errors.Is(err, pgx.ErrNoRows) {
|
|
||||||
t.Errorf("expected pgx.ErrNoRows but got: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = br.QueryRow().Scan(&amount)
|
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
if amount != 6 {
|
|
||||||
t.Errorf("amount => %v, want %v", amount, 6)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = br.Close()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ensureConnValid(t, conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnSendBatchMany(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
|
||||||
defer closeConn(t, conn)
|
|
||||||
|
|
||||||
sql := `create temporary table ledger(
|
|
||||||
id serial primary key,
|
|
||||||
description varchar not null,
|
|
||||||
amount int not null
|
|
||||||
);`
|
|
||||||
mustExec(t, conn, sql)
|
|
||||||
|
|
||||||
batch := &pgx.Batch{}
|
|
||||||
|
|
||||||
numInserts := 1000
|
|
||||||
|
|
||||||
for i := 0; i < numInserts; i++ {
|
|
||||||
batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1)
|
batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1)
|
||||||
}
|
batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2)
|
||||||
batch.Queue("select count(*) from ledger")
|
batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3)
|
||||||
|
batch.Queue("select id, description, amount from ledger order by id")
|
||||||
|
batch.Queue("select id, description, amount from ledger order by id")
|
||||||
|
batch.Queue("select * from ledger where false")
|
||||||
|
batch.Queue("select sum(amount) from ledger")
|
||||||
|
|
||||||
br := conn.SendBatch(context.Background(), batch)
|
br := conn.SendBatch(context.Background(), batch)
|
||||||
|
|
||||||
for i := 0; i < numInserts; i++ {
|
|
||||||
ct, err := br.Exec()
|
ct, err := br.Exec()
|
||||||
assert.NoError(t, err)
|
if err != nil {
|
||||||
assert.EqualValues(t, 1, ct.RowsAffected())
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
if ct.RowsAffected() != 1 {
|
||||||
|
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
|
||||||
|
}
|
||||||
|
|
||||||
var actualInserts int
|
ct, err = br.Exec()
|
||||||
err := br.QueryRow().Scan(&actualInserts)
|
if err != nil {
|
||||||
assert.NoError(t, err)
|
t.Error(err)
|
||||||
assert.EqualValues(t, numInserts, actualInserts)
|
}
|
||||||
|
if ct.RowsAffected() != 1 {
|
||||||
|
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
|
||||||
|
}
|
||||||
|
|
||||||
err = br.Close()
|
ct, err = br.Exec()
|
||||||
require.NoError(t, err)
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if ct.RowsAffected() != 1 {
|
||||||
|
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
|
||||||
|
}
|
||||||
|
|
||||||
ensureConnValid(t, conn)
|
selectFromLedgerExpectedRows := []struct {
|
||||||
}
|
id int32
|
||||||
|
description string
|
||||||
|
amount int32
|
||||||
|
}{
|
||||||
|
{1, "q1", 1},
|
||||||
|
{2, "q2", 2},
|
||||||
|
{3, "q3", 3},
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnSendBatchWithPreparedStatement(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
|
||||||
defer closeConn(t, conn)
|
|
||||||
|
|
||||||
skipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
|
|
||||||
|
|
||||||
_, err := conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
batch := &pgx.Batch{}
|
|
||||||
|
|
||||||
queryCount := 3
|
|
||||||
for i := 0; i < queryCount; i++ {
|
|
||||||
batch.Queue("ps1", 5)
|
|
||||||
}
|
|
||||||
|
|
||||||
br := conn.SendBatch(context.Background(), batch)
|
|
||||||
|
|
||||||
for i := 0; i < queryCount; i++ {
|
|
||||||
rows, err := br.Query()
|
rows, err := br.Query()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for k := 0; rows.Next(); k++ {
|
var id int32
|
||||||
var n int
|
var description string
|
||||||
if err := rows.Scan(&n); err != nil {
|
var amount int32
|
||||||
t.Fatal(err)
|
rowCount := 0
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
if rowCount >= len(selectFromLedgerExpectedRows) {
|
||||||
|
t.Fatalf("got too many rows: %d", rowCount)
|
||||||
}
|
}
|
||||||
if n != k {
|
|
||||||
t.Fatalf("n => %v, want %v", n, k)
|
if err := rows.Scan(&id, &description, &amount); err != nil {
|
||||||
|
t.Fatalf("row %d: %v", rowCount, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if id != selectFromLedgerExpectedRows[rowCount].id {
|
||||||
|
t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id)
|
||||||
|
}
|
||||||
|
if description != selectFromLedgerExpectedRows[rowCount].description {
|
||||||
|
t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description)
|
||||||
|
}
|
||||||
|
if amount != selectFromLedgerExpectedRows[rowCount].amount {
|
||||||
|
t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount)
|
||||||
|
}
|
||||||
|
|
||||||
|
rowCount++
|
||||||
}
|
}
|
||||||
|
|
||||||
if rows.Err() != nil {
|
if rows.Err() != nil {
|
||||||
t.Fatal(rows.Err())
|
t.Fatal(rows.Err())
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
err = br.Close()
|
rowCount = 0
|
||||||
if err != nil {
|
_, err = br.QueryFunc([]interface{}{&id, &description, &amount}, func(pgx.QueryFuncRow) error {
|
||||||
t.Fatal(err)
|
if id != selectFromLedgerExpectedRows[rowCount].id {
|
||||||
}
|
t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id)
|
||||||
|
}
|
||||||
|
if description != selectFromLedgerExpectedRows[rowCount].description {
|
||||||
|
t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description)
|
||||||
|
}
|
||||||
|
if amount != selectFromLedgerExpectedRows[rowCount].amount {
|
||||||
|
t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount)
|
||||||
|
}
|
||||||
|
|
||||||
ensureConnValid(t, conn)
|
rowCount++
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = br.QueryRow().Scan(&id, &description, &amount)
|
||||||
|
if !errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
t.Errorf("expected pgx.ErrNoRows but got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = br.QueryRow().Scan(&amount)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if amount != 6 {
|
||||||
|
t.Errorf("amount => %v, want %v", amount, 6)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = br.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnSendBatchMany(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
|
||||||
|
sql := `create temporary table ledger(
|
||||||
|
id serial primary key,
|
||||||
|
description varchar not null,
|
||||||
|
amount int not null
|
||||||
|
);`
|
||||||
|
mustExec(t, conn, sql)
|
||||||
|
|
||||||
|
batch := &pgx.Batch{}
|
||||||
|
|
||||||
|
numInserts := 1000
|
||||||
|
|
||||||
|
for i := 0; i < numInserts; i++ {
|
||||||
|
batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1)
|
||||||
|
}
|
||||||
|
batch.Queue("select count(*) from ledger")
|
||||||
|
|
||||||
|
br := conn.SendBatch(context.Background(), batch)
|
||||||
|
|
||||||
|
for i := 0; i < numInserts; i++ {
|
||||||
|
ct, err := br.Exec()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 1, ct.RowsAffected())
|
||||||
|
}
|
||||||
|
|
||||||
|
var actualInserts int
|
||||||
|
err := br.QueryRow().Scan(&actualInserts)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, numInserts, actualInserts)
|
||||||
|
|
||||||
|
err = br.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnSendBatchWithPreparedStatement(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
modes := []pgx.QueryExecMode{
|
||||||
|
pgx.QueryExecModeCacheStatement,
|
||||||
|
pgx.QueryExecModeCacheDescribe,
|
||||||
|
pgx.QueryExecModeDescribeExec,
|
||||||
|
pgx.QueryExecModeExec,
|
||||||
|
// Don't test simple mode with prepared statements.
|
||||||
|
}
|
||||||
|
testWithQueryExecModes(t, modes, func(t *testing.T, conn *pgx.Conn) {
|
||||||
|
skipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
|
||||||
|
_, err := conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
batch := &pgx.Batch{}
|
||||||
|
|
||||||
|
queryCount := 3
|
||||||
|
for i := 0; i < queryCount; i++ {
|
||||||
|
batch.Queue("ps1", 5)
|
||||||
|
}
|
||||||
|
|
||||||
|
br := conn.SendBatch(context.Background(), batch)
|
||||||
|
|
||||||
|
for i := 0; i < queryCount; i++ {
|
||||||
|
rows, err := br.Query()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for k := 0; rows.Next(); k++ {
|
||||||
|
var n int
|
||||||
|
if err := rows.Scan(&n); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if n != k {
|
||||||
|
t.Fatalf("n => %v, want %v", n, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows.Err() != nil {
|
||||||
|
t.Fatal(rows.Err())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = br.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://github.com/jackc/pgx/issues/856
|
// https://github.com/jackc/pgx/issues/856
|
||||||
|
@ -303,316 +300,308 @@ func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.
|
||||||
func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) {
|
func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
|
||||||
defer closeConn(t, conn)
|
|
||||||
|
|
||||||
batch := &pgx.Batch{}
|
batch := &pgx.Batch{}
|
||||||
batch.Queue("select n from generate_series(0,5) n")
|
batch.Queue("select n from generate_series(0,5) n")
|
||||||
batch.Queue("select n from generate_series(0,5) n")
|
batch.Queue("select n from generate_series(0,5) n")
|
||||||
|
|
||||||
br := conn.SendBatch(context.Background(), batch)
|
br := conn.SendBatch(context.Background(), batch)
|
||||||
|
|
||||||
rows, err := br.Query()
|
rows, err := br.Query()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < 3; i++ {
|
|
||||||
if !rows.Next() {
|
|
||||||
t.Error("expected a row to be available")
|
|
||||||
}
|
|
||||||
|
|
||||||
var n int
|
|
||||||
if err := rows.Scan(&n); err != nil {
|
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if n != i {
|
|
||||||
t.Errorf("n => %v, want %v", n, i)
|
for i := 0; i < 3; i++ {
|
||||||
|
if !rows.Next() {
|
||||||
|
t.Error("expected a row to be available")
|
||||||
|
}
|
||||||
|
|
||||||
|
var n int
|
||||||
|
if err := rows.Scan(&n); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if n != i {
|
||||||
|
t.Errorf("n => %v, want %v", n, i)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
rows.Close()
|
rows.Close()
|
||||||
|
|
||||||
rows, err = br.Query()
|
rows, err = br.Query()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; rows.Next(); i++ {
|
|
||||||
var n int
|
|
||||||
if err := rows.Scan(&n); err != nil {
|
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if n != i {
|
|
||||||
t.Errorf("n => %v, want %v", n, i)
|
for i := 0; rows.Next(); i++ {
|
||||||
|
var n int
|
||||||
|
if err := rows.Scan(&n); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if n != i {
|
||||||
|
t.Errorf("n => %v, want %v", n, i)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if rows.Err() != nil {
|
if rows.Err() != nil {
|
||||||
t.Error(rows.Err())
|
t.Error(rows.Err())
|
||||||
}
|
}
|
||||||
|
|
||||||
err = br.Close()
|
err = br.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ensureConnValid(t, conn)
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnSendBatchQueryError(t *testing.T) {
|
func TestConnSendBatchQueryError(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
|
||||||
defer closeConn(t, conn)
|
|
||||||
|
|
||||||
batch := &pgx.Batch{}
|
batch := &pgx.Batch{}
|
||||||
batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0")
|
batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0")
|
||||||
batch.Queue("select n from generate_series(0,5) n")
|
batch.Queue("select n from generate_series(0,5) n")
|
||||||
|
|
||||||
br := conn.SendBatch(context.Background(), batch)
|
br := conn.SendBatch(context.Background(), batch)
|
||||||
|
|
||||||
rows, err := br.Query()
|
rows, err := br.Query()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; rows.Next(); i++ {
|
|
||||||
var n int
|
|
||||||
if err := rows.Scan(&n); err != nil {
|
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if n != i {
|
|
||||||
t.Errorf("n => %v, want %v", n, i)
|
for i := 0; rows.Next(); i++ {
|
||||||
|
var n int
|
||||||
|
if err := rows.Scan(&n); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if n != i {
|
||||||
|
t.Errorf("n => %v, want %v", n, i)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if pgErr, ok := rows.Err().(*pgconn.PgError); !(ok && pgErr.Code == "22012") {
|
if pgErr, ok := rows.Err().(*pgconn.PgError); !(ok && pgErr.Code == "22012") {
|
||||||
t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012)
|
t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = br.Close()
|
err = br.Close()
|
||||||
if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") {
|
if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") {
|
||||||
t.Errorf("rows.Err() => %v, want error code %v", err, 22012)
|
t.Errorf("rows.Err() => %v, want error code %v", err, 22012)
|
||||||
}
|
}
|
||||||
|
|
||||||
ensureConnValid(t, conn)
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnSendBatchQuerySyntaxError(t *testing.T) {
|
func TestConnSendBatchQuerySyntaxError(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
|
||||||
defer closeConn(t, conn)
|
|
||||||
|
|
||||||
batch := &pgx.Batch{}
|
batch := &pgx.Batch{}
|
||||||
batch.Queue("select 1 1")
|
batch.Queue("select 1 1")
|
||||||
|
|
||||||
br := conn.SendBatch(context.Background(), batch)
|
br := conn.SendBatch(context.Background(), batch)
|
||||||
|
|
||||||
var n int32
|
var n int32
|
||||||
err := br.QueryRow().Scan(&n)
|
err := br.QueryRow().Scan(&n)
|
||||||
if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "42601") {
|
if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "42601") {
|
||||||
t.Errorf("rows.Err() => %v, want error code %v", err, 42601)
|
t.Errorf("rows.Err() => %v, want error code %v", err, 42601)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = br.Close()
|
err = br.Close()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected error")
|
t.Error("Expected error")
|
||||||
}
|
}
|
||||||
|
|
||||||
ensureConnValid(t, conn)
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnSendBatchQueryRowInsert(t *testing.T) {
|
func TestConnSendBatchQueryRowInsert(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
|
||||||
defer closeConn(t, conn)
|
|
||||||
|
|
||||||
sql := `create temporary table ledger(
|
sql := `create temporary table ledger(
|
||||||
id serial primary key,
|
id serial primary key,
|
||||||
description varchar not null,
|
description varchar not null,
|
||||||
amount int not null
|
amount int not null
|
||||||
);`
|
);`
|
||||||
mustExec(t, conn, sql)
|
mustExec(t, conn, sql)
|
||||||
|
|
||||||
batch := &pgx.Batch{}
|
batch := &pgx.Batch{}
|
||||||
batch.Queue("select 1")
|
batch.Queue("select 1")
|
||||||
batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1)
|
batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1)
|
||||||
|
|
||||||
br := conn.SendBatch(context.Background(), batch)
|
br := conn.SendBatch(context.Background(), batch)
|
||||||
|
|
||||||
var value int
|
var value int
|
||||||
err := br.QueryRow().Scan(&value)
|
err := br.QueryRow().Scan(&value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ct, err := br.Exec()
|
ct, err := br.Exec()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if ct.RowsAffected() != 2 {
|
if ct.RowsAffected() != 2 {
|
||||||
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2)
|
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
br.Close()
|
br.Close()
|
||||||
|
|
||||||
ensureConnValid(t, conn)
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnSendBatchQueryPartialReadInsert(t *testing.T) {
|
func TestConnSendBatchQueryPartialReadInsert(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
|
||||||
defer closeConn(t, conn)
|
|
||||||
|
|
||||||
sql := `create temporary table ledger(
|
sql := `create temporary table ledger(
|
||||||
id serial primary key,
|
id serial primary key,
|
||||||
description varchar not null,
|
description varchar not null,
|
||||||
amount int not null
|
amount int not null
|
||||||
);`
|
);`
|
||||||
mustExec(t, conn, sql)
|
mustExec(t, conn, sql)
|
||||||
|
|
||||||
batch := &pgx.Batch{}
|
batch := &pgx.Batch{}
|
||||||
batch.Queue("select 1 union all select 2 union all select 3")
|
batch.Queue("select 1 union all select 2 union all select 3")
|
||||||
batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1)
|
batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1)
|
||||||
|
|
||||||
br := conn.SendBatch(context.Background(), batch)
|
br := conn.SendBatch(context.Background(), batch)
|
||||||
|
|
||||||
rows, err := br.Query()
|
rows, err := br.Query()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
rows.Close()
|
rows.Close()
|
||||||
|
|
||||||
ct, err := br.Exec()
|
ct, err := br.Exec()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if ct.RowsAffected() != 2 {
|
if ct.RowsAffected() != 2 {
|
||||||
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2)
|
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
br.Close()
|
br.Close()
|
||||||
|
|
||||||
ensureConnValid(t, conn)
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTxSendBatch(t *testing.T) {
|
func TestTxSendBatch(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
|
||||||
defer closeConn(t, conn)
|
|
||||||
|
|
||||||
sql := `create temporary table ledger1(
|
sql := `create temporary table ledger1(
|
||||||
id serial primary key,
|
id serial primary key,
|
||||||
description varchar not null
|
description varchar not null
|
||||||
);`
|
);`
|
||||||
mustExec(t, conn, sql)
|
mustExec(t, conn, sql)
|
||||||
|
|
||||||
sql = `create temporary table ledger2(
|
sql = `create temporary table ledger2(
|
||||||
id int primary key,
|
id int primary key,
|
||||||
amount int not null
|
amount int not null
|
||||||
);`
|
);`
|
||||||
mustExec(t, conn, sql)
|
mustExec(t, conn, sql)
|
||||||
|
|
||||||
tx, _ := conn.Begin(context.Background())
|
tx, _ := conn.Begin(context.Background())
|
||||||
batch := &pgx.Batch{}
|
batch := &pgx.Batch{}
|
||||||
batch.Queue("insert into ledger1(description) values($1) returning id", "q1")
|
batch.Queue("insert into ledger1(description) values($1) returning id", "q1")
|
||||||
|
|
||||||
br := tx.SendBatch(context.Background(), batch)
|
br := tx.SendBatch(context.Background(), batch)
|
||||||
|
|
||||||
var id int
|
var id int
|
||||||
err := br.QueryRow().Scan(&id)
|
err := br.QueryRow().Scan(&id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
br.Close()
|
br.Close()
|
||||||
|
|
||||||
batch = &pgx.Batch{}
|
batch = &pgx.Batch{}
|
||||||
batch.Queue("insert into ledger2(id,amount) values($1, $2)", id, 2)
|
batch.Queue("insert into ledger2(id,amount) values($1, $2)", id, 2)
|
||||||
batch.Queue("select amount from ledger2 where id = $1", id)
|
batch.Queue("select amount from ledger2 where id = $1", id)
|
||||||
|
|
||||||
br = tx.SendBatch(context.Background(), batch)
|
br = tx.SendBatch(context.Background(), batch)
|
||||||
|
|
||||||
ct, err := br.Exec()
|
ct, err := br.Exec()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
if ct.RowsAffected() != 1 {
|
if ct.RowsAffected() != 1 {
|
||||||
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
|
t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
var amount int
|
var amount int
|
||||||
err = br.QueryRow().Scan(&amount)
|
err = br.QueryRow().Scan(&amount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
br.Close()
|
br.Close()
|
||||||
tx.Commit(context.Background())
|
tx.Commit(context.Background())
|
||||||
|
|
||||||
var count int
|
var count int
|
||||||
conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id).Scan(&count)
|
conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id).Scan(&count)
|
||||||
if count != 1 {
|
if count != 1 {
|
||||||
t.Errorf("count => %v, want %v", count, 1)
|
t.Errorf("count => %v, want %v", count, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = br.Close()
|
err = br.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ensureConnValid(t, conn)
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTxSendBatchRollback(t *testing.T) {
|
func TestTxSendBatchRollback(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
|
||||||
defer closeConn(t, conn)
|
|
||||||
|
|
||||||
sql := `create temporary table ledger1(
|
sql := `create temporary table ledger1(
|
||||||
id serial primary key,
|
id serial primary key,
|
||||||
description varchar not null
|
description varchar not null
|
||||||
);`
|
);`
|
||||||
mustExec(t, conn, sql)
|
mustExec(t, conn, sql)
|
||||||
|
|
||||||
tx, _ := conn.Begin(context.Background())
|
tx, _ := conn.Begin(context.Background())
|
||||||
batch := &pgx.Batch{}
|
batch := &pgx.Batch{}
|
||||||
batch.Queue("insert into ledger1(description) values($1) returning id", "q1")
|
batch.Queue("insert into ledger1(description) values($1) returning id", "q1")
|
||||||
|
|
||||||
br := tx.SendBatch(context.Background(), batch)
|
br := tx.SendBatch(context.Background(), batch)
|
||||||
|
|
||||||
var id int
|
var id int
|
||||||
err := br.QueryRow().Scan(&id)
|
err := br.QueryRow().Scan(&id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
br.Close()
|
br.Close()
|
||||||
tx.Rollback(context.Background())
|
tx.Rollback(context.Background())
|
||||||
|
|
||||||
row := conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id)
|
row := conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id)
|
||||||
var count int
|
var count int
|
||||||
row.Scan(&count)
|
row.Scan(&count)
|
||||||
if count != 0 {
|
if count != 0 {
|
||||||
t.Errorf("count => %v, want %v", count, 0)
|
t.Errorf("count => %v, want %v", count, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
ensureConnValid(t, conn)
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnBeginBatchDeferredError(t *testing.T) {
|
func TestConnBeginBatchDeferredError(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
|
||||||
defer closeConn(t, conn)
|
|
||||||
|
|
||||||
skipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
|
skipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
|
||||||
|
|
||||||
mustExec(t, conn, `create temporary table t (
|
mustExec(t, conn, `create temporary table t (
|
||||||
id text primary key,
|
id text primary key,
|
||||||
n int not null,
|
n int not null,
|
||||||
unique (n) deferrable initially deferred
|
unique (n) deferrable initially deferred
|
||||||
|
@ -620,36 +609,36 @@ func TestConnBeginBatchDeferredError(t *testing.T) {
|
||||||
|
|
||||||
insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`)
|
insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`)
|
||||||
|
|
||||||
batch := &pgx.Batch{}
|
batch := &pgx.Batch{}
|
||||||
|
|
||||||
batch.Queue(`update t set n=n+1 where id='b' returning *`)
|
batch.Queue(`update t set n=n+1 where id='b' returning *`)
|
||||||
|
|
||||||
br := conn.SendBatch(context.Background(), batch)
|
br := conn.SendBatch(context.Background(), batch)
|
||||||
|
|
||||||
rows, err := br.Query()
|
rows, err := br.Query()
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
var id string
|
|
||||||
var n int32
|
|
||||||
err = rows.Scan(&id, &n)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
err = br.Close()
|
for rows.Next() {
|
||||||
if err == nil {
|
var id string
|
||||||
t.Fatal("expected error 23505 but got none")
|
var n int32
|
||||||
}
|
err = rows.Scan(&id, &n)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err, ok := err.(*pgconn.PgError); !ok || err.Code != "23505" {
|
err = br.Close()
|
||||||
t.Fatalf("expected error 23505, got %v", err)
|
if err == nil {
|
||||||
}
|
t.Fatal("expected error 23505 but got none")
|
||||||
|
}
|
||||||
|
|
||||||
ensureConnValid(t, conn)
|
if err, ok := err.(*pgconn.PgError); !ok || err.Code != "23505" {
|
||||||
|
t.Fatalf("expected error 23505, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnSendBatchNoStatementCache(t *testing.T) {
|
func TestConnSendBatchNoStatementCache(t *testing.T) {
|
||||||
|
|
143
conn.go
143
conn.go
|
@ -861,9 +861,10 @@ func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, sc
|
||||||
// explicit transaction control statements are executed. The returned BatchResults must be closed before the connection
|
// explicit transaction control statements are executed. The returned BatchResults must be closed before the connection
|
||||||
// is used again.
|
// is used again.
|
||||||
func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
||||||
simpleProtocol := c.config.DefaultQueryExecMode == QueryExecModeSimpleProtocol
|
mode := c.config.DefaultQueryExecMode
|
||||||
var sb strings.Builder
|
|
||||||
if simpleProtocol {
|
if mode == QueryExecModeSimpleProtocol {
|
||||||
|
var sb strings.Builder
|
||||||
for i, bi := range b.items {
|
for i, bi := range b.items {
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
sb.WriteByte(';')
|
sb.WriteByte(';')
|
||||||
|
@ -884,66 +885,102 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
distinctUnpreparedQueries := map[string]struct{}{}
|
|
||||||
|
|
||||||
for _, bi := range b.items {
|
|
||||||
if _, ok := c.preparedStatements[bi.query]; ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
distinctUnpreparedQueries[bi.query] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
var stmtCache stmtcache.Cache
|
|
||||||
if len(distinctUnpreparedQueries) > 0 {
|
|
||||||
if c.statementCache != nil && c.statementCache.Cap() >= len(distinctUnpreparedQueries) {
|
|
||||||
stmtCache = c.statementCache
|
|
||||||
} else {
|
|
||||||
stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries))
|
|
||||||
}
|
|
||||||
|
|
||||||
for sql, _ := range distinctUnpreparedQueries {
|
|
||||||
_, err := stmtCache.Get(ctx, sql)
|
|
||||||
if err != nil {
|
|
||||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
batch := &pgconn.Batch{}
|
batch := &pgconn.Batch{}
|
||||||
|
|
||||||
for _, bi := range b.items {
|
if mode == QueryExecModeExec {
|
||||||
c.eqb.Reset()
|
for _, bi := range b.items {
|
||||||
|
c.eqb.Reset()
|
||||||
|
anynil.NormalizeSlice(bi.arguments)
|
||||||
|
|
||||||
sd := c.preparedStatements[bi.query]
|
sd := c.preparedStatements[bi.query]
|
||||||
if sd == nil {
|
if sd != nil {
|
||||||
var err error
|
if len(sd.ParamOIDs) != len(bi.arguments) {
|
||||||
sd, err = stmtCache.Get(ctx, bi.query)
|
return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")}
|
||||||
if err != nil {
|
}
|
||||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
|
||||||
|
for i := range bi.arguments {
|
||||||
|
err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], bi.arguments[i])
|
||||||
|
if err != nil {
|
||||||
|
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range sd.Fields {
|
||||||
|
c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID))
|
||||||
|
}
|
||||||
|
|
||||||
|
batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats)
|
||||||
|
} else {
|
||||||
|
err := c.appendParamsForQueryExecModeExec(bi.arguments)
|
||||||
|
if err != nil {
|
||||||
|
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||||
|
}
|
||||||
|
batch.ExecParams(bi.query, c.eqb.paramValues, nil, c.eqb.paramFormats, c.eqb.resultFormats)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
|
||||||
|
distinctUnpreparedQueries := map[string]struct{}{}
|
||||||
|
|
||||||
|
for _, bi := range b.items {
|
||||||
|
if _, ok := c.preparedStatements[bi.query]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
distinctUnpreparedQueries[bi.query] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
var stmtCache stmtcache.Cache
|
||||||
|
if len(distinctUnpreparedQueries) > 0 {
|
||||||
|
if mode == QueryExecModeCacheStatement && c.statementCache != nil && c.statementCache.Cap() >= len(distinctUnpreparedQueries) {
|
||||||
|
stmtCache = c.statementCache
|
||||||
|
} else if mode == QueryExecModeCacheStatement && c.descriptionCache != nil && c.descriptionCache.Cap() >= len(distinctUnpreparedQueries) {
|
||||||
|
stmtCache = c.descriptionCache
|
||||||
|
} else {
|
||||||
|
stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries))
|
||||||
|
}
|
||||||
|
|
||||||
|
for sql, _ := range distinctUnpreparedQueries {
|
||||||
|
_, err := stmtCache.Get(ctx, sql)
|
||||||
|
if err != nil {
|
||||||
|
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(sd.ParamOIDs) != len(bi.arguments) {
|
for _, bi := range b.items {
|
||||||
return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")}
|
c.eqb.Reset()
|
||||||
}
|
|
||||||
|
|
||||||
anynil.NormalizeSlice(bi.arguments)
|
sd := c.preparedStatements[bi.query]
|
||||||
|
if sd == nil {
|
||||||
for i := range bi.arguments {
|
var err error
|
||||||
err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], bi.arguments[i])
|
sd, err = stmtCache.Get(ctx, bi.query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
for i := range sd.Fields {
|
if len(sd.ParamOIDs) != len(bi.arguments) {
|
||||||
c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID))
|
return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")}
|
||||||
}
|
}
|
||||||
|
|
||||||
if sd.Name == "" {
|
anynil.NormalizeSlice(bi.arguments)
|
||||||
batch.ExecParams(bi.query, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats)
|
|
||||||
} else {
|
for i := range bi.arguments {
|
||||||
batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats)
|
err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], bi.arguments[i])
|
||||||
|
if err != nil {
|
||||||
|
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range sd.Fields {
|
||||||
|
c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID))
|
||||||
|
}
|
||||||
|
|
||||||
|
if sd.Name == "" {
|
||||||
|
batch.ExecParams(bi.query, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats)
|
||||||
|
} else {
|
||||||
|
batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -13,13 +13,18 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func testWithAllQueryExecModes(t *testing.T, f func(t *testing.T, conn *pgx.Conn)) {
|
func testWithAllQueryExecModes(t *testing.T, f func(t *testing.T, conn *pgx.Conn)) {
|
||||||
for _, mode := range []pgx.QueryExecMode{
|
modes := []pgx.QueryExecMode{
|
||||||
pgx.QueryExecModeCacheStatement,
|
pgx.QueryExecModeCacheStatement,
|
||||||
pgx.QueryExecModeCacheDescribe,
|
pgx.QueryExecModeCacheDescribe,
|
||||||
pgx.QueryExecModeDescribeExec,
|
pgx.QueryExecModeDescribeExec,
|
||||||
pgx.QueryExecModeExec,
|
pgx.QueryExecModeExec,
|
||||||
pgx.QueryExecModeSimpleProtocol,
|
pgx.QueryExecModeSimpleProtocol,
|
||||||
} {
|
}
|
||||||
|
testWithQueryExecModes(t, modes, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testWithQueryExecModes(t *testing.T, modes []pgx.QueryExecMode, f func(t *testing.T, conn *pgx.Conn)) {
|
||||||
|
for _, mode := range modes {
|
||||||
t.Run(mode.String(),
|
t.Run(mode.String(),
|
||||||
func(t *testing.T) {
|
func(t *testing.T) {
|
||||||
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
|
Loading…
Reference in New Issue