diff --git a/pgconn/config.go b/pgconn/config.go index 13167729..40cbd0bb 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -41,7 +41,19 @@ type Config struct { // allows implementing high availability behavior such as libpq does with target_session_attrs. AfterConnectFunc AfterConnectFunc - OnNotice NoticeHandler // Callback function called when a notice response is received. + // OnContextCancel is a callback function used to override cancellation behavior. It is called when a context.Context + // is canceled. Default cancellation behavior is to establish another connection to the PostgreSQL server and send a + // query cancel request. Some non-PostgreSQL servers (e.g. CockroachDB) that speak a subset of the PostgreSQL wire + // protocol do not support this cancellation method. + // + // It is called from a background goroutine. When the cancellation process has finished ContextCancel.Finish must be + // called whether it was successful or not. If an error occurs the connection should be closed. The connection must be + // in a ready for query state or be closed when ContextCancel.Finish is called. Use PgConn.ReceiveMessage() to read + // the connection until a ready for query message is received. + OnContextCancel func(*ContextCancel) + + // OnNotice is a callback function called when a notice response is received. + OnNotice NoticeHandler } // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index bab4370a..08fce16e 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -527,6 +527,22 @@ func (pgConn *PgConn) cancelRequest(ctx context.Context) error { return nil } +// WaitUntilReady waits until a previous context cancellation has been competed processed and the connection is ready +// for use. This is done automatically by all methods that need the connection to be ready for use. The only expected +// use for this method is for a connection pool to wait for a returned connection to be usable again before making it +// available. +func (pgConn *PgConn) WaitUntilReady(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + case pgConn.controller <- pgConn: + // The connection must be ready since it was locked. Immediately unlock it. + <-pgConn.controller + } + + return nil +} + // Exec executes SQL via the PostgreSQL simple query protocol. SQL may contain multiple queries. Execution is // implicitly wrapped in a transaction unless a transaction is already in progress or SQL contains transaction control // statements. @@ -942,7 +958,7 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { rr.commandConcluded = true } -func (pgConn *PgConn) recoverFromTimeout() { +func (pgConn *PgConn) defaultCancel() { // Regardless of recovery outcome the lock on the pgConn must be released. defer func() { <-pgConn.controller }() @@ -991,6 +1007,26 @@ func (pgConn *PgConn) recoverFromTimeout() { } } +type ContextCancel struct { + PgConn *PgConn +} + +// Finish must be called when the cancellation request has finished processing. The connection must be in a ready for +// query state or the connection must be closed. This must be called regardless of the success of the cancellation and +// whether the connection is still valid or not. It releases an internal busy lock on the connection. +func (cc *ContextCancel) Finish() { + <-cc.PgConn.controller +} + +func (pgConn *PgConn) recoverFromTimeout() { + if pgConn.Config.OnContextCancel == nil { + pgConn.defaultCancel() + } else { + cc := &ContextCancel{PgConn: pgConn} + pgConn.Config.OnContextCancel(cc) + } +} + // Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip. type Batch struct { buf []byte diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 2d8cc784..9452ffc0 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/jackc/pgx/pgconn" + "github.com/jackc/pgx/pgproto3" "github.com/pkg/errors" "github.com/stretchr/testify/assert" @@ -490,6 +491,72 @@ func TestCommandTag(t *testing.T) { } } +func TestConnContextCancelWithOnContextCancel(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + calledChan := make(chan struct{}) + + config.OnContextCancel = func(cc *pgconn.ContextCancel) { + defer cc.Finish() + close(calledChan) + + for { + msg, err := cc.PgConn.ReceiveMessage() + if err != nil { + cc.PgConn.Close(context.Background()) + return + } + + switch msg.(type) { + case *pgproto3.ReadyForQuery: + return + } + } + } + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + result := pgConn.ExecParams(ctx, "select 'Hello, world', pg_sleep(0.25)", nil, nil, nil, nil) + _, err = result.Close() + assert.Equal(t, context.DeadlineExceeded, err) + + called := false + select { + case <-calledChan: + called = true + case <-time.NewTimer(time.Second).C: + } + + assert.True(t, called) + + ensureConnValid(t, pgConn) +} + +func TestConnWaitUntilReady(t *testing.T) { + t.Parallel() + + pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + result := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil).Read() + assert.Equal(t, context.DeadlineExceeded, result.Err) + + err = pgConn.WaitUntilReady(context.Background()) + require.Nil(t, err) + + ensureConnValid(t, pgConn) +} + func TestConnOnNotice(t *testing.T) { t.Parallel()