Fix SendBatch of all prepared statements with statement cache disabled

fixes #856
pull/860/head
Jack Christensen 2020-10-29 20:28:57 -05:00
parent 5f8d853b34
commit 7c47415150
2 changed files with 66 additions and 9 deletions

View File

@ -228,6 +228,61 @@ func TestConnSendBatchWithPreparedStatement(t *testing.T) {
ensureConnValid(t, conn)
}
// https://github.com/jackc/pgx/issues/856
func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.T) {
t.Parallel()
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
config.BuildStatementCache = nil
conn := mustConnect(t, config)
defer closeConn(t, conn)
_, err = conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n")
if err != nil {
t.Fatal(err)
}
batch := &pgx.Batch{}
queryCount := 3
for i := 0; i < queryCount; i++ {
batch.Queue("ps1", 5)
}
br := conn.SendBatch(context.Background(), batch)
for i := 0; i < queryCount; i++ {
rows, err := br.Query()
if err != nil {
t.Fatal(err)
}
for k := 0; rows.Next(); k++ {
var n int
if err := rows.Scan(&n); err != nil {
t.Fatal(err)
}
if n != k {
t.Fatalf("n => %v, want %v", n, k)
}
}
if rows.Err() != nil {
t.Fatal(rows.Err())
}
}
err = br.Close()
if err != nil {
t.Fatal(err)
}
ensureConnValid(t, conn)
}
func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) {
t.Parallel()

20
conn.go
View File

@ -702,16 +702,18 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
}
var stmtCache stmtcache.Cache
if c.stmtcache != nil && c.stmtcache.Cap() >= len(distinctUnpreparedQueries) {
stmtCache = c.stmtcache
} else {
stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries))
}
if len(distinctUnpreparedQueries) > 0 {
if c.stmtcache != nil && c.stmtcache.Cap() >= len(distinctUnpreparedQueries) {
stmtCache = c.stmtcache
} else {
stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries))
}
for sql, _ := range distinctUnpreparedQueries {
_, err := stmtCache.Get(ctx, sql)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
for sql, _ := range distinctUnpreparedQueries {
_, err := stmtCache.Get(ctx, sql)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
}
}