mirror of https://github.com/jackc/pgx.git
Add PgConn.Deallocate method
This method uses the PostgreSQL protocol Close method to deallocate a prepared statement. This means that it can succeed in an aborted transaction.pull/1804/head
parent
0570b0e196
commit
4dbd57a7ed
|
@ -872,6 +872,50 @@ readloop:
|
|||
return psd, nil
|
||||
}
|
||||
|
||||
// Deallocate deallocates a prepared statement.
|
||||
//
|
||||
// Deallocate does not send a DEALLOCATE statement to the server. It uses the PostgreSQL Close protocol message
|
||||
// directly. This has the implication that Deallocate can succeed in an aborted transaction.
|
||||
func (pgConn *PgConn) Deallocate(ctx context.Context, name string) error {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return err
|
||||
}
|
||||
defer pgConn.unlock()
|
||||
|
||||
if ctx != context.Background() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return newContextAlreadyDoneError(ctx)
|
||||
default:
|
||||
}
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
defer pgConn.contextWatcher.Unwatch()
|
||||
}
|
||||
|
||||
pgConn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name})
|
||||
pgConn.frontend.SendSync(&pgproto3.Sync{})
|
||||
err := pgConn.flushWithPotentialWriteReadDeadlock()
|
||||
if err != nil {
|
||||
pgConn.asyncClose()
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
msg, err := pgConn.receiveMessage()
|
||||
if err != nil {
|
||||
pgConn.asyncClose()
|
||||
return normalizeTimeoutError(ctx, err)
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ErrorResponse:
|
||||
return ErrorResponseToPgError(msg)
|
||||
case *pgproto3.ReadyForQuery:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorResponseToPgError converts a wire protocol error message to a *PgError.
|
||||
func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError {
|
||||
return &PgError{
|
||||
|
|
|
@ -661,6 +661,73 @@ func TestConnPrepareContextPrecanceled(t *testing.T) {
|
|||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnDeallocate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
_, err = pgConn.Prepare(ctx, "ps1", "select 1", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = pgConn.Deallocate(ctx, "ps1")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close()
|
||||
require.Error(t, err)
|
||||
var pgErr *pgconn.PgError
|
||||
require.ErrorAs(t, err, &pgErr)
|
||||
require.Equal(t, "26000", pgErr.Code)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnDeallocateSucceedsInAbortedTransaction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
err = pgConn.Exec(ctx, "begin").Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = pgConn.Prepare(ctx, "ps1", "select 1", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = pgConn.Exec(ctx, "select 1/0").Close() // break transaction with divide by 0 error
|
||||
require.Error(t, err)
|
||||
var pgErr *pgconn.PgError
|
||||
require.ErrorAs(t, err, &pgErr)
|
||||
require.Equal(t, "22012", pgErr.Code)
|
||||
|
||||
err = pgConn.Deallocate(ctx, "ps1")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = pgConn.Exec(ctx, "rollback").Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close()
|
||||
require.Error(t, err)
|
||||
require.ErrorAs(t, err, &pgErr)
|
||||
require.Equal(t, "26000", pgErr.Code)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnExec(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
Loading…
Reference in New Issue