Add context.Context to Logger interface

This allows custom logger adapters to add additional fields to log
messages. For example, a HTTP server may with to log the request ID.

fixes #428
pull/586/head
Jack Christensen 2019-08-03 16:16:21 -05:00
parent ab1edc79e0
commit 3028821487
11 changed files with 68 additions and 22 deletions

View File

@ -1,6 +1,8 @@
package pgx
import (
"context"
"github.com/jackc/pgconn"
"github.com/jackc/pgtype"
errors "golang.org/x/xerrors"
@ -47,6 +49,7 @@ type BatchResults interface {
}
type batchResults struct {
ctx context.Context
conn *Conn
mrr *pgconn.MultiResultReader
err error
@ -71,7 +74,7 @@ func (br *batchResults) ExecResults() (pgconn.CommandTag, error) {
// QueryResults reads the results from the next query in the batch as if the query has been sent with Query.
func (br *batchResults) QueryResults() (Rows, error) {
rows := br.conn.getRows("batch query", nil)
rows := br.conn.getRows(br.ctx, "batch query", nil)
if br.err != nil {
rows.err = br.err

View File

@ -201,7 +201,8 @@ func BenchmarkSelectWithoutLogging(b *testing.B) {
type discardLogger struct{}
func (dl discardLogger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {}
func (dl discardLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
}
func BenchmarkSelectWithLoggingTraceDiscard(b *testing.B) {
conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")))

22
conn.go
View File

@ -146,7 +146,7 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
c.logger = c.config.Logger
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"host": config.Config.Host})
c.log(ctx, LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"host": config.Config.Host})
}
c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config)
if err != nil {
@ -155,7 +155,7 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
if err != nil {
if c.shouldLog(LogLevelError) {
c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err})
c.log(ctx, LogLevelError, "connect failed", map[string]interface{}{"err": err})
}
return nil, err
}
@ -184,7 +184,7 @@ func (c *Conn) Close(ctx context.Context) error {
err := c.pgConn.Close(ctx)
c.causeOfDeath = errors.New("Closed")
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "closed connection", nil)
c.log(ctx, LogLevelInfo, "closed connection", nil)
}
return err
}
@ -205,7 +205,7 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (ps *PreparedState
if c.shouldLog(LogLevelError) {
defer func() {
if err != nil {
c.log(LogLevelError, "Prepare failed", map[string]interface{}{"err": err, "name": name, "sql": sql})
c.log(ctx, LogLevelError, "Prepare failed", map[string]interface{}{"err": err, "name": name, "sql": sql})
}
}()
}
@ -307,7 +307,7 @@ func (c *Conn) shouldLog(lvl LogLevel) bool {
return c.logger != nil && c.logLevel >= lvl
}
func (c *Conn) log(lvl LogLevel, msg string, data map[string]interface{}) {
func (c *Conn) log(ctx context.Context, lvl LogLevel, msg string, data map[string]interface{}) {
if data == nil {
data = map[string]interface{}{}
}
@ -315,7 +315,7 @@ func (c *Conn) log(lvl LogLevel, msg string, data map[string]interface{}) {
data["pid"] = c.pgConn.PID()
}
c.logger.Log(lvl, msg, data)
c.logger.Log(ctx, lvl, msg, data)
}
// SetLogger replaces the current logger and returns the previous logger.
@ -386,14 +386,14 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (
commandTag, err := c.exec(ctx, sql, arguments...)
if err != nil {
if c.shouldLog(LogLevelError) {
c.log(LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err})
c.log(ctx, LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err})
}
return commandTag, err
}
if c.shouldLog(LogLevelInfo) {
endTime := time.Now()
c.log(LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag})
c.log(ctx, LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag})
}
return commandTag, err
@ -518,7 +518,7 @@ optionLoop:
}
func (c *Conn) getRows(sql string, args []interface{}) *connRows {
func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *connRows {
if len(c.preallocatedRows) == 0 {
c.preallocatedRows = make([]connRows, 64)
}
@ -526,6 +526,7 @@ func (c *Conn) getRows(sql string, args []interface{}) *connRows {
r := &c.preallocatedRows[len(c.preallocatedRows)-1]
c.preallocatedRows = c.preallocatedRows[0 : len(c.preallocatedRows)-1]
r.ctx = ctx
r.logger = c
r.connInfo = c.ConnInfo
r.startTime = time.Now()
@ -568,7 +569,7 @@ optionLoop:
}
}
rows := c.getRows(sql, args)
rows := c.getRows(ctx, sql, args)
var err error
if simpleProtocol {
@ -727,6 +728,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
mrr := c.pgConn.ExecBatch(ctx, batch)
return &batchResults{
ctx: ctx,
conn: c,
mrr: mrr,
}

View File

@ -543,10 +543,38 @@ type testLogger struct {
logs []testLog
}
func (l *testLogger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
func (l *testLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
data["ctxdata"] = ctx.Value("ctxdata")
l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data})
}
func TestLogPassesContext(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
ctx := context.WithValue(context.Background(), "ctxdata", "foo")
l1 := &testLogger{}
oldLogger := conn.SetLogger(l1)
if oldLogger != nil {
t.Fatalf("Expected conn.SetLogger to return %v, but it was %v", nil, oldLogger)
}
if _, err := conn.Exec(ctx, ";"); err != nil {
t.Fatal(err)
}
if len(l1.logs) != 1 {
t.Fatal("Expected new logger l1 to be called once, but it wasn't")
}
if l1.logs[0].data["ctxdata"] != "foo" {
t.Fatal("Expected context data to be passed to logger, but it wasn't")
}
}
func TestSetLogger(t *testing.T) {
t.Parallel()

View File

@ -3,6 +3,8 @@
package log15adapter
import (
"context"
"github.com/jackc/pgx/v4"
)
@ -24,7 +26,7 @@ func NewLogger(l Log15Logger) *Logger {
return &Logger{l: l}
}
func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
logArgs := make([]interface{}, 0, len(data))
for k, v := range data {
logArgs = append(logArgs, k, v)

View File

@ -3,6 +3,8 @@
package logrusadapter
import (
"context"
"github.com/jackc/pgx/v4"
"github.com/sirupsen/logrus"
)
@ -15,7 +17,7 @@ func NewLogger(l logrus.FieldLogger) *Logger {
return &Logger{l: l}
}
func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
var logger logrus.FieldLogger
if data != nil {
logger = l.l.WithFields(data)

View File

@ -3,6 +3,7 @@
package testingadapter
import (
"context"
"fmt"
"github.com/jackc/pgx/v4"
@ -22,7 +23,7 @@ func NewLogger(l TestingLogger) *Logger {
return &Logger{l: l}
}
func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
logArgs := make([]interface{}, 0, 2+len(data))
logArgs = append(logArgs, level, msg)
for k, v := range data {

View File

@ -2,6 +2,8 @@
package zapadapter
import (
"context"
"github.com/jackc/pgx/v4"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
@ -15,7 +17,7 @@ func NewLogger(logger *zap.Logger) *Logger {
return &Logger{logger: logger.WithOptions(zap.AddCallerSkip(1))}
}
func (pl *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
fields := make([]zapcore.Field, len(data))
i := 0
for k, v := range data {

View File

@ -2,6 +2,8 @@
package zerologadapter
import (
"context"
"github.com/jackc/pgx/v4"
"github.com/rs/zerolog"
)
@ -18,7 +20,7 @@ func NewLogger(logger zerolog.Logger) *Logger {
}
}
func (pl *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
var zlevel zerolog.Level
switch level {
case pgx.LogLevelNone:

View File

@ -1,6 +1,7 @@
package pgx
import (
"context"
"encoding/hex"
"fmt"
@ -44,7 +45,7 @@ func (ll LogLevel) String() string {
// Logger is the interface used to get logging from pgx internals.
type Logger interface {
// Log a message at the given level with data key/value pairs. data may be nil.
Log(level LogLevel, msg string, data map[string]interface{})
Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{})
}
// LogLevelFromString converts log level string to constant

View File

@ -1,6 +1,7 @@
package pgx
import (
"context"
"fmt"
"reflect"
"time"
@ -70,11 +71,12 @@ func (r *connRow) Scan(dest ...interface{}) (err error) {
type rowLog interface {
shouldLog(lvl LogLevel) bool
log(lvl LogLevel, msg string, data map[string]interface{})
log(ctx context.Context, lvl LogLevel, msg string, data map[string]interface{})
}
// connRows implements the Rows interface for Conn.Query.
type connRows struct {
ctx context.Context
logger rowLog
connInfo *pgtype.ConnInfo
values [][]byte
@ -119,10 +121,10 @@ func (rows *connRows) Close() {
if rows.err == nil {
if rows.logger.shouldLog(LogLevelInfo) {
endTime := time.Now()
rows.logger.log(LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount})
rows.logger.log(rows.ctx, LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount})
}
} else if rows.logger.shouldLog(LogLevelError) {
rows.logger.log(LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)})
rows.logger.log(rows.ctx, LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)})
}
}
}