From 5714896b1047d2415448d32caa87547625509783 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 21 May 2022 11:06:44 -0500 Subject: [PATCH] Restructure sending messages Use an internal buffer in pgproto3.Frontend and pgproto3.Backend instead of directly writing to the underlying net.Conn. This will allow tracing messages as well as simplify pipeline mode. --- internal/pgmock/pgmock.go | 3 +- pgconn/config.go | 2 +- pgconn/errors.go | 17 ------ pgconn/frontend_test.go | 70 ---------------------- pgconn/pgconn.go | 121 +++++++++++--------------------------- pgconn/pgconn_test.go | 43 -------------- pgproto3/backend.go | 28 +++++++-- pgproto3/frontend.go | 28 +++++++-- pgproto3/pgproto3.go | 19 ++++++ 9 files changed, 105 insertions(+), 226 deletions(-) delete mode 100644 pgconn/frontend_test.go diff --git a/internal/pgmock/pgmock.go b/internal/pgmock/pgmock.go index 97dd024d..c82d7ffc 100644 --- a/internal/pgmock/pgmock.go +++ b/internal/pgmock/pgmock.go @@ -97,7 +97,8 @@ type sendMessageStep struct { } func (e *sendMessageStep) Step(backend *pgproto3.Backend) error { - return backend.Send(e.msg) + backend.Send(e.msg) + return backend.Flush() } func SendMessage(msg pgproto3.BackendMessage) Step { diff --git a/pgconn/config.go b/pgconn/config.go index 8a22d4ce..bfec11d4 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -222,7 +222,7 @@ func ParseConfig(connString string) (*Config, error) { User: settings["user"], Password: settings["password"], RuntimeParams: make(map[string]string), - BuildFrontend: func(r io.Reader, w io.Writer) Frontend { + BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend { return pgproto3.NewFrontend(r, w) }, } diff --git a/pgconn/errors.go b/pgconn/errors.go index a32b29c9..030f7e0a 100644 --- a/pgconn/errors.go +++ b/pgconn/errors.go @@ -178,23 +178,6 @@ func newContextAlreadyDoneError(ctx context.Context) (err error) { return &errTimeout{&contextAlreadyDoneError{err: ctx.Err()}} } -type writeError struct { - err error - safeToRetry bool -} - -func (e *writeError) Error() string { - return fmt.Sprintf("write failed: %s", e.err.Error()) -} - -func (e *writeError) SafeToRetry() bool { - return e.safeToRetry -} - -func (e *writeError) Unwrap() error { - return e.err -} - func redactPW(connString string) string { if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { if u, err := url.Parse(connString); err == nil { diff --git a/pgconn/frontend_test.go b/pgconn/frontend_test.go deleted file mode 100644 index 439d3251..00000000 --- a/pgconn/frontend_test.go +++ /dev/null @@ -1,70 +0,0 @@ -package pgconn_test - -import ( - "context" - "io" - "os" - "testing" - - "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgproto3" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// frontendWrapper allows to hijack a regular frontend, and inject a specific response -type frontendWrapper struct { - front pgconn.Frontend - - msg pgproto3.BackendMessage -} - -// frontendWrapper implements the pgconn.Frontend interface -var _ pgconn.Frontend = (*frontendWrapper)(nil) - -func (f *frontendWrapper) Receive() (pgproto3.BackendMessage, error) { - if f.msg != nil { - return f.msg, nil - } - - return f.front.Receive() -} - -func TestFrontendFatalErrExec(t *testing.T) { - t.Parallel() - - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - - buildFrontend := config.BuildFrontend - var front *frontendWrapper - - config.BuildFrontend = func(r io.Reader, w io.Writer) pgconn.Frontend { - wrapped := buildFrontend(r, w) - front = &frontendWrapper{wrapped, nil} - - return front - } - - conn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - require.NotNil(t, conn) - require.NotNil(t, front) - - // set frontend to return a "FATAL" message on next call - front.msg = &pgproto3.ErrorResponse{Severity: "FATAL", Message: "unit testing fatal error"} - - _, err = conn.Exec(context.Background(), "SELECT 1").ReadAll() - assert.Error(t, err) - - err = conn.Close(context.Background()) - assert.NoError(t, err) - - select { - case <-conn.CleanupDone(): - t.Log("ok, CleanupDone() is not blocking") - - default: - assert.Fail(t, "connection closed but CleanupDone() still blocking") - } -} diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index a23b1daf..2cbf8c50 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -29,8 +29,6 @@ const ( connStatusBusy ) -const wbufLen = 1024 - // Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from // LISTEN/NOTIFY notification. type Notice PgError @@ -50,7 +48,7 @@ type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) type LookupFunc func(ctx context.Context, host string) (addrs []string, err error) // BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. -type BuildFrontendFunc func(r io.Reader, w io.Writer) Frontend +type BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend // NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at // any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin @@ -64,11 +62,6 @@ type NoticeHandler func(*PgConn, *Notice) // notice event. type NotificationHandler func(*PgConn, *Notification) -// Frontend used to receive messages from backend. -type Frontend interface { - Receive() (pgproto3.BackendMessage, error) -} - // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { conn net.Conn // the underlying TCP or unix domain socket connection @@ -76,7 +69,7 @@ type PgConn struct { secretKey uint32 // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server txStatus byte - frontend Frontend + frontend *pgproto3.Frontend config *Config @@ -90,7 +83,6 @@ type PgConn struct { peekedMsg pgproto3.BackendMessage // Reusable / preallocated resources - wbuf []byte // write buffer resultReader ResultReader multiResultReader MultiResultReader contextWatcher *ctxwatch.ContextWatcher @@ -230,7 +222,6 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { pgConn := new(PgConn) pgConn.config = config - pgConn.wbuf = make([]byte, 0, wbufLen) pgConn.cleanupDone = make(chan struct{}) var err error @@ -282,7 +273,8 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig startupMsg.Parameters["database"] = config.Database } - if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil { + pgConn.frontend.Send(&startupMsg) + if err := pgConn.frontend.Flush(); err != nil { pgConn.conn.Close() return nil, &connectError{config: config, msg: "failed to write startup message", err: err} } @@ -383,9 +375,8 @@ func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { } func (pgConn *PgConn) txPasswordMessage(password string) (err error) { - msg := &pgproto3.PasswordMessage{Password: password} - _, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf)) - return err + pgConn.frontend.Send(&pgproto3.PasswordMessage{Password: password}) + return pgConn.frontend.Flush() } func hexMD5(s string) string { @@ -412,36 +403,6 @@ func (pgConn *PgConn) signalMessage() chan struct{} { return ch } -// SendBytes sends buf to the PostgreSQL server. It must only be used when the connection is not busy. e.g. It is as -// error to call SendBytes while reading the result of a query. -// -// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. -// See https://www.postgresql.org/docs/current/protocol.html. -func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { - if err := pgConn.lock(); err != nil { - return err - } - defer pgConn.unlock() - - if ctx != context.Background() { - select { - case <-ctx.Done(): - return newContextAlreadyDoneError(ctx) - default: - } - pgConn.contextWatcher.Watch(ctx) - defer pgConn.contextWatcher.Unwatch() - } - - n, err := pgConn.conn.Write(buf) - if err != nil { - pgConn.asyncClose() - return &writeError{err: err, safeToRetry: n == 0} - } - - return nil -} - // ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the // connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages // are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger @@ -797,15 +758,13 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ defer pgConn.contextWatcher.Unwatch() } - buf := pgConn.wbuf - buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) - buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) - buf = (&pgproto3.Sync{}).Encode(buf) - - n, err := pgConn.conn.Write(buf) + pgConn.frontend.Send(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}) + pgConn.frontend.Send(&pgproto3.Describe{ObjectType: 'S', Name: name}) + pgConn.frontend.Send(&pgproto3.Sync{}) + err := pgConn.frontend.Flush() if err != nil { pgConn.asyncClose() - return nil, &writeError{err: err, safeToRetry: n == 0} + return nil, err } psd := &StatementDescription{Name: name, SQL: sql} @@ -971,15 +930,13 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { pgConn.contextWatcher.Watch(ctx) } - buf := pgConn.wbuf - buf = (&pgproto3.Query{String: sql}).Encode(buf) - - n, err := pgConn.conn.Write(buf) + pgConn.frontend.Send(&pgproto3.Query{String: sql}) + err := pgConn.frontend.Flush() if err != nil { pgConn.asyncClose() pgConn.contextWatcher.Unwatch() multiResult.closed = true - multiResult.err = &writeError{err: err, safeToRetry: n == 0} + multiResult.err = err pgConn.unlock() return multiResult } @@ -1045,11 +1002,10 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] return result } - buf := pgConn.wbuf - buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) - buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) + pgConn.frontend.Send(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}) + pgConn.frontend.Send(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) - pgConn.execExtendedSuffix(buf, result) + pgConn.execExtendedSuffix(result) return result } @@ -1072,10 +1028,9 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return result } - buf := pgConn.wbuf - buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) + pgConn.frontend.Send(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) - pgConn.execExtendedSuffix(buf, result) + pgConn.execExtendedSuffix(result) return result } @@ -1115,15 +1070,15 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by return result } -func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { - buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) - buf = (&pgproto3.Execute{}).Encode(buf) - buf = (&pgproto3.Sync{}).Encode(buf) +func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) { + pgConn.frontend.Send(&pgproto3.Describe{ObjectType: 'P'}) + pgConn.frontend.Send(&pgproto3.Execute{}) + pgConn.frontend.Send(&pgproto3.Sync{}) - n, err := pgConn.conn.Write(buf) + err := pgConn.frontend.Flush() if err != nil { pgConn.asyncClose() - result.concludeCommand(CommandTag{}, &writeError{err: err, safeToRetry: n == 0}) + result.concludeCommand(CommandTag{}, err) pgConn.contextWatcher.Unwatch() result.closed = true pgConn.unlock() @@ -1151,14 +1106,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm } // Send copy to command - buf := pgConn.wbuf - buf = (&pgproto3.Query{String: sql}).Encode(buf) + pgConn.frontend.Send(&pgproto3.Query{String: sql}) - n, err := pgConn.conn.Write(buf) + err := pgConn.frontend.Flush() if err != nil { pgConn.asyncClose() pgConn.unlock() - return CommandTag{}, &writeError{err: err, safeToRetry: n == 0} + return CommandTag{}, err } // Read results @@ -1211,13 +1165,12 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } // Send copy to command - buf := pgConn.wbuf - buf = (&pgproto3.Query{String: sql}).Encode(buf) + pgConn.frontend.Send(&pgproto3.Query{String: sql}) - n, err := pgConn.conn.Write(buf) + err := pgConn.frontend.Flush() if err != nil { pgConn.asyncClose() - return CommandTag{}, &writeError{err: err, safeToRetry: n == 0} + return CommandTag{}, err } // Send copy data @@ -1280,15 +1233,12 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } close(abortCopyChan) - buf = buf[:0] if copyErr == io.EOF || pgErr != nil { - copyDone := &pgproto3.CopyDone{} - buf = copyDone.Encode(buf) + pgConn.frontend.Send(&pgproto3.CopyDone{}) } else { - copyFail := &pgproto3.CopyFail{Message: copyErr.Error()} - buf = copyFail.Encode(buf) + pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()}) } - _, err = pgConn.conn.Write(buf) + err = pgConn.frontend.Flush() if err != nil { pgConn.asyncClose() return CommandTag{}, err @@ -1692,7 +1642,7 @@ type HijackedConn struct { SecretKey uint32 // key to use to send a cancel query message to the server ParameterStatuses map[string]string // parameters that have been reported by the server TxStatus byte - Frontend Frontend + Frontend *pgproto3.Frontend Config *Config } @@ -1736,7 +1686,6 @@ func Construct(hc *HijackedConn) (*PgConn, error) { status: connStatusIdle, - wbuf: make([]byte, 0, wbufLen), cleanupDone: make(chan struct{}), } diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 3ae0d1d4..fdce6e7d 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -1915,49 +1915,6 @@ func TestConnContextCanceledCancelsRunningQueryOnServer(t *testing.T) { } } -func TestConnSendBytesAndReceiveMessage(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) - require.NoError(t, err) - config.RuntimeParams["client_min_messages"] = "notice" // Ensure we only get the messages we expect. - - pgConn, err := pgconn.ConnectConfig(context.Background(), config) - require.NoError(t, err) - defer closeConn(t, pgConn) - - queryMsg := pgproto3.Query{String: "select 42"} - buf := queryMsg.Encode(nil) - - err = pgConn.SendBytes(ctx, buf) - require.NoError(t, err) - - msg, err := pgConn.ReceiveMessage(ctx) - require.NoError(t, err) - _, ok := msg.(*pgproto3.RowDescription) - require.True(t, ok) - - msg, err = pgConn.ReceiveMessage(ctx) - require.NoError(t, err) - _, ok = msg.(*pgproto3.DataRow) - require.True(t, ok) - - msg, err = pgConn.ReceiveMessage(ctx) - require.NoError(t, err) - _, ok = msg.(*pgproto3.CommandComplete) - require.True(t, ok) - - msg, err = pgConn.ReceiveMessage(ctx) - require.NoError(t, err) - _, ok = msg.(*pgproto3.ReadyForQuery) - require.True(t, ok) - - ensureConnValid(t, pgConn) -} - func TestHijackAndConstruct(t *testing.T) { t.Parallel() diff --git a/pgproto3/backend.go b/pgproto3/backend.go index b7db6f76..d619f7e7 100644 --- a/pgproto3/backend.go +++ b/pgproto3/backend.go @@ -11,6 +11,8 @@ type Backend struct { cr *chunkReader w io.Writer + wbuf []byte + // Frontend message flyweights bind Bind cancelRequest CancelRequest @@ -47,10 +49,28 @@ func NewBackend(r io.Reader, w io.Writer) *Backend { return &Backend{cr: cr, w: w} } -// Send sends a message to the frontend. -func (b *Backend) Send(msg BackendMessage) error { - _, err := b.w.Write(msg.Encode(nil)) - return err +// Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is +// called. +func (b *Backend) Send(msg BackendMessage) { + b.wbuf = msg.Encode(b.wbuf) +} + +// Flush writes any pending messages to the frontend (i.e. the client). +func (b *Backend) Flush() error { + n, err := b.w.Write(b.wbuf) + + const maxLen = 1024 + if len(b.wbuf) > maxLen { + b.wbuf = make([]byte, 0, maxLen) + } else { + b.wbuf = b.wbuf[:0] + } + + if err != nil { + return &writeError{err: err, safeToRetry: n == 0} + } + + return nil } // ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index 435275d6..beaaef5f 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -12,6 +12,8 @@ type Frontend struct { cr *chunkReader w io.Writer + wbuf []byte + // Backend message flyweights authenticationOk AuthenticationOk authenticationCleartextPassword AuthenticationCleartextPassword @@ -56,10 +58,28 @@ func NewFrontend(r io.Reader, w io.Writer) *Frontend { return &Frontend{cr: cr, w: w} } -// Send sends a message to the backend. -func (f *Frontend) Send(msg FrontendMessage) error { - _, err := f.w.Write(msg.Encode(nil)) - return err +// Send sends a message to the backend (i.e. the server). The message is not guaranteed to be written until Flush is +// called. +func (f *Frontend) Send(msg FrontendMessage) { + f.wbuf = msg.Encode(f.wbuf) +} + +// Flush writes any pending messages to the backend (i.e. the server). +func (f *Frontend) Flush() error { + n, err := f.w.Write(f.wbuf) + + const maxLen = 1024 + if len(f.wbuf) > maxLen { + f.wbuf = make([]byte, 0, maxLen) + } else { + f.wbuf = f.wbuf[:0] + } + + if err != nil { + return &writeError{err: err, safeToRetry: n == 0} + } + + return nil } func translateEOFtoErrUnexpectedEOF(err error) error { diff --git a/pgproto3/pgproto3.go b/pgproto3/pgproto3.go index 70c825e3..a0333aa5 100644 --- a/pgproto3/pgproto3.go +++ b/pgproto3/pgproto3.go @@ -17,11 +17,13 @@ type Message interface { Encode(dst []byte) []byte } +// FrontendMessage is a message sent by the frontend (i.e. the client). type FrontendMessage interface { Message Frontend() // no-op method to distinguish frontend from backend methods } +// BackendMessage is a message sent by the backend (i.e. the server). type BackendMessage interface { Message Backend() // no-op method to distinguish frontend from backend methods @@ -50,6 +52,23 @@ func (e *invalidMessageFormatErr) Error() string { return fmt.Sprintf("%s body is invalid", e.messageType) } +type writeError struct { + err error + safeToRetry bool +} + +func (e *writeError) Error() string { + return fmt.Sprintf("write failed: %s", e.err.Error()) +} + +func (e *writeError) SafeToRetry() bool { + return e.safeToRetry +} + +func (e *writeError) Unwrap() error { + return e.err +} + // getValueFromJSON gets the value from a protocol message representation in JSON. func getValueFromJSON(v map[string]string) ([]byte, error) { if v == nil {