mirror of https://github.com/jackc/pgx.git
Allow customizing context canceled behavior for pgconn
This feature made the ctxwatch package public.pull/1894/head
parent
60a01d044a
commit
42c9e9070a
|
@ -19,6 +19,7 @@ import (
|
||||||
|
|
||||||
"github.com/jackc/pgpassfile"
|
"github.com/jackc/pgpassfile"
|
||||||
"github.com/jackc/pgservicefile"
|
"github.com/jackc/pgservicefile"
|
||||||
|
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
|
||||||
"github.com/jackc/pgx/v5/pgproto3"
|
"github.com/jackc/pgx/v5/pgproto3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -39,7 +40,12 @@ type Config struct {
|
||||||
DialFunc DialFunc // e.g. net.Dialer.DialContext
|
DialFunc DialFunc // e.g. net.Dialer.DialContext
|
||||||
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
|
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
|
||||||
BuildFrontend BuildFrontendFunc
|
BuildFrontend BuildFrontendFunc
|
||||||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
|
||||||
|
// BuildContextWatcherHandler is called to create a ContextWatcherHandler for a connection. The handler is called
|
||||||
|
// when a context passed to a PgConn method is canceled.
|
||||||
|
BuildContextWatcherHandler func(*PgConn) ctxwatch.Handler
|
||||||
|
|
||||||
|
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||||
|
|
||||||
KerberosSrvName string
|
KerberosSrvName string
|
||||||
KerberosSpn string
|
KerberosSpn string
|
||||||
|
@ -266,6 +272,9 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
|
||||||
BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend {
|
BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend {
|
||||||
return pgproto3.NewFrontend(r, w)
|
return pgproto3.NewFrontend(r, w)
|
||||||
},
|
},
|
||||||
|
BuildContextWatcherHandler: func(pgConn *PgConn) ctxwatch.Handler {
|
||||||
|
return &DeadlineContextWatcherHandler{Conn: pgConn.conn}
|
||||||
|
},
|
||||||
OnPgError: func(_ *PgConn, pgErr *PgError) bool {
|
OnPgError: func(_ *PgConn, pgErr *PgError) bool {
|
||||||
// we want to automatically close any fatal errors
|
// we want to automatically close any fatal errors
|
||||||
if strings.EqualFold(pgErr.Severity, "FATAL") {
|
if strings.EqualFold(pgErr.Severity, "FATAL") {
|
||||||
|
|
|
@ -8,9 +8,8 @@ import (
|
||||||
// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a
|
// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a
|
||||||
// time.
|
// time.
|
||||||
type ContextWatcher struct {
|
type ContextWatcher struct {
|
||||||
onCancel func()
|
handler Handler
|
||||||
onUnwatchAfterCancel func()
|
unwatchChan chan struct{}
|
||||||
unwatchChan chan struct{}
|
|
||||||
|
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
watchInProgress bool
|
watchInProgress bool
|
||||||
|
@ -20,11 +19,10 @@ type ContextWatcher struct {
|
||||||
// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled.
|
// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled.
|
||||||
// OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and
|
// OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and
|
||||||
// onCancel called.
|
// onCancel called.
|
||||||
func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher {
|
func NewContextWatcher(handler Handler) *ContextWatcher {
|
||||||
cw := &ContextWatcher{
|
cw := &ContextWatcher{
|
||||||
onCancel: onCancel,
|
handler: handler,
|
||||||
onUnwatchAfterCancel: onUnwatchAfterCancel,
|
unwatchChan: make(chan struct{}),
|
||||||
unwatchChan: make(chan struct{}),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return cw
|
return cw
|
||||||
|
@ -46,7 +44,7 @@ func (cw *ContextWatcher) Watch(ctx context.Context) {
|
||||||
go func() {
|
go func() {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
cw.onCancel()
|
cw.handler.HandleCancel(ctx)
|
||||||
cw.onCancelWasCalled = true
|
cw.onCancelWasCalled = true
|
||||||
<-cw.unwatchChan
|
<-cw.unwatchChan
|
||||||
case <-cw.unwatchChan:
|
case <-cw.unwatchChan:
|
||||||
|
@ -66,8 +64,17 @@ func (cw *ContextWatcher) Unwatch() {
|
||||||
if cw.watchInProgress {
|
if cw.watchInProgress {
|
||||||
cw.unwatchChan <- struct{}{}
|
cw.unwatchChan <- struct{}{}
|
||||||
if cw.onCancelWasCalled {
|
if cw.onCancelWasCalled {
|
||||||
cw.onUnwatchAfterCancel()
|
cw.handler.HandleUnwatchAfterCancel()
|
||||||
}
|
}
|
||||||
cw.watchInProgress = false
|
cw.watchInProgress = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Handler interface {
|
||||||
|
// HandleCancel is called when the context that a ContextWatcher is currently watching is canceled. canceledCtx is the
|
||||||
|
// context that was canceled.
|
||||||
|
HandleCancel(canceledCtx context.Context)
|
||||||
|
|
||||||
|
// HandleUnwatchAfterCancel is called when a ContextWatcher that called HandleCancel on this Handler is unwatched.
|
||||||
|
HandleUnwatchAfterCancel()
|
||||||
|
}
|
|
@ -6,17 +6,32 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"
|
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type testHandler struct {
|
||||||
|
handleCancel func(context.Context)
|
||||||
|
handleUnwatchAfterCancel func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *testHandler) HandleCancel(ctx context.Context) {
|
||||||
|
h.handleCancel(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *testHandler) HandleUnwatchAfterCancel() {
|
||||||
|
h.handleUnwatchAfterCancel()
|
||||||
|
}
|
||||||
|
|
||||||
func TestContextWatcherContextCancelled(t *testing.T) {
|
func TestContextWatcherContextCancelled(t *testing.T) {
|
||||||
canceledChan := make(chan struct{})
|
canceledChan := make(chan struct{})
|
||||||
cleanupCalled := false
|
cleanupCalled := false
|
||||||
cw := ctxwatch.NewContextWatcher(func() {
|
cw := ctxwatch.NewContextWatcher(&testHandler{
|
||||||
canceledChan <- struct{}{}
|
handleCancel: func(context.Context) {
|
||||||
}, func() {
|
canceledChan <- struct{}{}
|
||||||
cleanupCalled = true
|
}, handleUnwatchAfterCancel: func() {
|
||||||
|
cleanupCalled = true
|
||||||
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
@ -35,10 +50,12 @@ func TestContextWatcherContextCancelled(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) {
|
func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) {
|
||||||
cw := ctxwatch.NewContextWatcher(func() {
|
cw := ctxwatch.NewContextWatcher(&testHandler{
|
||||||
t.Error("cancel func should not have been called")
|
handleCancel: func(context.Context) {
|
||||||
}, func() {
|
t.Error("cancel func should not have been called")
|
||||||
t.Error("cleanup func should not have been called")
|
}, handleUnwatchAfterCancel: func() {
|
||||||
|
t.Error("cleanup func should not have been called")
|
||||||
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
@ -48,7 +65,7 @@ func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestContextWatcherMultipleWatchPanics(t *testing.T) {
|
func TestContextWatcherMultipleWatchPanics(t *testing.T) {
|
||||||
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
@ -61,7 +78,7 @@ func TestContextWatcherMultipleWatchPanics(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) {
|
func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) {
|
||||||
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
|
||||||
cw.Unwatch() // unwatch when not / never watching
|
cw.Unwatch() // unwatch when not / never watching
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
@ -72,7 +89,7 @@ func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestContextWatcherUnwatchIsConcurrencySafe(t *testing.T) {
|
func TestContextWatcherUnwatchIsConcurrencySafe(t *testing.T) {
|
||||||
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
@ -88,10 +105,12 @@ func TestContextWatcherStress(t *testing.T) {
|
||||||
var cancelFuncCalls int64
|
var cancelFuncCalls int64
|
||||||
var cleanupFuncCalls int64
|
var cleanupFuncCalls int64
|
||||||
|
|
||||||
cw := ctxwatch.NewContextWatcher(func() {
|
cw := ctxwatch.NewContextWatcher(&testHandler{
|
||||||
atomic.AddInt64(&cancelFuncCalls, 1)
|
handleCancel: func(context.Context) {
|
||||||
}, func() {
|
atomic.AddInt64(&cancelFuncCalls, 1)
|
||||||
atomic.AddInt64(&cleanupFuncCalls, 1)
|
}, handleUnwatchAfterCancel: func() {
|
||||||
|
atomic.AddInt64(&cleanupFuncCalls, 1)
|
||||||
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
cycleCount := 100000
|
cycleCount := 100000
|
||||||
|
@ -134,7 +153,7 @@ func TestContextWatcherStress(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkContextWatcherUncancellable(b *testing.B) {
|
func BenchmarkContextWatcherUncancellable(b *testing.B) {
|
||||||
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
cw.Watch(context.Background())
|
cw.Watch(context.Background())
|
||||||
|
@ -143,7 +162,7 @@ func BenchmarkContextWatcherUncancellable(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkContextWatcherCancelled(b *testing.B) {
|
func BenchmarkContextWatcherCancelled(b *testing.B) {
|
||||||
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
@ -154,7 +173,7 @@ func BenchmarkContextWatcherCancelled(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkContextWatcherCancellable(b *testing.B) {
|
func BenchmarkContextWatcherCancellable(b *testing.B) {
|
||||||
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
|
cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
|
@ -18,8 +18,8 @@ import (
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5/internal/iobufpool"
|
"github.com/jackc/pgx/v5/internal/iobufpool"
|
||||||
"github.com/jackc/pgx/v5/internal/pgio"
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
|
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
|
||||||
"github.com/jackc/pgx/v5/pgconn/internal/bgreader"
|
"github.com/jackc/pgx/v5/pgconn/internal/bgreader"
|
||||||
"github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"
|
|
||||||
"github.com/jackc/pgx/v5/pgproto3"
|
"github.com/jackc/pgx/v5/pgproto3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -281,28 +281,26 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
|
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
|
||||||
netConn, err := config.DialFunc(ctx, network, address)
|
pgConn.conn, err = config.DialFunc(ctx, network, address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &ConnectError{Config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
|
return nil, &ConnectError{Config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
|
||||||
}
|
}
|
||||||
|
|
||||||
pgConn.conn = netConn
|
|
||||||
pgConn.contextWatcher = newContextWatcher(netConn)
|
|
||||||
pgConn.contextWatcher.Watch(ctx)
|
|
||||||
|
|
||||||
if fallbackConfig.TLSConfig != nil {
|
if fallbackConfig.TLSConfig != nil {
|
||||||
nbTLSConn, err := startTLS(netConn, fallbackConfig.TLSConfig)
|
pgConn.contextWatcher = ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: pgConn.conn})
|
||||||
|
pgConn.contextWatcher.Watch(ctx)
|
||||||
|
tlsConn, err := startTLS(pgConn.conn, fallbackConfig.TLSConfig)
|
||||||
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
|
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
netConn.Close()
|
pgConn.conn.Close()
|
||||||
return nil, &ConnectError{Config: config, msg: "tls error", err: normalizeTimeoutError(ctx, err)}
|
return nil, &ConnectError{Config: config, msg: "tls error", err: normalizeTimeoutError(ctx, err)}
|
||||||
}
|
}
|
||||||
|
|
||||||
pgConn.conn = nbTLSConn
|
pgConn.conn = tlsConn
|
||||||
pgConn.contextWatcher = newContextWatcher(nbTLSConn)
|
|
||||||
pgConn.contextWatcher.Watch(ctx)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pgConn.contextWatcher = ctxwatch.NewContextWatcher(config.BuildContextWatcherHandler(pgConn))
|
||||||
|
pgConn.contextWatcher.Watch(ctx)
|
||||||
defer pgConn.contextWatcher.Unwatch()
|
defer pgConn.contextWatcher.Unwatch()
|
||||||
|
|
||||||
pgConn.parameterStatuses = make(map[string]string)
|
pgConn.parameterStatuses = make(map[string]string)
|
||||||
|
@ -412,13 +410,6 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher {
|
|
||||||
return ctxwatch.NewContextWatcher(
|
|
||||||
func() { conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
|
|
||||||
func() { conn.SetDeadline(time.Time{}) },
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
|
func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
|
||||||
err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103})
|
err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -988,10 +979,7 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
|
||||||
defer cancelConn.Close()
|
defer cancelConn.Close()
|
||||||
|
|
||||||
if ctx != context.Background() {
|
if ctx != context.Background() {
|
||||||
contextWatcher := ctxwatch.NewContextWatcher(
|
contextWatcher := ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: cancelConn})
|
||||||
func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
|
|
||||||
func() { cancelConn.SetDeadline(time.Time{}) },
|
|
||||||
)
|
|
||||||
contextWatcher.Watch(ctx)
|
contextWatcher.Watch(ctx)
|
||||||
defer contextWatcher.Unwatch()
|
defer contextWatcher.Unwatch()
|
||||||
}
|
}
|
||||||
|
@ -1939,7 +1927,7 @@ func Construct(hc *HijackedConn) (*PgConn, error) {
|
||||||
cleanupDone: make(chan struct{}),
|
cleanupDone: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
pgConn.contextWatcher = newContextWatcher(pgConn.conn)
|
pgConn.contextWatcher = ctxwatch.NewContextWatcher(hc.Config.BuildContextWatcherHandler(pgConn))
|
||||||
pgConn.bgReader = bgreader.New(pgConn.conn)
|
pgConn.bgReader = bgreader.New(pgConn.conn)
|
||||||
pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64),
|
pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64),
|
||||||
func() {
|
func() {
|
||||||
|
@ -2246,3 +2234,19 @@ func (p *Pipeline) Close() error {
|
||||||
|
|
||||||
return p.err
|
return p.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeadlineContextWatcherHandler handles canceled contexts by setting a deadline on a net.Conn.
|
||||||
|
type DeadlineContextWatcherHandler struct {
|
||||||
|
Conn net.Conn
|
||||||
|
|
||||||
|
// DeadlineDelay is the delay to set on the deadline set on net.Conn when the context is canceled.
|
||||||
|
DeadlineDelay time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DeadlineContextWatcherHandler) HandleCancel(ctx context.Context) {
|
||||||
|
h.Conn.SetDeadline(time.Now().Add(h.DeadlineDelay))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DeadlineContextWatcherHandler) HandleUnwatchAfterCancel() {
|
||||||
|
h.Conn.SetDeadline(time.Time{})
|
||||||
|
}
|
||||||
|
|
|
@ -24,6 +24,7 @@ import (
|
||||||
"github.com/jackc/pgx/v5/internal/pgio"
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
"github.com/jackc/pgx/v5/internal/pgmock"
|
"github.com/jackc/pgx/v5/internal/pgmock"
|
||||||
"github.com/jackc/pgx/v5/pgconn"
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
|
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
|
||||||
"github.com/jackc/pgx/v5/pgproto3"
|
"github.com/jackc/pgx/v5/pgproto3"
|
||||||
"github.com/jackc/pgx/v5/pgtype"
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
)
|
)
|
||||||
|
@ -3480,3 +3481,49 @@ func mustEncode(buf []byte, err error) []byte {
|
||||||
}
|
}
|
||||||
return buf
|
return buf
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDeadlineContextWatcherHandler(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("DeadlineExceeded with zero DeadlineDelay", func(t *testing.T) {
|
||||||
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler {
|
||||||
|
return &pgconn.DeadlineContextWatcherHandler{Conn: conn.Conn()}
|
||||||
|
}
|
||||||
|
config.ConnectTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
pgConn, err := pgconn.ConnectConfig(context.Background(), config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, err = pgConn.Exec(ctx, "select 1, pg_sleep(1)").ReadAll()
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||||
|
require.True(t, pgConn.IsClosed())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DeadlineExceeded with DeadlineDelay", func(t *testing.T) {
|
||||||
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler {
|
||||||
|
return &pgconn.DeadlineContextWatcherHandler{Conn: conn.Conn(), DeadlineDelay: 500 * time.Millisecond}
|
||||||
|
}
|
||||||
|
config.ConnectTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
pgConn, err := pgconn.ConnectConfig(context.Background(), config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, err = pgConn.Exec(ctx, "select 1, pg_sleep(0.250)").ReadAll()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue