From e5015e2fac30c57fe03774de91b5b5a45bafd037 Mon Sep 17 00:00:00 2001
From: Jack Christensen <jack@jackchristensen.com>
Date: Sat, 11 Nov 2023 09:30:50 -0600
Subject: [PATCH] pgx.Conn.Deallocate uses PgConn.Deallocate

This uses the PostgreSQL protocol to deallocate a prepared statement
instead of a SQL statement. This allows it to work even in an aborted
transaction.
---
 conn.go      |  2 +-
 conn_test.go | 36 ++++++++++++++++++++++++++++++++++++
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/conn.go b/conn.go
index 0426873c..b760258c 100644
--- a/conn.go
+++ b/conn.go
@@ -347,7 +347,7 @@ func (c *Conn) Deallocate(ctx context.Context, name string) error {
 	} else {
 		psName = name
 	}
-	_, err := c.pgConn.Exec(ctx, "deallocate "+quoteIdentifier(psName)).ReadAll()
+	err := c.pgConn.Deallocate(ctx, psName)
 	return err
 }
 
diff --git a/conn_test.go b/conn_test.go
index 739b6619..a37e9091 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -548,6 +548,42 @@ func TestPrepareWithDigestedName(t *testing.T) {
 	})
 }
 
+// https://github.com/jackc/pgx/pull/1795
+func TestDeallocateInAbortedTransaction(t *testing.T) {
+	t.Parallel()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+		tx, err := conn.Begin(ctx)
+		require.NoError(t, err)
+
+		sql := "select $1::text"
+		sd, err := tx.Prepare(ctx, sql, sql)
+		require.NoError(t, err)
+		require.Equal(t, "stmt_2510cc7db17de3f42758a2a29c8b9ef8305d007b997ebdd6", sd.Name)
+
+		var s string
+		err = tx.QueryRow(ctx, sql, "hello").Scan(&s)
+		require.NoError(t, err)
+		require.Equal(t, "hello", s)
+
+		_, err = tx.Exec(ctx, "select 1/0") // abort transaction with divide by zero error
+		require.Error(t, err)
+
+		err = conn.Deallocate(ctx, sql)
+		require.NoError(t, err)
+
+		err = tx.Rollback(ctx)
+		require.NoError(t, err)
+
+		sd, err = conn.Prepare(ctx, sql, sql)
+		require.NoError(t, err)
+		require.Equal(t, "stmt_2510cc7db17de3f42758a2a29c8b9ef8305d007b997ebdd6", sd.Name)
+	})
+}
+
 func TestListenNotify(t *testing.T) {
 	t.Parallel()