Add context.Context to Logger interface

This allows custom logger adapters to add additional fields to log
messages. For example, a HTTP server may with to log the request ID.

fixes #428
pull/586/head
Jack Christensen 2019-08-03 16:16:21 -05:00
parent ab1edc79e0
commit 3028821487
11 changed files with 68 additions and 22 deletions

View File

@ -1,6 +1,8 @@
package pgx package pgx
import ( import (
"context"
"github.com/jackc/pgconn" "github.com/jackc/pgconn"
"github.com/jackc/pgtype" "github.com/jackc/pgtype"
errors "golang.org/x/xerrors" errors "golang.org/x/xerrors"
@ -47,6 +49,7 @@ type BatchResults interface {
} }
type batchResults struct { type batchResults struct {
ctx context.Context
conn *Conn conn *Conn
mrr *pgconn.MultiResultReader mrr *pgconn.MultiResultReader
err error err error
@ -71,7 +74,7 @@ func (br *batchResults) ExecResults() (pgconn.CommandTag, error) {
// QueryResults reads the results from the next query in the batch as if the query has been sent with Query. // QueryResults reads the results from the next query in the batch as if the query has been sent with Query.
func (br *batchResults) QueryResults() (Rows, error) { func (br *batchResults) QueryResults() (Rows, error) {
rows := br.conn.getRows("batch query", nil) rows := br.conn.getRows(br.ctx, "batch query", nil)
if br.err != nil { if br.err != nil {
rows.err = br.err rows.err = br.err

View File

@ -201,7 +201,8 @@ func BenchmarkSelectWithoutLogging(b *testing.B) {
type discardLogger struct{} type discardLogger struct{}
func (dl discardLogger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {} func (dl discardLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
}
func BenchmarkSelectWithLoggingTraceDiscard(b *testing.B) { func BenchmarkSelectWithLoggingTraceDiscard(b *testing.B) {
conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")))

22
conn.go
View File

@ -146,7 +146,7 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
c.logger = c.config.Logger c.logger = c.config.Logger
if c.shouldLog(LogLevelInfo) { if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"host": config.Config.Host}) c.log(ctx, LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"host": config.Config.Host})
} }
c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config) c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config)
if err != nil { if err != nil {
@ -155,7 +155,7 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
if err != nil { if err != nil {
if c.shouldLog(LogLevelError) { if c.shouldLog(LogLevelError) {
c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err}) c.log(ctx, LogLevelError, "connect failed", map[string]interface{}{"err": err})
} }
return nil, err return nil, err
} }
@ -184,7 +184,7 @@ func (c *Conn) Close(ctx context.Context) error {
err := c.pgConn.Close(ctx) err := c.pgConn.Close(ctx)
c.causeOfDeath = errors.New("Closed") c.causeOfDeath = errors.New("Closed")
if c.shouldLog(LogLevelInfo) { if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "closed connection", nil) c.log(ctx, LogLevelInfo, "closed connection", nil)
} }
return err return err
} }
@ -205,7 +205,7 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (ps *PreparedState
if c.shouldLog(LogLevelError) { if c.shouldLog(LogLevelError) {
defer func() { defer func() {
if err != nil { if err != nil {
c.log(LogLevelError, "Prepare failed", map[string]interface{}{"err": err, "name": name, "sql": sql}) c.log(ctx, LogLevelError, "Prepare failed", map[string]interface{}{"err": err, "name": name, "sql": sql})
} }
}() }()
} }
@ -307,7 +307,7 @@ func (c *Conn) shouldLog(lvl LogLevel) bool {
return c.logger != nil && c.logLevel >= lvl return c.logger != nil && c.logLevel >= lvl
} }
func (c *Conn) log(lvl LogLevel, msg string, data map[string]interface{}) { func (c *Conn) log(ctx context.Context, lvl LogLevel, msg string, data map[string]interface{}) {
if data == nil { if data == nil {
data = map[string]interface{}{} data = map[string]interface{}{}
} }
@ -315,7 +315,7 @@ func (c *Conn) log(lvl LogLevel, msg string, data map[string]interface{}) {
data["pid"] = c.pgConn.PID() data["pid"] = c.pgConn.PID()
} }
c.logger.Log(lvl, msg, data) c.logger.Log(ctx, lvl, msg, data)
} }
// SetLogger replaces the current logger and returns the previous logger. // SetLogger replaces the current logger and returns the previous logger.
@ -386,14 +386,14 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (
commandTag, err := c.exec(ctx, sql, arguments...) commandTag, err := c.exec(ctx, sql, arguments...)
if err != nil { if err != nil {
if c.shouldLog(LogLevelError) { if c.shouldLog(LogLevelError) {
c.log(LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err}) c.log(ctx, LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err})
} }
return commandTag, err return commandTag, err
} }
if c.shouldLog(LogLevelInfo) { if c.shouldLog(LogLevelInfo) {
endTime := time.Now() endTime := time.Now()
c.log(LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag}) c.log(ctx, LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag})
} }
return commandTag, err return commandTag, err
@ -518,7 +518,7 @@ optionLoop:
} }
func (c *Conn) getRows(sql string, args []interface{}) *connRows { func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *connRows {
if len(c.preallocatedRows) == 0 { if len(c.preallocatedRows) == 0 {
c.preallocatedRows = make([]connRows, 64) c.preallocatedRows = make([]connRows, 64)
} }
@ -526,6 +526,7 @@ func (c *Conn) getRows(sql string, args []interface{}) *connRows {
r := &c.preallocatedRows[len(c.preallocatedRows)-1] r := &c.preallocatedRows[len(c.preallocatedRows)-1]
c.preallocatedRows = c.preallocatedRows[0 : len(c.preallocatedRows)-1] c.preallocatedRows = c.preallocatedRows[0 : len(c.preallocatedRows)-1]
r.ctx = ctx
r.logger = c r.logger = c
r.connInfo = c.ConnInfo r.connInfo = c.ConnInfo
r.startTime = time.Now() r.startTime = time.Now()
@ -568,7 +569,7 @@ optionLoop:
} }
} }
rows := c.getRows(sql, args) rows := c.getRows(ctx, sql, args)
var err error var err error
if simpleProtocol { if simpleProtocol {
@ -727,6 +728,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
mrr := c.pgConn.ExecBatch(ctx, batch) mrr := c.pgConn.ExecBatch(ctx, batch)
return &batchResults{ return &batchResults{
ctx: ctx,
conn: c, conn: c,
mrr: mrr, mrr: mrr,
} }

View File

@ -543,10 +543,38 @@ type testLogger struct {
logs []testLog logs []testLog
} }
func (l *testLogger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { func (l *testLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
data["ctxdata"] = ctx.Value("ctxdata")
l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data}) l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data})
} }
func TestLogPassesContext(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
ctx := context.WithValue(context.Background(), "ctxdata", "foo")
l1 := &testLogger{}
oldLogger := conn.SetLogger(l1)
if oldLogger != nil {
t.Fatalf("Expected conn.SetLogger to return %v, but it was %v", nil, oldLogger)
}
if _, err := conn.Exec(ctx, ";"); err != nil {
t.Fatal(err)
}
if len(l1.logs) != 1 {
t.Fatal("Expected new logger l1 to be called once, but it wasn't")
}
if l1.logs[0].data["ctxdata"] != "foo" {
t.Fatal("Expected context data to be passed to logger, but it wasn't")
}
}
func TestSetLogger(t *testing.T) { func TestSetLogger(t *testing.T) {
t.Parallel() t.Parallel()

View File

@ -3,6 +3,8 @@
package log15adapter package log15adapter
import ( import (
"context"
"github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4"
) )
@ -24,7 +26,7 @@ func NewLogger(l Log15Logger) *Logger {
return &Logger{l: l} return &Logger{l: l}
} }
func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
logArgs := make([]interface{}, 0, len(data)) logArgs := make([]interface{}, 0, len(data))
for k, v := range data { for k, v := range data {
logArgs = append(logArgs, k, v) logArgs = append(logArgs, k, v)

View File

@ -3,6 +3,8 @@
package logrusadapter package logrusadapter
import ( import (
"context"
"github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -15,7 +17,7 @@ func NewLogger(l logrus.FieldLogger) *Logger {
return &Logger{l: l} return &Logger{l: l}
} }
func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
var logger logrus.FieldLogger var logger logrus.FieldLogger
if data != nil { if data != nil {
logger = l.l.WithFields(data) logger = l.l.WithFields(data)

View File

@ -3,6 +3,7 @@
package testingadapter package testingadapter
import ( import (
"context"
"fmt" "fmt"
"github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4"
@ -22,7 +23,7 @@ func NewLogger(l TestingLogger) *Logger {
return &Logger{l: l} return &Logger{l: l}
} }
func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
logArgs := make([]interface{}, 0, 2+len(data)) logArgs := make([]interface{}, 0, 2+len(data))
logArgs = append(logArgs, level, msg) logArgs = append(logArgs, level, msg)
for k, v := range data { for k, v := range data {

View File

@ -2,6 +2,8 @@
package zapadapter package zapadapter
import ( import (
"context"
"github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4"
"go.uber.org/zap" "go.uber.org/zap"
"go.uber.org/zap/zapcore" "go.uber.org/zap/zapcore"
@ -15,7 +17,7 @@ func NewLogger(logger *zap.Logger) *Logger {
return &Logger{logger: logger.WithOptions(zap.AddCallerSkip(1))} return &Logger{logger: logger.WithOptions(zap.AddCallerSkip(1))}
} }
func (pl *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
fields := make([]zapcore.Field, len(data)) fields := make([]zapcore.Field, len(data))
i := 0 i := 0
for k, v := range data { for k, v := range data {

View File

@ -2,6 +2,8 @@
package zerologadapter package zerologadapter
import ( import (
"context"
"github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4"
"github.com/rs/zerolog" "github.com/rs/zerolog"
) )
@ -18,7 +20,7 @@ func NewLogger(logger zerolog.Logger) *Logger {
} }
} }
func (pl *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
var zlevel zerolog.Level var zlevel zerolog.Level
switch level { switch level {
case pgx.LogLevelNone: case pgx.LogLevelNone:

View File

@ -1,6 +1,7 @@
package pgx package pgx
import ( import (
"context"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
@ -44,7 +45,7 @@ func (ll LogLevel) String() string {
// Logger is the interface used to get logging from pgx internals. // Logger is the interface used to get logging from pgx internals.
type Logger interface { type Logger interface {
// Log a message at the given level with data key/value pairs. data may be nil. // Log a message at the given level with data key/value pairs. data may be nil.
Log(level LogLevel, msg string, data map[string]interface{}) Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{})
} }
// LogLevelFromString converts log level string to constant // LogLevelFromString converts log level string to constant

View File

@ -1,6 +1,7 @@
package pgx package pgx
import ( import (
"context"
"fmt" "fmt"
"reflect" "reflect"
"time" "time"
@ -70,11 +71,12 @@ func (r *connRow) Scan(dest ...interface{}) (err error) {
type rowLog interface { type rowLog interface {
shouldLog(lvl LogLevel) bool shouldLog(lvl LogLevel) bool
log(lvl LogLevel, msg string, data map[string]interface{}) log(ctx context.Context, lvl LogLevel, msg string, data map[string]interface{})
} }
// connRows implements the Rows interface for Conn.Query. // connRows implements the Rows interface for Conn.Query.
type connRows struct { type connRows struct {
ctx context.Context
logger rowLog logger rowLog
connInfo *pgtype.ConnInfo connInfo *pgtype.ConnInfo
values [][]byte values [][]byte
@ -119,10 +121,10 @@ func (rows *connRows) Close() {
if rows.err == nil { if rows.err == nil {
if rows.logger.shouldLog(LogLevelInfo) { if rows.logger.shouldLog(LogLevelInfo) {
endTime := time.Now() endTime := time.Now()
rows.logger.log(LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount}) rows.logger.log(rows.ctx, LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount})
} }
} else if rows.logger.shouldLog(LogLevelError) { } else if rows.logger.shouldLog(LogLevelError) {
rows.logger.log(LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)}) rows.logger.log(rows.ctx, LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)})
} }
} }
} }