From f512b9688b99c1d25a47d6c92dc68bafb870986d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jul 2023 16:46:38 -0500 Subject: [PATCH] Add PgConn.SyncConn This provides a way to ensure it is safe to directly read or write to the underlying net.Conn. https://github.com/jackc/pgx/issues/1673 --- pgconn/internal/bgreader/bgreader.go | 47 ++++++++++++++++------------ pgconn/pgconn.go | 33 ++++++++++++++++--- pgconn/pgconn_test.go | 3 ++ pgproto3/frontend.go | 4 +++ 4 files changed, 63 insertions(+), 24 deletions(-) diff --git a/pgconn/internal/bgreader/bgreader.go b/pgconn/internal/bgreader/bgreader.go index aa1a3d39..e65c2c2b 100644 --- a/pgconn/internal/bgreader/bgreader.go +++ b/pgconn/internal/bgreader/bgreader.go @@ -9,18 +9,18 @@ import ( ) const ( - bgReaderStatusStopped = iota - bgReaderStatusRunning - bgReaderStatusStopping + StatusStopped = iota + StatusRunning + StatusStopping ) // BGReader is an io.Reader that can optionally buffer reads in the background. It is safe for concurrent use. type BGReader struct { r io.Reader - cond *sync.Cond - bgReaderStatus int32 - readResults []readResult + cond *sync.Cond + status int32 + readResults []readResult } type readResult struct { @@ -34,14 +34,14 @@ func (r *BGReader) Start() { r.cond.L.Lock() defer r.cond.L.Unlock() - switch r.bgReaderStatus { - case bgReaderStatusStopped: - r.bgReaderStatus = bgReaderStatusRunning + switch r.status { + case StatusStopped: + r.status = StatusRunning go r.bgRead() - case bgReaderStatusRunning: + case StatusRunning: // no-op - case bgReaderStatusStopping: - r.bgReaderStatus = bgReaderStatusRunning + case StatusStopping: + r.status = StatusRunning } } @@ -51,16 +51,23 @@ func (r *BGReader) Stop() { r.cond.L.Lock() defer r.cond.L.Unlock() - switch r.bgReaderStatus { - case bgReaderStatusStopped: + switch r.status { + case StatusStopped: // no-op - case bgReaderStatusRunning: - r.bgReaderStatus = bgReaderStatusStopping - case bgReaderStatusStopping: + case StatusRunning: + r.status = StatusStopping + case StatusStopping: // no-op } } +// Status returns the current status of the background reader. +func (r *BGReader) Status() int32 { + r.cond.L.Lock() + defer r.cond.L.Unlock() + return r.status +} + func (r *BGReader) bgRead() { keepReading := true for keepReading { @@ -70,8 +77,8 @@ func (r *BGReader) bgRead() { r.cond.L.Lock() r.readResults = append(r.readResults, readResult{buf: buf, err: err}) - if r.bgReaderStatus == bgReaderStatusStopping || err != nil { - r.bgReaderStatus = bgReaderStatusStopped + if r.status == StatusStopping || err != nil { + r.status = StatusStopped keepReading = false } r.cond.L.Unlock() @@ -89,7 +96,7 @@ func (r *BGReader) Read(p []byte) (int, error) { } // There are no unread background read results and the background reader is stopped. - if r.bgReaderStatus == bgReaderStatusStopped { + if r.status == StatusStopped { return r.r.Read(p) } diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 8a9f80a6..95a8a143 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -556,7 +556,8 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { return msg, nil } -// Conn returns the underlying net.Conn. This rarely necessary. +// Conn returns the underlying net.Conn. This rarely necessary. If the connection will be directly used for reading or +// writing then SyncConn should usually be called before Conn. func (pgConn *PgConn) Conn() net.Conn { return pgConn.conn } @@ -1740,6 +1741,30 @@ func (pgConn *PgConn) flushWithPotentialWriteReadDeadlock() error { return err } +// SyncConn prepares the underlying net.Conn for direct use. PgConn may internally buffer reads or use goroutines for +// background IO. This means that any direct use of the underlying net.Conn may be corrupted if a read is already +// buffered or a read is in progress. SyncConn drains read buffers and stops background IO. In some cases this may +// require sending a ping to the server. ctx can be used to cancel this operation. This should be called before any +// operation that will use the underlying net.Conn directly. e.g. Before Conn() or Hijack(). +// +// This should not be confused with the PostgreSQL protocol Sync message. +func (pgConn *PgConn) SyncConn(ctx context.Context) error { + for i := 0; i < 10; i++ { + if pgConn.bgReader.Status() == bgreader.StatusStopped && pgConn.frontend.ReadBufferLen() == 0 { + return nil + } + + err := pgConn.Ping(ctx) + if err != nil { + return fmt.Errorf("SyncConn: Ping failed while syncing conn: %w", err) + } + } + + // This should never happen. Only way I can imagine this occuring is if the server is constantly sending data such as + // LISTEN/NOTIFY or log notifications such that we never can get an empty buffer. + return errors.New("SyncConn: conn never synchronized") +} + // HijackedConn is the result of hijacking a connection. // // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning @@ -1754,9 +1779,9 @@ type HijackedConn struct { Config *Config } -// Hijack extracts the internal connection data. pgConn must be in an idle state. pgConn is unusable after hijacking. -// Hijacking is typically only useful when using pgconn to establish a connection, but taking complete control of the -// raw connection after that (e.g. a load balancer or proxy). +// Hijack extracts the internal connection data. pgConn must be in an idle state. SyncConn should be called immediately +// before Hijack. pgConn is unusable after hijacking. Hijacking is typically only useful when using pgconn to establish +// a connection, but taking complete control of the raw connection after that (e.g. a load balancer or proxy). // // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning // compatibility. diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index fbbf41eb..37a5512c 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -2319,6 +2319,9 @@ func TestHijackAndConstruct(t *testing.T) { origConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) + err = origConn.SyncConn(ctx) + require.NoError(t, err) + hc, err := origConn.Hijack() require.NoError(t, err) diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index 83dea963..33c3882a 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -361,3 +361,7 @@ func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, er func (f *Frontend) GetAuthType() uint32 { return f.authType } + +func (f *Frontend) ReadBufferLen() int { + return f.cr.wp - f.cr.rp +}