mirror of https://github.com/jackc/pgx.git
Allow customizing context canceled behavior for pgconn
This feature made the ctxwatch package public.pull/1894/head
parent
60a01d044a
commit
42c9e9070a
|
@ -19,6 +19,7 @@ import (
|
|||
|
||||
"github.com/jackc/pgpassfile"
|
||||
"github.com/jackc/pgservicefile"
|
||||
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
)
|
||||
|
||||
|
@ -39,7 +40,12 @@ type Config struct {
|
|||
DialFunc DialFunc // e.g. net.Dialer.DialContext
|
||||
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
|
||||
BuildFrontend BuildFrontendFunc
|
||||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||
|
||||
// BuildContextWatcherHandler is called to create a ContextWatcherHandler for a connection. The handler is called
|
||||
// when a context passed to a PgConn method is canceled.
|
||||
BuildContextWatcherHandler func(*PgConn) ctxwatch.Handler
|
||||
|
||||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||
|
||||
KerberosSrvName string
|
||||
KerberosSpn string
|
||||
|
@ -266,6 +272,9 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
|
|||
BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend {
|
||||
return pgproto3.NewFrontend(r, w)
|
||||
},
|
||||
BuildContextWatcherHandler: func(pgConn *PgConn) ctxwatch.Handler {
|
||||
return &DeadlineContextWatcherHandler{Conn: pgConn.conn}
|
||||
},
|
||||
OnPgError: func(_ *PgConn, pgErr *PgError) bool {
|
||||
// we want to automatically close any fatal errors
|
||||
if strings.EqualFold(pgErr.Severity, "FATAL") {
|
||||
|
|
|
@ -8,9 +8,8 @@ import (
|
|||
// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a
|
||||
// time.
|
||||
type ContextWatcher struct {
|
||||
onCancel func()
|
||||
onUnwatchAfterCancel func()
|
||||
unwatchChan chan struct{}
|
||||
handler Handler
|
||||
unwatchChan chan struct{}
|
||||
|
||||
lock sync.Mutex
|
||||
watchInProgress bool
|
||||
|
@ -20,11 +19,10 @@ type ContextWatcher struct {
|
|||
// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled.
|
||||
// OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and
|
||||
// onCancel called.
|
||||
func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher {
|
||||
func NewContextWatcher(handler Handler) *ContextWatcher {
|
||||
cw := &ContextWatcher{
|
||||
onCancel: onCancel,
|
||||
onUnwatchAfterCancel: onUnwatchAfterCancel,
|
||||
unwatchChan: make(chan struct{}),
|
||||
handler: handler,
|
||||
unwatchChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
return cw
|
||||
|
@ -46,7 +44,7 @@ func (cw *ContextWatcher) Watch(ctx context.Context) {
|
|||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
cw.onCancel()
|
||||
cw.handler.HandleCancel(ctx)
|
||||
cw.onCancelWasCalled = true
|
||||
<-cw.unwatchChan
|
||||
case <-cw.unwatchChan:
|
||||
|
@ -66,8 +64,17 @@ func (cw *ContextWatcher) Unwatch() {
|
|||
if cw.watchInProgress {
|
||||
cw.unwatchChan <- struct{}{}
|
||||
if cw.onCancelWasCalled {
|
||||
cw.onUnwatchAfterCancel()
|
||||
cw.handler.HandleUnwatchAfterCancel()
|
||||
}
|
||||
cw.watchInProgress = false
|
||||
}
|
||||
}
|
||||
|
||||
type Handler interface {
|
||||
// HandleCancel is called when the context that a ContextWatcher is currently watching is canceled. canceledCtx is the
|
||||
// context that was canceled.
|
||||
HandleCancel(canceledCtx context.Context)
|
||||
|
||||
// HandleUnwatchAfterCancel is called when a ContextWatcher that called HandleCancel on this Handler is unwatched.
|
||||
HandleUnwatchAfterCancel()
|
||||
}
|
|
@ -6,17 +6,32 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"
|
||||
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type testHandler struct {
|
||||
handleCancel func(context.Context)
|
||||
handleUnwatchAfterCancel func()
|
||||
}
|
||||
|
||||
func (h *testHandler) HandleCancel(ctx context.Context) {
|
||||
h.handleCancel(ctx)
|
||||
}
|
||||
|
||||
func (h *testHandler) HandleUnwatchAfterCancel() {
|
||||
h.handleUnwatchAfterCancel()
|
||||
}
|
||||
|
||||
func TestContextWatcherContextCancelled(t *testing.T) {
|
||||
canceledChan := make(chan struct{})
|
||||
cleanupCalled := false
|
||||
cw := ctxwatch.NewContextWatcher(func() {
|
||||
canceledChan <- struct{}{}
|
||||
}, func() {
|
||||
cleanupCalled = true
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{
|
||||
handleCancel: func(context.Context) {
|
||||
canceledChan <- struct{}{}
|
||||
}, handleUnwatchAfterCancel: func() {
|
||||
cleanupCalled = true
|
||||
},
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
@ -35,10 +50,12 @@ func TestContextWatcherContextCancelled(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) {
|
||||
cw := ctxwatch.NewContextWatcher(func() {
|
||||
t.Error("cancel func should not have been called")
|
||||
}, func() {
|
||||
t.Error("cleanup func should not have been called")
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{
|
||||
handleCancel: func(context.Context) {
|
||||
t.Error("cancel func should not have been called")
|
||||
}, handleUnwatchAfterCancel: func() {
|
||||
t.Error("cleanup func should not have been called")
|
||||
},
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
@ -48,7 +65,7 @@ func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestContextWatcherMultipleWatchPanics(t *testing.T) {
|
||||
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
@ -61,7 +78,7 @@ func TestContextWatcherMultipleWatchPanics(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) {
|
||||
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
|
||||
cw.Unwatch() // unwatch when not / never watching
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
@ -72,7 +89,7 @@ func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestContextWatcherUnwatchIsConcurrencySafe(t *testing.T) {
|
||||
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
@ -88,10 +105,12 @@ func TestContextWatcherStress(t *testing.T) {
|
|||
var cancelFuncCalls int64
|
||||
var cleanupFuncCalls int64
|
||||
|
||||
cw := ctxwatch.NewContextWatcher(func() {
|
||||
atomic.AddInt64(&cancelFuncCalls, 1)
|
||||
}, func() {
|
||||
atomic.AddInt64(&cleanupFuncCalls, 1)
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{
|
||||
handleCancel: func(context.Context) {
|
||||
atomic.AddInt64(&cancelFuncCalls, 1)
|
||||
}, handleUnwatchAfterCancel: func() {
|
||||
atomic.AddInt64(&cleanupFuncCalls, 1)
|
||||
},
|
||||
})
|
||||
|
||||
cycleCount := 100000
|
||||
|
@ -134,7 +153,7 @@ func TestContextWatcherStress(t *testing.T) {
|
|||
}
|
||||
|
||||
func BenchmarkContextWatcherUncancellable(b *testing.B) {
|
||||
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
cw.Watch(context.Background())
|
||||
|
@ -143,7 +162,7 @@ func BenchmarkContextWatcherUncancellable(b *testing.B) {
|
|||
}
|
||||
|
||||
func BenchmarkContextWatcherCancelled(b *testing.B) {
|
||||
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
@ -154,7 +173,7 @@ func BenchmarkContextWatcherCancelled(b *testing.B) {
|
|||
}
|
||||
|
||||
func BenchmarkContextWatcherCancellable(b *testing.B) {
|
||||
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
|
@ -18,8 +18,8 @@ import (
|
|||
|
||||
"github.com/jackc/pgx/v5/internal/iobufpool"
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
|
||||
"github.com/jackc/pgx/v5/pgconn/internal/bgreader"
|
||||
"github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
)
|
||||
|
||||
|
@ -281,28 +281,26 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||
|
||||
var err error
|
||||
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
|
||||
netConn, err := config.DialFunc(ctx, network, address)
|
||||
pgConn.conn, err = config.DialFunc(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, &ConnectError{Config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
|
||||
}
|
||||
|
||||
pgConn.conn = netConn
|
||||
pgConn.contextWatcher = newContextWatcher(netConn)
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
|
||||
if fallbackConfig.TLSConfig != nil {
|
||||
nbTLSConn, err := startTLS(netConn, fallbackConfig.TLSConfig)
|
||||
pgConn.contextWatcher = ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: pgConn.conn})
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
tlsConn, err := startTLS(pgConn.conn, fallbackConfig.TLSConfig)
|
||||
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
|
||||
if err != nil {
|
||||
netConn.Close()
|
||||
pgConn.conn.Close()
|
||||
return nil, &ConnectError{Config: config, msg: "tls error", err: normalizeTimeoutError(ctx, err)}
|
||||
}
|
||||
|
||||
pgConn.conn = nbTLSConn
|
||||
pgConn.contextWatcher = newContextWatcher(nbTLSConn)
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
pgConn.conn = tlsConn
|
||||
}
|
||||
|
||||
pgConn.contextWatcher = ctxwatch.NewContextWatcher(config.BuildContextWatcherHandler(pgConn))
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
defer pgConn.contextWatcher.Unwatch()
|
||||
|
||||
pgConn.parameterStatuses = make(map[string]string)
|
||||
|
@ -412,13 +410,6 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||
}
|
||||
}
|
||||
|
||||
func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher {
|
||||
return ctxwatch.NewContextWatcher(
|
||||
func() { conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
|
||||
func() { conn.SetDeadline(time.Time{}) },
|
||||
)
|
||||
}
|
||||
|
||||
func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
|
||||
err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103})
|
||||
if err != nil {
|
||||
|
@ -988,10 +979,7 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
|
|||
defer cancelConn.Close()
|
||||
|
||||
if ctx != context.Background() {
|
||||
contextWatcher := ctxwatch.NewContextWatcher(
|
||||
func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
|
||||
func() { cancelConn.SetDeadline(time.Time{}) },
|
||||
)
|
||||
contextWatcher := ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: cancelConn})
|
||||
contextWatcher.Watch(ctx)
|
||||
defer contextWatcher.Unwatch()
|
||||
}
|
||||
|
@ -1939,7 +1927,7 @@ func Construct(hc *HijackedConn) (*PgConn, error) {
|
|||
cleanupDone: make(chan struct{}),
|
||||
}
|
||||
|
||||
pgConn.contextWatcher = newContextWatcher(pgConn.conn)
|
||||
pgConn.contextWatcher = ctxwatch.NewContextWatcher(hc.Config.BuildContextWatcherHandler(pgConn))
|
||||
pgConn.bgReader = bgreader.New(pgConn.conn)
|
||||
pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64),
|
||||
func() {
|
||||
|
@ -2246,3 +2234,19 @@ func (p *Pipeline) Close() error {
|
|||
|
||||
return p.err
|
||||
}
|
||||
|
||||
// DeadlineContextWatcherHandler handles canceled contexts by setting a deadline on a net.Conn.
|
||||
type DeadlineContextWatcherHandler struct {
|
||||
Conn net.Conn
|
||||
|
||||
// DeadlineDelay is the delay to set on the deadline set on net.Conn when the context is canceled.
|
||||
DeadlineDelay time.Duration
|
||||
}
|
||||
|
||||
func (h *DeadlineContextWatcherHandler) HandleCancel(ctx context.Context) {
|
||||
h.Conn.SetDeadline(time.Now().Add(h.DeadlineDelay))
|
||||
}
|
||||
|
||||
func (h *DeadlineContextWatcherHandler) HandleUnwatchAfterCancel() {
|
||||
h.Conn.SetDeadline(time.Time{})
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
"github.com/jackc/pgx/v5/internal/pgmock"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
@ -3480,3 +3481,49 @@ func mustEncode(buf []byte, err error) []byte {
|
|||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
func TestDeadlineContextWatcherHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("DeadlineExceeded with zero DeadlineDelay", func(t *testing.T) {
|
||||
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler {
|
||||
return &pgconn.DeadlineContextWatcherHandler{Conn: conn.Conn()}
|
||||
}
|
||||
config.ConnectTimeout = 5 * time.Second
|
||||
|
||||
pgConn, err := pgconn.ConnectConfig(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err = pgConn.Exec(ctx, "select 1, pg_sleep(1)").ReadAll()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
require.True(t, pgConn.IsClosed())
|
||||
})
|
||||
|
||||
t.Run("DeadlineExceeded with DeadlineDelay", func(t *testing.T) {
|
||||
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler {
|
||||
return &pgconn.DeadlineContextWatcherHandler{Conn: conn.Conn(), DeadlineDelay: 500 * time.Millisecond}
|
||||
}
|
||||
config.ConnectTimeout = 5 * time.Second
|
||||
|
||||
pgConn, err := pgconn.ConnectConfig(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err = pgConn.Exec(ctx, "select 1, pg_sleep(0.250)").ReadAll()
|
||||
require.NoError(t, err)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue