Restructure sending messages

Use an internal buffer in pgproto3.Frontend and pgproto3.Backend instead
of directly writing to the underlying net.Conn. This will allow tracing
messages as well as simplify pipeline mode.
non-blocking
Jack Christensen 2022-05-21 11:06:44 -05:00
parent 989a4835de
commit 5714896b10
9 changed files with 105 additions and 226 deletions

View File

@ -97,7 +97,8 @@ type sendMessageStep struct {
}
func (e *sendMessageStep) Step(backend *pgproto3.Backend) error {
return backend.Send(e.msg)
backend.Send(e.msg)
return backend.Flush()
}
func SendMessage(msg pgproto3.BackendMessage) Step {

View File

@ -222,7 +222,7 @@ func ParseConfig(connString string) (*Config, error) {
User: settings["user"],
Password: settings["password"],
RuntimeParams: make(map[string]string),
BuildFrontend: func(r io.Reader, w io.Writer) Frontend {
BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend {
return pgproto3.NewFrontend(r, w)
},
}

View File

@ -178,23 +178,6 @@ func newContextAlreadyDoneError(ctx context.Context) (err error) {
return &errTimeout{&contextAlreadyDoneError{err: ctx.Err()}}
}
type writeError struct {
err error
safeToRetry bool
}
func (e *writeError) Error() string {
return fmt.Sprintf("write failed: %s", e.err.Error())
}
func (e *writeError) SafeToRetry() bool {
return e.safeToRetry
}
func (e *writeError) Unwrap() error {
return e.err
}
func redactPW(connString string) string {
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
if u, err := url.Parse(connString); err == nil {

View File

@ -1,70 +0,0 @@
package pgconn_test
import (
"context"
"io"
"os"
"testing"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgproto3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// frontendWrapper allows to hijack a regular frontend, and inject a specific response
type frontendWrapper struct {
front pgconn.Frontend
msg pgproto3.BackendMessage
}
// frontendWrapper implements the pgconn.Frontend interface
var _ pgconn.Frontend = (*frontendWrapper)(nil)
func (f *frontendWrapper) Receive() (pgproto3.BackendMessage, error) {
if f.msg != nil {
return f.msg, nil
}
return f.front.Receive()
}
func TestFrontendFatalErrExec(t *testing.T) {
t.Parallel()
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
buildFrontend := config.BuildFrontend
var front *frontendWrapper
config.BuildFrontend = func(r io.Reader, w io.Writer) pgconn.Frontend {
wrapped := buildFrontend(r, w)
front = &frontendWrapper{wrapped, nil}
return front
}
conn, err := pgconn.ConnectConfig(context.Background(), config)
require.NoError(t, err)
require.NotNil(t, conn)
require.NotNil(t, front)
// set frontend to return a "FATAL" message on next call
front.msg = &pgproto3.ErrorResponse{Severity: "FATAL", Message: "unit testing fatal error"}
_, err = conn.Exec(context.Background(), "SELECT 1").ReadAll()
assert.Error(t, err)
err = conn.Close(context.Background())
assert.NoError(t, err)
select {
case <-conn.CleanupDone():
t.Log("ok, CleanupDone() is not blocking")
default:
assert.Fail(t, "connection closed but CleanupDone() still blocking")
}
}

View File

@ -29,8 +29,6 @@ const (
connStatusBusy
)
const wbufLen = 1024
// Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from
// LISTEN/NOTIFY notification.
type Notice PgError
@ -50,7 +48,7 @@ type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
type LookupFunc func(ctx context.Context, host string) (addrs []string, err error)
// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection.
type BuildFrontendFunc func(r io.Reader, w io.Writer) Frontend
type BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend
// NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at
// any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin
@ -64,11 +62,6 @@ type NoticeHandler func(*PgConn, *Notice)
// notice event.
type NotificationHandler func(*PgConn, *Notification)
// Frontend used to receive messages from backend.
type Frontend interface {
Receive() (pgproto3.BackendMessage, error)
}
// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
type PgConn struct {
conn net.Conn // the underlying TCP or unix domain socket connection
@ -76,7 +69,7 @@ type PgConn struct {
secretKey uint32 // key to use to send a cancel query message to the server
parameterStatuses map[string]string // parameters that have been reported by the server
txStatus byte
frontend Frontend
frontend *pgproto3.Frontend
config *Config
@ -90,7 +83,6 @@ type PgConn struct {
peekedMsg pgproto3.BackendMessage
// Reusable / preallocated resources
wbuf []byte // write buffer
resultReader ResultReader
multiResultReader MultiResultReader
contextWatcher *ctxwatch.ContextWatcher
@ -230,7 +222,6 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
pgConn := new(PgConn)
pgConn.config = config
pgConn.wbuf = make([]byte, 0, wbufLen)
pgConn.cleanupDone = make(chan struct{})
var err error
@ -282,7 +273,8 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
startupMsg.Parameters["database"] = config.Database
}
if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil {
pgConn.frontend.Send(&startupMsg)
if err := pgConn.frontend.Flush(); err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed to write startup message", err: err}
}
@ -383,9 +375,8 @@ func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
}
func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
msg := &pgproto3.PasswordMessage{Password: password}
_, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf))
return err
pgConn.frontend.Send(&pgproto3.PasswordMessage{Password: password})
return pgConn.frontend.Flush()
}
func hexMD5(s string) string {
@ -412,36 +403,6 @@ func (pgConn *PgConn) signalMessage() chan struct{} {
return ch
}
// SendBytes sends buf to the PostgreSQL server. It must only be used when the connection is not busy. e.g. It is as
// error to call SendBytes while reading the result of a query.
//
// This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly.
// See https://www.postgresql.org/docs/current/protocol.html.
func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
if err := pgConn.lock(); err != nil {
return err
}
defer pgConn.unlock()
if ctx != context.Background() {
select {
case <-ctx.Done():
return newContextAlreadyDoneError(ctx)
default:
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
}
n, err := pgConn.conn.Write(buf)
if err != nil {
pgConn.asyncClose()
return &writeError{err: err, safeToRetry: n == 0}
}
return nil
}
// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
@ -797,15 +758,13 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
defer pgConn.contextWatcher.Unwatch()
}
buf := pgConn.wbuf
buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf)
buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf)
buf = (&pgproto3.Sync{}).Encode(buf)
n, err := pgConn.conn.Write(buf)
pgConn.frontend.Send(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs})
pgConn.frontend.Send(&pgproto3.Describe{ObjectType: 'S', Name: name})
pgConn.frontend.Send(&pgproto3.Sync{})
err := pgConn.frontend.Flush()
if err != nil {
pgConn.asyncClose()
return nil, &writeError{err: err, safeToRetry: n == 0}
return nil, err
}
psd := &StatementDescription{Name: name, SQL: sql}
@ -971,15 +930,13 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
pgConn.contextWatcher.Watch(ctx)
}
buf := pgConn.wbuf
buf = (&pgproto3.Query{String: sql}).Encode(buf)
n, err := pgConn.conn.Write(buf)
pgConn.frontend.Send(&pgproto3.Query{String: sql})
err := pgConn.frontend.Flush()
if err != nil {
pgConn.asyncClose()
pgConn.contextWatcher.Unwatch()
multiResult.closed = true
multiResult.err = &writeError{err: err, safeToRetry: n == 0}
multiResult.err = err
pgConn.unlock()
return multiResult
}
@ -1045,11 +1002,10 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
return result
}
buf := pgConn.wbuf
buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf)
buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
pgConn.frontend.Send(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs})
pgConn.frontend.Send(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
pgConn.execExtendedSuffix(buf, result)
pgConn.execExtendedSuffix(result)
return result
}
@ -1072,10 +1028,9 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
return result
}
buf := pgConn.wbuf
buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
pgConn.frontend.Send(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
pgConn.execExtendedSuffix(buf, result)
pgConn.execExtendedSuffix(result)
return result
}
@ -1115,15 +1070,15 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
return result
}
func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) {
buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf)
buf = (&pgproto3.Execute{}).Encode(buf)
buf = (&pgproto3.Sync{}).Encode(buf)
func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) {
pgConn.frontend.Send(&pgproto3.Describe{ObjectType: 'P'})
pgConn.frontend.Send(&pgproto3.Execute{})
pgConn.frontend.Send(&pgproto3.Sync{})
n, err := pgConn.conn.Write(buf)
err := pgConn.frontend.Flush()
if err != nil {
pgConn.asyncClose()
result.concludeCommand(CommandTag{}, &writeError{err: err, safeToRetry: n == 0})
result.concludeCommand(CommandTag{}, err)
pgConn.contextWatcher.Unwatch()
result.closed = true
pgConn.unlock()
@ -1151,14 +1106,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
}
// Send copy to command
buf := pgConn.wbuf
buf = (&pgproto3.Query{String: sql}).Encode(buf)
pgConn.frontend.Send(&pgproto3.Query{String: sql})
n, err := pgConn.conn.Write(buf)
err := pgConn.frontend.Flush()
if err != nil {
pgConn.asyncClose()
pgConn.unlock()
return CommandTag{}, &writeError{err: err, safeToRetry: n == 0}
return CommandTag{}, err
}
// Read results
@ -1211,13 +1165,12 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
}
// Send copy to command
buf := pgConn.wbuf
buf = (&pgproto3.Query{String: sql}).Encode(buf)
pgConn.frontend.Send(&pgproto3.Query{String: sql})
n, err := pgConn.conn.Write(buf)
err := pgConn.frontend.Flush()
if err != nil {
pgConn.asyncClose()
return CommandTag{}, &writeError{err: err, safeToRetry: n == 0}
return CommandTag{}, err
}
// Send copy data
@ -1280,15 +1233,12 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
}
close(abortCopyChan)
buf = buf[:0]
if copyErr == io.EOF || pgErr != nil {
copyDone := &pgproto3.CopyDone{}
buf = copyDone.Encode(buf)
pgConn.frontend.Send(&pgproto3.CopyDone{})
} else {
copyFail := &pgproto3.CopyFail{Message: copyErr.Error()}
buf = copyFail.Encode(buf)
pgConn.frontend.Send(&pgproto3.CopyFail{Message: copyErr.Error()})
}
_, err = pgConn.conn.Write(buf)
err = pgConn.frontend.Flush()
if err != nil {
pgConn.asyncClose()
return CommandTag{}, err
@ -1692,7 +1642,7 @@ type HijackedConn struct {
SecretKey uint32 // key to use to send a cancel query message to the server
ParameterStatuses map[string]string // parameters that have been reported by the server
TxStatus byte
Frontend Frontend
Frontend *pgproto3.Frontend
Config *Config
}
@ -1736,7 +1686,6 @@ func Construct(hc *HijackedConn) (*PgConn, error) {
status: connStatusIdle,
wbuf: make([]byte, 0, wbufLen),
cleanupDone: make(chan struct{}),
}

View File

@ -1915,49 +1915,6 @@ func TestConnContextCanceledCancelsRunningQueryOnServer(t *testing.T) {
}
}
func TestConnSendBytesAndReceiveMessage(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
config.RuntimeParams["client_min_messages"] = "notice" // Ensure we only get the messages we expect.
pgConn, err := pgconn.ConnectConfig(context.Background(), config)
require.NoError(t, err)
defer closeConn(t, pgConn)
queryMsg := pgproto3.Query{String: "select 42"}
buf := queryMsg.Encode(nil)
err = pgConn.SendBytes(ctx, buf)
require.NoError(t, err)
msg, err := pgConn.ReceiveMessage(ctx)
require.NoError(t, err)
_, ok := msg.(*pgproto3.RowDescription)
require.True(t, ok)
msg, err = pgConn.ReceiveMessage(ctx)
require.NoError(t, err)
_, ok = msg.(*pgproto3.DataRow)
require.True(t, ok)
msg, err = pgConn.ReceiveMessage(ctx)
require.NoError(t, err)
_, ok = msg.(*pgproto3.CommandComplete)
require.True(t, ok)
msg, err = pgConn.ReceiveMessage(ctx)
require.NoError(t, err)
_, ok = msg.(*pgproto3.ReadyForQuery)
require.True(t, ok)
ensureConnValid(t, pgConn)
}
func TestHijackAndConstruct(t *testing.T) {
t.Parallel()

View File

@ -11,6 +11,8 @@ type Backend struct {
cr *chunkReader
w io.Writer
wbuf []byte
// Frontend message flyweights
bind Bind
cancelRequest CancelRequest
@ -47,10 +49,28 @@ func NewBackend(r io.Reader, w io.Writer) *Backend {
return &Backend{cr: cr, w: w}
}
// Send sends a message to the frontend.
func (b *Backend) Send(msg BackendMessage) error {
_, err := b.w.Write(msg.Encode(nil))
return err
// Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is
// called.
func (b *Backend) Send(msg BackendMessage) {
b.wbuf = msg.Encode(b.wbuf)
}
// Flush writes any pending messages to the frontend (i.e. the client).
func (b *Backend) Flush() error {
n, err := b.w.Write(b.wbuf)
const maxLen = 1024
if len(b.wbuf) > maxLen {
b.wbuf = make([]byte, 0, maxLen)
} else {
b.wbuf = b.wbuf[:0]
}
if err != nil {
return &writeError{err: err, safeToRetry: n == 0}
}
return nil
}
// ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method

View File

@ -12,6 +12,8 @@ type Frontend struct {
cr *chunkReader
w io.Writer
wbuf []byte
// Backend message flyweights
authenticationOk AuthenticationOk
authenticationCleartextPassword AuthenticationCleartextPassword
@ -56,10 +58,28 @@ func NewFrontend(r io.Reader, w io.Writer) *Frontend {
return &Frontend{cr: cr, w: w}
}
// Send sends a message to the backend.
func (f *Frontend) Send(msg FrontendMessage) error {
_, err := f.w.Write(msg.Encode(nil))
return err
// Send sends a message to the backend (i.e. the server). The message is not guaranteed to be written until Flush is
// called.
func (f *Frontend) Send(msg FrontendMessage) {
f.wbuf = msg.Encode(f.wbuf)
}
// Flush writes any pending messages to the backend (i.e. the server).
func (f *Frontend) Flush() error {
n, err := f.w.Write(f.wbuf)
const maxLen = 1024
if len(f.wbuf) > maxLen {
f.wbuf = make([]byte, 0, maxLen)
} else {
f.wbuf = f.wbuf[:0]
}
if err != nil {
return &writeError{err: err, safeToRetry: n == 0}
}
return nil
}
func translateEOFtoErrUnexpectedEOF(err error) error {

View File

@ -17,11 +17,13 @@ type Message interface {
Encode(dst []byte) []byte
}
// FrontendMessage is a message sent by the frontend (i.e. the client).
type FrontendMessage interface {
Message
Frontend() // no-op method to distinguish frontend from backend methods
}
// BackendMessage is a message sent by the backend (i.e. the server).
type BackendMessage interface {
Message
Backend() // no-op method to distinguish frontend from backend methods
@ -50,6 +52,23 @@ func (e *invalidMessageFormatErr) Error() string {
return fmt.Sprintf("%s body is invalid", e.messageType)
}
type writeError struct {
err error
safeToRetry bool
}
func (e *writeError) Error() string {
return fmt.Sprintf("write failed: %s", e.err.Error())
}
func (e *writeError) SafeToRetry() bool {
return e.safeToRetry
}
func (e *writeError) Unwrap() error {
return e.err
}
// getValueFromJSON gets the value from a protocol message representation in JSON.
func getValueFromJSON(v map[string]string) ([]byte, error) {
if v == nil {