mirror of https://github.com/jackc/pgx.git
Handle writes that could deadlock with reads from the server
This commit adds a background reader that can optionally buffer reads. It is used whenever a potentially blocking write is made to the server. The background reader is started on a slight delay so there should be no meaningful performance impact as it doesn't run for quick queries and its overhead is minimal relative to slower queries.pull/1644/head
parent
85136a8efe
commit
26c79eb215
|
@ -42,7 +42,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
|
|||
Data: sc.clientFirstMessage(),
|
||||
}
|
||||
c.frontend.Send(saslInitialResponse)
|
||||
err = c.frontend.Flush()
|
||||
err = c.flushWithPotentialWriteReadDeadlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -62,7 +62,7 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
|
|||
Data: []byte(sc.clientFinalMessage()),
|
||||
}
|
||||
c.frontend.Send(saslResponse)
|
||||
err = c.frontend.Flush()
|
||||
err = c.flushWithPotentialWriteReadDeadlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -0,0 +1,132 @@
|
|||
// Package bgreader provides a io.Reader that can optionally buffer reads in the background.
|
||||
package bgreader
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/iobufpool"
|
||||
)
|
||||
|
||||
const (
|
||||
bgReaderStatusStopped = iota
|
||||
bgReaderStatusRunning
|
||||
bgReaderStatusStopping
|
||||
)
|
||||
|
||||
// BGReader is an io.Reader that can optionally buffer reads in the background. It is safe for concurrent use.
|
||||
type BGReader struct {
|
||||
r io.Reader
|
||||
|
||||
cond *sync.Cond
|
||||
bgReaderStatus int32
|
||||
readResults []readResult
|
||||
}
|
||||
|
||||
type readResult struct {
|
||||
buf *[]byte
|
||||
err error
|
||||
}
|
||||
|
||||
// Start starts the backgrounder reader. If the background reader is already running this is a no-op. The background
|
||||
// reader will stop automatically when the underlying reader returns an error.
|
||||
func (r *BGReader) Start() {
|
||||
r.cond.L.Lock()
|
||||
defer r.cond.L.Unlock()
|
||||
|
||||
switch r.bgReaderStatus {
|
||||
case bgReaderStatusStopped:
|
||||
r.bgReaderStatus = bgReaderStatusRunning
|
||||
go r.bgRead()
|
||||
case bgReaderStatusRunning:
|
||||
// no-op
|
||||
case bgReaderStatusStopping:
|
||||
r.bgReaderStatus = bgReaderStatusRunning
|
||||
}
|
||||
}
|
||||
|
||||
// Stop tells the background reader to stop after the in progress Read returns. It is safe to call Stop when the
|
||||
// background reader is not running.
|
||||
func (r *BGReader) Stop() {
|
||||
r.cond.L.Lock()
|
||||
defer r.cond.L.Unlock()
|
||||
|
||||
switch r.bgReaderStatus {
|
||||
case bgReaderStatusStopped:
|
||||
// no-op
|
||||
case bgReaderStatusRunning:
|
||||
r.bgReaderStatus = bgReaderStatusStopping
|
||||
case bgReaderStatusStopping:
|
||||
// no-op
|
||||
}
|
||||
}
|
||||
|
||||
func (r *BGReader) bgRead() {
|
||||
keepReading := true
|
||||
for keepReading {
|
||||
buf := iobufpool.Get(8192)
|
||||
n, err := r.r.Read(*buf)
|
||||
*buf = (*buf)[:n]
|
||||
|
||||
r.cond.L.Lock()
|
||||
r.readResults = append(r.readResults, readResult{buf: buf, err: err})
|
||||
if r.bgReaderStatus == bgReaderStatusStopping || err != nil {
|
||||
r.bgReaderStatus = bgReaderStatusStopped
|
||||
keepReading = false
|
||||
}
|
||||
r.cond.L.Unlock()
|
||||
r.cond.Broadcast()
|
||||
}
|
||||
}
|
||||
|
||||
// Read implements the io.Reader interface.
|
||||
func (r *BGReader) Read(p []byte) (int, error) {
|
||||
r.cond.L.Lock()
|
||||
defer r.cond.L.Unlock()
|
||||
|
||||
if len(r.readResults) > 0 {
|
||||
return r.readFromReadResults(p)
|
||||
}
|
||||
|
||||
// There are no unread background read results and the background reader is stopped.
|
||||
if r.bgReaderStatus == bgReaderStatusStopped {
|
||||
return r.r.Read(p)
|
||||
}
|
||||
|
||||
// Wait for results from the background reader
|
||||
for len(r.readResults) == 0 {
|
||||
r.cond.Wait()
|
||||
}
|
||||
return r.readFromReadResults(p)
|
||||
}
|
||||
|
||||
// readBackgroundResults reads a result previously read by the background reader. r.cond.L must be held.
|
||||
func (r *BGReader) readFromReadResults(p []byte) (int, error) {
|
||||
buf := r.readResults[0].buf
|
||||
var err error
|
||||
|
||||
n := copy(p, *buf)
|
||||
if n == len(*buf) {
|
||||
err = r.readResults[0].err
|
||||
iobufpool.Put(buf)
|
||||
if len(r.readResults) == 1 {
|
||||
r.readResults = nil
|
||||
} else {
|
||||
r.readResults = r.readResults[1:]
|
||||
}
|
||||
} else {
|
||||
*buf = (*buf)[n:]
|
||||
r.readResults[0].buf = buf
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
func New(r io.Reader) *BGReader {
|
||||
return &BGReader{
|
||||
r: r,
|
||||
cond: &sync.Cond{
|
||||
L: &sync.Mutex{},
|
||||
},
|
||||
}
|
||||
}
|
|
@ -0,0 +1,140 @@
|
|||
package bgreader_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn/internal/bgreader"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBGReaderReadWhenStopped(t *testing.T) {
|
||||
r := bytes.NewReader([]byte("foo bar baz"))
|
||||
bgr := bgreader.New(r)
|
||||
buf, err := io.ReadAll(bgr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("foo bar baz"), buf)
|
||||
}
|
||||
|
||||
func TestBGReaderReadWhenStarted(t *testing.T) {
|
||||
r := bytes.NewReader([]byte("foo bar baz"))
|
||||
bgr := bgreader.New(r)
|
||||
bgr.Start()
|
||||
buf, err := io.ReadAll(bgr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("foo bar baz"), buf)
|
||||
}
|
||||
|
||||
type mockReadFunc func(p []byte) (int, error)
|
||||
|
||||
type mockReader struct {
|
||||
readFuncs []mockReadFunc
|
||||
}
|
||||
|
||||
func (r *mockReader) Read(p []byte) (int, error) {
|
||||
if len(r.readFuncs) == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
fn := r.readFuncs[0]
|
||||
r.readFuncs = r.readFuncs[1:]
|
||||
|
||||
return fn(p)
|
||||
}
|
||||
|
||||
func TestBGReaderReadWaitsForBackgroundRead(t *testing.T) {
|
||||
rr := &mockReader{
|
||||
readFuncs: []mockReadFunc{
|
||||
func(p []byte) (int, error) { time.Sleep(1 * time.Second); return copy(p, []byte("foo")), nil },
|
||||
func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
|
||||
func(p []byte) (int, error) { return copy(p, []byte("baz")), nil },
|
||||
},
|
||||
}
|
||||
bgr := bgreader.New(rr)
|
||||
bgr.Start()
|
||||
buf := make([]byte, 3)
|
||||
n, err := bgr.Read(buf)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 3, n)
|
||||
require.Equal(t, []byte("foo"), buf)
|
||||
}
|
||||
|
||||
func TestBGReaderErrorWhenStarted(t *testing.T) {
|
||||
rr := &mockReader{
|
||||
readFuncs: []mockReadFunc{
|
||||
func(p []byte) (int, error) { return copy(p, []byte("foo")), nil },
|
||||
func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
|
||||
func(p []byte) (int, error) { return copy(p, []byte("baz")), errors.New("oops") },
|
||||
},
|
||||
}
|
||||
|
||||
bgr := bgreader.New(rr)
|
||||
bgr.Start()
|
||||
buf, err := io.ReadAll(bgr)
|
||||
require.Equal(t, []byte("foobarbaz"), buf)
|
||||
require.EqualError(t, err, "oops")
|
||||
}
|
||||
|
||||
func TestBGReaderErrorWhenStopped(t *testing.T) {
|
||||
rr := &mockReader{
|
||||
readFuncs: []mockReadFunc{
|
||||
func(p []byte) (int, error) { return copy(p, []byte("foo")), nil },
|
||||
func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
|
||||
func(p []byte) (int, error) { return copy(p, []byte("baz")), errors.New("oops") },
|
||||
},
|
||||
}
|
||||
|
||||
bgr := bgreader.New(rr)
|
||||
buf, err := io.ReadAll(bgr)
|
||||
require.Equal(t, []byte("foobarbaz"), buf)
|
||||
require.EqualError(t, err, "oops")
|
||||
}
|
||||
|
||||
type numberReader struct {
|
||||
v uint8
|
||||
rng *rand.Rand
|
||||
}
|
||||
|
||||
func (nr *numberReader) Read(p []byte) (int, error) {
|
||||
n := nr.rng.Intn(len(p))
|
||||
for i := 0; i < n; i++ {
|
||||
p[i] = nr.v
|
||||
nr.v++
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// TestBGReaderStress stress tests BGReader by reading a lot of bytes in random sizes while randomly starting and
|
||||
// stopping the background worker from other goroutines.
|
||||
func TestBGReaderStress(t *testing.T) {
|
||||
nr := &numberReader{rng: rand.New(rand.NewSource(0))}
|
||||
bgr := bgreader.New(nr)
|
||||
|
||||
bytesRead := 0
|
||||
var expected uint8
|
||||
buf := make([]byte, 10_000)
|
||||
rng := rand.New(rand.NewSource(0))
|
||||
|
||||
for bytesRead < 1_000_000 {
|
||||
randomNumber := rng.Intn(100)
|
||||
switch {
|
||||
case randomNumber < 10:
|
||||
go bgr.Start()
|
||||
case randomNumber < 20:
|
||||
go bgr.Stop()
|
||||
default:
|
||||
n, err := bgr.Read(buf)
|
||||
require.NoError(t, err)
|
||||
for i := 0; i < n; i++ {
|
||||
require.Equal(t, expected, buf[i])
|
||||
expected++
|
||||
}
|
||||
bytesRead += n
|
||||
}
|
||||
}
|
||||
}
|
|
@ -63,7 +63,7 @@ func (c *PgConn) gssAuth() error {
|
|||
Data: nextData,
|
||||
}
|
||||
c.frontend.Send(gssResponse)
|
||||
err = c.frontend.Flush()
|
||||
err = c.flushWithPotentialWriteReadDeadlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ import (
|
|||
|
||||
"github.com/jackc/pgx/v5/internal/iobufpool"
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
"github.com/jackc/pgx/v5/pgconn/internal/bgreader"
|
||||
"github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
)
|
||||
|
@ -71,6 +72,8 @@ type PgConn struct {
|
|||
parameterStatuses map[string]string // parameters that have been reported by the server
|
||||
txStatus byte
|
||||
frontend *pgproto3.Frontend
|
||||
bgReader *bgreader.BGReader
|
||||
slowWriteTimer *time.Timer
|
||||
|
||||
config *Config
|
||||
|
||||
|
@ -293,7 +296,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||
|
||||
pgConn.parameterStatuses = make(map[string]string)
|
||||
pgConn.status = connStatusConnecting
|
||||
pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn)
|
||||
pgConn.bgReader = bgreader.New(pgConn.conn)
|
||||
pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start)
|
||||
pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn)
|
||||
|
||||
startupMsg := pgproto3.StartupMessage{
|
||||
ProtocolVersion: pgproto3.ProtocolVersionNumber,
|
||||
|
@ -311,7 +316,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
|||
}
|
||||
|
||||
pgConn.frontend.Send(&startupMsg)
|
||||
if err := pgConn.frontend.Flush(); err != nil {
|
||||
if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil {
|
||||
pgConn.conn.Close()
|
||||
return nil, &connectError{config: config, msg: "failed to write startup message", err: err}
|
||||
}
|
||||
|
@ -416,7 +421,7 @@ func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
|
|||
|
||||
func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
|
||||
pgConn.frontend.Send(&pgproto3.PasswordMessage{Password: password})
|
||||
return pgConn.frontend.Flush()
|
||||
return pgConn.flushWithPotentialWriteReadDeadlock()
|
||||
}
|
||||
|
||||
func hexMD5(s string) string {
|
||||
|
@ -611,7 +616,7 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
|
|||
//
|
||||
// See https://github.com/jackc/pgx/issues/637
|
||||
pgConn.frontend.Send(&pgproto3.Terminate{})
|
||||
pgConn.frontend.Flush()
|
||||
pgConn.flushWithPotentialWriteReadDeadlock()
|
||||
|
||||
return pgConn.conn.Close()
|
||||
}
|
||||
|
@ -638,7 +643,7 @@ func (pgConn *PgConn) asyncClose() {
|
|||
pgConn.conn.SetDeadline(deadline)
|
||||
|
||||
pgConn.frontend.Send(&pgproto3.Terminate{})
|
||||
pgConn.frontend.Flush()
|
||||
pgConn.flushWithPotentialWriteReadDeadlock()
|
||||
}()
|
||||
}
|
||||
|
||||
|
@ -813,7 +818,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
|
|||
pgConn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs})
|
||||
pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name})
|
||||
pgConn.frontend.SendSync(&pgproto3.Sync{})
|
||||
err := pgConn.frontend.Flush()
|
||||
err := pgConn.flushWithPotentialWriteReadDeadlock()
|
||||
if err != nil {
|
||||
pgConn.asyncClose()
|
||||
return nil, err
|
||||
|
@ -995,7 +1000,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
|
|||
}
|
||||
|
||||
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
|
||||
err := pgConn.frontend.Flush()
|
||||
err := pgConn.flushWithPotentialWriteReadDeadlock()
|
||||
if err != nil {
|
||||
pgConn.asyncClose()
|
||||
pgConn.contextWatcher.Unwatch()
|
||||
|
@ -1106,7 +1111,7 @@ func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) {
|
|||
pgConn.frontend.SendExecute(&pgproto3.Execute{})
|
||||
pgConn.frontend.SendSync(&pgproto3.Sync{})
|
||||
|
||||
err := pgConn.frontend.Flush()
|
||||
err := pgConn.flushWithPotentialWriteReadDeadlock()
|
||||
if err != nil {
|
||||
pgConn.asyncClose()
|
||||
result.concludeCommand(CommandTag{}, err)
|
||||
|
@ -1139,7 +1144,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
|
|||
// Send copy to command
|
||||
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
|
||||
|
||||
err := pgConn.frontend.Flush()
|
||||
err := pgConn.flushWithPotentialWriteReadDeadlock()
|
||||
if err != nil {
|
||||
pgConn.asyncClose()
|
||||
pgConn.unlock()
|
||||
|
@ -1197,7 +1202,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||
|
||||
// Send copy from query
|
||||
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
|
||||
err := pgConn.frontend.Flush()
|
||||
err := pgConn.flushWithPotentialWriteReadDeadlock()
|
||||
if err != nil {
|
||||
pgConn.asyncClose()
|
||||
return CommandTag{}, err
|
||||
|
@ -1273,7 +1278,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||
} else {
|
||||
pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()})
|
||||
}
|
||||
err = pgConn.frontend.Flush()
|
||||
err = pgConn.flushWithPotentialWriteReadDeadlock()
|
||||
if err != nil {
|
||||
pgConn.asyncClose()
|
||||
return CommandTag{}, err
|
||||
|
@ -1634,7 +1639,9 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
|
|||
|
||||
batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)
|
||||
|
||||
pgConn.enterPotentialWriteReadDeadlock()
|
||||
_, err := pgConn.conn.Write(batch.buf)
|
||||
pgConn.exitPotentialWriteReadDeadlock()
|
||||
if err != nil {
|
||||
multiResult.closed = true
|
||||
multiResult.err = err
|
||||
|
@ -1688,6 +1695,26 @@ func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag {
|
|||
return CommandTag{s: string(buf)}
|
||||
}
|
||||
|
||||
// enterPotentialWriteReadDeadlock must be called before a write that could deadlock if the server is simultaneously
|
||||
// blocked writing to us.
|
||||
func (pgConn *PgConn) enterPotentialWriteReadDeadlock() {
|
||||
pgConn.slowWriteTimer.Reset(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
// exitPotentialWriteReadDeadlock must be called after a call to enterPotentialWriteReadDeadlock.
|
||||
func (pgConn *PgConn) exitPotentialWriteReadDeadlock() {
|
||||
if !pgConn.slowWriteTimer.Reset(time.Duration(math.MaxInt64)) {
|
||||
pgConn.slowWriteTimer.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func (pgConn *PgConn) flushWithPotentialWriteReadDeadlock() error {
|
||||
pgConn.enterPotentialWriteReadDeadlock()
|
||||
err := pgConn.frontend.Flush()
|
||||
pgConn.exitPotentialWriteReadDeadlock()
|
||||
return err
|
||||
}
|
||||
|
||||
// HijackedConn is the result of hijacking a connection.
|
||||
//
|
||||
// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning
|
||||
|
@ -1746,6 +1773,8 @@ func Construct(hc *HijackedConn) (*PgConn, error) {
|
|||
}
|
||||
|
||||
pgConn.contextWatcher = newContextWatcher(pgConn.conn)
|
||||
pgConn.bgReader = bgreader.New(pgConn.conn)
|
||||
pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), pgConn.bgReader.Start)
|
||||
|
||||
return pgConn, nil
|
||||
}
|
||||
|
@ -1868,7 +1897,7 @@ func (p *Pipeline) Flush() error {
|
|||
return errors.New("pipeline closed")
|
||||
}
|
||||
|
||||
err := p.conn.frontend.Flush()
|
||||
err := p.conn.flushWithPotentialWriteReadDeadlock()
|
||||
if err != nil {
|
||||
err = normalizeTimeoutError(p.ctx, err)
|
||||
|
||||
|
|
Loading…
Reference in New Issue