mirror of https://github.com/jackc/pgx.git
Add PgConn.SyncConn
This provides a way to ensure it is safe to directly read or write to the underlying net.Conn. https://github.com/jackc/pgx/issues/1673pull/1683/head
parent
05440f9d3f
commit
f512b9688b
|
@ -9,18 +9,18 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
bgReaderStatusStopped = iota
|
StatusStopped = iota
|
||||||
bgReaderStatusRunning
|
StatusRunning
|
||||||
bgReaderStatusStopping
|
StatusStopping
|
||||||
)
|
)
|
||||||
|
|
||||||
// BGReader is an io.Reader that can optionally buffer reads in the background. It is safe for concurrent use.
|
// BGReader is an io.Reader that can optionally buffer reads in the background. It is safe for concurrent use.
|
||||||
type BGReader struct {
|
type BGReader struct {
|
||||||
r io.Reader
|
r io.Reader
|
||||||
|
|
||||||
cond *sync.Cond
|
cond *sync.Cond
|
||||||
bgReaderStatus int32
|
status int32
|
||||||
readResults []readResult
|
readResults []readResult
|
||||||
}
|
}
|
||||||
|
|
||||||
type readResult struct {
|
type readResult struct {
|
||||||
|
@ -34,14 +34,14 @@ func (r *BGReader) Start() {
|
||||||
r.cond.L.Lock()
|
r.cond.L.Lock()
|
||||||
defer r.cond.L.Unlock()
|
defer r.cond.L.Unlock()
|
||||||
|
|
||||||
switch r.bgReaderStatus {
|
switch r.status {
|
||||||
case bgReaderStatusStopped:
|
case StatusStopped:
|
||||||
r.bgReaderStatus = bgReaderStatusRunning
|
r.status = StatusRunning
|
||||||
go r.bgRead()
|
go r.bgRead()
|
||||||
case bgReaderStatusRunning:
|
case StatusRunning:
|
||||||
// no-op
|
// no-op
|
||||||
case bgReaderStatusStopping:
|
case StatusStopping:
|
||||||
r.bgReaderStatus = bgReaderStatusRunning
|
r.status = StatusRunning
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -51,16 +51,23 @@ func (r *BGReader) Stop() {
|
||||||
r.cond.L.Lock()
|
r.cond.L.Lock()
|
||||||
defer r.cond.L.Unlock()
|
defer r.cond.L.Unlock()
|
||||||
|
|
||||||
switch r.bgReaderStatus {
|
switch r.status {
|
||||||
case bgReaderStatusStopped:
|
case StatusStopped:
|
||||||
// no-op
|
// no-op
|
||||||
case bgReaderStatusRunning:
|
case StatusRunning:
|
||||||
r.bgReaderStatus = bgReaderStatusStopping
|
r.status = StatusStopping
|
||||||
case bgReaderStatusStopping:
|
case StatusStopping:
|
||||||
// no-op
|
// no-op
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Status returns the current status of the background reader.
|
||||||
|
func (r *BGReader) Status() int32 {
|
||||||
|
r.cond.L.Lock()
|
||||||
|
defer r.cond.L.Unlock()
|
||||||
|
return r.status
|
||||||
|
}
|
||||||
|
|
||||||
func (r *BGReader) bgRead() {
|
func (r *BGReader) bgRead() {
|
||||||
keepReading := true
|
keepReading := true
|
||||||
for keepReading {
|
for keepReading {
|
||||||
|
@ -70,8 +77,8 @@ func (r *BGReader) bgRead() {
|
||||||
|
|
||||||
r.cond.L.Lock()
|
r.cond.L.Lock()
|
||||||
r.readResults = append(r.readResults, readResult{buf: buf, err: err})
|
r.readResults = append(r.readResults, readResult{buf: buf, err: err})
|
||||||
if r.bgReaderStatus == bgReaderStatusStopping || err != nil {
|
if r.status == StatusStopping || err != nil {
|
||||||
r.bgReaderStatus = bgReaderStatusStopped
|
r.status = StatusStopped
|
||||||
keepReading = false
|
keepReading = false
|
||||||
}
|
}
|
||||||
r.cond.L.Unlock()
|
r.cond.L.Unlock()
|
||||||
|
@ -89,7 +96,7 @@ func (r *BGReader) Read(p []byte) (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// There are no unread background read results and the background reader is stopped.
|
// There are no unread background read results and the background reader is stopped.
|
||||||
if r.bgReaderStatus == bgReaderStatusStopped {
|
if r.status == StatusStopped {
|
||||||
return r.r.Read(p)
|
return r.r.Read(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -556,7 +556,8 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
|
||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Conn returns the underlying net.Conn. This rarely necessary.
|
// Conn returns the underlying net.Conn. This rarely necessary. If the connection will be directly used for reading or
|
||||||
|
// writing then SyncConn should usually be called before Conn.
|
||||||
func (pgConn *PgConn) Conn() net.Conn {
|
func (pgConn *PgConn) Conn() net.Conn {
|
||||||
return pgConn.conn
|
return pgConn.conn
|
||||||
}
|
}
|
||||||
|
@ -1740,6 +1741,30 @@ func (pgConn *PgConn) flushWithPotentialWriteReadDeadlock() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SyncConn prepares the underlying net.Conn for direct use. PgConn may internally buffer reads or use goroutines for
|
||||||
|
// background IO. This means that any direct use of the underlying net.Conn may be corrupted if a read is already
|
||||||
|
// buffered or a read is in progress. SyncConn drains read buffers and stops background IO. In some cases this may
|
||||||
|
// require sending a ping to the server. ctx can be used to cancel this operation. This should be called before any
|
||||||
|
// operation that will use the underlying net.Conn directly. e.g. Before Conn() or Hijack().
|
||||||
|
//
|
||||||
|
// This should not be confused with the PostgreSQL protocol Sync message.
|
||||||
|
func (pgConn *PgConn) SyncConn(ctx context.Context) error {
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
if pgConn.bgReader.Status() == bgreader.StatusStopped && pgConn.frontend.ReadBufferLen() == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := pgConn.Ping(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("SyncConn: Ping failed while syncing conn: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// This should never happen. Only way I can imagine this occuring is if the server is constantly sending data such as
|
||||||
|
// LISTEN/NOTIFY or log notifications such that we never can get an empty buffer.
|
||||||
|
return errors.New("SyncConn: conn never synchronized")
|
||||||
|
}
|
||||||
|
|
||||||
// HijackedConn is the result of hijacking a connection.
|
// HijackedConn is the result of hijacking a connection.
|
||||||
//
|
//
|
||||||
// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning
|
// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning
|
||||||
|
@ -1754,9 +1779,9 @@ type HijackedConn struct {
|
||||||
Config *Config
|
Config *Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hijack extracts the internal connection data. pgConn must be in an idle state. pgConn is unusable after hijacking.
|
// Hijack extracts the internal connection data. pgConn must be in an idle state. SyncConn should be called immediately
|
||||||
// Hijacking is typically only useful when using pgconn to establish a connection, but taking complete control of the
|
// before Hijack. pgConn is unusable after hijacking. Hijacking is typically only useful when using pgconn to establish
|
||||||
// raw connection after that (e.g. a load balancer or proxy).
|
// a connection, but taking complete control of the raw connection after that (e.g. a load balancer or proxy).
|
||||||
//
|
//
|
||||||
// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning
|
// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning
|
||||||
// compatibility.
|
// compatibility.
|
||||||
|
|
|
@ -2319,6 +2319,9 @@ func TestHijackAndConstruct(t *testing.T) {
|
||||||
origConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
origConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = origConn.SyncConn(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
hc, err := origConn.Hijack()
|
hc, err := origConn.Hijack()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
|
|
@ -361,3 +361,7 @@ func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, er
|
||||||
func (f *Frontend) GetAuthType() uint32 {
|
func (f *Frontend) GetAuthType() uint32 {
|
||||||
return f.authType
|
return f.authType
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *Frontend) ReadBufferLen() int {
|
||||||
|
return f.cr.wp - f.cr.rp
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue