Add message tracing

non-blocking
Jack Christensen 2022-05-21 14:43:04 -05:00
parent 5714896b10
commit f2e96156a0
7 changed files with 321 additions and 8 deletions

View File

@ -41,7 +41,8 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
AuthMechanism: "SCRAM-SHA-256",
Data: sc.clientFirstMessage(),
}
_, err = c.conn.Write(saslInitialResponse.Encode(nil))
c.frontend.Send(saslInitialResponse)
err = c.frontend.Flush()
if err != nil {
return err
}
@ -60,7 +61,8 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
saslResponse := &pgproto3.SASLResponse{
Data: []byte(sc.clientFinalMessage()),
}
_, err = c.conn.Write(saslResponse.Encode(nil))
c.frontend.Send(saslResponse)
err = c.frontend.Flush()
if err != nil {
return err
}

View File

@ -61,7 +61,8 @@ func (c *PgConn) gssAuth() error {
gssResponse := &pgproto3.GSSResponse{
Data: nextData,
}
_, err = c.conn.Write(gssResponse.Encode(nil))
c.frontend.Send(gssResponse)
err = c.frontend.Flush()
if err != nil {
return err
}

View File

@ -515,7 +515,7 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
return msg, nil
}
// Conn returns the underlying net.Conn.
// Conn returns the underlying net.Conn. This rarely necessary.
func (pgConn *PgConn) Conn() net.Conn {
return pgConn.conn
}
@ -542,6 +542,11 @@ func (pgConn *PgConn) SecretKey() uint32 {
return pgConn.secretKey
}
// Frontend returns the underlying *pgproto3.Frontend. This rarely necessary.
func (pgConn *PgConn) Frontend() *pgproto3.Frontend {
return pgConn.frontend
}
// Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by
// sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The
// underlying net.Conn.Close() will always be called regardless of any other errors.
@ -571,7 +576,8 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
// ignores errors.
//
// See https://github.com/jackc/pgx/issues/637
pgConn.conn.Write([]byte{'X', 0, 0, 0, 4})
pgConn.frontend.Send(&pgproto3.Terminate{})
pgConn.frontend.Flush()
return pgConn.conn.Close()
}
@ -597,7 +603,8 @@ func (pgConn *PgConn) asyncClose() {
pgConn.conn.SetDeadline(deadline)
pgConn.conn.Write([]byte{'X', 0, 0, 0, 4})
pgConn.frontend.Send(&pgproto3.Terminate{})
pgConn.frontend.Flush()
}()
}

View File

@ -11,6 +11,10 @@ type Backend struct {
cr *chunkReader
w io.Writer
// MessageTracer is used to trace messages when Send or Receive is called. This means an outbound message is traced
// before it is actually transmitted (i.e. before Flush).
MessageTracer MessageTracer
wbuf []byte
// Frontend message flyweights
@ -52,7 +56,11 @@ func NewBackend(r io.Reader, w io.Writer) *Backend {
// Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is
// called.
func (b *Backend) Send(msg BackendMessage) {
prevLen := len(b.wbuf)
b.wbuf = msg.Encode(b.wbuf)
if b.MessageTracer != nil {
b.MessageTracer.TraceMessage('B', int32(len(b.wbuf)-prevLen), msg)
}
}
// Flush writes any pending messages to the frontend (i.e. the client).
@ -193,7 +201,15 @@ func (b *Backend) Receive() (FrontendMessage, error) {
b.partialMsg = false
err = msg.Decode(msgBody)
return msg, err
if err != nil {
return nil, err
}
if b.MessageTracer != nil {
b.MessageTracer.TraceMessage('F', int32(5+len(msgBody)), msg)
}
return msg, nil
}
// SetAuthType sets the authentication type in the backend.

View File

@ -12,6 +12,11 @@ type Frontend struct {
cr *chunkReader
w io.Writer
// MessageTracer is used to trace messages when Send or Receive is called. This means an outbound message is traced
// before it is actually transmitted (i.e. before Flush). It is safe to change this variable when the Frontend is
// idle. Setting and unsetting MessageTracer provides equivalent functionality to PQtrace and PQuntrace in libpq.
MessageTracer MessageTracer
wbuf []byte
// Backend message flyweights
@ -61,7 +66,11 @@ func NewFrontend(r io.Reader, w io.Writer) *Frontend {
// Send sends a message to the backend (i.e. the server). The message is not guaranteed to be written until Flush is
// called.
func (f *Frontend) Send(msg FrontendMessage) {
prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf)
if f.MessageTracer != nil {
f.MessageTracer.TraceMessage('F', int32(len(f.wbuf)-prevLen), msg)
}
}
// Flush writes any pending messages to the backend (i.e. the server).
@ -166,7 +175,15 @@ func (f *Frontend) Receive() (BackendMessage, error) {
}
err = msg.Decode(msgBody)
return msg, err
if err != nil {
return nil, err
}
if f.MessageTracer != nil {
f.MessageTracer.TraceMessage('B', int32(5+len(msgBody)), msg)
}
return msg, nil
}
// Authentication message type constants.

191
pgproto3/trace.go Normal file
View File

@ -0,0 +1,191 @@
package pgproto3
import (
"bytes"
"fmt"
"io"
"strings"
"time"
)
// MessageTracer is an interface that traces the messages send to and from a Backend or Frontend.
type MessageTracer interface {
// TraceMessage tracks the sending or receiving of a message. sender is either 'F' for frontend or 'B' for backend.
TraceMessage(sender byte, encodedLen int32, msg Message)
}
// LibpqMessageTracer is a MessageTracer that roughly mimics the format produced by the libpq C function PQtrace.
type LibpqMessageTracer struct {
Writer io.Writer
// SuppressTimestamps prevents printing of timestamps.
SuppressTimestamps bool
// RegressMode redacts fields that may be vary between executions.
RegressMode bool
}
func (t *LibpqMessageTracer) TraceMessage(sender byte, encodedLen int32, msg Message) {
buf := &bytes.Buffer{}
if !t.SuppressTimestamps {
now := time.Now()
buf.WriteString(now.Format("2006-01-02 15:04:05.000000"))
buf.WriteByte('\t')
}
buf.WriteByte(sender)
buf.WriteByte('\t')
switch msg := msg.(type) {
case *AuthenticationCleartextPassword:
buf.WriteString("AuthenticationCleartextPassword")
case *AuthenticationGSS:
buf.WriteString("AuthenticationGSS")
case *AuthenticationGSSContinue:
buf.WriteString("AuthenticationGSSContinue")
case *AuthenticationMD5Password:
buf.WriteString("AuthenticationMD5Password")
case *AuthenticationOk:
buf.WriteString("AuthenticationOk")
case *AuthenticationSASL:
buf.WriteString("AuthenticationSASL")
case *AuthenticationSASLContinue:
buf.WriteString("AuthenticationSASLContinue")
case *AuthenticationSASLFinal:
buf.WriteString("AuthenticationSASLFinal")
case *BackendKeyData:
if t.RegressMode {
buf.WriteString("BackendKeyData\t NNNN NNNN")
} else {
fmt.Fprintf(buf, "BackendKeyData\t %d %d", msg.ProcessID, msg.SecretKey)
}
case *Bind:
fmt.Fprintf(buf, "Bind\t %s %s %d", traceDoubleQuotedString([]byte(msg.DestinationPortal)), traceDoubleQuotedString([]byte(msg.PreparedStatement)), len(msg.ParameterFormatCodes))
for _, fc := range msg.ParameterFormatCodes {
fmt.Fprintf(buf, " %d", fc)
}
fmt.Fprintf(buf, " %d", len(msg.Parameters))
for _, p := range msg.Parameters {
fmt.Fprintf(buf, " %s", traceSingleQuotedString(p))
}
fmt.Fprintf(buf, " %d", len(msg.ResultFormatCodes))
for _, fc := range msg.ResultFormatCodes {
fmt.Fprintf(buf, " %d", fc)
}
case *BindComplete:
buf.WriteString("BindComplete")
case *CancelRequest:
buf.WriteString("CancelRequest")
case *Close:
buf.WriteString("Close")
case *CloseComplete:
buf.WriteString("CloseComplete")
case *CommandComplete:
fmt.Fprintf(buf, "CommandComplete\t %s", traceDoubleQuotedString(msg.CommandTag))
case *CopyBothResponse:
buf.WriteString("CopyBothResponse")
case *CopyData:
buf.WriteString("CopyData")
case *CopyDone:
buf.WriteString("CopyDone")
case *CopyFail:
fmt.Fprintf(buf, "CopyFail\t %s", traceDoubleQuotedString([]byte(msg.Message)))
case *CopyInResponse:
buf.WriteString("CopyInResponse")
case *CopyOutResponse:
buf.WriteString("CopyOutResponse")
case *DataRow:
fmt.Fprintf(buf, "DataRow\t %d", len(msg.Values))
for _, v := range msg.Values {
if v == nil {
buf.WriteString(" -1")
} else {
fmt.Fprintf(buf, " %d %s", len(v), traceSingleQuotedString(v))
}
}
case *Describe:
fmt.Fprintf(buf, "Describe\t %c %s", msg.ObjectType, traceDoubleQuotedString([]byte(msg.Name)))
case *EmptyQueryResponse:
buf.WriteString("EmptyQueryResponse")
case *ErrorResponse:
buf.WriteString("ErrorResponse")
case *Execute:
fmt.Fprintf(buf, "Execute\t %s %d", traceDoubleQuotedString([]byte(msg.Portal)), msg.MaxRows)
case *Flush:
buf.WriteString("Flush")
case *FunctionCall:
buf.WriteString("FunctionCall")
case *FunctionCallResponse:
buf.WriteString("FunctionCallResponse")
case *GSSEncRequest:
buf.WriteString("GSSEncRequest")
case *NoData:
buf.WriteString("NoData")
case *NoticeResponse:
buf.WriteString("NoticeResponse")
case *NotificationResponse:
fmt.Fprintf(buf, "NotificationResponse\t %d %s %s", msg.PID, traceDoubleQuotedString([]byte(msg.Channel)), traceDoubleQuotedString([]byte(msg.Payload)))
case *ParameterDescription:
buf.WriteString("ParameterDescription")
case *ParameterStatus:
fmt.Fprintf(buf, "ParameterStatus\t %s %s", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Value)))
case *Parse:
fmt.Fprintf(buf, "Parse\t %s %s %d", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Query)), len(msg.ParameterOIDs))
for _, oid := range msg.ParameterOIDs {
fmt.Fprintf(buf, " %d", oid)
}
case *ParseComplete:
buf.WriteString("ParseComplete")
case *PortalSuspended:
buf.WriteString("PortalSuspended")
case *Query:
buf.WriteString("Query\t")
fmt.Fprintf(buf, ` "%s"`, msg.String)
case *ReadyForQuery:
fmt.Fprintf(buf, "ReadyForQuery\t %c", msg.TxStatus)
case *RowDescription:
buf.WriteString("RowDescription\t")
fmt.Fprintf(buf, " %d", len(msg.Fields))
for _, fd := range msg.Fields {
fmt.Fprintf(buf, ` %s %d %d %d %d %d %d`, traceDoubleQuotedString(fd.Name), fd.TableOID, fd.TableAttributeNumber, fd.DataTypeOID, fd.DataTypeSize, fd.TypeModifier, fd.Format)
}
case *SSLRequest:
buf.WriteString("SSLRequest")
case *StartupMessage:
buf.WriteString("StartupMessage")
case *Sync:
buf.WriteString("Sync")
case *Terminate:
buf.WriteString("Terminate")
default:
buf.WriteString("Unknown")
}
buf.WriteByte('\n')
buf.WriteTo(t.Writer)
}
// traceDoubleQuotedString returns buf as a double-quoted string without any escaping. It is roughly equivalent to
// pqTraceOutputString in libpq.
func traceDoubleQuotedString(buf []byte) string {
return `"` + string(buf) + `"`
}
// traceSingleQuotedString returns buf as a single-quoted string with non-printable characters hex-escaped. It is
// roughly equivalent to pqTraceOutputNchar in libpq.
func traceSingleQuotedString(buf []byte) string {
sb := &strings.Builder{}
sb.WriteByte('\'')
for _, b := range buf {
if b < 32 || b > 126 {
fmt.Fprintf(sb, `\x%x`, b)
} else {
sb.WriteByte(b)
}
}
sb.WriteByte('\'')
return sb.String()
}

79
pgproto3/trace_test.go Normal file
View File

@ -0,0 +1,79 @@
package pgproto3_test
import (
"bytes"
"context"
"io"
"os"
"testing"
"time"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgproto3"
"github.com/stretchr/testify/require"
)
func TestLibpqMessageTracer(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
require.NoError(t, err)
traceOutput := &bytes.Buffer{}
config.BuildFrontend = func(r io.Reader, w io.Writer) *pgproto3.Frontend {
f := pgproto3.NewFrontend(r, w)
f.MessageTracer = &pgproto3.LibpqMessageTracer{
Writer: traceOutput,
SuppressTimestamps: true,
RegressMode: true,
}
return f
}
conn, err := pgconn.ConnectConfig(ctx, config)
require.NoError(t, err)
defer conn.Close(ctx)
result := conn.ExecParams(ctx, "select n from generate_series(1,5) n", nil, nil, nil, nil).Read()
require.NoError(t, result.Err)
expected := `F StartupMessage
B AuthenticationOk
B ParameterStatus "application_name" ""
B ParameterStatus "client_encoding" "UTF8"
B ParameterStatus "DateStyle" "ISO, MDY"
B ParameterStatus "default_transaction_read_only" "off"
B ParameterStatus "in_hot_standby" "off"
B ParameterStatus "integer_datetimes" "on"
B ParameterStatus "IntervalStyle" "postgres"
B ParameterStatus "is_superuser" "on"
B ParameterStatus "server_encoding" "UTF8"
B ParameterStatus "server_version" "14.3"
B ParameterStatus "session_authorization" "jack"
B ParameterStatus "standard_conforming_strings" "on"
B ParameterStatus "TimeZone" "America/Chicago"
B BackendKeyData NNNN NNNN
B ReadyForQuery I
F Parse "" "select n from generate_series(1,5) n" 0
F Bind "" "" 0 0 0
F Describe P ""
F Execute "" 0
F Sync
B ParseComplete
B BindComplete
B RowDescription 1 "n" 0 0 23 4 -1 0
B DataRow 1 1 '1'
B DataRow 1 1 '2'
B DataRow 1 1 '3'
B DataRow 1 1 '4'
B DataRow 1 1 '5'
B CommandComplete "SELECT 5"
B ReadyForQuery I
`
require.Equal(t, expected, traceOutput.String())
}