diff --git a/pgconn/auth_scram.go b/pgconn/auth_scram.go index de13c687..37b602d6 100644 --- a/pgconn/auth_scram.go +++ b/pgconn/auth_scram.go @@ -41,7 +41,8 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { AuthMechanism: "SCRAM-SHA-256", Data: sc.clientFirstMessage(), } - _, err = c.conn.Write(saslInitialResponse.Encode(nil)) + c.frontend.Send(saslInitialResponse) + err = c.frontend.Flush() if err != nil { return err } @@ -60,7 +61,8 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { saslResponse := &pgproto3.SASLResponse{ Data: []byte(sc.clientFinalMessage()), } - _, err = c.conn.Write(saslResponse.Encode(nil)) + c.frontend.Send(saslResponse) + err = c.frontend.Flush() if err != nil { return err } diff --git a/pgconn/krb5.go b/pgconn/krb5.go index 8dffc879..a4bca01f 100644 --- a/pgconn/krb5.go +++ b/pgconn/krb5.go @@ -61,7 +61,8 @@ func (c *PgConn) gssAuth() error { gssResponse := &pgproto3.GSSResponse{ Data: nextData, } - _, err = c.conn.Write(gssResponse.Encode(nil)) + c.frontend.Send(gssResponse) + err = c.frontend.Flush() if err != nil { return err } diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 2cbf8c50..935d3530 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -515,7 +515,7 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { return msg, nil } -// Conn returns the underlying net.Conn. +// Conn returns the underlying net.Conn. This rarely necessary. func (pgConn *PgConn) Conn() net.Conn { return pgConn.conn } @@ -542,6 +542,11 @@ func (pgConn *PgConn) SecretKey() uint32 { return pgConn.secretKey } +// Frontend returns the underlying *pgproto3.Frontend. This rarely necessary. +func (pgConn *PgConn) Frontend() *pgproto3.Frontend { + return pgConn.frontend +} + // Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by // sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The // underlying net.Conn.Close() will always be called regardless of any other errors. @@ -571,7 +576,8 @@ func (pgConn *PgConn) Close(ctx context.Context) error { // ignores errors. // // See https://github.com/jackc/pgx/issues/637 - pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) + pgConn.frontend.Send(&pgproto3.Terminate{}) + pgConn.frontend.Flush() return pgConn.conn.Close() } @@ -597,7 +603,8 @@ func (pgConn *PgConn) asyncClose() { pgConn.conn.SetDeadline(deadline) - pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) + pgConn.frontend.Send(&pgproto3.Terminate{}) + pgConn.frontend.Flush() }() } diff --git a/pgproto3/backend.go b/pgproto3/backend.go index d619f7e7..ba0be3d3 100644 --- a/pgproto3/backend.go +++ b/pgproto3/backend.go @@ -11,6 +11,10 @@ type Backend struct { cr *chunkReader w io.Writer + // MessageTracer is used to trace messages when Send or Receive is called. This means an outbound message is traced + // before it is actually transmitted (i.e. before Flush). + MessageTracer MessageTracer + wbuf []byte // Frontend message flyweights @@ -52,7 +56,11 @@ func NewBackend(r io.Reader, w io.Writer) *Backend { // Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is // called. func (b *Backend) Send(msg BackendMessage) { + prevLen := len(b.wbuf) b.wbuf = msg.Encode(b.wbuf) + if b.MessageTracer != nil { + b.MessageTracer.TraceMessage('B', int32(len(b.wbuf)-prevLen), msg) + } } // Flush writes any pending messages to the frontend (i.e. the client). @@ -193,7 +201,15 @@ func (b *Backend) Receive() (FrontendMessage, error) { b.partialMsg = false err = msg.Decode(msgBody) - return msg, err + if err != nil { + return nil, err + } + + if b.MessageTracer != nil { + b.MessageTracer.TraceMessage('F', int32(5+len(msgBody)), msg) + } + + return msg, nil } // SetAuthType sets the authentication type in the backend. diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index beaaef5f..342a0ddd 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -12,6 +12,11 @@ type Frontend struct { cr *chunkReader w io.Writer + // MessageTracer is used to trace messages when Send or Receive is called. This means an outbound message is traced + // before it is actually transmitted (i.e. before Flush). It is safe to change this variable when the Frontend is + // idle. Setting and unsetting MessageTracer provides equivalent functionality to PQtrace and PQuntrace in libpq. + MessageTracer MessageTracer + wbuf []byte // Backend message flyweights @@ -61,7 +66,11 @@ func NewFrontend(r io.Reader, w io.Writer) *Frontend { // Send sends a message to the backend (i.e. the server). The message is not guaranteed to be written until Flush is // called. func (f *Frontend) Send(msg FrontendMessage) { + prevLen := len(f.wbuf) f.wbuf = msg.Encode(f.wbuf) + if f.MessageTracer != nil { + f.MessageTracer.TraceMessage('F', int32(len(f.wbuf)-prevLen), msg) + } } // Flush writes any pending messages to the backend (i.e. the server). @@ -166,7 +175,15 @@ func (f *Frontend) Receive() (BackendMessage, error) { } err = msg.Decode(msgBody) - return msg, err + if err != nil { + return nil, err + } + + if f.MessageTracer != nil { + f.MessageTracer.TraceMessage('B', int32(5+len(msgBody)), msg) + } + + return msg, nil } // Authentication message type constants. diff --git a/pgproto3/trace.go b/pgproto3/trace.go new file mode 100644 index 00000000..b35ecdb6 --- /dev/null +++ b/pgproto3/trace.go @@ -0,0 +1,191 @@ +package pgproto3 + +import ( + "bytes" + "fmt" + "io" + "strings" + "time" +) + +// MessageTracer is an interface that traces the messages send to and from a Backend or Frontend. +type MessageTracer interface { + // TraceMessage tracks the sending or receiving of a message. sender is either 'F' for frontend or 'B' for backend. + TraceMessage(sender byte, encodedLen int32, msg Message) +} + +// LibpqMessageTracer is a MessageTracer that roughly mimics the format produced by the libpq C function PQtrace. +type LibpqMessageTracer struct { + Writer io.Writer + + // SuppressTimestamps prevents printing of timestamps. + SuppressTimestamps bool + + // RegressMode redacts fields that may be vary between executions. + RegressMode bool +} + +func (t *LibpqMessageTracer) TraceMessage(sender byte, encodedLen int32, msg Message) { + buf := &bytes.Buffer{} + + if !t.SuppressTimestamps { + now := time.Now() + buf.WriteString(now.Format("2006-01-02 15:04:05.000000")) + buf.WriteByte('\t') + } + + buf.WriteByte(sender) + buf.WriteByte('\t') + + switch msg := msg.(type) { + case *AuthenticationCleartextPassword: + buf.WriteString("AuthenticationCleartextPassword") + case *AuthenticationGSS: + buf.WriteString("AuthenticationGSS") + case *AuthenticationGSSContinue: + buf.WriteString("AuthenticationGSSContinue") + case *AuthenticationMD5Password: + buf.WriteString("AuthenticationMD5Password") + case *AuthenticationOk: + buf.WriteString("AuthenticationOk") + case *AuthenticationSASL: + buf.WriteString("AuthenticationSASL") + case *AuthenticationSASLContinue: + buf.WriteString("AuthenticationSASLContinue") + case *AuthenticationSASLFinal: + buf.WriteString("AuthenticationSASLFinal") + case *BackendKeyData: + if t.RegressMode { + buf.WriteString("BackendKeyData\t NNNN NNNN") + } else { + fmt.Fprintf(buf, "BackendKeyData\t %d %d", msg.ProcessID, msg.SecretKey) + } + case *Bind: + fmt.Fprintf(buf, "Bind\t %s %s %d", traceDoubleQuotedString([]byte(msg.DestinationPortal)), traceDoubleQuotedString([]byte(msg.PreparedStatement)), len(msg.ParameterFormatCodes)) + for _, fc := range msg.ParameterFormatCodes { + fmt.Fprintf(buf, " %d", fc) + } + fmt.Fprintf(buf, " %d", len(msg.Parameters)) + for _, p := range msg.Parameters { + fmt.Fprintf(buf, " %s", traceSingleQuotedString(p)) + } + fmt.Fprintf(buf, " %d", len(msg.ResultFormatCodes)) + for _, fc := range msg.ResultFormatCodes { + fmt.Fprintf(buf, " %d", fc) + } + case *BindComplete: + buf.WriteString("BindComplete") + case *CancelRequest: + buf.WriteString("CancelRequest") + case *Close: + buf.WriteString("Close") + case *CloseComplete: + buf.WriteString("CloseComplete") + case *CommandComplete: + fmt.Fprintf(buf, "CommandComplete\t %s", traceDoubleQuotedString(msg.CommandTag)) + case *CopyBothResponse: + buf.WriteString("CopyBothResponse") + case *CopyData: + buf.WriteString("CopyData") + case *CopyDone: + buf.WriteString("CopyDone") + case *CopyFail: + fmt.Fprintf(buf, "CopyFail\t %s", traceDoubleQuotedString([]byte(msg.Message))) + case *CopyInResponse: + buf.WriteString("CopyInResponse") + case *CopyOutResponse: + buf.WriteString("CopyOutResponse") + case *DataRow: + fmt.Fprintf(buf, "DataRow\t %d", len(msg.Values)) + for _, v := range msg.Values { + if v == nil { + buf.WriteString(" -1") + } else { + fmt.Fprintf(buf, " %d %s", len(v), traceSingleQuotedString(v)) + } + } + case *Describe: + fmt.Fprintf(buf, "Describe\t %c %s", msg.ObjectType, traceDoubleQuotedString([]byte(msg.Name))) + case *EmptyQueryResponse: + buf.WriteString("EmptyQueryResponse") + case *ErrorResponse: + buf.WriteString("ErrorResponse") + case *Execute: + fmt.Fprintf(buf, "Execute\t %s %d", traceDoubleQuotedString([]byte(msg.Portal)), msg.MaxRows) + case *Flush: + buf.WriteString("Flush") + case *FunctionCall: + buf.WriteString("FunctionCall") + case *FunctionCallResponse: + buf.WriteString("FunctionCallResponse") + case *GSSEncRequest: + buf.WriteString("GSSEncRequest") + case *NoData: + buf.WriteString("NoData") + case *NoticeResponse: + buf.WriteString("NoticeResponse") + case *NotificationResponse: + fmt.Fprintf(buf, "NotificationResponse\t %d %s %s", msg.PID, traceDoubleQuotedString([]byte(msg.Channel)), traceDoubleQuotedString([]byte(msg.Payload))) + case *ParameterDescription: + buf.WriteString("ParameterDescription") + case *ParameterStatus: + fmt.Fprintf(buf, "ParameterStatus\t %s %s", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Value))) + case *Parse: + fmt.Fprintf(buf, "Parse\t %s %s %d", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Query)), len(msg.ParameterOIDs)) + for _, oid := range msg.ParameterOIDs { + fmt.Fprintf(buf, " %d", oid) + } + case *ParseComplete: + buf.WriteString("ParseComplete") + case *PortalSuspended: + buf.WriteString("PortalSuspended") + case *Query: + buf.WriteString("Query\t") + fmt.Fprintf(buf, ` "%s"`, msg.String) + case *ReadyForQuery: + fmt.Fprintf(buf, "ReadyForQuery\t %c", msg.TxStatus) + case *RowDescription: + buf.WriteString("RowDescription\t") + fmt.Fprintf(buf, " %d", len(msg.Fields)) + for _, fd := range msg.Fields { + fmt.Fprintf(buf, ` %s %d %d %d %d %d %d`, traceDoubleQuotedString(fd.Name), fd.TableOID, fd.TableAttributeNumber, fd.DataTypeOID, fd.DataTypeSize, fd.TypeModifier, fd.Format) + } + case *SSLRequest: + buf.WriteString("SSLRequest") + case *StartupMessage: + buf.WriteString("StartupMessage") + case *Sync: + buf.WriteString("Sync") + case *Terminate: + buf.WriteString("Terminate") + default: + buf.WriteString("Unknown") + } + + buf.WriteByte('\n') + buf.WriteTo(t.Writer) +} + +// traceDoubleQuotedString returns buf as a double-quoted string without any escaping. It is roughly equivalent to +// pqTraceOutputString in libpq. +func traceDoubleQuotedString(buf []byte) string { + return `"` + string(buf) + `"` +} + +// traceSingleQuotedString returns buf as a single-quoted string with non-printable characters hex-escaped. It is +// roughly equivalent to pqTraceOutputNchar in libpq. +func traceSingleQuotedString(buf []byte) string { + sb := &strings.Builder{} + + sb.WriteByte('\'') + for _, b := range buf { + if b < 32 || b > 126 { + fmt.Fprintf(sb, `\x%x`, b) + } else { + sb.WriteByte(b) + } + } + sb.WriteByte('\'') + + return sb.String() +} diff --git a/pgproto3/trace_test.go b/pgproto3/trace_test.go new file mode 100644 index 00000000..f78bd346 --- /dev/null +++ b/pgproto3/trace_test.go @@ -0,0 +1,79 @@ +package pgproto3_test + +import ( + "bytes" + "context" + "io" + "os" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/require" +) + +func TestLibpqMessageTracer(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + + traceOutput := &bytes.Buffer{} + + config.BuildFrontend = func(r io.Reader, w io.Writer) *pgproto3.Frontend { + f := pgproto3.NewFrontend(r, w) + f.MessageTracer = &pgproto3.LibpqMessageTracer{ + Writer: traceOutput, + SuppressTimestamps: true, + RegressMode: true, + } + return f + } + + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer conn.Close(ctx) + + result := conn.ExecParams(ctx, "select n from generate_series(1,5) n", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + + expected := `F StartupMessage +B AuthenticationOk +B ParameterStatus "application_name" "" +B ParameterStatus "client_encoding" "UTF8" +B ParameterStatus "DateStyle" "ISO, MDY" +B ParameterStatus "default_transaction_read_only" "off" +B ParameterStatus "in_hot_standby" "off" +B ParameterStatus "integer_datetimes" "on" +B ParameterStatus "IntervalStyle" "postgres" +B ParameterStatus "is_superuser" "on" +B ParameterStatus "server_encoding" "UTF8" +B ParameterStatus "server_version" "14.3" +B ParameterStatus "session_authorization" "jack" +B ParameterStatus "standard_conforming_strings" "on" +B ParameterStatus "TimeZone" "America/Chicago" +B BackendKeyData NNNN NNNN +B ReadyForQuery I +F Parse "" "select n from generate_series(1,5) n" 0 +F Bind "" "" 0 0 0 +F Describe P "" +F Execute "" 0 +F Sync +B ParseComplete +B BindComplete +B RowDescription 1 "n" 0 0 23 4 -1 0 +B DataRow 1 1 '1' +B DataRow 1 1 '2' +B DataRow 1 1 '3' +B DataRow 1 1 '4' +B DataRow 1 1 '5' +B CommandComplete "SELECT 5" +B ReadyForQuery I +` + + require.Equal(t, expected, traceOutput.String()) +}