From 7c47415150692ec8d1bb50ba0702a174db8a4b37 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 29 Oct 2020 20:28:57 -0500 Subject: [PATCH] Fix SendBatch of all prepared statements with statement cache disabled fixes #856 --- batch_test.go | 55 +++++++++++++++++++++++++++++++++++++++++++++++++++ conn.go | 20 ++++++++++--------- 2 files changed, 66 insertions(+), 9 deletions(-) diff --git a/batch_test.go b/batch_test.go index 113ce3cf..7a25ba52 100644 --- a/batch_test.go +++ b/batch_test.go @@ -228,6 +228,61 @@ func TestConnSendBatchWithPreparedStatement(t *testing.T) { ensureConnValid(t, conn) } +// https://github.com/jackc/pgx/issues/856 +func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.T) { + t.Parallel() + + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + config.BuildStatementCache = nil + + conn := mustConnect(t, config) + defer closeConn(t, conn) + + _, err = conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n") + if err != nil { + t.Fatal(err) + } + + batch := &pgx.Batch{} + + queryCount := 3 + for i := 0; i < queryCount; i++ { + batch.Queue("ps1", 5) + } + + br := conn.SendBatch(context.Background(), batch) + + for i := 0; i < queryCount; i++ { + rows, err := br.Query() + if err != nil { + t.Fatal(err) + } + + for k := 0; rows.Next(); k++ { + var n int + if err := rows.Scan(&n); err != nil { + t.Fatal(err) + } + if n != k { + t.Fatalf("n => %v, want %v", n, k) + } + } + + if rows.Err() != nil { + t.Fatal(rows.Err()) + } + } + + err = br.Close() + if err != nil { + t.Fatal(err) + } + + ensureConnValid(t, conn) +} + func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) { t.Parallel() diff --git a/conn.go b/conn.go index f14e1801..28b0f87c 100644 --- a/conn.go +++ b/conn.go @@ -702,16 +702,18 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { } var stmtCache stmtcache.Cache - if c.stmtcache != nil && c.stmtcache.Cap() >= len(distinctUnpreparedQueries) { - stmtCache = c.stmtcache - } else { - stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries)) - } + if len(distinctUnpreparedQueries) > 0 { + if c.stmtcache != nil && c.stmtcache.Cap() >= len(distinctUnpreparedQueries) { + stmtCache = c.stmtcache + } else { + stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries)) + } - for sql, _ := range distinctUnpreparedQueries { - _, err := stmtCache.Get(ctx, sql) - if err != nil { - return &batchResults{ctx: ctx, conn: c, err: err} + for sql, _ := range distinctUnpreparedQueries { + _, err := stmtCache.Get(ctx, sql) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: err} + } } }