mirror of https://github.com/jackc/pgx.git
Add custom context cancellation hook
parent
1257b89df7
commit
7f373ee92b
|
@ -41,7 +41,19 @@ type Config struct {
|
|||
// allows implementing high availability behavior such as libpq does with target_session_attrs.
|
||||
AfterConnectFunc AfterConnectFunc
|
||||
|
||||
OnNotice NoticeHandler // Callback function called when a notice response is received.
|
||||
// OnContextCancel is a callback function used to override cancellation behavior. It is called when a context.Context
|
||||
// is canceled. Default cancellation behavior is to establish another connection to the PostgreSQL server and send a
|
||||
// query cancel request. Some non-PostgreSQL servers (e.g. CockroachDB) that speak a subset of the PostgreSQL wire
|
||||
// protocol do not support this cancellation method.
|
||||
//
|
||||
// It is called from a background goroutine. When the cancellation process has finished ContextCancel.Finish must be
|
||||
// called whether it was successful or not. If an error occurs the connection should be closed. The connection must be
|
||||
// in a ready for query state or be closed when ContextCancel.Finish is called. Use PgConn.ReceiveMessage() to read
|
||||
// the connection until a ready for query message is received.
|
||||
OnContextCancel func(*ContextCancel)
|
||||
|
||||
// OnNotice is a callback function called when a notice response is received.
|
||||
OnNotice NoticeHandler
|
||||
}
|
||||
|
||||
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
|
||||
|
|
|
@ -527,6 +527,22 @@ func (pgConn *PgConn) cancelRequest(ctx context.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// WaitUntilReady waits until a previous context cancellation has been competed processed and the connection is ready
|
||||
// for use. This is done automatically by all methods that need the connection to be ready for use. The only expected
|
||||
// use for this method is for a connection pool to wait for a returned connection to be usable again before making it
|
||||
// available.
|
||||
func (pgConn *PgConn) WaitUntilReady(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case pgConn.controller <- pgConn:
|
||||
// The connection must be ready since it was locked. Immediately unlock it.
|
||||
<-pgConn.controller
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exec executes SQL via the PostgreSQL simple query protocol. SQL may contain multiple queries. Execution is
|
||||
// implicitly wrapped in a transaction unless a transaction is already in progress or SQL contains transaction control
|
||||
// statements.
|
||||
|
@ -942,7 +958,7 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
|
|||
rr.commandConcluded = true
|
||||
}
|
||||
|
||||
func (pgConn *PgConn) recoverFromTimeout() {
|
||||
func (pgConn *PgConn) defaultCancel() {
|
||||
// Regardless of recovery outcome the lock on the pgConn must be released.
|
||||
defer func() { <-pgConn.controller }()
|
||||
|
||||
|
@ -991,6 +1007,26 @@ func (pgConn *PgConn) recoverFromTimeout() {
|
|||
}
|
||||
}
|
||||
|
||||
type ContextCancel struct {
|
||||
PgConn *PgConn
|
||||
}
|
||||
|
||||
// Finish must be called when the cancellation request has finished processing. The connection must be in a ready for
|
||||
// query state or the connection must be closed. This must be called regardless of the success of the cancellation and
|
||||
// whether the connection is still valid or not. It releases an internal busy lock on the connection.
|
||||
func (cc *ContextCancel) Finish() {
|
||||
<-cc.PgConn.controller
|
||||
}
|
||||
|
||||
func (pgConn *PgConn) recoverFromTimeout() {
|
||||
if pgConn.Config.OnContextCancel == nil {
|
||||
pgConn.defaultCancel()
|
||||
} else {
|
||||
cc := &ContextCancel{PgConn: pgConn}
|
||||
pgConn.Config.OnContextCancel(cc)
|
||||
}
|
||||
}
|
||||
|
||||
// Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip.
|
||||
type Batch struct {
|
||||
buf []byte
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/jackc/pgx/pgconn"
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -490,6 +491,72 @@ func TestCommandTag(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestConnContextCancelWithOnContextCancel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.Nil(t, err)
|
||||
|
||||
calledChan := make(chan struct{})
|
||||
|
||||
config.OnContextCancel = func(cc *pgconn.ContextCancel) {
|
||||
defer cc.Finish()
|
||||
close(calledChan)
|
||||
|
||||
for {
|
||||
msg, err := cc.PgConn.ReceiveMessage()
|
||||
if err != nil {
|
||||
cc.PgConn.Close(context.Background())
|
||||
return
|
||||
}
|
||||
|
||||
switch msg.(type) {
|
||||
case *pgproto3.ReadyForQuery:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pgConn, err := pgconn.ConnectConfig(context.Background(), config)
|
||||
require.Nil(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
result := pgConn.ExecParams(ctx, "select 'Hello, world', pg_sleep(0.25)", nil, nil, nil, nil)
|
||||
_, err = result.Close()
|
||||
assert.Equal(t, context.DeadlineExceeded, err)
|
||||
|
||||
called := false
|
||||
select {
|
||||
case <-calledChan:
|
||||
called = true
|
||||
case <-time.NewTimer(time.Second).C:
|
||||
}
|
||||
|
||||
assert.True(t, called)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnWaitUntilReady(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.Nil(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
result := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil).Read()
|
||||
assert.Equal(t, context.DeadlineExceeded, result.Err)
|
||||
|
||||
err = pgConn.WaitUntilReady(context.Background())
|
||||
require.Nil(t, err)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnOnNotice(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
Loading…
Reference in New Issue