diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index ceb1b137..cf4464a4 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -872,6 +872,50 @@ readloop: return psd, nil } +// Deallocate deallocates a prepared statement. +// +// Deallocate does not send a DEALLOCATE statement to the server. It uses the PostgreSQL Close protocol message +// directly. This has the implication that Deallocate can succeed in an aborted transaction. +func (pgConn *PgConn) Deallocate(ctx context.Context, name string) error { + if err := pgConn.lock(); err != nil { + return err + } + defer pgConn.unlock() + + if ctx != context.Background() { + select { + case <-ctx.Done(): + return newContextAlreadyDoneError(ctx) + default: + } + pgConn.contextWatcher.Watch(ctx) + defer pgConn.contextWatcher.Unwatch() + } + + pgConn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name}) + pgConn.frontend.SendSync(&pgproto3.Sync{}) + err := pgConn.flushWithPotentialWriteReadDeadlock() + if err != nil { + pgConn.asyncClose() + return err + } + + for { + msg, err := pgConn.receiveMessage() + if err != nil { + pgConn.asyncClose() + return normalizeTimeoutError(ctx, err) + } + + switch msg := msg.(type) { + case *pgproto3.ErrorResponse: + return ErrorResponseToPgError(msg) + case *pgproto3.ReadyForQuery: + return nil + } + } +} + // ErrorResponseToPgError converts a wire protocol error message to a *PgError. func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { return &PgError{ diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 30cf62ff..79b8c82b 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -661,6 +661,73 @@ func TestConnPrepareContextPrecanceled(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnDeallocate(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Prepare(ctx, "ps1", "select 1", nil) + require.NoError(t, err) + + _, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close() + require.NoError(t, err) + + err = pgConn.Deallocate(ctx, "ps1") + require.NoError(t, err) + + _, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close() + require.Error(t, err) + var pgErr *pgconn.PgError + require.ErrorAs(t, err, &pgErr) + require.Equal(t, "26000", pgErr.Code) + + ensureConnValid(t, pgConn) +} + +func TestConnDeallocateSucceedsInAbortedTransaction(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + err = pgConn.Exec(ctx, "begin").Close() + require.NoError(t, err) + + _, err = pgConn.Prepare(ctx, "ps1", "select 1", nil) + require.NoError(t, err) + + _, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close() + require.NoError(t, err) + + err = pgConn.Exec(ctx, "select 1/0").Close() // break transaction with divide by 0 error + require.Error(t, err) + var pgErr *pgconn.PgError + require.ErrorAs(t, err, &pgErr) + require.Equal(t, "22012", pgErr.Code) + + err = pgConn.Deallocate(ctx, "ps1") + require.NoError(t, err) + + err = pgConn.Exec(ctx, "rollback").Close() + require.NoError(t, err) + + _, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close() + require.Error(t, err) + require.ErrorAs(t, err, &pgErr) + require.Equal(t, "26000", pgErr.Code) + + ensureConnValid(t, pgConn) +} + func TestConnExec(t *testing.T) { t.Parallel()