diff --git a/batch.go b/batch.go index c652c2a5..3a9bcdda 100644 --- a/batch.go +++ b/batch.go @@ -1,6 +1,8 @@ package pgx import ( + "context" + "github.com/jackc/pgconn" "github.com/jackc/pgtype" errors "golang.org/x/xerrors" @@ -47,6 +49,7 @@ type BatchResults interface { } type batchResults struct { + ctx context.Context conn *Conn mrr *pgconn.MultiResultReader err error @@ -71,7 +74,7 @@ func (br *batchResults) ExecResults() (pgconn.CommandTag, error) { // QueryResults reads the results from the next query in the batch as if the query has been sent with Query. func (br *batchResults) QueryResults() (Rows, error) { - rows := br.conn.getRows("batch query", nil) + rows := br.conn.getRows(br.ctx, "batch query", nil) if br.err != nil { rows.err = br.err diff --git a/bench_test.go b/bench_test.go index 0722ae6a..765547d7 100644 --- a/bench_test.go +++ b/bench_test.go @@ -201,7 +201,8 @@ func BenchmarkSelectWithoutLogging(b *testing.B) { type discardLogger struct{} -func (dl discardLogger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {} +func (dl discardLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { +} func BenchmarkSelectWithLoggingTraceDiscard(b *testing.B) { conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) diff --git a/conn.go b/conn.go index 0459e6e7..b1018206 100644 --- a/conn.go +++ b/conn.go @@ -146,7 +146,7 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { c.logger = c.config.Logger if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"host": config.Config.Host}) + c.log(ctx, LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"host": config.Config.Host}) } c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config) if err != nil { @@ -155,7 +155,7 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { if err != nil { if c.shouldLog(LogLevelError) { - c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err}) + c.log(ctx, LogLevelError, "connect failed", map[string]interface{}{"err": err}) } return nil, err } @@ -184,7 +184,7 @@ func (c *Conn) Close(ctx context.Context) error { err := c.pgConn.Close(ctx) c.causeOfDeath = errors.New("Closed") if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "closed connection", nil) + c.log(ctx, LogLevelInfo, "closed connection", nil) } return err } @@ -205,7 +205,7 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (ps *PreparedState if c.shouldLog(LogLevelError) { defer func() { if err != nil { - c.log(LogLevelError, "Prepare failed", map[string]interface{}{"err": err, "name": name, "sql": sql}) + c.log(ctx, LogLevelError, "Prepare failed", map[string]interface{}{"err": err, "name": name, "sql": sql}) } }() } @@ -307,7 +307,7 @@ func (c *Conn) shouldLog(lvl LogLevel) bool { return c.logger != nil && c.logLevel >= lvl } -func (c *Conn) log(lvl LogLevel, msg string, data map[string]interface{}) { +func (c *Conn) log(ctx context.Context, lvl LogLevel, msg string, data map[string]interface{}) { if data == nil { data = map[string]interface{}{} } @@ -315,7 +315,7 @@ func (c *Conn) log(lvl LogLevel, msg string, data map[string]interface{}) { data["pid"] = c.pgConn.PID() } - c.logger.Log(lvl, msg, data) + c.logger.Log(ctx, lvl, msg, data) } // SetLogger replaces the current logger and returns the previous logger. @@ -386,14 +386,14 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) ( commandTag, err := c.exec(ctx, sql, arguments...) if err != nil { if c.shouldLog(LogLevelError) { - c.log(LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err}) + c.log(ctx, LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err}) } return commandTag, err } if c.shouldLog(LogLevelInfo) { endTime := time.Now() - c.log(LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag}) + c.log(ctx, LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag}) } return commandTag, err @@ -518,7 +518,7 @@ optionLoop: } -func (c *Conn) getRows(sql string, args []interface{}) *connRows { +func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *connRows { if len(c.preallocatedRows) == 0 { c.preallocatedRows = make([]connRows, 64) } @@ -526,6 +526,7 @@ func (c *Conn) getRows(sql string, args []interface{}) *connRows { r := &c.preallocatedRows[len(c.preallocatedRows)-1] c.preallocatedRows = c.preallocatedRows[0 : len(c.preallocatedRows)-1] + r.ctx = ctx r.logger = c r.connInfo = c.ConnInfo r.startTime = time.Now() @@ -568,7 +569,7 @@ optionLoop: } } - rows := c.getRows(sql, args) + rows := c.getRows(ctx, sql, args) var err error if simpleProtocol { @@ -727,6 +728,7 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { mrr := c.pgConn.ExecBatch(ctx, batch) return &batchResults{ + ctx: ctx, conn: c, mrr: mrr, } diff --git a/conn_test.go b/conn_test.go index 286ff403..a506349c 100644 --- a/conn_test.go +++ b/conn_test.go @@ -543,10 +543,38 @@ type testLogger struct { logs []testLog } -func (l *testLogger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { +func (l *testLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { + data["ctxdata"] = ctx.Value("ctxdata") l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data}) } +func TestLogPassesContext(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + ctx := context.WithValue(context.Background(), "ctxdata", "foo") + + l1 := &testLogger{} + oldLogger := conn.SetLogger(l1) + if oldLogger != nil { + t.Fatalf("Expected conn.SetLogger to return %v, but it was %v", nil, oldLogger) + } + + if _, err := conn.Exec(ctx, ";"); err != nil { + t.Fatal(err) + } + + if len(l1.logs) != 1 { + t.Fatal("Expected new logger l1 to be called once, but it wasn't") + } + + if l1.logs[0].data["ctxdata"] != "foo" { + t.Fatal("Expected context data to be passed to logger, but it wasn't") + } +} + func TestSetLogger(t *testing.T) { t.Parallel() diff --git a/log/log15adapter/adapter.go b/log/log15adapter/adapter.go index 6f120f88..70608e33 100644 --- a/log/log15adapter/adapter.go +++ b/log/log15adapter/adapter.go @@ -3,6 +3,8 @@ package log15adapter import ( + "context" + "github.com/jackc/pgx/v4" ) @@ -24,7 +26,7 @@ func NewLogger(l Log15Logger) *Logger { return &Logger{l: l} } -func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { +func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { logArgs := make([]interface{}, 0, len(data)) for k, v := range data { logArgs = append(logArgs, k, v) diff --git a/log/logrusadapter/adapter.go b/log/logrusadapter/adapter.go index 3c3d6f0e..e0cd6328 100644 --- a/log/logrusadapter/adapter.go +++ b/log/logrusadapter/adapter.go @@ -3,6 +3,8 @@ package logrusadapter import ( + "context" + "github.com/jackc/pgx/v4" "github.com/sirupsen/logrus" ) @@ -15,7 +17,7 @@ func NewLogger(l logrus.FieldLogger) *Logger { return &Logger{l: l} } -func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { +func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { var logger logrus.FieldLogger if data != nil { logger = l.l.WithFields(data) diff --git a/log/testingadapter/adapter.go b/log/testingadapter/adapter.go index 28e89bd1..3ddce5a1 100644 --- a/log/testingadapter/adapter.go +++ b/log/testingadapter/adapter.go @@ -3,6 +3,7 @@ package testingadapter import ( + "context" "fmt" "github.com/jackc/pgx/v4" @@ -22,7 +23,7 @@ func NewLogger(l TestingLogger) *Logger { return &Logger{l: l} } -func (l *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { +func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { logArgs := make([]interface{}, 0, 2+len(data)) logArgs = append(logArgs, level, msg) for k, v := range data { diff --git a/log/zapadapter/adapter.go b/log/zapadapter/adapter.go index 5011ee48..ff03d9b9 100644 --- a/log/zapadapter/adapter.go +++ b/log/zapadapter/adapter.go @@ -2,6 +2,8 @@ package zapadapter import ( + "context" + "github.com/jackc/pgx/v4" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -15,7 +17,7 @@ func NewLogger(logger *zap.Logger) *Logger { return &Logger{logger: logger.WithOptions(zap.AddCallerSkip(1))} } -func (pl *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { +func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { fields := make([]zapcore.Field, len(data)) i := 0 for k, v := range data { diff --git a/log/zerologadapter/adapter.go b/log/zerologadapter/adapter.go index f3f07585..efbcb5bf 100644 --- a/log/zerologadapter/adapter.go +++ b/log/zerologadapter/adapter.go @@ -2,6 +2,8 @@ package zerologadapter import ( + "context" + "github.com/jackc/pgx/v4" "github.com/rs/zerolog" ) @@ -18,7 +20,7 @@ func NewLogger(logger zerolog.Logger) *Logger { } } -func (pl *Logger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { +func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { var zlevel zerolog.Level switch level { case pgx.LogLevelNone: diff --git a/logger.go b/logger.go index 2e37afdb..f69a152f 100644 --- a/logger.go +++ b/logger.go @@ -1,6 +1,7 @@ package pgx import ( + "context" "encoding/hex" "fmt" @@ -44,7 +45,7 @@ func (ll LogLevel) String() string { // Logger is the interface used to get logging from pgx internals. type Logger interface { // Log a message at the given level with data key/value pairs. data may be nil. - Log(level LogLevel, msg string, data map[string]interface{}) + Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) } // LogLevelFromString converts log level string to constant diff --git a/rows.go b/rows.go index 9d5dfcfb..97f152fe 100644 --- a/rows.go +++ b/rows.go @@ -1,6 +1,7 @@ package pgx import ( + "context" "fmt" "reflect" "time" @@ -70,11 +71,12 @@ func (r *connRow) Scan(dest ...interface{}) (err error) { type rowLog interface { shouldLog(lvl LogLevel) bool - log(lvl LogLevel, msg string, data map[string]interface{}) + log(ctx context.Context, lvl LogLevel, msg string, data map[string]interface{}) } // connRows implements the Rows interface for Conn.Query. type connRows struct { + ctx context.Context logger rowLog connInfo *pgtype.ConnInfo values [][]byte @@ -119,10 +121,10 @@ func (rows *connRows) Close() { if rows.err == nil { if rows.logger.shouldLog(LogLevelInfo) { endTime := time.Now() - rows.logger.log(LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount}) + rows.logger.log(rows.ctx, LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount}) } } else if rows.logger.shouldLog(LogLevelError) { - rows.logger.log(LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)}) + rows.logger.log(rows.ctx, LogLevelError, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args)}) } } }