pgx/tracelog/tracelog.go

296 lines
7.7 KiB
Go

// Package tracelog provides a tracer that acts as a traditional logger.
package tracelog
import (
"context"
"encoding/hex"
"errors"
"fmt"
"time"
"github.com/jackc/pgx/v5"
)
// LogLevel represents the pgx logging level. See LogLevel* constants for
// possible values.
type LogLevel int
// The values for log levels are chosen such that the zero value means that no
// log level was specified.
const (
LogLevelTrace = LogLevel(6)
LogLevelDebug = LogLevel(5)
LogLevelInfo = LogLevel(4)
LogLevelWarn = LogLevel(3)
LogLevelError = LogLevel(2)
LogLevelNone = LogLevel(1)
)
func (ll LogLevel) String() string {
switch ll {
case LogLevelTrace:
return "trace"
case LogLevelDebug:
return "debug"
case LogLevelInfo:
return "info"
case LogLevelWarn:
return "warn"
case LogLevelError:
return "error"
case LogLevelNone:
return "none"
default:
return fmt.Sprintf("invalid level %d", ll)
}
}
// Logger is the interface used to get log output from pgx.
type Logger interface {
// Log a message at the given level with data key/value pairs. data may be nil.
Log(ctx context.Context, level LogLevel, msg string, data map[string]any)
}
// LoggerFunc is a wrapper around a function to satisfy the pgx.Logger interface
type LoggerFunc func(ctx context.Context, level LogLevel, msg string, data map[string]interface{})
// Log delegates the logging request to the wrapped function
func (f LoggerFunc) Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) {
f(ctx, level, msg, data)
}
// LogLevelFromString converts log level string to constant
//
// Valid levels:
// trace
// debug
// info
// warn
// error
// none
func LogLevelFromString(s string) (LogLevel, error) {
switch s {
case "trace":
return LogLevelTrace, nil
case "debug":
return LogLevelDebug, nil
case "info":
return LogLevelInfo, nil
case "warn":
return LogLevelWarn, nil
case "error":
return LogLevelError, nil
case "none":
return LogLevelNone, nil
default:
return 0, errors.New("invalid log level")
}
}
func logQueryArgs(args []any) []any {
logArgs := make([]any, 0, len(args))
for _, a := range args {
switch v := a.(type) {
case []byte:
if len(v) < 64 {
a = hex.EncodeToString(v)
} else {
a = fmt.Sprintf("%x (truncated %d bytes)", v[:64], len(v)-64)
}
case string:
if len(v) > 64 {
a = fmt.Sprintf("%s (truncated %d bytes)", v[:64], len(v)-64)
}
}
logArgs = append(logArgs, a)
}
return logArgs
}
// TraceLog implements pgx.QueryTracer, pgx.BatchTracer, pgx.ConnectTracer, and pgx.CopyFromTracer. All fields are
// required.
type TraceLog struct {
Logger Logger
LogLevel LogLevel
}
type ctxKey int
const (
_ ctxKey = iota
tracelogQueryCtxKey
tracelogBatchCtxKey
tracelogCopyFromCtxKey
tracelogConnectCtxKey
)
type traceQueryData struct {
startTime time.Time
sql string
args []any
}
func (tl *TraceLog) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
return context.WithValue(ctx, tracelogQueryCtxKey, &traceQueryData{
startTime: time.Now(),
sql: data.SQL,
args: data.Args,
})
}
func (tl *TraceLog) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
queryData := ctx.Value(tracelogQueryCtxKey).(*traceQueryData)
endTime := time.Now()
interval := endTime.Sub(queryData.startTime)
if data.Err != nil {
if tl.shouldLog(LogLevelError) {
tl.log(ctx, conn, LogLevelError, "Query", map[string]any{"sql": queryData.sql, "args": logQueryArgs(queryData.args), "err": data.Err, "time": interval})
}
return
}
if tl.shouldLog(LogLevelInfo) {
tl.log(ctx, conn, LogLevelInfo, "Query", map[string]any{"sql": queryData.sql, "args": logQueryArgs(queryData.args), "time": interval, "commandTag": data.CommandTag.String()})
}
}
type traceBatchData struct {
startTime time.Time
}
func (tl *TraceLog) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
return context.WithValue(ctx, tracelogBatchCtxKey, &traceBatchData{
startTime: time.Now(),
})
}
func (tl *TraceLog) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
if data.Err != nil {
if tl.shouldLog(LogLevelError) {
tl.log(ctx, conn, LogLevelError, "BatchQuery", map[string]any{"sql": data.SQL, "args": logQueryArgs(data.Args), "err": data.Err})
}
return
}
if tl.shouldLog(LogLevelInfo) {
tl.log(ctx, conn, LogLevelInfo, "BatchQuery", map[string]any{"sql": data.SQL, "args": logQueryArgs(data.Args), "commandTag": data.CommandTag.String()})
}
}
func (tl *TraceLog) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
queryData := ctx.Value(tracelogBatchCtxKey).(*traceBatchData)
endTime := time.Now()
interval := endTime.Sub(queryData.startTime)
if data.Err != nil {
if tl.shouldLog(LogLevelError) {
tl.log(ctx, conn, LogLevelError, "BatchClose", map[string]any{"err": data.Err, "time": interval})
}
return
}
if tl.shouldLog(LogLevelInfo) {
tl.log(ctx, conn, LogLevelInfo, "BatchClose", map[string]any{"time": interval})
}
}
type traceCopyFromData struct {
startTime time.Time
TableName pgx.Identifier
ColumnNames []string
}
func (tl *TraceLog) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
return context.WithValue(ctx, tracelogCopyFromCtxKey, &traceCopyFromData{
startTime: time.Now(),
TableName: data.TableName,
ColumnNames: data.ColumnNames,
})
}
func (tl *TraceLog) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
copyFromData := ctx.Value(tracelogCopyFromCtxKey).(*traceCopyFromData)
endTime := time.Now()
interval := endTime.Sub(copyFromData.startTime)
if data.Err != nil {
if tl.shouldLog(LogLevelError) {
tl.log(ctx, conn, LogLevelError, "CopyFrom", map[string]any{"tableName": copyFromData.TableName, "columnNames": copyFromData.ColumnNames, "err": data.Err, "time": interval})
}
return
}
if tl.shouldLog(LogLevelInfo) {
tl.log(ctx, conn, LogLevelInfo, "CopyFrom", map[string]any{"tableName": copyFromData.TableName, "columnNames": copyFromData.ColumnNames, "err": data.Err, "time": interval, "rowCount": data.CommandTag.RowsAffected()})
}
}
type traceConnectData struct {
startTime time.Time
connConfig *pgx.ConnConfig
}
func (tl *TraceLog) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context {
return context.WithValue(ctx, tracelogConnectCtxKey, &traceConnectData{
startTime: time.Now(),
connConfig: data.ConnConfig,
})
}
func (tl *TraceLog) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) {
connectData := ctx.Value(tracelogConnectCtxKey).(*traceConnectData)
endTime := time.Now()
interval := endTime.Sub(connectData.startTime)
if data.Err != nil {
if tl.shouldLog(LogLevelError) {
tl.Logger.Log(ctx, LogLevelError, "Connect", map[string]any{
"host": connectData.connConfig.Host,
"port": connectData.connConfig.Port,
"database": connectData.connConfig.Database,
"time": interval,
"err": data.Err,
})
}
return
}
if data.Conn != nil {
if tl.shouldLog(LogLevelInfo) {
tl.log(ctx, data.Conn, LogLevelInfo, "Connect", map[string]any{
"host": connectData.connConfig.Host,
"port": connectData.connConfig.Port,
"database": connectData.connConfig.Database,
"time": interval,
})
}
}
}
func (tl *TraceLog) shouldLog(lvl LogLevel) bool {
return tl.LogLevel >= lvl
}
func (tl *TraceLog) log(ctx context.Context, conn *pgx.Conn, lvl LogLevel, msg string, data map[string]any) {
if data == nil {
data = map[string]any{}
}
pgConn := conn.PgConn()
if pgConn != nil {
pid := pgConn.PID()
if pid != 0 {
data["pid"] = pid
}
}
tl.Logger.Log(ctx, lvl, msg, data)
}