From b74c109f61fd17a89d2850ee379faf7c74754c8b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 21 May 2022 17:18:05 -0500 Subject: [PATCH] Optimize tracing The addition of tracing caused messages to escape to the heap. By avoiding interfaces the messages no longer escape. --- pgconn/pgconn.go | 26 +-- pgproto3/backend.go | 28 ++- pgproto3/frontend.go | 120 ++++++++++- pgproto3/trace.go | 446 ++++++++++++++++++++++++++++++++--------- pgproto3/trace_test.go | 7 +- 5 files changed, 497 insertions(+), 130 deletions(-) diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 935d3530..af8aeb57 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -765,9 +765,9 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ defer pgConn.contextWatcher.Unwatch() } - pgConn.frontend.Send(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}) - pgConn.frontend.Send(&pgproto3.Describe{ObjectType: 'S', Name: name}) - pgConn.frontend.Send(&pgproto3.Sync{}) + pgConn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}) + pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name}) + pgConn.frontend.SendSync(&pgproto3.Sync{}) err := pgConn.frontend.Flush() if err != nil { pgConn.asyncClose() @@ -937,7 +937,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { pgConn.contextWatcher.Watch(ctx) } - pgConn.frontend.Send(&pgproto3.Query{String: sql}) + pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) err := pgConn.frontend.Flush() if err != nil { pgConn.asyncClose() @@ -1009,8 +1009,8 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] return result } - pgConn.frontend.Send(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}) - pgConn.frontend.Send(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) + pgConn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}) + pgConn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) pgConn.execExtendedSuffix(result) @@ -1035,7 +1035,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa return result } - pgConn.frontend.Send(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) + pgConn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) pgConn.execExtendedSuffix(result) @@ -1078,9 +1078,9 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by } func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) { - pgConn.frontend.Send(&pgproto3.Describe{ObjectType: 'P'}) - pgConn.frontend.Send(&pgproto3.Execute{}) - pgConn.frontend.Send(&pgproto3.Sync{}) + pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) + pgConn.frontend.SendExecute(&pgproto3.Execute{}) + pgConn.frontend.SendSync(&pgproto3.Sync{}) err := pgConn.frontend.Flush() if err != nil { @@ -1113,7 +1113,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm } // Send copy to command - pgConn.frontend.Send(&pgproto3.Query{String: sql}) + pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) err := pgConn.frontend.Flush() if err != nil { @@ -1172,7 +1172,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } // Send copy to command - pgConn.frontend.Send(&pgproto3.Query{String: sql}) + pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) err := pgConn.frontend.Flush() if err != nil { @@ -1196,7 +1196,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co buf = buf[0 : n+5] pgio.SetInt32(buf[sp:], int32(n+4)) - _, writeErr := pgConn.conn.Write(buf) + writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(buf) if writeErr != nil { // Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine. pgConn.conn.Close() diff --git a/pgproto3/backend.go b/pgproto3/backend.go index ba0be3d3..09aeb7c8 100644 --- a/pgproto3/backend.go +++ b/pgproto3/backend.go @@ -1,6 +1,7 @@ package pgproto3 import ( + "bytes" "encoding/binary" "fmt" "io" @@ -11,9 +12,9 @@ type Backend struct { cr *chunkReader w io.Writer - // MessageTracer is used to trace messages when Send or Receive is called. This means an outbound message is traced + // tracer is used to trace messages when Send or Receive is called. This means an outbound message is traced // before it is actually transmitted (i.e. before Flush). - MessageTracer MessageTracer + tracer *tracer wbuf []byte @@ -58,8 +59,8 @@ func NewBackend(r io.Reader, w io.Writer) *Backend { func (b *Backend) Send(msg BackendMessage) { prevLen := len(b.wbuf) b.wbuf = msg.Encode(b.wbuf) - if b.MessageTracer != nil { - b.MessageTracer.TraceMessage('B', int32(len(b.wbuf)-prevLen), msg) + if b.tracer != nil { + b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg) } } @@ -81,6 +82,21 @@ func (b *Backend) Flush() error { return nil } +// Trace starts tracing the message traffic to w. It writes in a similar format to that produced by the libpq function +// PQtrace. +func (b *Backend) Trace(w io.Writer, options TracerOptions) { + b.tracer = &tracer{ + w: w, + buf: &bytes.Buffer{}, + TracerOptions: options, + } +} + +// Untrace stops tracing. +func (b *Backend) Untrace() { + b.tracer = nil +} + // ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method // because the initial connection message is "special" and does not include the message type as the first byte. This // will return either a StartupMessage, SSLRequest, GSSEncRequest, or CancelRequest. @@ -205,8 +221,8 @@ func (b *Backend) Receive() (FrontendMessage, error) { return nil, err } - if b.MessageTracer != nil { - b.MessageTracer.TraceMessage('F', int32(5+len(msgBody)), msg) + if b.tracer != nil { + b.tracer.traceMessage('F', int32(5+len(msgBody)), msg) } return msg, nil diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index 342a0ddd..321d0bf9 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -1,6 +1,7 @@ package pgproto3 import ( + "bytes" "encoding/binary" "errors" "fmt" @@ -12,10 +13,10 @@ type Frontend struct { cr *chunkReader w io.Writer - // MessageTracer is used to trace messages when Send or Receive is called. This means an outbound message is traced + // tracer is used to trace messages when Send or Receive is called. This means an outbound message is traced // before it is actually transmitted (i.e. before Flush). It is safe to change this variable when the Frontend is - // idle. Setting and unsetting MessageTracer provides equivalent functionality to PQtrace and PQuntrace in libpq. - MessageTracer MessageTracer + // idle. Setting and unsetting tracer provides equivalent functionality to PQtrace and PQuntrace in libpq. + tracer *tracer wbuf []byte @@ -65,16 +66,25 @@ func NewFrontend(r io.Reader, w io.Writer) *Frontend { // Send sends a message to the backend (i.e. the server). The message is not guaranteed to be written until Flush is // called. +// +// Send can work with any FrontendMessage. Some commonly used message types such as Bind have specialized send methods +// such as SendBind. These methods should be preferred when the type of message is known up front (e.g. when building an +// extended query protocol query) as they may be faster due to knowing the type of msg rather than it being hidden +// behind an interface. func (f *Frontend) Send(msg FrontendMessage) { prevLen := len(f.wbuf) f.wbuf = msg.Encode(f.wbuf) - if f.MessageTracer != nil { - f.MessageTracer.TraceMessage('F', int32(len(f.wbuf)-prevLen), msg) + if f.tracer != nil { + f.tracer.traceMessage('F', int32(len(f.wbuf)-prevLen), msg) } } // Flush writes any pending messages to the backend (i.e. the server). func (f *Frontend) Flush() error { + if len(f.wbuf) == 0 { + return nil + } + n, err := f.w.Write(f.wbuf) const maxLen = 1024 @@ -91,6 +101,102 @@ func (f *Frontend) Flush() error { return nil } +// Trace starts tracing the message traffic to w. It writes in a similar format to that produced by the libpq function +// PQtrace. +func (f *Frontend) Trace(w io.Writer, options TracerOptions) { + f.tracer = &tracer{ + w: w, + buf: &bytes.Buffer{}, + TracerOptions: options, + } +} + +// Untrace stops tracing. +func (f *Frontend) Untrace() { + f.tracer = nil +} + +// SendBind sends a Bind message to the backend (i.e. the server). The message is not guaranteed to be written until +// Flush is called. +func (f *Frontend) SendBind(msg *Bind) { + prevLen := len(f.wbuf) + f.wbuf = msg.Encode(f.wbuf) + if f.tracer != nil { + f.tracer.traceBind('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// SendParse sends a Parse message to the backend (i.e. the server). The message is not guaranteed to be written until +// Flush is called. +func (f *Frontend) SendParse(msg *Parse) { + prevLen := len(f.wbuf) + f.wbuf = msg.Encode(f.wbuf) + if f.tracer != nil { + f.tracer.traceParse('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// SendDescribe sends a Describe message to the backend (i.e. the server). The message is not guaranteed to be written until +// Flush is called. +func (f *Frontend) SendDescribe(msg *Describe) { + prevLen := len(f.wbuf) + f.wbuf = msg.Encode(f.wbuf) + if f.tracer != nil { + f.tracer.traceDescribe('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// SendExecute sends a Execute message to the backend (i.e. the server). The message is not guaranteed to be written until +// Flush is called. +func (f *Frontend) SendExecute(msg *Execute) { + prevLen := len(f.wbuf) + f.wbuf = msg.Encode(f.wbuf) + if f.tracer != nil { + f.tracer.traceExecute('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// SendSync sends a Sync message to the backend (i.e. the server). The message is not guaranteed to be written until +// Flush is called. +func (f *Frontend) SendSync(msg *Sync) { + prevLen := len(f.wbuf) + f.wbuf = msg.Encode(f.wbuf) + if f.tracer != nil { + f.tracer.traceSync('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// SendQuery sends a Query message to the backend (i.e. the server). The message is not guaranteed to be written until +// Flush is called. +func (f *Frontend) SendQuery(msg *Query) { + prevLen := len(f.wbuf) + f.wbuf = msg.Encode(f.wbuf) + if f.tracer != nil { + f.tracer.traceQuery('F', int32(len(f.wbuf)-prevLen), msg) + } +} + +// SendUnbufferedEncodedCopyData immediately sends an encoded CopyData message to the backend (i.e. the server). This method +// is more efficient than sending a CopyData message with Send as the message data is not copied to the internal buffer +// before being written out. The internal buffer is flushed before the message is sent. +func (f *Frontend) SendUnbufferedEncodedCopyData(msg []byte) error { + err := f.Flush() + if err != nil { + return err + } + + n, err := f.w.Write(msg) + if err != nil { + return &writeError{err: err, safeToRetry: n == 0} + } + + if f.tracer != nil { + f.tracer.traceCopyData('F', int32(len(msg)-1), &CopyData{}) + } + + return nil +} + func translateEOFtoErrUnexpectedEOF(err error) error { if err == io.EOF { return io.ErrUnexpectedEOF @@ -179,8 +285,8 @@ func (f *Frontend) Receive() (BackendMessage, error) { return nil, err } - if f.MessageTracer != nil { - f.MessageTracer.TraceMessage('B', int32(5+len(msgBody)), msg) + if f.tracer != nil { + f.tracer.traceMessage('B', int32(5+len(msgBody)), msg) } return msg, nil diff --git a/pgproto3/trace.go b/pgproto3/trace.go index b35ecdb6..704b2ee7 100644 --- a/pgproto3/trace.go +++ b/pgproto3/trace.go @@ -8,16 +8,16 @@ import ( "time" ) -// MessageTracer is an interface that traces the messages send to and from a Backend or Frontend. -type MessageTracer interface { - // TraceMessage tracks the sending or receiving of a message. sender is either 'F' for frontend or 'B' for backend. - TraceMessage(sender byte, encodedLen int32, msg Message) +// tracer traces the messages send to and from a Backend or Frontend. The format it produces roughly mimics the +// format produced by the libpq C function PQtrace. +type tracer struct { + w io.Writer + buf *bytes.Buffer + TracerOptions } -// LibpqMessageTracer is a MessageTracer that roughly mimics the format produced by the libpq C function PQtrace. -type LibpqMessageTracer struct { - Writer io.Writer - +// TracerOptions controls tracing behavior. It is roughly equivalent to the libpq function PQsetTraceFlags. +type TracerOptions struct { // SuppressTimestamps prevents printing of timestamps. SuppressTimestamps bool @@ -25,148 +25,394 @@ type LibpqMessageTracer struct { RegressMode bool } -func (t *LibpqMessageTracer) TraceMessage(sender byte, encodedLen int32, msg Message) { - buf := &bytes.Buffer{} - - if !t.SuppressTimestamps { - now := time.Now() - buf.WriteString(now.Format("2006-01-02 15:04:05.000000")) - buf.WriteByte('\t') - } - - buf.WriteByte(sender) - buf.WriteByte('\t') - +func (t *tracer) traceMessage(sender byte, encodedLen int32, msg Message) { switch msg := msg.(type) { case *AuthenticationCleartextPassword: - buf.WriteString("AuthenticationCleartextPassword") + t.traceAuthenticationCleartextPassword(sender, encodedLen, msg) case *AuthenticationGSS: - buf.WriteString("AuthenticationGSS") + t.traceAuthenticationGSS(sender, encodedLen, msg) case *AuthenticationGSSContinue: - buf.WriteString("AuthenticationGSSContinue") + t.traceAuthenticationGSSContinue(sender, encodedLen, msg) case *AuthenticationMD5Password: - buf.WriteString("AuthenticationMD5Password") + t.traceAuthenticationMD5Password(sender, encodedLen, msg) case *AuthenticationOk: - buf.WriteString("AuthenticationOk") + t.traceAuthenticationOk(sender, encodedLen, msg) case *AuthenticationSASL: - buf.WriteString("AuthenticationSASL") + t.traceAuthenticationSASL(sender, encodedLen, msg) case *AuthenticationSASLContinue: - buf.WriteString("AuthenticationSASLContinue") + t.traceAuthenticationSASLContinue(sender, encodedLen, msg) case *AuthenticationSASLFinal: - buf.WriteString("AuthenticationSASLFinal") + t.traceAuthenticationSASLFinal(sender, encodedLen, msg) case *BackendKeyData: - if t.RegressMode { - buf.WriteString("BackendKeyData\t NNNN NNNN") - } else { - fmt.Fprintf(buf, "BackendKeyData\t %d %d", msg.ProcessID, msg.SecretKey) - } + t.traceBackendKeyData(sender, encodedLen, msg) case *Bind: - fmt.Fprintf(buf, "Bind\t %s %s %d", traceDoubleQuotedString([]byte(msg.DestinationPortal)), traceDoubleQuotedString([]byte(msg.PreparedStatement)), len(msg.ParameterFormatCodes)) - for _, fc := range msg.ParameterFormatCodes { - fmt.Fprintf(buf, " %d", fc) - } - fmt.Fprintf(buf, " %d", len(msg.Parameters)) - for _, p := range msg.Parameters { - fmt.Fprintf(buf, " %s", traceSingleQuotedString(p)) - } - fmt.Fprintf(buf, " %d", len(msg.ResultFormatCodes)) - for _, fc := range msg.ResultFormatCodes { - fmt.Fprintf(buf, " %d", fc) - } + t.traceBind(sender, encodedLen, msg) case *BindComplete: - buf.WriteString("BindComplete") + t.traceBindComplete(sender, encodedLen, msg) case *CancelRequest: - buf.WriteString("CancelRequest") + t.traceCancelRequest(sender, encodedLen, msg) case *Close: - buf.WriteString("Close") + t.traceClose(sender, encodedLen, msg) case *CloseComplete: - buf.WriteString("CloseComplete") + t.traceCloseComplete(sender, encodedLen, msg) case *CommandComplete: - fmt.Fprintf(buf, "CommandComplete\t %s", traceDoubleQuotedString(msg.CommandTag)) + t.traceCommandComplete(sender, encodedLen, msg) case *CopyBothResponse: - buf.WriteString("CopyBothResponse") + t.traceCopyBothResponse(sender, encodedLen, msg) case *CopyData: - buf.WriteString("CopyData") + t.traceCopyData(sender, encodedLen, msg) case *CopyDone: - buf.WriteString("CopyDone") + t.traceCopyDone(sender, encodedLen, msg) case *CopyFail: - fmt.Fprintf(buf, "CopyFail\t %s", traceDoubleQuotedString([]byte(msg.Message))) + t.traceCopyFail(sender, encodedLen, msg) case *CopyInResponse: - buf.WriteString("CopyInResponse") + t.traceCopyInResponse(sender, encodedLen, msg) case *CopyOutResponse: - buf.WriteString("CopyOutResponse") + t.traceCopyOutResponse(sender, encodedLen, msg) case *DataRow: - fmt.Fprintf(buf, "DataRow\t %d", len(msg.Values)) - for _, v := range msg.Values { - if v == nil { - buf.WriteString(" -1") - } else { - fmt.Fprintf(buf, " %d %s", len(v), traceSingleQuotedString(v)) - } - } + t.traceDataRow(sender, encodedLen, msg) case *Describe: - fmt.Fprintf(buf, "Describe\t %c %s", msg.ObjectType, traceDoubleQuotedString([]byte(msg.Name))) + t.traceDescribe(sender, encodedLen, msg) case *EmptyQueryResponse: - buf.WriteString("EmptyQueryResponse") + t.traceEmptyQueryResponse(sender, encodedLen, msg) case *ErrorResponse: - buf.WriteString("ErrorResponse") + t.traceErrorResponse(sender, encodedLen, msg) case *Execute: - fmt.Fprintf(buf, "Execute\t %s %d", traceDoubleQuotedString([]byte(msg.Portal)), msg.MaxRows) + t.traceExecute(sender, encodedLen, msg) case *Flush: - buf.WriteString("Flush") + t.traceFlush(sender, encodedLen, msg) case *FunctionCall: - buf.WriteString("FunctionCall") + t.traceFunctionCall(sender, encodedLen, msg) case *FunctionCallResponse: - buf.WriteString("FunctionCallResponse") + t.traceFunctionCallResponse(sender, encodedLen, msg) case *GSSEncRequest: - buf.WriteString("GSSEncRequest") + t.traceGSSEncRequest(sender, encodedLen, msg) case *NoData: - buf.WriteString("NoData") + t.traceNoData(sender, encodedLen, msg) case *NoticeResponse: - buf.WriteString("NoticeResponse") + t.traceNoticeResponse(sender, encodedLen, msg) case *NotificationResponse: - fmt.Fprintf(buf, "NotificationResponse\t %d %s %s", msg.PID, traceDoubleQuotedString([]byte(msg.Channel)), traceDoubleQuotedString([]byte(msg.Payload))) + t.traceNotificationResponse(sender, encodedLen, msg) case *ParameterDescription: - buf.WriteString("ParameterDescription") + t.traceParameterDescription(sender, encodedLen, msg) case *ParameterStatus: - fmt.Fprintf(buf, "ParameterStatus\t %s %s", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Value))) + t.traceParameterStatus(sender, encodedLen, msg) case *Parse: - fmt.Fprintf(buf, "Parse\t %s %s %d", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Query)), len(msg.ParameterOIDs)) - for _, oid := range msg.ParameterOIDs { - fmt.Fprintf(buf, " %d", oid) - } + t.traceParse(sender, encodedLen, msg) case *ParseComplete: - buf.WriteString("ParseComplete") + t.traceParseComplete(sender, encodedLen, msg) case *PortalSuspended: - buf.WriteString("PortalSuspended") + t.tracePortalSuspended(sender, encodedLen, msg) case *Query: - buf.WriteString("Query\t") - fmt.Fprintf(buf, ` "%s"`, msg.String) + t.traceQuery(sender, encodedLen, msg) case *ReadyForQuery: - fmt.Fprintf(buf, "ReadyForQuery\t %c", msg.TxStatus) + t.traceReadyForQuery(sender, encodedLen, msg) case *RowDescription: - buf.WriteString("RowDescription\t") - fmt.Fprintf(buf, " %d", len(msg.Fields)) - for _, fd := range msg.Fields { - fmt.Fprintf(buf, ` %s %d %d %d %d %d %d`, traceDoubleQuotedString(fd.Name), fd.TableOID, fd.TableAttributeNumber, fd.DataTypeOID, fd.DataTypeSize, fd.TypeModifier, fd.Format) - } + t.traceRowDescription(sender, encodedLen, msg) case *SSLRequest: - buf.WriteString("SSLRequest") + t.traceSSLRequest(sender, encodedLen, msg) case *StartupMessage: - buf.WriteString("StartupMessage") + t.traceStartupMessage(sender, encodedLen, msg) case *Sync: - buf.WriteString("Sync") + t.traceSync(sender, encodedLen, msg) case *Terminate: - buf.WriteString("Terminate") + t.traceTerminate(sender, encodedLen, msg) default: - buf.WriteString("Unknown") + t.beginTrace(sender, encodedLen, "Unknown") + t.finishTrace() } - - buf.WriteByte('\n') - buf.WriteTo(t.Writer) } -// traceDoubleQuotedString returns buf as a double-quoted string without any escaping. It is roughly equivalent to +func (t *tracer) traceAuthenticationCleartextPassword(sender byte, encodedLen int32, msg *AuthenticationCleartextPassword) { + t.beginTrace(sender, encodedLen, "AuthenticationCleartextPassword") + t.finishTrace() +} + +func (t *tracer) traceAuthenticationGSS(sender byte, encodedLen int32, msg *AuthenticationGSS) { + t.beginTrace(sender, encodedLen, "AuthenticationGSS") + t.finishTrace() +} + +func (t *tracer) traceAuthenticationGSSContinue(sender byte, encodedLen int32, msg *AuthenticationGSSContinue) { + t.beginTrace(sender, encodedLen, "AuthenticationGSSContinue") + t.finishTrace() +} + +func (t *tracer) traceAuthenticationMD5Password(sender byte, encodedLen int32, msg *AuthenticationMD5Password) { + t.beginTrace(sender, encodedLen, "AuthenticationMD5Password") + t.finishTrace() +} + +func (t *tracer) traceAuthenticationOk(sender byte, encodedLen int32, msg *AuthenticationOk) { + t.beginTrace(sender, encodedLen, "AuthenticationOk") + t.finishTrace() +} + +func (t *tracer) traceAuthenticationSASL(sender byte, encodedLen int32, msg *AuthenticationSASL) { + t.beginTrace(sender, encodedLen, "AuthenticationSASL") + t.finishTrace() +} + +func (t *tracer) traceAuthenticationSASLContinue(sender byte, encodedLen int32, msg *AuthenticationSASLContinue) { + t.beginTrace(sender, encodedLen, "AuthenticationSASLContinue") + t.finishTrace() +} + +func (t *tracer) traceAuthenticationSASLFinal(sender byte, encodedLen int32, msg *AuthenticationSASLFinal) { + t.beginTrace(sender, encodedLen, "AuthenticationSASLFinal") + t.finishTrace() +} + +func (t *tracer) traceBackendKeyData(sender byte, encodedLen int32, msg *BackendKeyData) { + t.beginTrace(sender, encodedLen, "BackendKeyData") + if t.RegressMode { + t.buf.WriteString("\t NNNN NNNN") + } else { + fmt.Fprintf(t.buf, "\t %d %d", msg.ProcessID, msg.SecretKey) + } + t.finishTrace() +} + +func (t *tracer) traceBind(sender byte, encodedLen int32, msg *Bind) { + t.beginTrace(sender, encodedLen, "Bind") + fmt.Fprintf(t.buf, "\t %s %s %d", traceDoubleQuotedString([]byte(msg.DestinationPortal)), traceDoubleQuotedString([]byte(msg.PreparedStatement)), len(msg.ParameterFormatCodes)) + for _, fc := range msg.ParameterFormatCodes { + fmt.Fprintf(t.buf, " %d", fc) + } + fmt.Fprintf(t.buf, " %d", len(msg.Parameters)) + for _, p := range msg.Parameters { + fmt.Fprintf(t.buf, " %s", traceSingleQuotedString(p)) + } + fmt.Fprintf(t.buf, " %d", len(msg.ResultFormatCodes)) + for _, fc := range msg.ResultFormatCodes { + fmt.Fprintf(t.buf, " %d", fc) + } + t.finishTrace() +} + +func (t *tracer) traceBindComplete(sender byte, encodedLen int32, msg *BindComplete) { + t.beginTrace(sender, encodedLen, "BindComplete") + t.finishTrace() +} + +func (t *tracer) traceCancelRequest(sender byte, encodedLen int32, msg *CancelRequest) { + t.beginTrace(sender, encodedLen, "CancelRequest") + t.finishTrace() +} + +func (t *tracer) traceClose(sender byte, encodedLen int32, msg *Close) { + t.beginTrace(sender, encodedLen, "Close") + t.finishTrace() +} + +func (t *tracer) traceCloseComplete(sender byte, encodedLen int32, msg *CloseComplete) { + t.beginTrace(sender, encodedLen, "CloseComplete") + t.finishTrace() +} + +func (t *tracer) traceCommandComplete(sender byte, encodedLen int32, msg *CommandComplete) { + t.beginTrace(sender, encodedLen, "CommandComplete") + fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString(msg.CommandTag)) + t.finishTrace() +} + +func (t *tracer) traceCopyBothResponse(sender byte, encodedLen int32, msg *CopyBothResponse) { + t.beginTrace(sender, encodedLen, "CopyBothResponse") + t.finishTrace() +} + +func (t *tracer) traceCopyData(sender byte, encodedLen int32, msg *CopyData) { + t.beginTrace(sender, encodedLen, "CopyData") + t.finishTrace() +} + +func (t *tracer) traceCopyDone(sender byte, encodedLen int32, msg *CopyDone) { + t.beginTrace(sender, encodedLen, "CopyDone") + t.finishTrace() +} + +func (t *tracer) traceCopyFail(sender byte, encodedLen int32, msg *CopyFail) { + t.beginTrace(sender, encodedLen, "CopyFail") + fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString([]byte(msg.Message))) + t.finishTrace() +} + +func (t *tracer) traceCopyInResponse(sender byte, encodedLen int32, msg *CopyInResponse) { + t.beginTrace(sender, encodedLen, "CopyInResponse") + t.finishTrace() +} + +func (t *tracer) traceCopyOutResponse(sender byte, encodedLen int32, msg *CopyOutResponse) { + t.beginTrace(sender, encodedLen, "CopyOutResponse") + t.finishTrace() +} + +func (t *tracer) traceDataRow(sender byte, encodedLen int32, msg *DataRow) { + t.beginTrace(sender, encodedLen, "DataRow") + fmt.Fprintf(t.buf, "\t %d", len(msg.Values)) + for _, v := range msg.Values { + if v == nil { + t.buf.WriteString(" -1") + } else { + fmt.Fprintf(t.buf, " %d %s", len(v), traceSingleQuotedString(v)) + } + } + t.finishTrace() +} + +func (t *tracer) traceDescribe(sender byte, encodedLen int32, msg *Describe) { + t.beginTrace(sender, encodedLen, "Describe") + fmt.Fprintf(t.buf, "\t %c %s", msg.ObjectType, traceDoubleQuotedString([]byte(msg.Name))) + t.finishTrace() +} + +func (t *tracer) traceEmptyQueryResponse(sender byte, encodedLen int32, msg *EmptyQueryResponse) { + t.beginTrace(sender, encodedLen, "EmptyQueryResponse") + t.finishTrace() +} + +func (t *tracer) traceErrorResponse(sender byte, encodedLen int32, msg *ErrorResponse) { + t.beginTrace(sender, encodedLen, "ErrorResponse") + t.finishTrace() +} + +func (t *tracer) traceExecute(sender byte, encodedLen int32, msg *Execute) { + t.beginTrace(sender, encodedLen, "Execute") + fmt.Fprintf(t.buf, "\t %s %d", traceDoubleQuotedString([]byte(msg.Portal)), msg.MaxRows) + t.finishTrace() +} + +func (t *tracer) traceFlush(sender byte, encodedLen int32, msg *Flush) { + t.beginTrace(sender, encodedLen, "Flush") + t.finishTrace() +} + +func (t *tracer) traceFunctionCall(sender byte, encodedLen int32, msg *FunctionCall) { + t.beginTrace(sender, encodedLen, "FunctionCall") + t.finishTrace() +} + +func (t *tracer) traceFunctionCallResponse(sender byte, encodedLen int32, msg *FunctionCallResponse) { + t.beginTrace(sender, encodedLen, "FunctionCallResponse") + t.finishTrace() +} + +func (t *tracer) traceGSSEncRequest(sender byte, encodedLen int32, msg *GSSEncRequest) { + t.beginTrace(sender, encodedLen, "GSSEncRequest") + t.finishTrace() +} + +func (t *tracer) traceNoData(sender byte, encodedLen int32, msg *NoData) { + t.beginTrace(sender, encodedLen, "NoData") + t.finishTrace() +} + +func (t *tracer) traceNoticeResponse(sender byte, encodedLen int32, msg *NoticeResponse) { + t.beginTrace(sender, encodedLen, "NoticeResponse") + t.finishTrace() +} + +func (t *tracer) traceNotificationResponse(sender byte, encodedLen int32, msg *NotificationResponse) { + t.beginTrace(sender, encodedLen, "NotificationResponse") + fmt.Fprintf(t.buf, "\t %d %s %s", msg.PID, traceDoubleQuotedString([]byte(msg.Channel)), traceDoubleQuotedString([]byte(msg.Payload))) + t.finishTrace() +} + +func (t *tracer) traceParameterDescription(sender byte, encodedLen int32, msg *ParameterDescription) { + t.beginTrace(sender, encodedLen, "ParameterDescription") + t.finishTrace() +} + +func (t *tracer) traceParameterStatus(sender byte, encodedLen int32, msg *ParameterStatus) { + t.beginTrace(sender, encodedLen, "ParameterStatus") + fmt.Fprintf(t.buf, "\t %s %s", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Value))) + t.finishTrace() +} + +func (t *tracer) traceParse(sender byte, encodedLen int32, msg *Parse) { + t.beginTrace(sender, encodedLen, "Parse") + fmt.Fprintf(t.buf, "\t %s %s %d", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Query)), len(msg.ParameterOIDs)) + for _, oid := range msg.ParameterOIDs { + fmt.Fprintf(t.buf, " %d", oid) + } + t.finishTrace() +} + +func (t *tracer) traceParseComplete(sender byte, encodedLen int32, msg *ParseComplete) { + t.beginTrace(sender, encodedLen, "ParseComplete") + t.finishTrace() +} + +func (t *tracer) tracePortalSuspended(sender byte, encodedLen int32, msg *PortalSuspended) { + t.beginTrace(sender, encodedLen, "PortalSuspended") + t.finishTrace() +} + +func (t *tracer) traceQuery(sender byte, encodedLen int32, msg *Query) { + t.beginTrace(sender, encodedLen, "Query") + fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString([]byte(msg.String))) + t.finishTrace() +} + +func (t *tracer) traceReadyForQuery(sender byte, encodedLen int32, msg *ReadyForQuery) { + t.beginTrace(sender, encodedLen, "ReadyForQuery") + fmt.Fprintf(t.buf, "\t %c", msg.TxStatus) + t.finishTrace() +} + +func (t *tracer) traceRowDescription(sender byte, encodedLen int32, msg *RowDescription) { + t.beginTrace(sender, encodedLen, "RowDescription") + fmt.Fprintf(t.buf, "\t %d", len(msg.Fields)) + for _, fd := range msg.Fields { + fmt.Fprintf(t.buf, ` %s %d %d %d %d %d %d`, traceDoubleQuotedString(fd.Name), fd.TableOID, fd.TableAttributeNumber, fd.DataTypeOID, fd.DataTypeSize, fd.TypeModifier, fd.Format) + } + t.finishTrace() +} + +func (t *tracer) traceSSLRequest(sender byte, encodedLen int32, msg *SSLRequest) { + t.beginTrace(sender, encodedLen, "SSLRequest") + t.finishTrace() +} + +func (t *tracer) traceStartupMessage(sender byte, encodedLen int32, msg *StartupMessage) { + t.beginTrace(sender, encodedLen, "StartupMessage") + t.finishTrace() +} + +func (t *tracer) traceSync(sender byte, encodedLen int32, msg *Sync) { + t.beginTrace(sender, encodedLen, "Sync") + t.finishTrace() +} + +func (t *tracer) traceTerminate(sender byte, encodedLen int32, msg *Terminate) { + t.beginTrace(sender, encodedLen, "Terminate") + t.finishTrace() +} + +func (t *tracer) beginTrace(sender byte, encodedLen int32, msgType string) { + if !t.SuppressTimestamps { + now := time.Now() + t.buf.WriteString(now.Format("2006-01-02 15:04:05.000000")) + t.buf.WriteByte('\t') + } + + t.buf.WriteByte(sender) + t.buf.WriteByte('\t') + t.buf.WriteString(msgType) +} + +func (t *tracer) finishTrace() { + t.buf.WriteByte('\n') + t.buf.WriteTo(t.w) + + if t.buf.Cap() > 1024 { + t.buf = &bytes.Buffer{} + } else { + t.buf.Reset() + } +} + +// traceDoubleQuotedString returns t.buf as a double-quoted string without any escaping. It is roughly equivalent to // pqTraceOutputString in libpq. func traceDoubleQuotedString(buf []byte) string { return `"` + string(buf) + `"` diff --git a/pgproto3/trace_test.go b/pgproto3/trace_test.go index f78bd346..a4057008 100644 --- a/pgproto3/trace_test.go +++ b/pgproto3/trace_test.go @@ -13,7 +13,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestLibpqMessageTracer(t *testing.T) { +func TestTrace(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -26,11 +26,10 @@ func TestLibpqMessageTracer(t *testing.T) { config.BuildFrontend = func(r io.Reader, w io.Writer) *pgproto3.Frontend { f := pgproto3.NewFrontend(r, w) - f.MessageTracer = &pgproto3.LibpqMessageTracer{ - Writer: traceOutput, + f.Trace(traceOutput, pgproto3.TracerOptions{ SuppressTimestamps: true, RegressMode: true, - } + }) return f }