Handle writes that could deadlock with reads from the server
This commit adds a background reader that can optionally buffer reads. It is used whenever a potentially blocking write is made to the server. The background reader is started on a slight delay, so there should be no meaningful performance impact: it never runs for quick queries, and its overhead is small relative to slower queries.
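The standalone sketch below is not part of the commit; it illustrates the deadlock being guarded against and the delayed-start idea, using net.Pipe in place of a PostgreSQL connection and a plain draining goroutine in place of the bgreader package added here. The 10ms arm and the MaxInt64 re-arm mirror the enterPotentialWriteReadDeadlock/exitPotentialWriteReadDeadlock helpers in the diff.

// Illustrative only: a write/read deadlock and how a delayed background reader resolves it.
package main

import (
    "fmt"
    "io"
    "math"
    "net"
    "time"
)

func main() {
    client, server := net.Pipe()

    // Stand-in for the background reader: drain whatever the peer sends us.
    startBackgroundRead := func() { go io.Copy(io.Discard, client) }

    // Armed effectively "never"; only a slow write pulls the deadline close enough to fire.
    slowWriteTimer := time.AfterFunc(time.Duration(math.MaxInt64), startBackgroundRead)

    // The peer insists on sending us data before it will read anything we write.
    go func() {
        server.Write([]byte("server output"))
        io.Copy(io.Discard, server)
    }()

    // enterPotentialWriteReadDeadlock: give the write 10ms before starting the background reader.
    slowWriteTimer.Reset(10 * time.Millisecond)
    _, err := client.Write([]byte("client request")) // would block forever without the background reader
    // exitPotentialWriteReadDeadlock: re-arm the timer far in the future.
    if !slowWriteTimer.Reset(time.Duration(math.MaxInt64)) {
        slowWriteTimer.Stop()
    }

    fmt.Println("write completed, err =", err)
}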
This commit is contained in:
parent 85136a8efe
commit 26c79eb215
@@ -42,7 +42,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
         Data: sc.clientFirstMessage(),
     }
     c.frontend.Send(saslInitialResponse)
-    err = c.frontend.Flush()
+    err = c.flushWithPotentialWriteReadDeadlock()
     if err != nil {
         return err
     }
@@ -62,7 +62,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
         Data: []byte(sc.clientFinalMessage()),
     }
     c.frontend.Send(saslResponse)
-    err = c.frontend.Flush()
+    err = c.flushWithPotentialWriteReadDeadlock()
     if err != nil {
         return err
     }
pgconn/internal/bgreader/bgreader.go (new file, 132 lines)
@@ -0,0 +1,132 @@
// Package bgreader provides an io.Reader that can optionally buffer reads in the background.
package bgreader

import (
    "io"
    "sync"

    "github.com/jackc/pgx/v5/internal/iobufpool"
)

const (
    bgReaderStatusStopped = iota
    bgReaderStatusRunning
    bgReaderStatusStopping
)

// BGReader is an io.Reader that can optionally buffer reads in the background. It is safe for concurrent use.
type BGReader struct {
    r io.Reader

    cond *sync.Cond
    bgReaderStatus int32
    readResults []readResult
}

type readResult struct {
    buf *[]byte
    err error
}

// Start starts the background reader. If the background reader is already running this is a no-op. The background
// reader will stop automatically when the underlying reader returns an error.
func (r *BGReader) Start() {
    r.cond.L.Lock()
    defer r.cond.L.Unlock()

    switch r.bgReaderStatus {
    case bgReaderStatusStopped:
        r.bgReaderStatus = bgReaderStatusRunning
        go r.bgRead()
    case bgReaderStatusRunning:
        // no-op
    case bgReaderStatusStopping:
        r.bgReaderStatus = bgReaderStatusRunning
    }
}

// Stop tells the background reader to stop after the in-progress Read returns. It is safe to call Stop when the
// background reader is not running.
func (r *BGReader) Stop() {
    r.cond.L.Lock()
    defer r.cond.L.Unlock()

    switch r.bgReaderStatus {
    case bgReaderStatusStopped:
        // no-op
    case bgReaderStatusRunning:
        r.bgReaderStatus = bgReaderStatusStopping
    case bgReaderStatusStopping:
        // no-op
    }
}

func (r *BGReader) bgRead() {
    keepReading := true
    for keepReading {
        buf := iobufpool.Get(8192)
        n, err := r.r.Read(*buf)
        *buf = (*buf)[:n]

        r.cond.L.Lock()
        r.readResults = append(r.readResults, readResult{buf: buf, err: err})
        if r.bgReaderStatus == bgReaderStatusStopping || err != nil {
            r.bgReaderStatus = bgReaderStatusStopped
            keepReading = false
        }
        r.cond.L.Unlock()
        r.cond.Broadcast()
    }
}

// Read implements the io.Reader interface.
func (r *BGReader) Read(p []byte) (int, error) {
    r.cond.L.Lock()
    defer r.cond.L.Unlock()

    if len(r.readResults) > 0 {
        return r.readFromReadResults(p)
    }

    // There are no unread background read results and the background reader is stopped.
    if r.bgReaderStatus == bgReaderStatusStopped {
        return r.r.Read(p)
    }

    // Wait for results from the background reader.
    for len(r.readResults) == 0 {
        r.cond.Wait()
    }
    return r.readFromReadResults(p)
}

// readFromReadResults reads a result previously read by the background reader. r.cond.L must be held.
func (r *BGReader) readFromReadResults(p []byte) (int, error) {
    buf := r.readResults[0].buf
    var err error

    n := copy(p, *buf)
    if n == len(*buf) {
        err = r.readResults[0].err
        iobufpool.Put(buf)
        if len(r.readResults) == 1 {
            r.readResults = nil
        } else {
            r.readResults = r.readResults[1:]
        }
    } else {
        *buf = (*buf)[n:]
        r.readResults[0].buf = buf
    }

    return n, err
}

func New(r io.Reader) *BGReader {
    return &BGReader{
        r: r,
        cond: &sync.Cond{
            L: &sync.Mutex{},
        },
    }
}
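As a quick illustration (not part of the commit) of the contract implemented above: results buffered by the background goroutine are served first, and Read falls through to the underlying reader once the background reader has stopped. Because bgreader is an internal package, a snippet like this only compiles inside the pgx module, for example as an extra example in the test package that follows.

package bgreader_test

import (
    "fmt"
    "strings"

    "github.com/jackc/pgx/v5/pgconn/internal/bgreader"
)

// ExampleBGReader is a hypothetical usage example, not part of this commit.
func ExampleBGReader() {
    bgr := bgreader.New(strings.NewReader("hello world"))

    bgr.Start() // begin buffering reads in a background goroutine
    buf := make([]byte, 5)
    n, _ := bgr.Read(buf) // served from the background read, waiting for it if necessary
    fmt.Println(string(buf[:n]))

    bgr.Stop() // the background goroutine exits after its in-progress Read completes
    // Output: hello
}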
pgconn/internal/bgreader/bgreader_test.go (new file, 140 lines)
@@ -0,0 +1,140 @@
package bgreader_test

import (
    "bytes"
    "errors"
    "io"
    "math/rand"
    "testing"
    "time"

    "github.com/jackc/pgx/v5/pgconn/internal/bgreader"
    "github.com/stretchr/testify/require"
)

func TestBGReaderReadWhenStopped(t *testing.T) {
    r := bytes.NewReader([]byte("foo bar baz"))
    bgr := bgreader.New(r)
    buf, err := io.ReadAll(bgr)
    require.NoError(t, err)
    require.Equal(t, []byte("foo bar baz"), buf)
}

func TestBGReaderReadWhenStarted(t *testing.T) {
    r := bytes.NewReader([]byte("foo bar baz"))
    bgr := bgreader.New(r)
    bgr.Start()
    buf, err := io.ReadAll(bgr)
    require.NoError(t, err)
    require.Equal(t, []byte("foo bar baz"), buf)
}

type mockReadFunc func(p []byte) (int, error)

type mockReader struct {
    readFuncs []mockReadFunc
}

func (r *mockReader) Read(p []byte) (int, error) {
    if len(r.readFuncs) == 0 {
        return 0, io.EOF
    }

    fn := r.readFuncs[0]
    r.readFuncs = r.readFuncs[1:]

    return fn(p)
}

func TestBGReaderReadWaitsForBackgroundRead(t *testing.T) {
    rr := &mockReader{
        readFuncs: []mockReadFunc{
            func(p []byte) (int, error) { time.Sleep(1 * time.Second); return copy(p, []byte("foo")), nil },
            func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
            func(p []byte) (int, error) { return copy(p, []byte("baz")), nil },
        },
    }
    bgr := bgreader.New(rr)
    bgr.Start()
    buf := make([]byte, 3)
    n, err := bgr.Read(buf)
    require.NoError(t, err)
    require.EqualValues(t, 3, n)
    require.Equal(t, []byte("foo"), buf)
}

func TestBGReaderErrorWhenStarted(t *testing.T) {
    rr := &mockReader{
        readFuncs: []mockReadFunc{
            func(p []byte) (int, error) { return copy(p, []byte("foo")), nil },
            func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
            func(p []byte) (int, error) { return copy(p, []byte("baz")), errors.New("oops") },
        },
    }

    bgr := bgreader.New(rr)
    bgr.Start()
    buf, err := io.ReadAll(bgr)
    require.Equal(t, []byte("foobarbaz"), buf)
    require.EqualError(t, err, "oops")
}

func TestBGReaderErrorWhenStopped(t *testing.T) {
    rr := &mockReader{
        readFuncs: []mockReadFunc{
            func(p []byte) (int, error) { return copy(p, []byte("foo")), nil },
            func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
            func(p []byte) (int, error) { return copy(p, []byte("baz")), errors.New("oops") },
        },
    }

    bgr := bgreader.New(rr)
    buf, err := io.ReadAll(bgr)
    require.Equal(t, []byte("foobarbaz"), buf)
    require.EqualError(t, err, "oops")
}

type numberReader struct {
    v uint8
    rng *rand.Rand
}

func (nr *numberReader) Read(p []byte) (int, error) {
    n := nr.rng.Intn(len(p))
    for i := 0; i < n; i++ {
        p[i] = nr.v
        nr.v++
    }

    return n, nil
}

// TestBGReaderStress stress tests BGReader by reading a lot of bytes in random sizes while randomly starting and
// stopping the background worker from other goroutines.
func TestBGReaderStress(t *testing.T) {
    nr := &numberReader{rng: rand.New(rand.NewSource(0))}
    bgr := bgreader.New(nr)

    bytesRead := 0
    var expected uint8
    buf := make([]byte, 10_000)
    rng := rand.New(rand.NewSource(0))

    for bytesRead < 1_000_000 {
        randomNumber := rng.Intn(100)
        switch {
        case randomNumber < 10:
            go bgr.Start()
        case randomNumber < 20:
            go bgr.Stop()
        default:
            n, err := bgr.Read(buf)
            require.NoError(t, err)
            for i := 0; i < n; i++ {
                require.Equal(t, expected, buf[i])
                expected++
            }
            bytesRead += n
        }
    }
}
@@ -63,7 +63,7 @@ func (c *PgConn) gssAuth() error {
         Data: nextData,
     }
     c.frontend.Send(gssResponse)
-    err = c.frontend.Flush()
+    err = c.flushWithPotentialWriteReadDeadlock()
     if err != nil {
         return err
     }
@@ -18,6 +18,7 @@ import (
 
     "github.com/jackc/pgx/v5/internal/iobufpool"
    "github.com/jackc/pgx/v5/internal/pgio"
+    "github.com/jackc/pgx/v5/pgconn/internal/bgreader"
     "github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"
     "github.com/jackc/pgx/v5/pgproto3"
 )
@@ -71,6 +72,8 @@ type PgConn struct {
     parameterStatuses map[string]string // parameters that have been reported by the server
     txStatus byte
     frontend *pgproto3.Frontend
+    bgReader *bgreader.BGReader
+    slowWriteTimer *time.Timer
 
     config *Config
 
@@ -293,7 +296,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
 
     pgConn.parameterStatuses = make(map[string]string)
     pgConn.status = connStatusConnecting
-    pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn)
+    pgConn.bgReader = bgreader.New(pgConn.conn)
+    pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start)
+    pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn)
 
     startupMsg := pgproto3.StartupMessage{
         ProtocolVersion: pgproto3.ProtocolVersionNumber,
@@ -311,7 +316,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
     }
 
     pgConn.frontend.Send(&startupMsg)
-    if err := pgConn.frontend.Flush(); err != nil {
+    if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil {
         pgConn.conn.Close()
         return nil, &connectError{config: config, msg: "failed to write startup message", err: err}
     }
@@ -416,7 +421,7 @@ func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
 
 func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
     pgConn.frontend.Send(&pgproto3.PasswordMessage{Password: password})
-    return pgConn.frontend.Flush()
+    return pgConn.flushWithPotentialWriteReadDeadlock()
 }
 
 func hexMD5(s string) string {
@@ -611,7 +616,7 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
     //
     // See https://github.com/jackc/pgx/issues/637
     pgConn.frontend.Send(&pgproto3.Terminate{})
-    pgConn.frontend.Flush()
+    pgConn.flushWithPotentialWriteReadDeadlock()
 
     return pgConn.conn.Close()
 }
@@ -638,7 +643,7 @@ func (pgConn *PgConn) asyncClose() {
         pgConn.conn.SetDeadline(deadline)
 
         pgConn.frontend.Send(&pgproto3.Terminate{})
-        pgConn.frontend.Flush()
+        pgConn.flushWithPotentialWriteReadDeadlock()
     }()
 }
 
@@ -813,7 +818,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
     pgConn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs})
     pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name})
     pgConn.frontend.SendSync(&pgproto3.Sync{})
-    err := pgConn.frontend.Flush()
+    err := pgConn.flushWithPotentialWriteReadDeadlock()
     if err != nil {
         pgConn.asyncClose()
         return nil, err
|
||||
}
|
||||
|
||||
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
|
||||
err := pgConn.frontend.Flush()
|
||||
err := pgConn.flushWithPotentialWriteReadDeadlock()
|
||||
if err != nil {
|
||||
pgConn.asyncClose()
|
||||
pgConn.contextWatcher.Unwatch()
|
||||
@@ -1106,7 +1111,7 @@ func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) {
     pgConn.frontend.SendExecute(&pgproto3.Execute{})
     pgConn.frontend.SendSync(&pgproto3.Sync{})
 
-    err := pgConn.frontend.Flush()
+    err := pgConn.flushWithPotentialWriteReadDeadlock()
     if err != nil {
         pgConn.asyncClose()
         result.concludeCommand(CommandTag{}, err)
@@ -1139,7 +1144,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
     // Send copy to command
     pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
 
-    err := pgConn.frontend.Flush()
+    err := pgConn.flushWithPotentialWriteReadDeadlock()
     if err != nil {
         pgConn.asyncClose()
         pgConn.unlock()
@@ -1197,7 +1202,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
 
     // Send copy from query
     pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
-    err := pgConn.frontend.Flush()
+    err := pgConn.flushWithPotentialWriteReadDeadlock()
     if err != nil {
         pgConn.asyncClose()
         return CommandTag{}, err
@@ -1273,7 +1278,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
     } else {
         pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()})
     }
-    err = pgConn.frontend.Flush()
+    err = pgConn.flushWithPotentialWriteReadDeadlock()
     if err != nil {
         pgConn.asyncClose()
         return CommandTag{}, err
@@ -1634,7 +1639,9 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
 
     batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)
 
+    pgConn.enterPotentialWriteReadDeadlock()
     _, err := pgConn.conn.Write(batch.buf)
+    pgConn.exitPotentialWriteReadDeadlock()
     if err != nil {
         multiResult.closed = true
         multiResult.err = err
@@ -1688,6 +1695,26 @@ func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag {
     return CommandTag{s: string(buf)}
 }
 
+// enterPotentialWriteReadDeadlock must be called before a write that could deadlock if the server is simultaneously
+// blocked writing to us.
+func (pgConn *PgConn) enterPotentialWriteReadDeadlock() {
+    pgConn.slowWriteTimer.Reset(10 * time.Millisecond)
+}
+
+// exitPotentialWriteReadDeadlock must be called after a call to enterPotentialWriteReadDeadlock.
+func (pgConn *PgConn) exitPotentialWriteReadDeadlock() {
+    if !pgConn.slowWriteTimer.Reset(time.Duration(math.MaxInt64)) {
+        pgConn.slowWriteTimer.Stop()
+    }
+}
+
+func (pgConn *PgConn) flushWithPotentialWriteReadDeadlock() error {
+    pgConn.enterPotentialWriteReadDeadlock()
+    err := pgConn.frontend.Flush()
+    pgConn.exitPotentialWriteReadDeadlock()
+    return err
+}
+
 // HijackedConn is the result of hijacking a connection.
 //
 // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning
@@ -1746,6 +1773,8 @@ func Construct(hc *HijackedConn) (*PgConn, error) {
     }
 
     pgConn.contextWatcher = newContextWatcher(pgConn.conn)
+    pgConn.bgReader = bgreader.New(pgConn.conn)
+    pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start)
 
     return pgConn, nil
 }
@@ -1868,7 +1897,7 @@ func (p *Pipeline) Flush() error {
         return errors.New("pipeline closed")
     }
 
-    err := p.conn.frontend.Flush()
+    err := p.conn.flushWithPotentialWriteReadDeadlock()
     if err != nil {
         err = normalizeTimeoutError(p.ctx, err)
 