From 7ceeea6fe604db989489de762c9d6e0a7cb2ace7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 28 Apr 2022 07:58:24 -0500 Subject: [PATCH] Fix explicitly prepared statements with describe statement cache mode fixes https://github.com/jackc/pgx/issues/1196 --- conn.go | 2 +- conn_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/conn.go b/conn.go index 0030ea21..9b620d27 100644 --- a/conn.go +++ b/conn.go @@ -645,7 +645,7 @@ optionLoop: resultFormats = c.eqb.resultFormats } - if c.stmtcache != nil && c.stmtcache.Mode() == stmtcache.ModeDescribe { + if c.stmtcache != nil && c.stmtcache.Mode() == stmtcache.ModeDescribe && !ok { rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, resultFormats) } else { rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats) diff --git a/conn_test.go b/conn_test.go index beddcdcd..e34662ae 100644 --- a/conn_test.go +++ b/conn_test.go @@ -496,6 +496,50 @@ func TestPrepareIdempotency(t *testing.T) { } } +func TestPrepareStatementCacheModes(t *testing.T) { + t.Parallel() + + config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) + + tests := []struct { + name string + buildStatementCache pgx.BuildStatementCacheFunc + }{ + { + name: "disabled", + buildStatementCache: nil, + }, + { + name: "prepare", + buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache { + return stmtcache.New(conn, stmtcache.ModePrepare, 32) + }, + }, + { + name: "describe", + buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache { + return stmtcache.New(conn, stmtcache.ModeDescribe, 32) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config.BuildStatementCache = tt.buildStatementCache + conn := mustConnect(t, config) + defer closeConn(t, conn) + + _, err := conn.Prepare(context.Background(), "test", "select $1::text") + require.NoError(t, err) + + var s string + err = conn.QueryRow(context.Background(), "test", "hello").Scan(&s) + require.NoError(t, err) + require.Equal(t, "hello", s) + }) + } +} + func TestListenNotify(t *testing.T) { t.Parallel()