From 26c79eb215cd1a8022fee209e1df9601626d2f89 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 3 Jun 2023 18:01:01 -0500 Subject: [PATCH] Handle writes that could deadlock with reads from the server This commit adds a background reader that can optionally buffer reads. It is used whenever a potentially blocking write is made to the server. The background reader is started on a slight delay so there should be no meaningful performance impact as it doesn't run for quick queries and its overhead is minimal relative to slower queries. --- pgconn/auth_scram.go | 4 +- pgconn/internal/bgreader/bgreader.go | 132 ++++++++++++++++++++ pgconn/internal/bgreader/bgreader_test.go | 140 ++++++++++++++++++++++ pgconn/krb5.go | 2 +- pgconn/pgconn.go | 53 ++++++-- 5 files changed, 316 insertions(+), 15 deletions(-) create mode 100644 pgconn/internal/bgreader/bgreader.go create mode 100644 pgconn/internal/bgreader/bgreader_test.go diff --git a/pgconn/auth_scram.go b/pgconn/auth_scram.go index 6ca9e337..8c4b2de3 100644 --- a/pgconn/auth_scram.go +++ b/pgconn/auth_scram.go @@ -42,7 +42,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { Data: sc.clientFirstMessage(), } c.frontend.Send(saslInitialResponse) - err = c.frontend.Flush() + err = c.flushWithPotentialWriteReadDeadlock() if err != nil { return err } @@ -62,7 +62,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { Data: []byte(sc.clientFinalMessage()), } c.frontend.Send(saslResponse) - err = c.frontend.Flush() + err = c.flushWithPotentialWriteReadDeadlock() if err != nil { return err } diff --git a/pgconn/internal/bgreader/bgreader.go b/pgconn/internal/bgreader/bgreader.go new file mode 100644 index 00000000..aa1a3d39 --- /dev/null +++ b/pgconn/internal/bgreader/bgreader.go @@ -0,0 +1,132 @@ +// Package bgreader provides a io.Reader that can optionally buffer reads in the background. +package bgreader + +import ( + "io" + "sync" + + "github.com/jackc/pgx/v5/internal/iobufpool" +) + +const ( + bgReaderStatusStopped = iota + bgReaderStatusRunning + bgReaderStatusStopping +) + +// BGReader is an io.Reader that can optionally buffer reads in the background. It is safe for concurrent use. +type BGReader struct { + r io.Reader + + cond *sync.Cond + bgReaderStatus int32 + readResults []readResult +} + +type readResult struct { + buf *[]byte + err error +} + +// Start starts the backgrounder reader. If the background reader is already running this is a no-op. The background +// reader will stop automatically when the underlying reader returns an error. +func (r *BGReader) Start() { + r.cond.L.Lock() + defer r.cond.L.Unlock() + + switch r.bgReaderStatus { + case bgReaderStatusStopped: + r.bgReaderStatus = bgReaderStatusRunning + go r.bgRead() + case bgReaderStatusRunning: + // no-op + case bgReaderStatusStopping: + r.bgReaderStatus = bgReaderStatusRunning + } +} + +// Stop tells the background reader to stop after the in progress Read returns. It is safe to call Stop when the +// background reader is not running. 
+func (r *BGReader) Stop() { + r.cond.L.Lock() + defer r.cond.L.Unlock() + + switch r.bgReaderStatus { + case bgReaderStatusStopped: + // no-op + case bgReaderStatusRunning: + r.bgReaderStatus = bgReaderStatusStopping + case bgReaderStatusStopping: + // no-op + } +} + +func (r *BGReader) bgRead() { + keepReading := true + for keepReading { + buf := iobufpool.Get(8192) + n, err := r.r.Read(*buf) + *buf = (*buf)[:n] + + r.cond.L.Lock() + r.readResults = append(r.readResults, readResult{buf: buf, err: err}) + if r.bgReaderStatus == bgReaderStatusStopping || err != nil { + r.bgReaderStatus = bgReaderStatusStopped + keepReading = false + } + r.cond.L.Unlock() + r.cond.Broadcast() + } +} + +// Read implements the io.Reader interface. +func (r *BGReader) Read(p []byte) (int, error) { + r.cond.L.Lock() + defer r.cond.L.Unlock() + + if len(r.readResults) > 0 { + return r.readFromReadResults(p) + } + + // There are no unread background read results and the background reader is stopped. + if r.bgReaderStatus == bgReaderStatusStopped { + return r.r.Read(p) + } + + // Wait for results from the background reader + for len(r.readResults) == 0 { + r.cond.Wait() + } + return r.readFromReadResults(p) +} + +// readBackgroundResults reads a result previously read by the background reader. r.cond.L must be held. +func (r *BGReader) readFromReadResults(p []byte) (int, error) { + buf := r.readResults[0].buf + var err error + + n := copy(p, *buf) + if n == len(*buf) { + err = r.readResults[0].err + iobufpool.Put(buf) + if len(r.readResults) == 1 { + r.readResults = nil + } else { + r.readResults = r.readResults[1:] + } + } else { + *buf = (*buf)[n:] + r.readResults[0].buf = buf + } + + return n, err +} + +func New(r io.Reader) *BGReader { + return &BGReader{ + r: r, + cond: &sync.Cond{ + L: &sync.Mutex{}, + }, + } +} diff --git a/pgconn/internal/bgreader/bgreader_test.go b/pgconn/internal/bgreader/bgreader_test.go new file mode 100644 index 00000000..f787e2f1 --- /dev/null +++ b/pgconn/internal/bgreader/bgreader_test.go @@ -0,0 +1,140 @@ +package bgreader_test + +import ( + "bytes" + "errors" + "io" + "math/rand" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgconn/internal/bgreader" + "github.com/stretchr/testify/require" +) + +func TestBGReaderReadWhenStopped(t *testing.T) { + r := bytes.NewReader([]byte("foo bar baz")) + bgr := bgreader.New(r) + buf, err := io.ReadAll(bgr) + require.NoError(t, err) + require.Equal(t, []byte("foo bar baz"), buf) +} + +func TestBGReaderReadWhenStarted(t *testing.T) { + r := bytes.NewReader([]byte("foo bar baz")) + bgr := bgreader.New(r) + bgr.Start() + buf, err := io.ReadAll(bgr) + require.NoError(t, err) + require.Equal(t, []byte("foo bar baz"), buf) +} + +type mockReadFunc func(p []byte) (int, error) + +type mockReader struct { + readFuncs []mockReadFunc +} + +func (r *mockReader) Read(p []byte) (int, error) { + if len(r.readFuncs) == 0 { + return 0, io.EOF + } + + fn := r.readFuncs[0] + r.readFuncs = r.readFuncs[1:] + + return fn(p) +} + +func TestBGReaderReadWaitsForBackgroundRead(t *testing.T) { + rr := &mockReader{ + readFuncs: []mockReadFunc{ + func(p []byte) (int, error) { time.Sleep(1 * time.Second); return copy(p, []byte("foo")), nil }, + func(p []byte) (int, error) { return copy(p, []byte("bar")), nil }, + func(p []byte) (int, error) { return copy(p, []byte("baz")), nil }, + }, + } + bgr := bgreader.New(rr) + bgr.Start() + buf := make([]byte, 3) + n, err := bgr.Read(buf) + require.NoError(t, err) + require.EqualValues(t, 3, n) + require.Equal(t, []byte("foo"), 
buf) +} + +func TestBGReaderErrorWhenStarted(t *testing.T) { + rr := &mockReader{ + readFuncs: []mockReadFunc{ + func(p []byte) (int, error) { return copy(p, []byte("foo")), nil }, + func(p []byte) (int, error) { return copy(p, []byte("bar")), nil }, + func(p []byte) (int, error) { return copy(p, []byte("baz")), errors.New("oops") }, + }, + } + + bgr := bgreader.New(rr) + bgr.Start() + buf, err := io.ReadAll(bgr) + require.Equal(t, []byte("foobarbaz"), buf) + require.EqualError(t, err, "oops") +} + +func TestBGReaderErrorWhenStopped(t *testing.T) { + rr := &mockReader{ + readFuncs: []mockReadFunc{ + func(p []byte) (int, error) { return copy(p, []byte("foo")), nil }, + func(p []byte) (int, error) { return copy(p, []byte("bar")), nil }, + func(p []byte) (int, error) { return copy(p, []byte("baz")), errors.New("oops") }, + }, + } + + bgr := bgreader.New(rr) + buf, err := io.ReadAll(bgr) + require.Equal(t, []byte("foobarbaz"), buf) + require.EqualError(t, err, "oops") +} + +type numberReader struct { + v uint8 + rng *rand.Rand +} + +func (nr *numberReader) Read(p []byte) (int, error) { + n := nr.rng.Intn(len(p)) + for i := 0; i < n; i++ { + p[i] = nr.v + nr.v++ + } + + return n, nil +} + +// TestBGReaderStress stress tests BGReader by reading a lot of bytes in random sizes while randomly starting and +// stopping the background worker from other goroutines. +func TestBGReaderStress(t *testing.T) { + nr := &numberReader{rng: rand.New(rand.NewSource(0))} + bgr := bgreader.New(nr) + + bytesRead := 0 + var expected uint8 + buf := make([]byte, 10_000) + rng := rand.New(rand.NewSource(0)) + + for bytesRead < 1_000_000 { + randomNumber := rng.Intn(100) + switch { + case randomNumber < 10: + go bgr.Start() + case randomNumber < 20: + go bgr.Stop() + default: + n, err := bgr.Read(buf) + require.NoError(t, err) + for i := 0; i < n; i++ { + require.Equal(t, expected, buf[i]) + expected++ + } + bytesRead += n + } + } +} diff --git a/pgconn/krb5.go b/pgconn/krb5.go index 969675fd..3c1af347 100644 --- a/pgconn/krb5.go +++ b/pgconn/krb5.go @@ -63,7 +63,7 @@ func (c *PgConn) gssAuth() error { Data: nextData, } c.frontend.Send(gssResponse) - err = c.frontend.Flush() + err = c.flushWithPotentialWriteReadDeadlock() if err != nil { return err } diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 768d3e71..02226e0a 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -18,6 +18,7 @@ import ( "github.com/jackc/pgx/v5/internal/iobufpool" "github.com/jackc/pgx/v5/internal/pgio" + "github.com/jackc/pgx/v5/pgconn/internal/bgreader" "github.com/jackc/pgx/v5/pgconn/internal/ctxwatch" "github.com/jackc/pgx/v5/pgproto3" ) @@ -71,6 +72,8 @@ type PgConn struct { parameterStatuses map[string]string // parameters that have been reported by the server txStatus byte frontend *pgproto3.Frontend + bgReader *bgreader.BGReader + slowWriteTimer *time.Timer config *Config @@ -293,7 +296,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.parameterStatuses = make(map[string]string) pgConn.status = connStatusConnecting - pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) + pgConn.bgReader = bgreader.New(pgConn.conn) + pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start) + pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn) startupMsg := pgproto3.StartupMessage{ ProtocolVersion: pgproto3.ProtocolVersionNumber, @@ -311,7 +316,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } 
pgConn.frontend.Send(&startupMsg) - if err := pgConn.frontend.Flush(); err != nil { + if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil { pgConn.conn.Close() return nil, &connectError{config: config, msg: "failed to write startup message", err: err} } @@ -416,7 +421,7 @@ func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { func (pgConn *PgConn) txPasswordMessage(password string) (err error) { pgConn.frontend.Send(&pgproto3.PasswordMessage{Password: password}) - return pgConn.frontend.Flush() + return pgConn.flushWithPotentialWriteReadDeadlock() } func hexMD5(s string) string { @@ -611,7 +616,7 @@ func (pgConn *PgConn) Close(ctx context.Context) error { // // See https://github.com/jackc/pgx/issues/637 pgConn.frontend.Send(&pgproto3.Terminate{}) - pgConn.frontend.Flush() + pgConn.flushWithPotentialWriteReadDeadlock() return pgConn.conn.Close() } @@ -638,7 +643,7 @@ func (pgConn *PgConn) asyncClose() { pgConn.conn.SetDeadline(deadline) pgConn.frontend.Send(&pgproto3.Terminate{}) - pgConn.frontend.Flush() + pgConn.flushWithPotentialWriteReadDeadlock() }() } @@ -813,7 +818,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ pgConn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}) pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name}) pgConn.frontend.SendSync(&pgproto3.Sync{}) - err := pgConn.frontend.Flush() + err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() return nil, err @@ -995,7 +1000,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { } pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) - err := pgConn.frontend.Flush() + err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() pgConn.contextWatcher.Unwatch() @@ -1106,7 +1111,7 @@ func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) { pgConn.frontend.SendExecute(&pgproto3.Execute{}) pgConn.frontend.SendSync(&pgproto3.Sync{}) - err := pgConn.frontend.Flush() + err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() result.concludeCommand(CommandTag{}, err) @@ -1139,7 +1144,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm // Send copy to command pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) - err := pgConn.frontend.Flush() + err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() pgConn.unlock() @@ -1197,7 +1202,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co // Send copy from query pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) - err := pgConn.frontend.Flush() + err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() return CommandTag{}, err @@ -1273,7 +1278,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } else { pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()}) } - err = pgConn.frontend.Flush() + err = pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { pgConn.asyncClose() return CommandTag{}, err @@ -1634,7 +1639,9 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) + pgConn.enterPotentialWriteReadDeadlock() _, err := pgConn.conn.Write(batch.buf) + pgConn.exitPotentialWriteReadDeadlock() if err != nil { multiResult.closed = true multiResult.err = err @@ -1688,6 
+1695,26 @@ func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag { return CommandTag{s: string(buf)} } +// enterPotentialWriteReadDeadlock must be called before a write that could deadlock if the server is simultaneously +// blocked writing to us. +func (pgConn *PgConn) enterPotentialWriteReadDeadlock() { + pgConn.slowWriteTimer.Reset(10 * time.Millisecond) +} + +// exitPotentialWriteReadDeadlock must be called after a call to enterPotentialWriteReadDeadlock. +func (pgConn *PgConn) exitPotentialWriteReadDeadlock() { + if !pgConn.slowWriteTimer.Reset(time.Duration(math.MaxInt64)) { + pgConn.slowWriteTimer.Stop() + } +} + +func (pgConn *PgConn) flushWithPotentialWriteReadDeadlock() error { + pgConn.enterPotentialWriteReadDeadlock() + err := pgConn.frontend.Flush() + pgConn.exitPotentialWriteReadDeadlock() + return err +} + // HijackedConn is the result of hijacking a connection. // // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning @@ -1746,6 +1773,8 @@ func Construct(hc *HijackedConn) (*PgConn, error) { } pgConn.contextWatcher = newContextWatcher(pgConn.conn) + pgConn.bgReader = bgreader.New(pgConn.conn) + pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start) return pgConn, nil } @@ -1868,7 +1897,7 @@ func (p *Pipeline) Flush() error { return errors.New("pipeline closed") } - err := p.conn.frontend.Flush() + err := p.conn.flushWithPotentialWriteReadDeadlock() if err != nil { err = normalizeTimeoutError(p.ctx, err)
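A note on the failure mode this patch addresses (not part of the commit): the deadlock can only happen when both peers are writing at once. pgconn blocks writing a large query, batch, or COPY payload while the server blocks writing rows or notices back; with neither side reading, both TCP buffers fill and neither write can ever complete. The toy program below reproduces that situation on an in-memory net.Pipe and then breaks it the same way the background reader does, by draining incoming data. Everything in it is illustrative and uses no pgconn code.

package main

import (
	"fmt"
	"net"
	"time"
)

func main() {
	client, server := net.Pipe()
	defer client.Close()
	defer server.Close()

	// The "server": it pushes a large result at us and only starts reading our
	// request after its own write has finished, like a real server whose send
	// buffer is full.
	go func() {
		server.Write(make([]byte, 1<<20))
		buf := make([]byte, 32*1024)
		for {
			if _, err := server.Read(buf); err != nil {
				return
			}
		}
	}()

	// The "client": a large request write while nobody reads the connection.
	done := make(chan struct{})
	go func() {
		client.Write(make([]byte, 1<<20))
		close(done)
	}()

	select {
	case <-done:
		fmt.Println("write finished without help")
	case <-time.After(200 * time.Millisecond):
		fmt.Println("stuck: both sides are writing and neither is reading")
		// Drain incoming data, which is what the background reader does, and
		// the stuck write completes.
		go func() {
			buf := make([]byte, 32*1024)
			for {
				if _, err := client.Read(buf); err != nil {
					return
				}
			}
		}()
		<-done
		fmt.Println("write finished once reads were drained")
	}
}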
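The slow-write timer is how the fix stays off the hot path: the timer is created disarmed, armed just before a potentially blocking write, and disarmed right after, so it only ever fires for writes that take longer than 10 ms. The sketch below isolates that pattern under the same threshold the diff uses. slowWriteTimer matches the field name above, startBackgroundReads is a stand-in for pgConn.bgReader.Start, and the program as a whole is hypothetical rather than pgconn code.

package main

import (
	"fmt"
	"math"
	"net"
	"time"
)

func main() {
	client, server := net.Pipe()
	defer client.Close()
	defer server.Close()

	// Stand-in for bgReader.Start: in pgconn this begins draining server output.
	startBackgroundReads := func() { fmt.Println("slow write detected, background reader started") }

	// Created disarmed: a duration of math.MaxInt64 means the timer effectively
	// never fires on its own, mirroring connect() and Construct() above.
	slowWriteTimer := time.AfterFunc(time.Duration(math.MaxInt64), startBackgroundReads)

	// The peer drains our write, but only after a delay, so the write below is
	// "slow" and the timer gets a chance to fire.
	go func() {
		time.Sleep(50 * time.Millisecond)
		buf := make([]byte, 32*1024)
		for {
			if _, err := server.Read(buf); err != nil {
				return
			}
		}
	}()

	// enterPotentialWriteReadDeadlock: arm the timer just before the write.
	slowWriteTimer.Reset(10 * time.Millisecond)

	_, err := client.Write(make([]byte, 1<<20)) // a potentially blocking write

	// exitPotentialWriteReadDeadlock: disarm the timer again. For a fast write
	// the timer never fires at all, so the common path costs only two Resets.
	if !slowWriteTimer.Reset(time.Duration(math.MaxInt64)) {
		slowWriteTimer.Stop()
	}

	fmt.Println("write finished:", err)
}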
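Finally, the wiring: connect() and Construct() hand pgConn.bgReader to the frontend as its read side while writes keep going to pgConn.conn directly, so anything the background goroutine buffered is transparently replayed to the protocol layer and outgoing data is never touched. The fragment below restates that arrangement in isolation. newFrontend is a made-up helper, it assumes pgproto3.NewFrontend's (io.Reader, io.Writer) constructor, and because bgreader is internal to pgx the import only compiles from inside that module.

package pgconnwiring // hypothetical package, for illustration only

import (
	"net"

	"github.com/jackc/pgx/v5/pgconn/internal/bgreader"
	"github.com/jackc/pgx/v5/pgproto3"
)

// newFrontend is an illustrative helper, not part of pgconn.
func newFrontend(conn net.Conn) (*pgproto3.Frontend, *bgreader.BGReader) {
	bgr := bgreader.New(conn)

	// Reads flow conn -> BGReader -> Frontend; writes still go straight to conn.
	// The wrapper covers only the read side, matching
	// config.BuildFrontend(pgConn.bgReader, pgConn.conn) in the diff above.
	frontend := pgproto3.NewFrontend(bgr, conn)

	return frontend, bgr
}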