diff --git a/batch_test.go b/batch_test.go index da35646b..3e5a2d46 100644 --- a/batch_test.go +++ b/batch_test.go @@ -15,230 +15,227 @@ import ( func TestConnSendBatch(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + skipCockroachDB(t, conn, "Server serial type is incompatible with test") - skipCockroachDB(t, conn, "Server serial type is incompatible with test") - - sql := `create temporary table ledger( + sql := `create temporary table ledger( id serial primary key, description varchar not null, amount int not null );` - mustExec(t, conn, sql) + mustExec(t, conn, sql) - batch := &pgx.Batch{} - batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1) - batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2) - batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3) - batch.Queue("select id, description, amount from ledger order by id") - batch.Queue("select id, description, amount from ledger order by id") - batch.Queue("select * from ledger where false") - batch.Queue("select sum(amount) from ledger") - - br := conn.SendBatch(context.Background(), batch) - - ct, err := br.Exec() - if err != nil { - t.Error(err) - } - if ct.RowsAffected() != 1 { - t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) - } - - ct, err = br.Exec() - if err != nil { - t.Error(err) - } - if ct.RowsAffected() != 1 { - t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) - } - - ct, err = br.Exec() - if err != nil { - t.Error(err) - } - if ct.RowsAffected() != 1 { - t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) - } - - selectFromLedgerExpectedRows := []struct { - id int32 - description string - amount int32 - }{ - {1, "q1", 1}, - {2, "q2", 2}, - {3, "q3", 3}, - } - - rows, err := br.Query() - if err != nil { - t.Error(err) - } - - var id int32 - var description string - var amount int32 - rowCount := 0 - - for rows.Next() { - if rowCount >= len(selectFromLedgerExpectedRows) { - t.Fatalf("got too many rows: %d", rowCount) - } - - if err := rows.Scan(&id, &description, &amount); err != nil { - t.Fatalf("row %d: %v", rowCount, err) - } - - if id != selectFromLedgerExpectedRows[rowCount].id { - t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id) - } - if description != selectFromLedgerExpectedRows[rowCount].description { - t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description) - } - if amount != selectFromLedgerExpectedRows[rowCount].amount { - t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount) - } - - rowCount++ - } - - if rows.Err() != nil { - t.Fatal(rows.Err()) - } - - rowCount = 0 - _, err = br.QueryFunc([]interface{}{&id, &description, &amount}, func(pgx.QueryFuncRow) error { - if id != selectFromLedgerExpectedRows[rowCount].id { - t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id) - } - if description != selectFromLedgerExpectedRows[rowCount].description { - t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description) - } - if amount != selectFromLedgerExpectedRows[rowCount].amount { - t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount) - } - - rowCount++ - - return nil - }) - if err != nil { - t.Error(err) - } - - err = br.QueryRow().Scan(&id, &description, &amount) - if !errors.Is(err, pgx.ErrNoRows) { - t.Errorf("expected pgx.ErrNoRows but got: %v", err) - } - - err = br.QueryRow().Scan(&amount) - if err != nil { - t.Error(err) - } - if amount != 6 { - t.Errorf("amount => %v, want %v", amount, 6) - } - - err = br.Close() - if err != nil { - t.Fatal(err) - } - - ensureConnValid(t, conn) -} - -func TestConnSendBatchMany(t *testing.T) { - t.Parallel() - - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - sql := `create temporary table ledger( - id serial primary key, - description varchar not null, - amount int not null - );` - mustExec(t, conn, sql) - - batch := &pgx.Batch{} - - numInserts := 1000 - - for i := 0; i < numInserts; i++ { + batch := &pgx.Batch{} batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1) - } - batch.Queue("select count(*) from ledger") + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2) + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3) + batch.Queue("select id, description, amount from ledger order by id") + batch.Queue("select id, description, amount from ledger order by id") + batch.Queue("select * from ledger where false") + batch.Queue("select sum(amount) from ledger") - br := conn.SendBatch(context.Background(), batch) + br := conn.SendBatch(context.Background(), batch) - for i := 0; i < numInserts; i++ { ct, err := br.Exec() - assert.NoError(t, err) - assert.EqualValues(t, 1, ct.RowsAffected()) - } + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 1 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) + } - var actualInserts int - err := br.QueryRow().Scan(&actualInserts) - assert.NoError(t, err) - assert.EqualValues(t, numInserts, actualInserts) + ct, err = br.Exec() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 1 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) + } - err = br.Close() - require.NoError(t, err) + ct, err = br.Exec() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 1 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) + } - ensureConnValid(t, conn) -} + selectFromLedgerExpectedRows := []struct { + id int32 + description string + amount int32 + }{ + {1, "q1", 1}, + {2, "q2", 2}, + {3, "q3", 3}, + } -func TestConnSendBatchWithPreparedStatement(t *testing.T) { - t.Parallel() - - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - skipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") - - _, err := conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n") - if err != nil { - t.Fatal(err) - } - - batch := &pgx.Batch{} - - queryCount := 3 - for i := 0; i < queryCount; i++ { - batch.Queue("ps1", 5) - } - - br := conn.SendBatch(context.Background(), batch) - - for i := 0; i < queryCount; i++ { rows, err := br.Query() if err != nil { - t.Fatal(err) + t.Error(err) } - for k := 0; rows.Next(); k++ { - var n int - if err := rows.Scan(&n); err != nil { - t.Fatal(err) + var id int32 + var description string + var amount int32 + rowCount := 0 + + for rows.Next() { + if rowCount >= len(selectFromLedgerExpectedRows) { + t.Fatalf("got too many rows: %d", rowCount) } - if n != k { - t.Fatalf("n => %v, want %v", n, k) + + if err := rows.Scan(&id, &description, &amount); err != nil { + t.Fatalf("row %d: %v", rowCount, err) } + + if id != selectFromLedgerExpectedRows[rowCount].id { + t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id) + } + if description != selectFromLedgerExpectedRows[rowCount].description { + t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description) + } + if amount != selectFromLedgerExpectedRows[rowCount].amount { + t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount) + } + + rowCount++ } if rows.Err() != nil { t.Fatal(rows.Err()) } - } - err = br.Close() - if err != nil { - t.Fatal(err) - } + rowCount = 0 + _, err = br.QueryFunc([]interface{}{&id, &description, &amount}, func(pgx.QueryFuncRow) error { + if id != selectFromLedgerExpectedRows[rowCount].id { + t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id) + } + if description != selectFromLedgerExpectedRows[rowCount].description { + t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description) + } + if amount != selectFromLedgerExpectedRows[rowCount].amount { + t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount) + } - ensureConnValid(t, conn) + rowCount++ + + return nil + }) + if err != nil { + t.Error(err) + } + + err = br.QueryRow().Scan(&id, &description, &amount) + if !errors.Is(err, pgx.ErrNoRows) { + t.Errorf("expected pgx.ErrNoRows but got: %v", err) + } + + err = br.QueryRow().Scan(&amount) + if err != nil { + t.Error(err) + } + if amount != 6 { + t.Errorf("amount => %v, want %v", amount, 6) + } + + err = br.Close() + if err != nil { + t.Fatal(err) + } + }) +} + +func TestConnSendBatchMany(t *testing.T) { + t.Parallel() + + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { + sql := `create temporary table ledger( + id serial primary key, + description varchar not null, + amount int not null + );` + mustExec(t, conn, sql) + + batch := &pgx.Batch{} + + numInserts := 1000 + + for i := 0; i < numInserts; i++ { + batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1) + } + batch.Queue("select count(*) from ledger") + + br := conn.SendBatch(context.Background(), batch) + + for i := 0; i < numInserts; i++ { + ct, err := br.Exec() + assert.NoError(t, err) + assert.EqualValues(t, 1, ct.RowsAffected()) + } + + var actualInserts int + err := br.QueryRow().Scan(&actualInserts) + assert.NoError(t, err) + assert.EqualValues(t, numInserts, actualInserts) + + err = br.Close() + require.NoError(t, err) + }) +} + +func TestConnSendBatchWithPreparedStatement(t *testing.T) { + t.Parallel() + + modes := []pgx.QueryExecMode{ + pgx.QueryExecModeCacheStatement, + pgx.QueryExecModeCacheDescribe, + pgx.QueryExecModeDescribeExec, + pgx.QueryExecModeExec, + // Don't test simple mode with prepared statements. + } + testWithQueryExecModes(t, modes, func(t *testing.T, conn *pgx.Conn) { + skipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") + _, err := conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n") + if err != nil { + t.Fatal(err) + } + + batch := &pgx.Batch{} + + queryCount := 3 + for i := 0; i < queryCount; i++ { + batch.Queue("ps1", 5) + } + + br := conn.SendBatch(context.Background(), batch) + + for i := 0; i < queryCount; i++ { + rows, err := br.Query() + if err != nil { + t.Fatal(err) + } + + for k := 0; rows.Next(); k++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Fatal(err) + } + if n != k { + t.Fatalf("n => %v, want %v", n, k) + } + } + + if rows.Err() != nil { + t.Fatal(rows.Err()) + } + } + + err = br.Close() + if err != nil { + t.Fatal(err) + } + }) } // https://github.com/jackc/pgx/issues/856 @@ -303,316 +300,308 @@ func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing. func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { - batch := &pgx.Batch{} - batch.Queue("select n from generate_series(0,5) n") - batch.Queue("select n from generate_series(0,5) n") + batch := &pgx.Batch{} + batch.Queue("select n from generate_series(0,5) n") + batch.Queue("select n from generate_series(0,5) n") - br := conn.SendBatch(context.Background(), batch) + br := conn.SendBatch(context.Background(), batch) - rows, err := br.Query() - if err != nil { - t.Error(err) - } - - for i := 0; i < 3; i++ { - if !rows.Next() { - t.Error("expected a row to be available") - } - - var n int - if err := rows.Scan(&n); err != nil { + rows, err := br.Query() + if err != nil { t.Error(err) } - if n != i { - t.Errorf("n => %v, want %v", n, i) + + for i := 0; i < 3; i++ { + if !rows.Next() { + t.Error("expected a row to be available") + } + + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } } - } - rows.Close() + rows.Close() - rows, err = br.Query() - if err != nil { - t.Error(err) - } - - for i := 0; rows.Next(); i++ { - var n int - if err := rows.Scan(&n); err != nil { + rows, err = br.Query() + if err != nil { t.Error(err) } - if n != i { - t.Errorf("n => %v, want %v", n, i) + + for i := 0; rows.Next(); i++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } } - } - if rows.Err() != nil { - t.Error(rows.Err()) - } + if rows.Err() != nil { + t.Error(rows.Err()) + } - err = br.Close() - if err != nil { - t.Fatal(err) - } + err = br.Close() + if err != nil { + t.Fatal(err) + } - ensureConnValid(t, conn) + }) } func TestConnSendBatchQueryError(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { - batch := &pgx.Batch{} - batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0") - batch.Queue("select n from generate_series(0,5) n") + batch := &pgx.Batch{} + batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0") + batch.Queue("select n from generate_series(0,5) n") - br := conn.SendBatch(context.Background(), batch) + br := conn.SendBatch(context.Background(), batch) - rows, err := br.Query() - if err != nil { - t.Error(err) - } - - for i := 0; rows.Next(); i++ { - var n int - if err := rows.Scan(&n); err != nil { + rows, err := br.Query() + if err != nil { t.Error(err) } - if n != i { - t.Errorf("n => %v, want %v", n, i) + + for i := 0; rows.Next(); i++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + if n != i { + t.Errorf("n => %v, want %v", n, i) + } } - } - if pgErr, ok := rows.Err().(*pgconn.PgError); !(ok && pgErr.Code == "22012") { - t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012) - } + if pgErr, ok := rows.Err().(*pgconn.PgError); !(ok && pgErr.Code == "22012") { + t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012) + } - err = br.Close() - if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") { - t.Errorf("rows.Err() => %v, want error code %v", err, 22012) - } + err = br.Close() + if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") { + t.Errorf("rows.Err() => %v, want error code %v", err, 22012) + } - ensureConnValid(t, conn) + }) } func TestConnSendBatchQuerySyntaxError(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { - batch := &pgx.Batch{} - batch.Queue("select 1 1") + batch := &pgx.Batch{} + batch.Queue("select 1 1") - br := conn.SendBatch(context.Background(), batch) + br := conn.SendBatch(context.Background(), batch) - var n int32 - err := br.QueryRow().Scan(&n) - if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "42601") { - t.Errorf("rows.Err() => %v, want error code %v", err, 42601) - } + var n int32 + err := br.QueryRow().Scan(&n) + if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "42601") { + t.Errorf("rows.Err() => %v, want error code %v", err, 42601) + } - err = br.Close() - if err == nil { - t.Error("Expected error") - } + err = br.Close() + if err == nil { + t.Error("Expected error") + } - ensureConnValid(t, conn) + }) } func TestConnSendBatchQueryRowInsert(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { - sql := `create temporary table ledger( + sql := `create temporary table ledger( id serial primary key, description varchar not null, amount int not null );` - mustExec(t, conn, sql) + mustExec(t, conn, sql) - batch := &pgx.Batch{} - batch.Queue("select 1") - batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1) + batch := &pgx.Batch{} + batch.Queue("select 1") + batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1) - br := conn.SendBatch(context.Background(), batch) + br := conn.SendBatch(context.Background(), batch) - var value int - err := br.QueryRow().Scan(&value) - if err != nil { - t.Error(err) - } + var value int + err := br.QueryRow().Scan(&value) + if err != nil { + t.Error(err) + } - ct, err := br.Exec() - if err != nil { - t.Error(err) - } - if ct.RowsAffected() != 2 { - t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2) - } + ct, err := br.Exec() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 2 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2) + } - br.Close() + br.Close() - ensureConnValid(t, conn) + }) } func TestConnSendBatchQueryPartialReadInsert(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { - sql := `create temporary table ledger( + sql := `create temporary table ledger( id serial primary key, description varchar not null, amount int not null );` - mustExec(t, conn, sql) + mustExec(t, conn, sql) - batch := &pgx.Batch{} - batch.Queue("select 1 union all select 2 union all select 3") - batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1) + batch := &pgx.Batch{} + batch.Queue("select 1 union all select 2 union all select 3") + batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1) - br := conn.SendBatch(context.Background(), batch) + br := conn.SendBatch(context.Background(), batch) - rows, err := br.Query() - if err != nil { - t.Error(err) - } - rows.Close() + rows, err := br.Query() + if err != nil { + t.Error(err) + } + rows.Close() - ct, err := br.Exec() - if err != nil { - t.Error(err) - } - if ct.RowsAffected() != 2 { - t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2) - } + ct, err := br.Exec() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 2 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2) + } - br.Close() + br.Close() - ensureConnValid(t, conn) + }) } func TestTxSendBatch(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { - sql := `create temporary table ledger1( + sql := `create temporary table ledger1( id serial primary key, description varchar not null );` - mustExec(t, conn, sql) + mustExec(t, conn, sql) - sql = `create temporary table ledger2( + sql = `create temporary table ledger2( id int primary key, amount int not null );` - mustExec(t, conn, sql) + mustExec(t, conn, sql) - tx, _ := conn.Begin(context.Background()) - batch := &pgx.Batch{} - batch.Queue("insert into ledger1(description) values($1) returning id", "q1") + tx, _ := conn.Begin(context.Background()) + batch := &pgx.Batch{} + batch.Queue("insert into ledger1(description) values($1) returning id", "q1") - br := tx.SendBatch(context.Background(), batch) + br := tx.SendBatch(context.Background(), batch) - var id int - err := br.QueryRow().Scan(&id) - if err != nil { - t.Error(err) - } - br.Close() + var id int + err := br.QueryRow().Scan(&id) + if err != nil { + t.Error(err) + } + br.Close() - batch = &pgx.Batch{} - batch.Queue("insert into ledger2(id,amount) values($1, $2)", id, 2) - batch.Queue("select amount from ledger2 where id = $1", id) + batch = &pgx.Batch{} + batch.Queue("insert into ledger2(id,amount) values($1, $2)", id, 2) + batch.Queue("select amount from ledger2 where id = $1", id) - br = tx.SendBatch(context.Background(), batch) + br = tx.SendBatch(context.Background(), batch) - ct, err := br.Exec() - if err != nil { - t.Error(err) - } - if ct.RowsAffected() != 1 { - t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) - } + ct, err := br.Exec() + if err != nil { + t.Error(err) + } + if ct.RowsAffected() != 1 { + t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) + } - var amount int - err = br.QueryRow().Scan(&amount) - if err != nil { - t.Error(err) - } + var amount int + err = br.QueryRow().Scan(&amount) + if err != nil { + t.Error(err) + } - br.Close() - tx.Commit(context.Background()) + br.Close() + tx.Commit(context.Background()) - var count int - conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id).Scan(&count) - if count != 1 { - t.Errorf("count => %v, want %v", count, 1) - } + var count int + conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id).Scan(&count) + if count != 1 { + t.Errorf("count => %v, want %v", count, 1) + } - err = br.Close() - if err != nil { - t.Fatal(err) - } + err = br.Close() + if err != nil { + t.Fatal(err) + } - ensureConnValid(t, conn) + }) } func TestTxSendBatchRollback(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { - sql := `create temporary table ledger1( + sql := `create temporary table ledger1( id serial primary key, description varchar not null );` - mustExec(t, conn, sql) + mustExec(t, conn, sql) - tx, _ := conn.Begin(context.Background()) - batch := &pgx.Batch{} - batch.Queue("insert into ledger1(description) values($1) returning id", "q1") + tx, _ := conn.Begin(context.Background()) + batch := &pgx.Batch{} + batch.Queue("insert into ledger1(description) values($1) returning id", "q1") - br := tx.SendBatch(context.Background(), batch) + br := tx.SendBatch(context.Background(), batch) - var id int - err := br.QueryRow().Scan(&id) - if err != nil { - t.Error(err) - } - br.Close() - tx.Rollback(context.Background()) + var id int + err := br.QueryRow().Scan(&id) + if err != nil { + t.Error(err) + } + br.Close() + tx.Rollback(context.Background()) - row := conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id) - var count int - row.Scan(&count) - if count != 0 { - t.Errorf("count => %v, want %v", count, 0) - } + row := conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id) + var count int + row.Scan(&count) + if count != 0 { + t.Errorf("count => %v, want %v", count, 0) + } - ensureConnValid(t, conn) + }) } func TestConnBeginBatchDeferredError(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) { - skipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") + skipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") - mustExec(t, conn, `create temporary table t ( + mustExec(t, conn, `create temporary table t ( id text primary key, n int not null, unique (n) deferrable initially deferred @@ -620,36 +609,36 @@ func TestConnBeginBatchDeferredError(t *testing.T) { insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`) - batch := &pgx.Batch{} + batch := &pgx.Batch{} - batch.Queue(`update t set n=n+1 where id='b' returning *`) + batch.Queue(`update t set n=n+1 where id='b' returning *`) - br := conn.SendBatch(context.Background(), batch) + br := conn.SendBatch(context.Background(), batch) - rows, err := br.Query() - if err != nil { - t.Error(err) - } - - for rows.Next() { - var id string - var n int32 - err = rows.Scan(&id, &n) + rows, err := br.Query() if err != nil { - t.Fatal(err) + t.Error(err) } - } - err = br.Close() - if err == nil { - t.Fatal("expected error 23505 but got none") - } + for rows.Next() { + var id string + var n int32 + err = rows.Scan(&id, &n) + if err != nil { + t.Fatal(err) + } + } - if err, ok := err.(*pgconn.PgError); !ok || err.Code != "23505" { - t.Fatalf("expected error 23505, got %v", err) - } + err = br.Close() + if err == nil { + t.Fatal("expected error 23505 but got none") + } - ensureConnValid(t, conn) + if err, ok := err.(*pgconn.PgError); !ok || err.Code != "23505" { + t.Fatalf("expected error 23505, got %v", err) + } + + }) } func TestConnSendBatchNoStatementCache(t *testing.T) { diff --git a/conn.go b/conn.go index 410acad7..025d8022 100644 --- a/conn.go +++ b/conn.go @@ -861,9 +861,10 @@ func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, sc // explicit transaction control statements are executed. The returned BatchResults must be closed before the connection // is used again. func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { - simpleProtocol := c.config.DefaultQueryExecMode == QueryExecModeSimpleProtocol - var sb strings.Builder - if simpleProtocol { + mode := c.config.DefaultQueryExecMode + + if mode == QueryExecModeSimpleProtocol { + var sb strings.Builder for i, bi := range b.items { if i > 0 { sb.WriteByte(';') @@ -884,66 +885,102 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { } } - distinctUnpreparedQueries := map[string]struct{}{} - - for _, bi := range b.items { - if _, ok := c.preparedStatements[bi.query]; ok { - continue - } - distinctUnpreparedQueries[bi.query] = struct{}{} - } - - var stmtCache stmtcache.Cache - if len(distinctUnpreparedQueries) > 0 { - if c.statementCache != nil && c.statementCache.Cap() >= len(distinctUnpreparedQueries) { - stmtCache = c.statementCache - } else { - stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries)) - } - - for sql, _ := range distinctUnpreparedQueries { - _, err := stmtCache.Get(ctx, sql) - if err != nil { - return &batchResults{ctx: ctx, conn: c, err: err} - } - } - } - batch := &pgconn.Batch{} - for _, bi := range b.items { - c.eqb.Reset() + if mode == QueryExecModeExec { + for _, bi := range b.items { + c.eqb.Reset() + anynil.NormalizeSlice(bi.arguments) - sd := c.preparedStatements[bi.query] - if sd == nil { - var err error - sd, err = stmtCache.Get(ctx, bi.query) - if err != nil { - return &batchResults{ctx: ctx, conn: c, err: err} + sd := c.preparedStatements[bi.query] + if sd != nil { + if len(sd.ParamOIDs) != len(bi.arguments) { + return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")} + } + + for i := range bi.arguments { + err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], bi.arguments[i]) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + } + + for i := range sd.Fields { + c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID)) + } + + batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats) + } else { + err := c.appendParamsForQueryExecModeExec(bi.arguments) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + batch.ExecParams(bi.query, c.eqb.paramValues, nil, c.eqb.paramFormats, c.eqb.resultFormats) + } + } + } else { + + distinctUnpreparedQueries := map[string]struct{}{} + + for _, bi := range b.items { + if _, ok := c.preparedStatements[bi.query]; ok { + continue + } + distinctUnpreparedQueries[bi.query] = struct{}{} + } + + var stmtCache stmtcache.Cache + if len(distinctUnpreparedQueries) > 0 { + if mode == QueryExecModeCacheStatement && c.statementCache != nil && c.statementCache.Cap() >= len(distinctUnpreparedQueries) { + stmtCache = c.statementCache + } else if mode == QueryExecModeCacheStatement && c.descriptionCache != nil && c.descriptionCache.Cap() >= len(distinctUnpreparedQueries) { + stmtCache = c.descriptionCache + } else { + stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries)) + } + + for sql, _ := range distinctUnpreparedQueries { + _, err := stmtCache.Get(ctx, sql) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } } } - if len(sd.ParamOIDs) != len(bi.arguments) { - return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")} - } + for _, bi := range b.items { + c.eqb.Reset() - anynil.NormalizeSlice(bi.arguments) - - for i := range bi.arguments { - err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], bi.arguments[i]) - if err != nil { - return &batchResults{ctx: ctx, conn: c, err: err} + sd := c.preparedStatements[bi.query] + if sd == nil { + var err error + sd, err = stmtCache.Get(ctx, bi.query) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } } - } - for i := range sd.Fields { - c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID)) - } + if len(sd.ParamOIDs) != len(bi.arguments) { + return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")} + } - if sd.Name == "" { - batch.ExecParams(bi.query, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats) - } else { - batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats) + anynil.NormalizeSlice(bi.arguments) + + for i := range bi.arguments { + err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], bi.arguments[i]) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } + } + + for i := range sd.Fields { + c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID)) + } + + if sd.Name == "" { + batch.ExecParams(bi.query, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats) + } else { + batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats) + } } } diff --git a/helper_test.go b/helper_test.go index 0ef21c5a..26509946 100644 --- a/helper_test.go +++ b/helper_test.go @@ -13,13 +13,18 @@ import ( ) func testWithAllQueryExecModes(t *testing.T, f func(t *testing.T, conn *pgx.Conn)) { - for _, mode := range []pgx.QueryExecMode{ + modes := []pgx.QueryExecMode{ pgx.QueryExecModeCacheStatement, pgx.QueryExecModeCacheDescribe, pgx.QueryExecModeDescribeExec, pgx.QueryExecModeExec, pgx.QueryExecModeSimpleProtocol, - } { + } + testWithQueryExecModes(t, modes, f) +} + +func testWithQueryExecModes(t *testing.T, modes []pgx.QueryExecMode, f func(t *testing.T, conn *pgx.Conn)) { + for _, mode := range modes { t.Run(mode.String(), func(t *testing.T) { config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))