Refactor context handling into ctxwatch package

query-exec-mode
Jack Christensen 2019-05-07 18:05:06 -05:00
parent 1e3961bd0e
commit 1baf0ef57e
8 changed files with 261 additions and 80 deletions

View File

@ -206,3 +206,19 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) {
}
}
}
// func BenchmarkChanToSetDeadlinePossibleToCancel(b *testing.B) {
// conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
// require.Nil(b, err)
// defer closeConn(b, conn)
// ctx, cancel := context.WithCancel(context.Background())
// defer cancel()
// b.ResetTimer()
// for i := 0; i < b.N; i++ {
// conn.ChanToSetDeadline().Watch(ctx)
// conn.ChanToSetDeadline().Ignore()
// }
// }

View File

@ -1,51 +0,0 @@
package pgconn
import (
"time"
)
var deadlineTime = time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)
type setDeadliner interface {
SetDeadline(time.Time) error
}
type chanToSetDeadline struct {
cleanupChan chan struct{}
conn setDeadliner
deadlineWasSet bool
cleanupComplete bool
}
func (this *chanToSetDeadline) start(doneChan <-chan struct{}, conn setDeadliner) {
if this.cleanupChan == nil {
this.cleanupChan = make(chan struct{})
}
this.conn = conn
this.deadlineWasSet = false
this.cleanupComplete = false
if doneChan != nil {
go func() {
select {
case <-doneChan:
conn.SetDeadline(deadlineTime)
this.deadlineWasSet = true
<-this.cleanupChan
case <-this.cleanupChan:
}
}()
} else {
this.cleanupComplete = true
}
}
func (this *chanToSetDeadline) cleanup() {
if !this.cleanupComplete {
this.cleanupChan <- struct{}{}
if this.deadlineWasSet {
this.conn.SetDeadline(time.Time{})
}
this.cleanupComplete = true
}
}

1
go.mod
View File

@ -6,7 +6,6 @@ require (
github.com/jackc/pgio v1.0.0
github.com/jackc/pgpassfile v1.0.0
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db
github.com/pkg/errors v0.8.1
github.com/stretchr/testify v1.3.0
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a
golang.org/x/text v0.3.0

1
go.sum
View File

@ -17,6 +17,7 @@ github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g=
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e h1:nFYrTHrdrAOpShe27kaFHjsqYSEQ0KWqdWLu3xuZJts=
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View File

@ -12,9 +12,9 @@ import (
)
func closeConn(t testing.TB, conn *pgconn.PgConn) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
require.Nil(t, conn.Close(ctx))
require.NoError(t, conn.Close(ctx))
}
// Do a simple query to ensure the connection is still usable

View File

@ -0,0 +1,64 @@
package ctxwatch
import (
"context"
)
// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a
// time.
type ContextWatcher struct {
onCancel func()
onUnwatchAfterCancel func()
unwatchChan chan struct{}
watchInProgress bool
onCancelWasCalled bool
}
// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled.
// OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and
// onCancel called.
func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher {
cw := &ContextWatcher{
onCancel: onCancel,
onUnwatchAfterCancel: onUnwatchAfterCancel,
unwatchChan: make(chan struct{}),
}
return cw
}
// Watch starts watching ctx. If ctx is canceled then the onCancel function passed to NewContextWatcher will be called.
func (cw *ContextWatcher) Watch(ctx context.Context) {
if cw.watchInProgress {
panic("Watch already in progress")
}
cw.onCancelWasCalled = false
if ctx.Done() != nil {
cw.watchInProgress = true
go func() {
select {
case <-ctx.Done():
cw.onCancel()
cw.onCancelWasCalled = true
<-cw.unwatchChan
case <-cw.unwatchChan:
}
}()
} else {
cw.watchInProgress = false
}
}
// Unwatch stops watching the previously watched context. If the onCancel function passed to NewContextWatcher was
// called then onUnwatchAfterCancel will also be called.
func (cw *ContextWatcher) Unwatch() {
if cw.watchInProgress {
cw.unwatchChan <- struct{}{}
if cw.onCancelWasCalled {
cw.onUnwatchAfterCancel()
}
cw.watchInProgress = false
}
}

View File

@ -0,0 +1,139 @@
package ctxwatch_test
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/jackc/pgconn/internal/ctxwatch"
"github.com/stretchr/testify/require"
)
func TestContextWatcherContextCancelled(t *testing.T) {
canceledChan := make(chan struct{})
cleanupCalled := false
cw := ctxwatch.NewContextWatcher(func() {
canceledChan <- struct{}{}
}, func() {
cleanupCalled = true
})
ctx, cancel := context.WithCancel(context.Background())
cw.Watch(ctx)
cancel()
select {
case <-canceledChan:
case <-time.NewTimer(time.Second).C:
t.Fatal("Timed out waiting for cancel func to be called")
}
cw.Unwatch()
require.True(t, cleanupCalled, "Cleanup func was not called")
}
func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) {
cw := ctxwatch.NewContextWatcher(func() {
t.Error("cancel func should not have been called")
}, func() {
t.Error("cleanup func should not have been called")
})
ctx, cancel := context.WithCancel(context.Background())
cw.Watch(ctx)
cw.Unwatch()
cancel()
}
func TestContextWatcherMultipleWatchPanics(t *testing.T) {
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cw.Watch(ctx)
ctx2, cancel2 := context.WithCancel(context.Background())
defer cancel2()
require.Panics(t, func() { cw.Watch(ctx2) }, "Expected panic when Watch called multiple times")
}
func TestContextWatcherStress(t *testing.T) {
var cancelFuncCalls int64
var cleanupFuncCalls int64
cw := ctxwatch.NewContextWatcher(func() {
atomic.AddInt64(&cancelFuncCalls, 1)
}, func() {
atomic.AddInt64(&cleanupFuncCalls, 1)
})
cycleCount := 100000
for i := 0; i < cycleCount; i++ {
ctx, cancel := context.WithCancel(context.Background())
cw.Watch(ctx)
if i%2 == 0 {
cancel()
}
// Without time.Sleep, cw.Unwatch will almost always run before the cancel func which means cancel will never happen. This gives us a better mix.
if i%3 == 0 {
time.Sleep(time.Nanosecond)
}
cw.Unwatch()
if i%2 == 1 {
cancel()
}
}
actualCancelFuncCalls := atomic.LoadInt64(&cancelFuncCalls)
actualCleanupFuncCalls := atomic.LoadInt64(&cleanupFuncCalls)
if actualCancelFuncCalls == 0 {
t.Fatal("actualCancelFuncCalls == 0")
}
maxCancelFuncCalls := int64(cycleCount) / 2
if actualCancelFuncCalls > maxCancelFuncCalls {
t.Errorf("cancel func calls should be no more than %d but was %d", actualCancelFuncCalls, maxCancelFuncCalls)
}
if actualCancelFuncCalls != actualCleanupFuncCalls {
t.Errorf("cancel func calls (%d) should be equal to cleanup func calls (%d) but was not", actualCancelFuncCalls, actualCleanupFuncCalls)
}
}
func BenchmarkContextWatcherUncancellable(b *testing.B) {
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
for i := 0; i < b.N; i++ {
cw.Watch(context.Background())
cw.Unwatch()
}
}
func BenchmarkContextWatcherCancelled(b *testing.B) {
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
for i := 0; i < b.N; i++ {
ctx, cancel := context.WithCancel(context.Background())
cw.Watch(ctx)
cancel()
cw.Unwatch()
}
}
func BenchmarkContextWatcherCancellable(b *testing.B) {
cw := ctxwatch.NewContextWatcher(func() {}, func() {})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for i := 0; i < b.N; i++ {
cw.Watch(ctx)
cw.Unwatch()
}
}

View File

@ -13,7 +13,9 @@ import (
"strconv"
"strings"
"sync"
"time"
"github.com/jackc/pgconn/internal/ctxwatch"
"github.com/jackc/pgio"
"github.com/jackc/pgproto3/v2"
errors "golang.org/x/xerrors"
@ -21,6 +23,7 @@ import (
const (
connStatusUninitialized = iota
connStatusConnecting
connStatusClosed
connStatusIdle
connStatusBusy
@ -71,10 +74,10 @@ type PgConn struct {
bufferingReceiveErr error
// Reusable / preallocated resources
wbuf []byte // write buffer
resultReader ResultReader
multiResultReader MultiResultReader
doneChanToDeadline chanToSetDeadline
wbuf []byte // write buffer
resultReader ResultReader
multiResultReader MultiResultReader
contextWatcher *ctxwatch.ContextWatcher
}
// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format)
@ -149,6 +152,12 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
}
}
pgConn.status = connStatusConnecting
pgConn.contextWatcher = ctxwatch.NewContextWatcher(
func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
func() { pgConn.conn.SetDeadline(time.Time{}) },
)
pgConn.Frontend, err = pgproto3.NewFrontend(pgproto3.NewChunkReader(pgConn.conn), pgConn.conn)
if err != nil {
return nil, err
@ -355,8 +364,8 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
defer pgConn.conn.Close()
pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
defer pgConn.doneChanToDeadline.cleanup()
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
_, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4})
if err != nil {
@ -377,6 +386,7 @@ func (pgConn *PgConn) hardClose() error {
return nil
}
pgConn.status = connStatusClosed
return pgConn.conn.Close()
}
@ -453,8 +463,8 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
default:
}
pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
defer pgConn.doneChanToDeadline.cleanup()
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
buf := pgConn.wbuf
buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf)
@ -543,9 +553,12 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
}
defer cancelConn.Close()
var doneChanToDeadline chanToSetDeadline
doneChanToDeadline.start(ctx.Done(), cancelConn)
defer doneChanToDeadline.cleanup()
contextWatcher := ctxwatch.NewContextWatcher(
func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
func() { cancelConn.SetDeadline(time.Time{}) },
)
contextWatcher.Watch(ctx)
defer contextWatcher.Unwatch()
buf := make([]byte, 16)
binary.BigEndian.PutUint32(buf[0:4], 16)
@ -579,8 +592,8 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
default:
}
pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
defer pgConn.doneChanToDeadline.cleanup()
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
for {
msg, err := pgConn.ReceiveMessage()
@ -622,7 +635,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
return multiResult
default:
}
pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
pgConn.contextWatcher.Watch(ctx)
buf := pgConn.wbuf
buf = (&pgproto3.Query{String: sql}).Encode(buf)
@ -630,7 +643,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
n, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.hardClose()
pgConn.doneChanToDeadline.cleanup()
pgConn.contextWatcher.Unwatch()
multiResult.closed = true
if n == 0 {
err = linkErrors(err, ErrNoBytesSent)
@ -732,7 +745,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
return result
default:
}
pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
pgConn.contextWatcher.Watch(ctx)
return result
}
@ -749,7 +762,7 @@ func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result
err = linkErrors(err, ErrNoBytesSent)
}
result.concludeCommand(nil, linkErrors(ctx.Err(), err))
pgConn.doneChanToDeadline.cleanup()
pgConn.contextWatcher.Unwatch()
result.closed = true
pgConn.unlock()
}
@ -767,8 +780,8 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
default:
}
pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
defer pgConn.doneChanToDeadline.cleanup()
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
// Send copy to command
buf := pgConn.wbuf
@ -828,8 +841,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
return nil, linkErrors(ctx.Err(), ErrNoBytesSent)
default:
}
pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
defer pgConn.doneChanToDeadline.cleanup()
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
// Send copy to command
buf := pgConn.wbuf
@ -962,7 +975,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
msg, err := mrr.pgConn.ReceiveMessage()
if err != nil {
mrr.pgConn.doneChanToDeadline.cleanup()
mrr.pgConn.contextWatcher.Unwatch()
mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err)
mrr.closed = true
mrr.pgConn.hardClose()
@ -971,7 +984,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
switch msg := msg.(type) {
case *pgproto3.ReadyForQuery:
mrr.pgConn.doneChanToDeadline.cleanup()
mrr.pgConn.contextWatcher.Unwatch()
mrr.closed = true
mrr.pgConn.unlock()
case *pgproto3.ErrorResponse:
@ -1129,7 +1142,7 @@ func (rr *ResultReader) Close() (CommandTag, error) {
switch msg.(type) {
case *pgproto3.ReadyForQuery:
rr.pgConn.doneChanToDeadline.cleanup()
rr.pgConn.contextWatcher.Unwatch()
rr.pgConn.unlock()
return rr.commandTag, rr.err
}
@ -1148,7 +1161,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
if err != nil {
rr.concludeCommand(nil, err)
rr.pgConn.doneChanToDeadline.cleanup()
rr.pgConn.contextWatcher.Unwatch()
rr.closed = true
if rr.multiResultReader == nil {
rr.pgConn.hardClose()
@ -1223,7 +1236,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
return multiResult
default:
}
pgConn.doneChanToDeadline.start(ctx.Done(), pgConn.conn)
pgConn.contextWatcher.Watch(ctx)
batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)