mirror of https://github.com/jackc/pgx.git
SendBatch supports default QueryExecMode
parent
1390a11fe2
commit
cb721dfb5b
|
@ -15,9 +15,7 @@ import (
|
|||
func TestConnSendBatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
|
||||
skipCockroachDB(t, conn, "Server serial type is incompatible with test")
|
||||
|
||||
sql := `create temporary table ledger(
|
||||
|
@ -145,16 +143,13 @@ func TestConnSendBatch(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnSendBatchMany(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
|
||||
sql := `create temporary table ledger(
|
||||
id serial primary key,
|
||||
description varchar not null,
|
||||
|
@ -186,18 +181,21 @@ func TestConnSendBatchMany(t *testing.T) {
|
|||
|
||||
err = br.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnSendBatchWithPreparedStatement(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
modes := []pgx.QueryExecMode{
|
||||
pgx.QueryExecModeCacheStatement,
|
||||
pgx.QueryExecModeCacheDescribe,
|
||||
pgx.QueryExecModeDescribeExec,
|
||||
pgx.QueryExecModeExec,
|
||||
// Don't test simple mode with prepared statements.
|
||||
}
|
||||
testWithQueryExecModes(t, modes, func(t *testing.T, conn *pgx.Conn) {
|
||||
skipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
|
||||
|
||||
_, err := conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -237,8 +235,7 @@ func TestConnSendBatchWithPreparedStatement(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
})
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/856
|
||||
|
@ -303,8 +300,7 @@ func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.
|
|||
func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
|
||||
|
||||
batch := &pgx.Batch{}
|
||||
batch.Queue("select n from generate_series(0,5) n")
|
||||
|
@ -357,14 +353,13 @@ func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnSendBatchQueryError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
|
||||
|
||||
batch := &pgx.Batch{}
|
||||
batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0")
|
||||
|
@ -396,14 +391,13 @@ func TestConnSendBatchQueryError(t *testing.T) {
|
|||
t.Errorf("rows.Err() => %v, want error code %v", err, 22012)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnSendBatchQuerySyntaxError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
|
||||
|
||||
batch := &pgx.Batch{}
|
||||
batch.Queue("select 1 1")
|
||||
|
@ -421,14 +415,13 @@ func TestConnSendBatchQuerySyntaxError(t *testing.T) {
|
|||
t.Error("Expected error")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnSendBatchQueryRowInsert(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
|
||||
|
||||
sql := `create temporary table ledger(
|
||||
id serial primary key,
|
||||
|
@ -459,14 +452,13 @@ func TestConnSendBatchQueryRowInsert(t *testing.T) {
|
|||
|
||||
br.Close()
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnSendBatchQueryPartialReadInsert(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
|
||||
|
||||
sql := `create temporary table ledger(
|
||||
id serial primary key,
|
||||
|
@ -497,14 +489,13 @@ func TestConnSendBatchQueryPartialReadInsert(t *testing.T) {
|
|||
|
||||
br.Close()
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTxSendBatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
|
||||
|
||||
sql := `create temporary table ledger1(
|
||||
id serial primary key,
|
||||
|
@ -565,14 +556,13 @@ func TestTxSendBatch(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTxSendBatchRollback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
|
||||
|
||||
sql := `create temporary table ledger1(
|
||||
id serial primary key,
|
||||
|
@ -601,14 +591,13 @@ func TestTxSendBatchRollback(t *testing.T) {
|
|||
t.Errorf("count => %v, want %v", count, 0)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnBeginBatchDeferredError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
testWithAllQueryExecModes(t, func(t *testing.T, conn *pgx.Conn) {
|
||||
|
||||
skipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
|
||||
|
||||
|
@ -649,7 +638,7 @@ func TestConnBeginBatchDeferredError(t *testing.T) {
|
|||
t.Fatalf("expected error 23505, got %v", err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnSendBatchNoStatementCache(t *testing.T) {
|
||||
|
|
47
conn.go
47
conn.go
|
@ -861,9 +861,10 @@ func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, sc
|
|||
// explicit transaction control statements are executed. The returned BatchResults must be closed before the connection
|
||||
// is used again.
|
||||
func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
||||
simpleProtocol := c.config.DefaultQueryExecMode == QueryExecModeSimpleProtocol
|
||||
mode := c.config.DefaultQueryExecMode
|
||||
|
||||
if mode == QueryExecModeSimpleProtocol {
|
||||
var sb strings.Builder
|
||||
if simpleProtocol {
|
||||
for i, bi := range b.items {
|
||||
if i > 0 {
|
||||
sb.WriteByte(';')
|
||||
|
@ -884,6 +885,41 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
|||
}
|
||||
}
|
||||
|
||||
batch := &pgconn.Batch{}
|
||||
|
||||
if mode == QueryExecModeExec {
|
||||
for _, bi := range b.items {
|
||||
c.eqb.Reset()
|
||||
anynil.NormalizeSlice(bi.arguments)
|
||||
|
||||
sd := c.preparedStatements[bi.query]
|
||||
if sd != nil {
|
||||
if len(sd.ParamOIDs) != len(bi.arguments) {
|
||||
return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")}
|
||||
}
|
||||
|
||||
for i := range bi.arguments {
|
||||
err := c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], bi.arguments[i])
|
||||
if err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
}
|
||||
|
||||
for i := range sd.Fields {
|
||||
c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID))
|
||||
}
|
||||
|
||||
batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats)
|
||||
} else {
|
||||
err := c.appendParamsForQueryExecModeExec(bi.arguments)
|
||||
if err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
batch.ExecParams(bi.query, c.eqb.paramValues, nil, c.eqb.paramFormats, c.eqb.resultFormats)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
||||
distinctUnpreparedQueries := map[string]struct{}{}
|
||||
|
||||
for _, bi := range b.items {
|
||||
|
@ -895,8 +931,10 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
|||
|
||||
var stmtCache stmtcache.Cache
|
||||
if len(distinctUnpreparedQueries) > 0 {
|
||||
if c.statementCache != nil && c.statementCache.Cap() >= len(distinctUnpreparedQueries) {
|
||||
if mode == QueryExecModeCacheStatement && c.statementCache != nil && c.statementCache.Cap() >= len(distinctUnpreparedQueries) {
|
||||
stmtCache = c.statementCache
|
||||
} else if mode == QueryExecModeCacheStatement && c.descriptionCache != nil && c.descriptionCache.Cap() >= len(distinctUnpreparedQueries) {
|
||||
stmtCache = c.descriptionCache
|
||||
} else {
|
||||
stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries))
|
||||
}
|
||||
|
@ -909,8 +947,6 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
|||
}
|
||||
}
|
||||
|
||||
batch := &pgconn.Batch{}
|
||||
|
||||
for _, bi := range b.items {
|
||||
c.eqb.Reset()
|
||||
|
||||
|
@ -946,6 +982,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
|||
batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.
|
||||
|
||||
|
|
|
@ -13,13 +13,18 @@ import (
|
|||
)
|
||||
|
||||
func testWithAllQueryExecModes(t *testing.T, f func(t *testing.T, conn *pgx.Conn)) {
|
||||
for _, mode := range []pgx.QueryExecMode{
|
||||
modes := []pgx.QueryExecMode{
|
||||
pgx.QueryExecModeCacheStatement,
|
||||
pgx.QueryExecModeCacheDescribe,
|
||||
pgx.QueryExecModeDescribeExec,
|
||||
pgx.QueryExecModeExec,
|
||||
pgx.QueryExecModeSimpleProtocol,
|
||||
} {
|
||||
}
|
||||
testWithQueryExecModes(t, modes, f)
|
||||
}
|
||||
|
||||
func testWithQueryExecModes(t *testing.T, modes []pgx.QueryExecMode, f func(t *testing.T, conn *pgx.Conn)) {
|
||||
for _, mode := range modes {
|
||||
t.Run(mode.String(),
|
||||
func(t *testing.T) {
|
||||
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||
|
|
Loading…
Reference in New Issue