From 72b1dcff2fa42a0c15de9369116a626416f94610 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 25 Jun 2022 15:55:09 -0500 Subject: [PATCH] Add pgconn.CheckConn --- bench_test.go | 11 +++++++- internal/nbconn/nbconn.go | 49 +++++++++++++++++++++++++--------- internal/nbconn/nbconn_test.go | 29 ++++++++++++++++++++ pgconn/pgconn.go | 30 ++++++++++++++------- pgconn/pgconn_test.go | 28 +++++++++++++++++++ 5 files changed, 124 insertions(+), 23 deletions(-) diff --git a/bench_test.go b/bench_test.go index c441b374..31b3b38e 100644 --- a/bench_test.go +++ b/bench_test.go @@ -13,6 +13,7 @@ import ( "time" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/internal/nbconn" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/require" @@ -1236,7 +1237,7 @@ func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) { } type queryRecorder struct { - conn net.Conn + conn nbconn.Conn writeBuf []byte readCount int } @@ -1252,6 +1253,14 @@ func (qr *queryRecorder) Write(b []byte) (n int, err error) { return qr.conn.Write(b) } +func (qr *queryRecorder) BufferReadUntilBlock() error { + return qr.conn.BufferReadUntilBlock() +} + +func (qr *queryRecorder) Flush() error { + return qr.conn.Flush() +} + func (qr *queryRecorder) Close() error { return qr.conn.Close() } diff --git a/internal/nbconn/nbconn.go b/internal/nbconn/nbconn.go index 00d0e420..16c4b713 100644 --- a/internal/nbconn/nbconn.go +++ b/internal/nbconn/nbconn.go @@ -13,6 +13,7 @@ package nbconn import ( "crypto/tls" "errors" + "io" "net" "os" "sync" @@ -46,11 +47,16 @@ func (*wouldBlockError) Error() string { func (*wouldBlockError) Timeout() bool { return true } func (*wouldBlockError) Temporary() bool { return true } -// Conn is a net.Conn where Write never blocks and always succeeds. Flush must be called to actually write to the -// underlying connection. +// Conn is a net.Conn where Write never blocks and always succeeds. Flush or Read must be called to actually write to +// the underlying connection. type Conn interface { net.Conn + + // Flush flushes any buffered writes. Flush() error + + // BufferReadUntilBlock reads and buffers any sucessfully read bytes until the read would block. + BufferReadUntilBlock() error } // NetConn is a non-blocking net.Conn wrapper. It implements net.Conn. @@ -303,24 +309,35 @@ func (c *NetConn) flush() error { return nil } +func (c *NetConn) BufferReadUntilBlock() error { + for { + buf := iobufpool.Get(8 * 1024) + n, err := c.nonblockingRead(buf) + if n > 0 { + buf = buf[:n] + c.readQueue.pushBack(buf) + } + + if err != nil { + if errors.Is(err, ErrWouldBlock) { + return nil + } else { + return err + } + } + } +} + func (c *NetConn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) { stopChan = make(chan struct{}) errChan = make(chan error, 1) go func() { for { - buf := iobufpool.Get(8 * 1024) - n, err := c.nonblockingRead(buf) - if n > 0 { - buf = buf[:n] - c.readQueue.pushBack(buf) - } - + err := c.BufferReadUntilBlock() if err != nil { - if !errors.Is(err, ErrWouldBlock) { - errChan <- err - return - } + errChan <- err + return } select { @@ -456,6 +473,11 @@ func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) { return n, err } + // syscall read did not return an error and 0 bytes were read means EOF. + if n == 0 { + return 0, io.EOF + } + return n, nil } @@ -494,6 +516,7 @@ type TLSConn struct { func (tc *TLSConn) Read(b []byte) (n int, err error) { return tc.tlsConn.Read(b) } func (tc *TLSConn) Write(b []byte) (n int, err error) { return tc.tlsConn.Write(b) } +func (tc *TLSConn) BufferReadUntilBlock() error { return tc.nbConn.BufferReadUntilBlock() } func (tc *TLSConn) Flush() error { return tc.nbConn.Flush() } func (tc *TLSConn) LocalAddr() net.Addr { return tc.tlsConn.LocalAddr() } func (tc *TLSConn) RemoteAddr() net.Addr { return tc.tlsConn.RemoteAddr() } diff --git a/internal/nbconn/nbconn_test.go b/internal/nbconn/nbconn_test.go index 2db47039..de32b9c7 100644 --- a/internal/nbconn/nbconn_test.go +++ b/internal/nbconn/nbconn_test.go @@ -2,6 +2,7 @@ package nbconn_test import ( "crypto/tls" + "errors" "io" "net" "strings" @@ -345,6 +346,34 @@ func TestNonBlockingRead(t *testing.T) { }) } +func TestBufferNonBlockingRead(t *testing.T) { + testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { + err := conn.BufferReadUntilBlock() + require.NoError(t, err) + + errChan := make(chan error, 1) + go func() { + _, err := remote.Write([]byte("okay")) + errChan <- err + }() + + for i := 0; i < 1000; i++ { + err = conn.BufferReadUntilBlock() + if !errors.Is(err, nbconn.ErrWouldBlock) { + break + } + time.Sleep(time.Millisecond) + } + require.NoError(t, err) + + buf := make([]byte, 4) + n, err := conn.Read(buf) + require.NoError(t, err) + require.EqualValues(t, 4, n) + require.Equal(t, []byte("okay"), buf) + }) +} + func TestReadPreviouslyBuffered(t *testing.T) { testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) { diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 002db39a..306b2e16 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -65,7 +65,7 @@ type NotificationHandler func(*PgConn, *Notification) // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { - conn net.Conn // the underlying TCP or unix domain socket connection + conn nbconn.Conn // the non-blocking wrapper for the underlying TCP or unix domain socket connection pid uint32 // backend pid secretKey uint32 // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server @@ -230,22 +230,22 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig } return nil, &connectError{config: config, msg: "dial error", err: err} } - netConn = nbconn.NewNetConn(netConn, false) + nbNetConn := nbconn.NewNetConn(netConn, false) - pgConn.conn = netConn - pgConn.contextWatcher = newContextWatcher(netConn) + pgConn.conn = nbNetConn + pgConn.contextWatcher = newContextWatcher(nbNetConn) pgConn.contextWatcher.Watch(ctx) if fallbackConfig.TLSConfig != nil { - tlsConn, err := startTLS(netConn.(*nbconn.NetConn), fallbackConfig.TLSConfig) + nbTLSConn, err := startTLS(nbNetConn, fallbackConfig.TLSConfig) pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. if err != nil { netConn.Close() return nil, &connectError{config: config, msg: "tls error", err: err} } - pgConn.conn = tlsConn - pgConn.contextWatcher = newContextWatcher(tlsConn) + pgConn.conn = nbTLSConn + pgConn.contextWatcher = newContextWatcher(nbTLSConn) pgConn.contextWatcher.Watch(ctx) } @@ -353,7 +353,7 @@ func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher { ) } -func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (net.Conn, error) { +func startTLS(conn *nbconn.NetConn, tlsConfig *tls.Config) (*nbconn.TLSConn, error) { err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) if err != nil { return nil, err @@ -1596,6 +1596,18 @@ func (pgConn *PgConn) EscapeString(s string) (string, error) { return strings.Replace(s, "'", "''", -1), nil } +// CheckConn checks the underlying connection without writing any bytes. This is currently implemented by reading and +// buffering until the read would block or an error occurs. This can be used to check if the server has closed the +// connection. If this is done immediately before sending a query it reduces the chances a query will be sent that fails +// without the client knowing whether the server received it or not. +func (pgConn *PgConn) CheckConn() error { + err := pgConn.conn.BufferReadUntilBlock() + if err != nil && !errors.Is(err, nbconn.ErrWouldBlock) { + return err + } + return nil +} + // makeCommandTag makes a CommandTag. It does not retain a reference to buf or buf's underlying memory. func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag { ct := make([]byte, len(buf)) @@ -1608,7 +1620,7 @@ func (pgConn *PgConn) makeCommandTag(buf []byte) CommandTag { // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning // compatibility. type HijackedConn struct { - Conn net.Conn // the underlying TCP or unix domain socket connection + Conn nbconn.Conn // the non-blocking wrapper of the underlying TCP or unix domain socket connection PID uint32 // backend pid SecretKey uint32 // key to use to send a cancel query message to the server ParameterStatuses map[string]string // parameters that have been reported by the server diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 87936225..f517f268 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -2059,6 +2059,34 @@ func TestConnLargeResponseWhileWritingDoesNotDeadlock(t *testing.T) { ensureConnValid(t, pgConn) } +func TestConnCheckConn(t *testing.T) { + t.Parallel() + + c1, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_TCP_CONN_STRING")) + require.NoError(t, err) + defer c1.Close(context.Background()) + + if c1.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)") + } + + err = c1.CheckConn() + require.NoError(t, err) + + c2, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_TCP_CONN_STRING")) + require.NoError(t, err) + defer c2.Close(context.Background()) + + _, err = c2.Exec(context.Background(), fmt.Sprintf("select pg_terminate_backend(%d)", c1.PID())).ReadAll() + require.NoError(t, err) + + // Give a little time for the signal to actually kill the backend. + time.Sleep(500 * time.Millisecond) + + err = c1.CheckConn() + require.Error(t, err) +} + func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil {