Add custom context cancellation hook

pull/483/head
Jack Christensen 2019-01-12 11:37:13 -06:00
parent 1257b89df7
commit 7f373ee92b
3 changed files with 117 additions and 2 deletions

View File

@ -41,7 +41,19 @@ type Config struct {
// allows implementing high availability behavior such as libpq does with target_session_attrs.
AfterConnectFunc AfterConnectFunc
OnNotice NoticeHandler // Callback function called when a notice response is received.
// OnContextCancel is a callback function used to override cancellation behavior. It is called when a context.Context
// is canceled. Default cancellation behavior is to establish another connection to the PostgreSQL server and send a
// query cancel request. Some non-PostgreSQL servers (e.g. CockroachDB) that speak a subset of the PostgreSQL wire
// protocol do not support this cancellation method.
//
// It is called from a background goroutine. When the cancellation process has finished ContextCancel.Finish must be
// called whether it was successful or not. If an error occurs the connection should be closed. The connection must be
// in a ready for query state or be closed when ContextCancel.Finish is called. Use PgConn.ReceiveMessage() to read
// the connection until a ready for query message is received.
OnContextCancel func(*ContextCancel)
// OnNotice is a callback function called when a notice response is received.
OnNotice NoticeHandler
}
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a

View File

@ -527,6 +527,22 @@ func (pgConn *PgConn) cancelRequest(ctx context.Context) error {
return nil
}
// WaitUntilReady waits until a previous context cancellation has been competed processed and the connection is ready
// for use. This is done automatically by all methods that need the connection to be ready for use. The only expected
// use for this method is for a connection pool to wait for a returned connection to be usable again before making it
// available.
func (pgConn *PgConn) WaitUntilReady(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case pgConn.controller <- pgConn:
// The connection must be ready since it was locked. Immediately unlock it.
<-pgConn.controller
}
return nil
}
// Exec executes SQL via the PostgreSQL simple query protocol. SQL may contain multiple queries. Execution is
// implicitly wrapped in a transaction unless a transaction is already in progress or SQL contains transaction control
// statements.
@ -942,7 +958,7 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
rr.commandConcluded = true
}
func (pgConn *PgConn) recoverFromTimeout() {
func (pgConn *PgConn) defaultCancel() {
// Regardless of recovery outcome the lock on the pgConn must be released.
defer func() { <-pgConn.controller }()
@ -991,6 +1007,26 @@ func (pgConn *PgConn) recoverFromTimeout() {
}
}
type ContextCancel struct {
PgConn *PgConn
}
// Finish must be called when the cancellation request has finished processing. The connection must be in a ready for
// query state or the connection must be closed. This must be called regardless of the success of the cancellation and
// whether the connection is still valid or not. It releases an internal busy lock on the connection.
func (cc *ContextCancel) Finish() {
<-cc.PgConn.controller
}
func (pgConn *PgConn) recoverFromTimeout() {
if pgConn.Config.OnContextCancel == nil {
pgConn.defaultCancel()
} else {
cc := &ContextCancel{PgConn: pgConn}
pgConn.Config.OnContextCancel(cc)
}
}
// Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip.
type Batch struct {
buf []byte

View File

@ -11,6 +11,7 @@ import (
"time"
"github.com/jackc/pgx/pgconn"
"github.com/jackc/pgx/pgproto3"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
@ -490,6 +491,72 @@ func TestCommandTag(t *testing.T) {
}
}
func TestConnContextCancelWithOnContextCancel(t *testing.T) {
t.Parallel()
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)
calledChan := make(chan struct{})
config.OnContextCancel = func(cc *pgconn.ContextCancel) {
defer cc.Finish()
close(calledChan)
for {
msg, err := cc.PgConn.ReceiveMessage()
if err != nil {
cc.PgConn.Close(context.Background())
return
}
switch msg.(type) {
case *pgproto3.ReadyForQuery:
return
}
}
}
pgConn, err := pgconn.ConnectConfig(context.Background(), config)
require.Nil(t, err)
defer closeConn(t, pgConn)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
result := pgConn.ExecParams(ctx, "select 'Hello, world', pg_sleep(0.25)", nil, nil, nil, nil)
_, err = result.Close()
assert.Equal(t, context.DeadlineExceeded, err)
called := false
select {
case <-calledChan:
called = true
case <-time.NewTimer(time.Second).C:
}
assert.True(t, called)
ensureConnValid(t, pgConn)
}
func TestConnWaitUntilReady(t *testing.T) {
t.Parallel()
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)
defer closeConn(t, pgConn)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
result := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil).Read()
assert.Equal(t, context.DeadlineExceeded, result.Err)
err = pgConn.WaitUntilReady(context.Background())
require.Nil(t, err)
ensureConnValid(t, pgConn)
}
func TestConnOnNotice(t *testing.T) {
t.Parallel()