mirror of https://github.com/jackc/pgx.git
Add context.Context to Logger interface
This allows custom logger adapters to add additional fields to log messages. For example, a HTTP server may with to log the request ID. fixes #428pull/586/head
parent
ab1edc79e0
commit
3028821487
5
batch.go
5
batch.go
|
@ -1,6 +1,8 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgtype"
|
||||
errors "golang.org/x/xerrors"
|
||||
|
@ -47,6 +49,7 @@ type BatchResults interface {
|
|||
}
|
||||
|
||||
type batchResults struct {
|
||||
ctx context.Context
|
||||
conn *Conn
|
||||
mrr *pgconn.MultiResultReader
|
||||
err error
|
||||
|
@ -71,7 +74,7 @@ func (br *batchResults) ExecResults() (pgconn.CommandTag, error) {
|
|||
|
||||
// QueryResults reads the results from the next query in the batch as if the query has been sent with Query.
|
||||
func (br *batchResults) QueryResults() (Rows, error) {
|
||||
rows := br.conn.getRows("batch query", nil)
|
||||
rows := br.conn.getRows(br.ctx, "batch query", nil)
|
||||
|
||||
if br.err != nil {
|
||||
rows.err = br.err
|
||||
|
|
|
@ -201,7 +201,8 @@ func BenchmarkSelectWithoutLogging(b *testing.B) {
|
|||
|
||||
type discardLogger struct{}
|
||||
|
||||
func (dl discardLogger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {}
|
||||
func (dl discardLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
}
|
||||
|
||||
func BenchmarkSelectWithLoggingTraceDiscard(b *testing.B) {
|
||||
conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")))
|
||||
|
|
22
conn.go
22
conn.go
|
@ -146,7 +146,7 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
|
|||
c.logger = c.config.Logger
|
||||
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"host": config.Config.Host})
|
||||
c.log(ctx, LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"host": config.Config.Host})
|
||||
}
|
||||
c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config)
|
||||
if err != nil {
|
||||
|
@ -155,7 +155,7 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
|
|||
|
||||
if err != nil {
|
||||
if c.shouldLog(LogLevelError) {
|
||||
c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err})
|
||||
c.log(ctx, LogLevelError, "connect failed", map[string]interface{}{"err": err})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
@ -184,7 +184,7 @@ func (c *Conn) Close(ctx context.Context) error {
|
|||
err := c.pgConn.Close(ctx)
|
||||
c.causeOfDeath = errors.New("Closed")
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
c.log(LogLevelInfo, "closed connection", nil)
|
||||
c.log(ctx, LogLevelInfo, "closed connection", nil)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -205,7 +205,7 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (ps *PreparedState
|
|||
if c.shouldLog(LogLevelError) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
c.log(LogLevelError, "Prepare failed", map[string]interface{}{"err": err, "name": name, "sql": sql})
|
||||
c.log(ctx, LogLevelError, "Prepare failed", map[string]interface{}{"err": err, "name": name, "sql": sql})
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
@ -307,7 +307,7 @@ func (c *Conn) shouldLog(lvl LogLevel) bool {
|
|||
return c.logger != nil && c.logLevel >= lvl
|
||||
}
|
||||
|
||||
func (c *Conn) log(lvl LogLevel, msg string, data map[string]interface{}) {
|
||||
func (c *Conn) log(ctx context.Context, lvl LogLevel, msg string, data map[string]interface{}) {
|
||||
if data == nil {
|
||||
data = map[string]interface{}{}
|
||||
}
|
||||
|
@ -315,7 +315,7 @@ func (c *Conn) log(lvl LogLevel, msg string, data map[string]interface{}) {
|
|||
data["pid"] = c.pgConn.PID()
|
||||
}
|
||||
|
||||
c.logger.Log(lvl, msg, data)
|
||||
c.logger.Log(ctx, lvl, msg, data)
|
||||
}
|
||||
|
||||
// SetLogger replaces the current logger and returns the previous logger.
|
||||
|
@ -386,14 +386,14 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (
|
|||
commandTag, err := c.exec(ctx, sql, arguments...)
|
||||
if err != nil {
|
||||
if c.shouldLog(LogLevelError) {
|
||||
c.log(LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err})
|
||||
c.log(ctx, LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err})
|
||||
}
|
||||
return commandTag, err
|
||||
}
|
||||
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
endTime := time.Now()
|
||||
c.log(LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag})
|
||||
c.log(ctx, LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag})
|
||||
}
|
||||
|
||||
return commandTag, err
|
||||
|
@ -518,7 +518,7 @@ optionLoop:
|
|||
|
||||
}
|
||||
|
||||
func (c *Conn) getRows(sql string, args []interface{}) *connRows {
|
||||
func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *connRows {
|
||||
if len(c.preallocatedRows) == 0 {
|
||||
c.preallocatedRows = make([]connRows, 64)
|
||||
}
|
||||
|
@ -526,6 +526,7 @@ func (c *Conn) getRows(sql string, args []interface{}) *connRows {
|
|||
r := &c.preallocatedRows[len(c.preallocatedRows)-1]
|
||||
c.preallocatedRows = c.preallocatedRows[0 : len(c.preallocatedRows)-1]
|
||||
|
||||
r.ctx = ctx
|
||||
r.logger = c
|
||||
r.connInfo = c.ConnInfo
|
||||
r.startTime = time.Now()
|
||||
|
@ -568,7 +569,7 @@ optionLoop:
|
|||
}
|
||||
}
|
||||
|
||||
rows := c.getRows(sql, args)
|
||||
rows := c.getRows(ctx, sql, args)
|
||||
|
||||
var err error
|
||||
if simpleProtocol {
|
||||
|
@ -727,6 +728,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
|||
mrr := c.pgConn.ExecBatch(ctx, batch)
|
||||
|
||||
return &batchResults{
|
||||
ctx: ctx,
|
||||
conn: c,
|
||||
mrr: mrr,
|
||||
}
|
||||
|
|
30
conn_test.go
30
conn_test.go
|
@ -543,10 +543,38 @@ type testLogger struct {
|
|||
logs []testLog
|
||||
}
|
||||
|
||||
func (l *testLogger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
func (l *testLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
data["ctxdata"] = ctx.Value("ctxdata")
|
||||
l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data})
|
||||
}
|
||||
|
||||
func TestLogPassesContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
ctx := context.WithValue(context.Background(), "ctxdata", "foo")
|
||||
|
||||
l1 := &testLogger{}
|
||||
oldLogger := conn.SetLogger(l1)
|
||||
if oldLogger != nil {
|
||||
t.Fatalf("Expected conn.SetLogger to return %v, but it was %v", nil, oldLogger)
|
||||
}
|
||||
|
||||
if _, err := conn.Exec(ctx, ";"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(l1.logs) != 1 {
|
||||
t.Fatal("Expected new logger l1 to be called once, but it wasn't")
|
||||
}
|
||||
|
||||
if l1.logs[0].data["ctxdata"] != "foo" {
|
||||
t.Fatal("Expected context data to be passed to logger, but it wasn't")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetLogger(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
|
@ -3,6 +3,8 @@
|
|||
package log15adapter
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v4"
|
||||
)
|
||||
|
||||
|
@ -24,7 +26,7 @@ func NewLogger(l Log15Logger) *Logger {
|
|||
return &Logger{l: l}
|
||||
}
|
||||
|
||||
func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
logArgs := make([]interface{}, 0, len(data))
|
||||
for k, v := range data {
|
||||
logArgs = append(logArgs, k, v)
|
||||
|
|
|
@ -3,6 +3,8 @@
|
|||
package logrusadapter
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
@ -15,7 +17,7 @@ func NewLogger(l logrus.FieldLogger) *Logger {
|
|||
return &Logger{l: l}
|
||||
}
|
||||
|
||||
func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
var logger logrus.FieldLogger
|
||||
if data != nil {
|
||||
logger = l.l.WithFields(data)
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
package testingadapter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v4"
|
||||
|
@ -22,7 +23,7 @@ func NewLogger(l TestingLogger) *Logger {
|
|||
return &Logger{l: l}
|
||||
}
|
||||
|
||||
func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
logArgs := make([]interface{}, 0, 2+len(data))
|
||||
logArgs = append(logArgs, level, msg)
|
||||
for k, v := range data {
|
||||
|
|
|
@ -2,6 +2,8 @@
|
|||
package zapadapter
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v4"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
@ -15,7 +17,7 @@ func NewLogger(logger *zap.Logger) *Logger {
|
|||
return &Logger{logger: logger.WithOptions(zap.AddCallerSkip(1))}
|
||||
}
|
||||
|
||||
func (pl *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
fields := make([]zapcore.Field, len(data))
|
||||
i := 0
|
||||
for k, v := range data {
|
||||
|
|
|
@ -2,6 +2,8 @@
|
|||
package zerologadapter
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
@ -18,7 +20,7 @@ func NewLogger(logger zerolog.Logger) *Logger {
|
|||
}
|
||||
}
|
||||
|
||||
func (pl *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
var zlevel zerolog.Level
|
||||
switch level {
|
||||
case pgx.LogLevelNone:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
|
||||
|
@ -44,7 +45,7 @@ func (ll LogLevel) String() string {
|
|||
// Logger is the interface used to get logging from pgx internals.
|
||||
type Logger interface {
|
||||
// Log a message at the given level with data key/value pairs. data may be nil.
|
||||
Log(level LogLevel, msg string, data map[string]interface{})
|
||||
Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{})
|
||||
}
|
||||
|
||||
// LogLevelFromString converts log level string to constant
|
||||
|
|
8
rows.go
8
rows.go
|
@ -1,6 +1,7 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
@ -70,11 +71,12 @@ func (r *connRow) Scan(dest ...interface{}) (err error) {
|
|||
|
||||
type rowLog interface {
|
||||
shouldLog(lvl LogLevel) bool
|
||||
log(lvl LogLevel, msg string, data map[string]interface{})
|
||||
log(ctx context.Context, lvl LogLevel, msg string, data map[string]interface{})
|
||||
}
|
||||
|
||||
// connRows implements the Rows interface for Conn.Query.
|
||||
type connRows struct {
|
||||
ctx context.Context
|
||||
logger rowLog
|
||||
connInfo *pgtype.ConnInfo
|
||||
values [][]byte
|
||||
|
@ -119,10 +121,10 @@ func (rows *connRows) Close() {
|
|||
if rows.err == nil {
|
||||
if rows.logger.shouldLog(LogLevelInfo) {
|
||||
endTime := time.Now()
|
||||
rows.logger.log(LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount})
|
||||
rows.logger.log(rows.ctx, LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount})
|
||||
}
|
||||
} else if rows.logger.shouldLog(LogLevelError) {
|
||||
rows.logger.log(LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)})
|
||||
rows.logger.log(rows.ctx, LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue