Add tracing support

Replaces existing logging support. Package tracelog provides adapter for
old style logging.

https://github.com/jackc/pgx/issues/1061
pull/1281/head
Jack Christensen 2022-07-16 12:27:10 -05:00
parent 9201cc0341
commit 78875bb95a
19 changed files with 1446 additions and 485 deletions

View File

@ -159,7 +159,9 @@ Previously, a batch with 10 unique parameterized statements executed 100 times w
for each prepare / describe and 1 for executing them all. Now pipeline mode is used to prepare / describe all statements
in a single network round trip. So it would only take 2 round trips.
## 3rd Party Logger Integration
## Tracing and Logging
Internal logging support has been replaced with tracing hooks. This allows custom tracing integration with tools like OpenTelemetry. Package tracelog provides an adapter for pgx v4 loggers to act as a tracer.
All integrations with 3rd party loggers have been extracted to separate repositories. This trims the pgx dependency
tree.

View File

@ -163,6 +163,8 @@ pgerrcode contains constants for the PostgreSQL error codes.
## Adapters for 3rd Party Loggers
These adapters can be used with the tracelog package.
* [github.com/jackc/pgx-go-kit-log](https://github.com/jackc/pgx-go-kit-log)
* [github.com/jackc/pgx-log15](https://github.com/jackc/pgx-log15)
* [github.com/jackc/pgx-logrus](https://github.com/jackc/pgx-logrus)

186
batch.go
View File

@ -50,13 +50,14 @@ type BatchResults interface {
}
type batchResults struct {
ctx context.Context
conn *Conn
mrr *pgconn.MultiResultReader
err error
b *Batch
ix int
closed bool
ctx context.Context
conn *Conn
mrr *pgconn.MultiResultReader
err error
b *Batch
ix int
closed bool
endTraced bool
}
// Exec reads the results from the next query in the batch as if the query has been sent with Exec.
@ -75,35 +76,29 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) {
if err == nil {
err = errors.New("no result")
}
if br.conn.shouldLog(LogLevelError) {
br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]any{
"sql": query,
"args": logQueryArgs(arguments),
"err": err,
if br.conn.batchTracer != nil {
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
SQL: query,
Args: arguments,
Err: err,
})
}
return pgconn.CommandTag{}, err
}
commandTag, err := br.mrr.ResultReader().Close()
br.err = err
if err != nil {
if br.conn.shouldLog(LogLevelError) {
br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]any{
"sql": query,
"args": logQueryArgs(arguments),
"err": err,
})
}
} else if br.conn.shouldLog(LogLevelInfo) {
br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]any{
"sql": query,
"args": logQueryArgs(arguments),
"commandTag": commandTag,
if br.conn.batchTracer != nil {
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
SQL: query,
Args: arguments,
CommandTag: commandTag,
Err: br.err,
})
}
return commandTag, err
return commandTag, br.err
}
// Query reads the results from the next query in the batch as if the query has been sent with Query.
@ -123,6 +118,7 @@ func (br *batchResults) Query() (Rows, error) {
}
rows := br.conn.getRows(br.ctx, query, arguments)
rows.batchTracer = br.conn.batchTracer
if !br.mrr.NextResult() {
rows.err = br.mrr.Close()
@ -131,11 +127,11 @@ func (br *batchResults) Query() (Rows, error) {
}
rows.closed = true
if br.conn.shouldLog(LogLevelError) {
br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]any{
"sql": query,
"args": logQueryArgs(arguments),
"err": rows.err,
if br.conn.batchTracer != nil {
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
SQL: query,
Args: arguments,
Err: rows.err,
})
}
@ -156,6 +152,15 @@ func (br *batchResults) QueryRow() Row {
// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to
// resyncronize the connection with the server. In this case the underlying connection will have been closed.
func (br *batchResults) Close() error {
defer func() {
if !br.endTraced {
if br.conn != nil && br.conn.batchTracer != nil {
br.conn.batchTracer.TraceBatchEnd(br.ctx, br.conn, TraceBatchEndData{Err: br.err})
}
br.endTraced = true
}
}()
if br.err != nil {
return br.err
}
@ -163,24 +168,26 @@ func (br *batchResults) Close() error {
if br.closed {
return nil
}
br.closed = true
// log any queries that haven't yet been logged by Exec or Query
for {
query, args, ok := br.nextQueryAndArgs()
if !ok {
break
}
if br.conn.shouldLog(LogLevelInfo) {
br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]any{
"sql": query,
"args": logQueryArgs(args),
})
// consume and log any queries that haven't yet been logged by Exec or Query
if br.conn.batchTracer != nil {
for br.err == nil && !br.closed && br.b != nil && br.ix < len(br.b.items) {
br.Exec()
}
}
return br.mrr.Close()
br.closed = true
err := br.mrr.Close()
if br.err == nil {
br.err = err
}
return br.err
}
func (br *batchResults) earlyError() error {
return br.err
}
func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
@ -195,14 +202,15 @@ func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
}
type pipelineBatchResults struct {
ctx context.Context
conn *Conn
pipeline *pgconn.Pipeline
lastRows *baseRows
err error
b *Batch
ix int
closed bool
ctx context.Context
conn *Conn
pipeline *pgconn.Pipeline
lastRows *baseRows
err error
b *Batch
ix int
closed bool
endTraced bool
}
// Exec reads the results from the next query in the batch as if the query has been sent with Exec.
@ -227,25 +235,17 @@ func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) {
var commandTag pgconn.CommandTag
switch results := results.(type) {
case *pgconn.ResultReader:
commandTag, err = results.Close()
commandTag, br.err = results.Close()
default:
return pgconn.CommandTag{}, fmt.Errorf("unexpected pipeline result: %T", results)
}
if err != nil {
br.err = err
if br.conn.shouldLog(LogLevelError) {
br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]any{
"sql": query,
"args": logQueryArgs(arguments),
"err": err,
})
}
} else if br.conn.shouldLog(LogLevelInfo) {
br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]any{
"sql": query,
"args": logQueryArgs(arguments),
"commandTag": commandTag,
if br.conn.batchTracer != nil {
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
SQL: query,
Args: arguments,
CommandTag: commandTag,
Err: br.err,
})
}
@ -274,6 +274,7 @@ func (br *pipelineBatchResults) Query() (Rows, error) {
}
rows := br.conn.getRows(br.ctx, query, arguments)
rows.batchTracer = br.conn.batchTracer
br.lastRows = rows
results, err := br.pipeline.GetResults()
@ -281,11 +282,12 @@ func (br *pipelineBatchResults) Query() (Rows, error) {
br.err = err
rows.err = err
rows.closed = true
if br.conn.shouldLog(LogLevelError) {
br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]any{
"sql": query,
"args": logQueryArgs(arguments),
"err": rows.err,
if br.conn.batchTracer != nil {
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
SQL: query,
Args: arguments,
Err: err,
})
}
} else {
@ -313,6 +315,15 @@ func (br *pipelineBatchResults) QueryRow() Row {
// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to
// resyncronize the connection with the server. In this case the underlying connection will have been closed.
func (br *pipelineBatchResults) Close() error {
defer func() {
if !br.endTraced {
if br.conn.batchTracer != nil {
br.conn.batchTracer.TraceBatchEnd(br.ctx, br.conn, TraceBatchEndData{Err: br.err})
}
br.endTraced = true
}
}()
if br.err != nil {
return br.err
}
@ -325,24 +336,25 @@ func (br *pipelineBatchResults) Close() error {
if br.closed {
return nil
}
br.closed = true
// log any queries that haven't yet been logged by Exec or Query
for {
query, args, ok := br.nextQueryAndArgs()
if !ok {
break
}
if br.conn.shouldLog(LogLevelInfo) {
br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]any{
"sql": query,
"args": logQueryArgs(args),
})
// consume and log any queries that haven't yet been logged by Exec or Query
if br.conn.batchTracer != nil {
for br.err == nil && !br.closed && br.b != nil && br.ix < len(br.b.items) {
br.Exec()
}
}
br.closed = true
return br.pipeline.Close()
err := br.pipeline.Close()
if br.err == nil {
br.err = err
}
return br.err
}
func (br *pipelineBatchResults) earlyError() error {
return br.err
}
func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) {

View File

@ -733,94 +733,6 @@ func testConnSendBatch(t *testing.T, conn *pgx.Conn, queryCount int) {
require.NoError(t, err)
}
func TestLogBatchStatementsOnExec(t *testing.T) {
l1 := &testLogger{}
config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
config.Logger = l1
conn := mustConnect(t, config)
defer closeConn(t, conn)
l1.logs = l1.logs[0:0] // Clear logs written when establishing connection
batch := &pgx.Batch{}
batch.Queue("create table foo (id bigint)")
batch.Queue("drop table foo")
br := conn.SendBatch(context.Background(), batch)
_, err := br.Exec()
if err != nil {
t.Fatalf("Unexpected error creating table: %v", err)
}
_, err = br.Exec()
if err != nil {
t.Fatalf("Unexpected error dropping table: %v", err)
}
if len(l1.logs) != 2 {
t.Fatalf("Expected two log entries but got %d", len(l1.logs))
}
if l1.logs[0].msg != "BatchResult.Exec" {
t.Errorf("Expected first log message to be 'BatchResult.Exec' but was '%s", l1.logs[0].msg)
}
if l1.logs[0].data["sql"] != "create table foo (id bigint)" {
t.Errorf("Expected the first query to be 'create table foo (id bigint)' but was '%s'", l1.logs[0].data["sql"])
}
if l1.logs[1].msg != "BatchResult.Exec" {
t.Errorf("Expected second log message to be 'BatchResult.Exec' but was '%s", l1.logs[1].msg)
}
if l1.logs[1].data["sql"] != "drop table foo" {
t.Errorf("Expected the second query to be 'drop table foo' but was '%s'", l1.logs[1].data["sql"])
}
}
func TestLogBatchStatementsOnBatchResultClose(t *testing.T) {
l1 := &testLogger{}
config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
config.Logger = l1
conn := mustConnect(t, config)
defer closeConn(t, conn)
l1.logs = l1.logs[0:0] // Clear logs written when establishing connection
batch := &pgx.Batch{}
batch.Queue("select generate_series(1,$1)", 100)
batch.Queue("select 1 = 1;")
br := conn.SendBatch(context.Background(), batch)
if err := br.Close(); err != nil {
t.Fatalf("Unexpected batch error: %v", err)
}
if len(l1.logs) != 2 {
t.Fatalf("Expected 2 log statements but found %d", len(l1.logs))
}
if l1.logs[0].msg != "BatchResult.Close" {
t.Errorf("Expected first log statement to be 'BatchResult.Close' but was %s", l1.logs[0].msg)
}
if l1.logs[0].data["sql"] != "select generate_series(1,$1)" {
t.Errorf("Expected first query to be 'select generate_series(1,$1)' but was '%s'", l1.logs[0].data["sql"])
}
if l1.logs[1].msg != "BatchResult.Close" {
t.Errorf("Expected second log statement to be 'BatchResult.Close' but was %s", l1.logs[1].msg)
}
if l1.logs[1].data["sql"] != "select 1 = 1;" {
t.Errorf("Expected second query to be 'select 1 = 1;' but was '%s'", l1.logs[1].data["sql"])
}
}
func TestSendBatchSimpleProtocol(t *testing.T) {
t.Parallel()

View File

@ -284,123 +284,6 @@ func BenchmarkPointerPointerWithPresentValues(b *testing.B) {
}
}
func BenchmarkSelectWithoutLogging(b *testing.B) {
conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")))
defer closeConn(b, conn)
benchmarkSelectWithLog(b, conn)
}
type discardLogger struct{}
func (dl discardLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]any) {
}
func BenchmarkSelectWithLoggingTraceDiscard(b *testing.B) {
var logger discardLogger
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
config.Logger = logger
config.LogLevel = pgx.LogLevelTrace
conn := mustConnect(b, config)
defer closeConn(b, conn)
benchmarkSelectWithLog(b, conn)
}
func BenchmarkSelectWithLoggingDebugWithDiscard(b *testing.B) {
var logger discardLogger
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
config.Logger = logger
config.LogLevel = pgx.LogLevelDebug
conn := mustConnect(b, config)
defer closeConn(b, conn)
benchmarkSelectWithLog(b, conn)
}
func BenchmarkSelectWithLoggingInfoWithDiscard(b *testing.B) {
var logger discardLogger
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
config.Logger = logger
config.LogLevel = pgx.LogLevelInfo
conn := mustConnect(b, config)
defer closeConn(b, conn)
benchmarkSelectWithLog(b, conn)
}
func BenchmarkSelectWithLoggingErrorWithDiscard(b *testing.B) {
var logger discardLogger
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
config.Logger = logger
config.LogLevel = pgx.LogLevelError
conn := mustConnect(b, config)
defer closeConn(b, conn)
benchmarkSelectWithLog(b, conn)
}
func benchmarkSelectWithLog(b *testing.B, conn *pgx.Conn) {
_, err := conn.Prepare(context.Background(), "test", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz")
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
var record struct {
id int32
userName string
email string
name string
sex string
birthDate time.Time
lastLoginTime time.Time
}
err = conn.QueryRow(context.Background(), "test").Scan(
&record.id,
&record.userName,
&record.email,
&record.name,
&record.sex,
&record.birthDate,
&record.lastLoginTime,
)
if err != nil {
b.Fatal(err)
}
// These checks both ensure that the correct data was returned
// and provide a benchmark of accessing the returned values.
if record.id != 1 {
b.Fatalf("bad value for id: %v", record.id)
}
if record.userName != "johnsmith" {
b.Fatalf("bad value for userName: %v", record.userName)
}
if record.email != "johnsmith@example.com" {
b.Fatalf("bad value for email: %v", record.email)
}
if record.name != "John Smith" {
b.Fatalf("bad value for name: %v", record.name)
}
if record.sex != "male" {
b.Fatalf("bad value for sex: %v", record.sex)
}
if record.birthDate != time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) {
b.Fatalf("bad value for birthDate: %v", record.birthDate)
}
if record.lastLoginTime != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) {
b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime)
}
}
}
const benchmarkWriteTableCreateSQL = `drop table if exists t;
create table t(

119
conn.go
View File

@ -19,8 +19,8 @@ import (
// then it can be modified. A manually initialized ConnConfig will cause ConnectConfig to panic.
type ConnConfig struct {
pgconn.Config
Logger Logger
LogLevel LogLevel
Tracer QueryTracer
// Original connection string that was parsed into config.
connString string
@ -63,8 +63,11 @@ type Conn struct {
preparedStatements map[string]*pgconn.StatementDescription
statementCache stmtcache.Cache
descriptionCache stmtcache.Cache
logger Logger
logLevel LogLevel
queryTracer QueryTracer
batchTracer BatchTracer
copyFromTracer CopyFromTracer
prepareTracer PrepareTracer
notifications []*pgconn.Notification
@ -94,9 +97,6 @@ func (ident Identifier) Sanitize() string {
// ErrNoRows occurs when rows are expected but none are returned.
var ErrNoRows = errors.New("no rows in result set")
// ErrInvalidLogLevel occurs on attempt to set an invalid log level.
var ErrInvalidLogLevel = errors.New("invalid log level")
var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
@ -182,7 +182,6 @@ func ParseConfig(connString string) (*ConnConfig, error) {
connConfig := &ConnConfig{
Config: *config,
createdByParseConfig: true,
LogLevel: LogLevelInfo,
StatementCacheCapacity: statementCacheCapacity,
DescriptionCacheCapacity: descriptionCacheCapacity,
DefaultQueryExecMode: defaultQueryExecMode,
@ -194,6 +193,13 @@ func ParseConfig(connString string) (*ConnConfig, error) {
// connect connects to a database. connect takes ownership of config. The caller must not use or access it again.
func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
if connectTracer, ok := config.Tracer.(ConnectTracer); ok {
ctx = connectTracer.TraceConnectStart(ctx, TraceConnectStartData{ConnConfig: config})
defer func() {
connectTracer.TraceConnectEnd(ctx, TraceConnectEndData{Conn: c, Err: err})
}()
}
// Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from
// zero values.
if !config.createdByParseConfig {
@ -201,29 +207,28 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
}
c = &Conn{
config: config,
typeMap: pgtype.NewMap(),
logLevel: config.LogLevel,
logger: config.Logger,
config: config,
typeMap: pgtype.NewMap(),
queryTracer: config.Tracer,
}
if t, ok := c.queryTracer.(BatchTracer); ok {
c.batchTracer = t
}
if t, ok := c.queryTracer.(CopyFromTracer); ok {
c.copyFromTracer = t
}
if t, ok := c.queryTracer.(PrepareTracer); ok {
c.prepareTracer = t
}
// Only install pgx notification system if no other callback handler is present.
if config.Config.OnNotification == nil {
config.Config.OnNotification = c.bufferNotifications
} else {
if c.shouldLog(LogLevelDebug) {
c.log(ctx, LogLevelDebug, "pgx notification handler disabled by application supplied OnNotification", map[string]any{"host": config.Config.Host})
}
}
if c.shouldLog(LogLevelInfo) {
c.log(ctx, LogLevelInfo, "Dialing PostgreSQL server", map[string]any{"host": config.Config.Host})
}
c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config)
if err != nil {
if c.shouldLog(LogLevelError) {
c.log(ctx, LogLevelError, "connect failed", map[string]any{"err": err})
}
return nil, err
}
@ -251,9 +256,6 @@ func (c *Conn) Close(ctx context.Context) error {
}
err := c.pgConn.Close(ctx)
if c.shouldLog(LogLevelInfo) {
c.log(ctx, LogLevelInfo, "closed connection", nil)
}
return err
}
@ -264,18 +266,23 @@ func (c *Conn) Close(ctx context.Context) error {
// name and sql arguments. This allows a code path to Prepare and Query/Exec without
// concern for if the statement has already been prepared.
func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) {
if c.prepareTracer != nil {
ctx = c.prepareTracer.TracePrepareStart(ctx, c, TracePrepareStartData{Name: name, SQL: sql})
}
if name != "" {
var ok bool
if sd, ok = c.preparedStatements[name]; ok && sd.SQL == sql {
if c.prepareTracer != nil {
c.prepareTracer.TracePrepareEnd(ctx, c, TracePrepareEndData{AlreadyPrepared: true})
}
return sd, nil
}
}
if c.shouldLog(LogLevelError) {
if c.prepareTracer != nil {
defer func() {
if err != nil {
c.log(ctx, LogLevelError, "Prepare failed", map[string]any{"err": err, "name": name, "sql": sql})
}
c.prepareTracer.TracePrepareEnd(ctx, c, TracePrepareEndData{Err: err})
}()
}
@ -337,21 +344,6 @@ func (c *Conn) die(err error) {
c.pgConn.Close(ctx)
}
func (c *Conn) shouldLog(lvl LogLevel) bool {
return c.logger != nil && c.logLevel >= lvl
}
func (c *Conn) log(ctx context.Context, lvl LogLevel, msg string, data map[string]any) {
if data == nil {
data = map[string]any{}
}
if c.pgConn != nil && c.pgConn.PID() != 0 {
data["pid"] = c.pgConn.PID()
}
c.logger.Log(ctx, lvl, msg, data)
}
func quoteIdentifier(s string) string {
return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
}
@ -379,24 +371,18 @@ func (c *Conn) Config() *ConnConfig { return c.config.Copy() }
// Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced
// positionally from the sql string as $1, $2, etc.
func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) {
if c.queryTracer != nil {
ctx = c.queryTracer.TraceQueryStart(ctx, c, TraceQueryStartData{SQL: sql, Args: arguments})
}
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
return pgconn.CommandTag{}, err
}
startTime := time.Now()
commandTag, err := c.exec(ctx, sql, arguments...)
if err != nil {
if c.shouldLog(LogLevelError) {
endTime := time.Now()
c.log(ctx, LogLevelError, "Exec", map[string]any{"sql": sql, "args": logQueryArgs(arguments), "err": err, "time": endTime.Sub(startTime)})
}
return commandTag, err
}
if c.shouldLog(LogLevelInfo) {
endTime := time.Now()
c.log(ctx, LogLevelInfo, "Exec", map[string]any{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag})
if c.queryTracer != nil {
c.queryTracer.TraceQueryEnd(ctx, c, TraceQueryEndData{CommandTag: commandTag, Err: err})
}
return commandTag, err
@ -537,7 +523,7 @@ func (c *Conn) getRows(ctx context.Context, sql string, args []any) *baseRows {
r := &baseRows{}
r.ctx = ctx
r.logger = c
r.queryTracer = c.queryTracer
r.typeMap = c.typeMap
r.startTime = time.Now()
r.sql = sql
@ -628,7 +614,14 @@ type QueryRewriter interface {
// QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely
// needed. See the documentation for those types for details.
func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error) {
if c.queryTracer != nil {
ctx = c.queryTracer.TraceQueryStart(ctx, c, TraceQueryStartData{SQL: sql, Args: args})
}
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
if c.queryTracer != nil {
c.queryTracer.TraceQueryEnd(ctx, c, TraceQueryEndData{Err: err})
}
return &baseRows{err: err, closed: true}, err
}
@ -791,7 +784,17 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row {
// SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless
// explicit transaction control statements are executed. The returned BatchResults must be closed before the connection
// is used again.
func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
if c.batchTracer != nil {
ctx = c.batchTracer.TraceBatchStart(ctx, c, TraceBatchStartData{Batch: b})
defer func() {
err := br.(interface{ earlyError() error }).earlyError()
if err != nil {
c.batchTracer.TraceBatchEnd(ctx, c, TraceBatchEndData{Err: err})
}
}()
}
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}

View File

@ -3,7 +3,6 @@ package pgx_test
import (
"bytes"
"context"
"log"
"os"
"strings"
"sync"
@ -743,79 +742,6 @@ func TestInsertTimestampArray(t *testing.T) {
})
}
type testLog struct {
lvl pgx.LogLevel
msg string
data map[string]any
}
type testLogger struct {
logs []testLog
}
func (l *testLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]any) {
data["ctxdata"] = ctx.Value("ctxdata")
l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data})
}
func TestLogPassesContext(t *testing.T) {
t.Parallel()
l1 := &testLogger{}
config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
config.Logger = l1
conn := mustConnect(t, config)
defer closeConn(t, conn)
l1.logs = l1.logs[0:0] // Clear logs written when establishing connection
ctx := context.WithValue(context.Background(), "ctxdata", "foo")
if _, err := conn.Exec(ctx, ";"); err != nil {
t.Fatal(err)
}
if len(l1.logs) != 1 {
t.Fatal("Expected logger to be called once, but it wasn't")
}
if l1.logs[0].data["ctxdata"] != "foo" {
t.Fatal("Expected context data to be passed to logger, but it wasn't")
}
}
func TestLoggerFunc(t *testing.T) {
t.Parallel()
const testMsg = "foo"
buf := bytes.Buffer{}
logger := log.New(&buf, "", 0)
createAdapterFn := func(logger *log.Logger) pgx.LoggerFunc {
return func(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
logger.Printf("%s", testMsg)
}
}
config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
config.Logger = createAdapterFn(logger)
conn := mustConnect(t, config)
defer closeConn(t, conn)
buf.Reset() // Clear logs written when establishing connection
if _, err := conn.Exec(context.TODO(), ";"); err != nil {
t.Fatal(err)
}
if strings.TrimSpace(buf.String()) != testMsg {
t.Errorf("Expected logger function to return '%s', but it was '%s'", testMsg, buf.String())
}
}
func TestIdentifierSanitize(t *testing.T) {
t.Parallel()

View File

@ -5,7 +5,6 @@ import (
"context"
"fmt"
"io"
"time"
"github.com/jackc/pgx/v5/internal/pgio"
"github.com/jackc/pgx/v5/pgconn"
@ -89,6 +88,13 @@ type copyFrom struct {
}
func (ct *copyFrom) run(ctx context.Context) (int64, error) {
if ct.conn.copyFromTracer != nil {
ctx = ct.conn.copyFromTracer.TraceCopyFromStart(ctx, ct.conn, TraceCopyFromStartData{
TableName: ct.tableName,
ColumnNames: ct.columnNames,
})
}
quotedTableName := ct.tableName.Sanitize()
cbuf := &bytes.Buffer{}
for i, cn := range ct.columnNames {
@ -145,24 +151,19 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
w.Close()
}()
startTime := time.Now()
commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
r.Close()
<-doneChan
rowsAffected := commandTag.RowsAffected()
endTime := time.Now()
if err == nil {
if ct.conn.shouldLog(LogLevelInfo) {
ct.conn.log(ctx, LogLevelInfo, "CopyFrom", map[string]any{"tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime), "rowCount": rowsAffected})
}
} else if ct.conn.shouldLog(LogLevelError) {
ct.conn.log(ctx, LogLevelError, "CopyFrom", map[string]any{"err": err, "tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime)})
if ct.conn.copyFromTracer != nil {
ct.conn.copyFromTracer.TraceCopyFromEnd(ctx, ct.conn, TraceCopyFromEndData{
CommandTag: commandTag,
Err: err,
})
}
return rowsAffected, err
return commandTag.RowsAffected(), err
}
func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) {

18
doc.go
View File

@ -12,15 +12,7 @@ The primary way of establishing a connection is with `pgx.Connect`.
The database connection string can be in URL or DSN format. Both PostgreSQL settings and pgx settings can be specified
here. In addition, a config struct can be created by `ParseConfig` and modified before establishing the connection with
`ConnectConfig`.
config, err := pgx.ParseConfig(os.Getenv("DATABASE_URL"))
if err != nil {
// ...
}
config.Logger = log15adapter.NewLogger(log.New("module", "pgx"))
conn, err := pgx.ConnectConfig(context.Background(), config)
`ConnectConfig` to configure settings such as tracing that cannot be configured with a connection string.
Connection Pool
@ -315,11 +307,11 @@ notification is received or the context is canceled.
}
Logging
Tracing and Logging
pgx defines a simple logger interface. Connections optionally accept a logger that satisfies this interface. Set
LogLevel to control logging verbosity. Adapters for github.com/inconshreveable/log15, github.com/sirupsen/logrus,
go.uber.org/zap, github.com/rs/zerolog, and the testing log are provided in the log directory.
pgx supports tracing by setting ConnConfig.Tracer.
In addition, the tracelog package provides the TraceLog type which lets a traditional logger act as a Tracer.
Lower Level PostgreSQL Functionality

View File

@ -98,8 +98,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName
return
}
assert.Equalf(t, expected.Logger, actual.Logger, "%s - Logger", testName)
assert.Equalf(t, expected.LogLevel, actual.LogLevel, "%s - LogLevel", testName)
assert.Equalf(t, expected.Tracer, actual.Tracer, "%s - Tracer", testName)
assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName)
assert.Equalf(t, expected.StatementCacheCapacity, actual.StatementCacheCapacity, "%s - StatementCacheCapacity", testName)
assert.Equalf(t, expected.DescriptionCacheCapacity, actual.DescriptionCacheCapacity, "%s - DescriptionCacheCapacity", testName)

View File

@ -162,7 +162,7 @@ func (f *Frontend) SendExecute(msg *Execute) {
prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf)
if f.tracer != nil {
f.tracer.traceExecute('F', int32(len(f.wbuf)-prevLen), msg)
f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg)
}
}

View File

@ -79,7 +79,7 @@ func (t *tracer) traceMessage(sender byte, encodedLen int32, msg Message) {
case *ErrorResponse:
t.traceErrorResponse(sender, encodedLen, msg)
case *Execute:
t.traceExecute(sender, encodedLen, msg)
t.TraceQueryute(sender, encodedLen, msg)
case *Flush:
t.traceFlush(sender, encodedLen, msg)
case *FunctionCall:
@ -277,7 +277,7 @@ func (t *tracer) traceErrorResponse(sender byte, encodedLen int32, msg *ErrorRes
t.finishTrace()
}
func (t *tracer) traceExecute(sender byte, encodedLen int32, msg *Execute) {
func (t *tracer) TraceQueryute(sender byte, encodedLen int32, msg *Execute) {
t.beginTrace(sender, encodedLen, "Execute")
fmt.Fprintf(t.buf, "\t %s %d", traceDoubleQuotedString([]byte(msg.Portal)), msg.MaxRows)
t.finishTrace()

View File

@ -160,8 +160,7 @@ func assertConnConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, test
return
}
assert.Equalf(t, expected.Logger, actual.Logger, "%s - Logger", testName)
assert.Equalf(t, expected.LogLevel, actual.LogLevel, "%s - LogLevel", testName)
assert.Equalf(t, expected.Tracer, actual.Tracer, "%s - Tracer", testName)
assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName)
assert.Equalf(t, expected.StatementCacheCapacity, actual.StatementCacheCapacity, "%s - StatementCacheCapacity", testName)
assert.Equalf(t, expected.DescriptionCacheCapacity, actual.DescriptionCacheCapacity, "%s - DescriptionCacheCapacity", testName)

38
rows.go
View File

@ -105,11 +105,6 @@ func (r *connRow) Scan(dest ...any) (err error) {
return rows.Err()
}
type rowLog interface {
shouldLog(lvl LogLevel) bool
log(ctx context.Context, lvl LogLevel, msg string, data map[string]any)
}
// baseRows implements the Rows interface for Conn.Query.
type baseRows struct {
typeMap *pgtype.Map
@ -127,12 +122,13 @@ type baseRows struct {
conn *Conn
multiResultReader *pgconn.MultiResultReader
logger rowLog
ctx context.Context
startTime time.Time
sql string
args []any
rowCount int
queryTracer QueryTracer
batchTracer BatchTracer
ctx context.Context
startTime time.Time
sql string
args []any
rowCount int
}
func (rows *baseRows) FieldDescriptions() []pgproto3.FieldDescription {
@ -161,20 +157,6 @@ func (rows *baseRows) Close() {
}
}
if rows.logger != nil {
endTime := time.Now()
if rows.err == nil {
if rows.logger.shouldLog(LogLevelInfo) {
rows.logger.log(rows.ctx, LogLevelInfo, "Query", map[string]any{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount})
}
} else {
if rows.logger.shouldLog(LogLevelError) {
rows.logger.log(rows.ctx, LogLevelError, "Query", map[string]any{"err": rows.err, "sql": rows.sql, "time": endTime.Sub(rows.startTime), "args": logQueryArgs(rows.args)})
}
}
}
if rows.err != nil && rows.conn != nil && rows.sql != "" {
if stmtcache.IsStatementInvalid(rows.err) {
if sc := rows.conn.statementCache; sc != nil {
@ -186,6 +168,12 @@ func (rows *baseRows) Close() {
}
}
}
if rows.batchTracer != nil {
rows.batchTracer.TraceBatchQuery(rows.ctx, rows.conn, TraceBatchQueryData{SQL: rows.sql, Args: rows.args, CommandTag: rows.commandTag, Err: rows.err})
} else if rows.queryTracer != nil {
rows.queryTracer.TraceQueryEnd(rows.ctx, rows.conn, TraceQueryEndData{rows.commandTag, rows.err})
}
}
func (rows *baseRows) CommandTag() pgconn.CommandTag {

View File

@ -17,6 +17,7 @@ import (
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/stdlib"
"github.com/jackc/pgx/v5/tracelog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -976,7 +977,7 @@ func TestScanJSONIntoJSONRawMessage(t *testing.T) {
}
type testLog struct {
lvl pgx.LogLevel
lvl tracelog.LogLevel
msg string
data map[string]any
}
@ -985,7 +986,7 @@ type testLogger struct {
logs []testLog
}
func (l *testLogger) Log(ctx context.Context, lvl pgx.LogLevel, msg string, data map[string]any) {
func (l *testLogger) Log(ctx context.Context, lvl tracelog.LogLevel, msg string, data map[string]any) {
l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data})
}
@ -994,7 +995,7 @@ func TestRegisterConnConfig(t *testing.T) {
require.NoError(t, err)
logger := &testLogger{}
connConfig.Logger = logger
connConfig.Tracer = &tracelog.TraceLog{Logger: logger, LogLevel: tracelog.LogLevelInfo}
// Issue 947: Register and unregister a ConnConfig and ensure that the
// returned connection string is not reused.

295
tracelog/tracelog.go Normal file
View File

@ -0,0 +1,295 @@
// Package tracelog provides a tracer that acts as a traditional logger.
package tracelog
import (
"context"
"encoding/hex"
"errors"
"fmt"
"time"
"github.com/jackc/pgx/v5"
)
// LogLevel represents the pgx logging level. See LogLevel* constants for
// possible values.
type LogLevel int
// The values for log levels are chosen such that the zero value means that no
// log level was specified.
const (
LogLevelTrace = LogLevel(6)
LogLevelDebug = LogLevel(5)
LogLevelInfo = LogLevel(4)
LogLevelWarn = LogLevel(3)
LogLevelError = LogLevel(2)
LogLevelNone = LogLevel(1)
)
func (ll LogLevel) String() string {
switch ll {
case LogLevelTrace:
return "trace"
case LogLevelDebug:
return "debug"
case LogLevelInfo:
return "info"
case LogLevelWarn:
return "warn"
case LogLevelError:
return "error"
case LogLevelNone:
return "none"
default:
return fmt.Sprintf("invalid level %d", ll)
}
}
// Logger is the interface used to get log output from pgx.
type Logger interface {
// Log a message at the given level with data key/value pairs. data may be nil.
Log(ctx context.Context, level LogLevel, msg string, data map[string]any)
}
// LoggerFunc is a wrapper around a function to satisfy the pgx.Logger interface
type LoggerFunc func(ctx context.Context, level LogLevel, msg string, data map[string]interface{})
// Log delegates the logging request to the wrapped function
func (f LoggerFunc) Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) {
f(ctx, level, msg, data)
}
// LogLevelFromString converts log level string to constant
//
// Valid levels:
// trace
// debug
// info
// warn
// error
// none
func LogLevelFromString(s string) (LogLevel, error) {
switch s {
case "trace":
return LogLevelTrace, nil
case "debug":
return LogLevelDebug, nil
case "info":
return LogLevelInfo, nil
case "warn":
return LogLevelWarn, nil
case "error":
return LogLevelError, nil
case "none":
return LogLevelNone, nil
default:
return 0, errors.New("invalid log level")
}
}
func logQueryArgs(args []any) []any {
logArgs := make([]any, 0, len(args))
for _, a := range args {
switch v := a.(type) {
case []byte:
if len(v) < 64 {
a = hex.EncodeToString(v)
} else {
a = fmt.Sprintf("%x (truncated %d bytes)", v[:64], len(v)-64)
}
case string:
if len(v) > 64 {
a = fmt.Sprintf("%s (truncated %d bytes)", v[:64], len(v)-64)
}
}
logArgs = append(logArgs, a)
}
return logArgs
}
// TraceLog implements pgx.QueryTracer, pgx.BatchTracer, pgx.ConnectTracer, and pgx.CopyFromTracer. All fields are
// required.
type TraceLog struct {
Logger Logger
LogLevel LogLevel
}
type ctxKey int
const (
_ ctxKey = iota
tracelogQueryCtxKey
tracelogBatchCtxKey
tracelogCopyFromCtxKey
tracelogConnectCtxKey
)
type traceQueryData struct {
startTime time.Time
sql string
args []any
}
func (tl *TraceLog) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
return context.WithValue(ctx, tracelogQueryCtxKey, &traceQueryData{
startTime: time.Now(),
sql: data.SQL,
args: data.Args,
})
}
func (tl *TraceLog) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
queryData := ctx.Value(tracelogQueryCtxKey).(*traceQueryData)
endTime := time.Now()
interval := endTime.Sub(queryData.startTime)
if data.Err != nil {
if tl.shouldLog(LogLevelError) {
tl.log(ctx, conn, LogLevelError, "Query", map[string]any{"sql": queryData.sql, "args": logQueryArgs(queryData.args), "err": data.Err, "time": interval})
}
return
}
if tl.shouldLog(LogLevelInfo) {
tl.log(ctx, conn, LogLevelInfo, "Query", map[string]any{"sql": queryData.sql, "args": logQueryArgs(queryData.args), "time": interval, "commandTag": data.CommandTag.String()})
}
}
type traceBatchData struct {
startTime time.Time
}
func (tl *TraceLog) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
return context.WithValue(ctx, tracelogBatchCtxKey, &traceBatchData{
startTime: time.Now(),
})
}
func (tl *TraceLog) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
if data.Err != nil {
if tl.shouldLog(LogLevelError) {
tl.log(ctx, conn, LogLevelError, "BatchQuery", map[string]any{"sql": data.SQL, "args": logQueryArgs(data.Args), "err": data.Err})
}
return
}
if tl.shouldLog(LogLevelInfo) {
tl.log(ctx, conn, LogLevelInfo, "BatchQuery", map[string]any{"sql": data.SQL, "args": logQueryArgs(data.Args), "commandTag": data.CommandTag.String()})
}
}
func (tl *TraceLog) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
queryData := ctx.Value(tracelogBatchCtxKey).(*traceBatchData)
endTime := time.Now()
interval := endTime.Sub(queryData.startTime)
if data.Err != nil {
if tl.shouldLog(LogLevelError) {
tl.log(ctx, conn, LogLevelError, "BatchClose", map[string]any{"err": data.Err, "time": interval})
}
return
}
if tl.shouldLog(LogLevelInfo) {
tl.log(ctx, conn, LogLevelInfo, "BatchClose", map[string]any{"time": interval})
}
}
type traceCopyFromData struct {
startTime time.Time
TableName pgx.Identifier
ColumnNames []string
}
func (tl *TraceLog) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
return context.WithValue(ctx, tracelogCopyFromCtxKey, &traceCopyFromData{
startTime: time.Now(),
TableName: data.TableName,
ColumnNames: data.ColumnNames,
})
}
func (tl *TraceLog) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
copyFromData := ctx.Value(tracelogCopyFromCtxKey).(*traceCopyFromData)
endTime := time.Now()
interval := endTime.Sub(copyFromData.startTime)
if data.Err != nil {
if tl.shouldLog(LogLevelError) {
tl.log(ctx, conn, LogLevelError, "CopyFrom", map[string]any{"tableName": copyFromData.TableName, "columnNames": copyFromData.ColumnNames, "err": data.Err, "time": interval})
}
return
}
if tl.shouldLog(LogLevelInfo) {
tl.log(ctx, conn, LogLevelInfo, "CopyFrom", map[string]any{"tableName": copyFromData.TableName, "columnNames": copyFromData.ColumnNames, "err": data.Err, "time": interval, "rowCount": data.CommandTag.RowsAffected()})
}
}
type traceConnectData struct {
startTime time.Time
connConfig *pgx.ConnConfig
}
func (tl *TraceLog) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context {
return context.WithValue(ctx, tracelogConnectCtxKey, &traceConnectData{
startTime: time.Now(),
connConfig: data.ConnConfig,
})
}
func (tl *TraceLog) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) {
connectData := ctx.Value(tracelogConnectCtxKey).(*traceConnectData)
endTime := time.Now()
interval := endTime.Sub(connectData.startTime)
if data.Err != nil {
if tl.shouldLog(LogLevelError) {
tl.Logger.Log(ctx, LogLevelError, "Connect", map[string]any{
"host": connectData.connConfig.Host,
"port": connectData.connConfig.Port,
"database": connectData.connConfig.Database,
"time": interval,
"err": data.Err,
})
}
return
}
if data.Conn != nil {
if tl.shouldLog(LogLevelInfo) {
tl.log(ctx, data.Conn, LogLevelInfo, "Connect", map[string]any{
"host": connectData.connConfig.Host,
"port": connectData.connConfig.Port,
"database": connectData.connConfig.Database,
"time": interval,
})
}
}
}
func (tl *TraceLog) shouldLog(lvl LogLevel) bool {
return tl.LogLevel >= lvl
}
func (tl *TraceLog) log(ctx context.Context, conn *pgx.Conn, lvl LogLevel, msg string, data map[string]any) {
if data == nil {
data = map[string]any{}
}
pgConn := conn.PgConn()
if pgConn != nil {
pid := pgConn.PID()
if pid != 0 {
data["pid"] = pid
}
}
tl.Logger.Log(ctx, lvl, msg, data)
}

301
tracelog/tracelog_test.go Normal file
View File

@ -0,0 +1,301 @@
package tracelog_test
import (
"bytes"
"context"
"log"
"os"
"strings"
"testing"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxtest"
"github.com/jackc/pgx/v5/tracelog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var defaultConnTestRunner pgxtest.ConnTestRunner
func init() {
defaultConnTestRunner = pgxtest.DefaultConnTestRunner()
defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
return config
}
}
type testLog struct {
lvl tracelog.LogLevel
msg string
data map[string]any
}
type testLogger struct {
logs []testLog
}
func (l *testLogger) Log(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]any) {
data["ctxdata"] = ctx.Value("ctxdata")
l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data})
}
func TestContextGetsPassedToLogMethod(t *testing.T) {
t.Parallel()
logger := &testLogger{}
tracer := &tracelog.TraceLog{
Logger: logger,
LogLevel: tracelog.LogLevelTrace,
}
ctr := defaultConnTestRunner
ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
config := defaultConnTestRunner.CreateConfig(ctx, t)
config.Tracer = tracer
return config
}
pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
logger.logs = logger.logs[0:0] // Clear any logs written when establishing connection
ctx = context.WithValue(context.Background(), "ctxdata", "foo")
_, err := conn.Exec(ctx, `;`)
require.NoError(t, err)
require.Len(t, logger.logs, 1)
require.Equal(t, "foo", logger.logs[0].data["ctxdata"])
})
}
func TestLoggerFunc(t *testing.T) {
t.Parallel()
const testMsg = "foo"
buf := bytes.Buffer{}
logger := log.New(&buf, "", 0)
createAdapterFn := func(logger *log.Logger) tracelog.LoggerFunc {
return func(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]interface{}) {
logger.Printf("%s", testMsg)
}
}
config := defaultConnTestRunner.CreateConfig(context.Background(), t)
config.Tracer = &tracelog.TraceLog{
Logger: createAdapterFn(logger),
LogLevel: tracelog.LogLevelTrace,
}
conn, err := pgx.ConnectConfig(context.Background(), config)
require.NoError(t, err)
defer conn.Close(context.Background())
buf.Reset() // Clear logs written when establishing connection
if _, err := conn.Exec(context.TODO(), ";"); err != nil {
t.Fatal(err)
}
if strings.TrimSpace(buf.String()) != testMsg {
t.Errorf("Expected logger function to return '%s', but it was '%s'", testMsg, buf.String())
}
}
func TestLogQuery(t *testing.T) {
t.Parallel()
logger := &testLogger{}
tracer := &tracelog.TraceLog{
Logger: logger,
LogLevel: tracelog.LogLevelTrace,
}
ctr := defaultConnTestRunner
ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
config := defaultConnTestRunner.CreateConfig(ctx, t)
config.Tracer = tracer
return config
}
pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
logger.logs = logger.logs[0:0] // Clear any logs written when establishing connection
_, err := conn.Exec(ctx, `select $1::text`, "testing")
require.NoError(t, err)
require.Len(t, logger.logs, 1)
require.Equal(t, "Query", logger.logs[0].msg)
require.Equal(t, tracelog.LogLevelInfo, logger.logs[0].lvl)
_, err = conn.Exec(ctx, `foo`, "testing")
require.Error(t, err)
require.Len(t, logger.logs, 2)
require.Equal(t, "Query", logger.logs[1].msg)
require.Equal(t, tracelog.LogLevelError, logger.logs[1].lvl)
require.Equal(t, err, logger.logs[1].data["err"])
})
}
func TestLogCopyFrom(t *testing.T) {
t.Parallel()
logger := &testLogger{}
tracer := &tracelog.TraceLog{
Logger: logger,
LogLevel: tracelog.LogLevelTrace,
}
ctr := defaultConnTestRunner
ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
config := defaultConnTestRunner.CreateConfig(ctx, t)
config.Tracer = tracer
return config
}
pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Exec(context.Background(), `create temporary table foo(a int4)`)
require.NoError(t, err)
logger.logs = logger.logs[0:0]
inputRows := [][]any{
{int32(1)},
{nil},
}
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows))
require.NoError(t, err)
require.EqualValues(t, len(inputRows), copyCount)
require.Len(t, logger.logs, 1)
require.Equal(t, "CopyFrom", logger.logs[0].msg)
require.Equal(t, tracelog.LogLevelInfo, logger.logs[0].lvl)
logger.logs = logger.logs[0:0]
inputRows = [][]any{
{"not an integer"},
{nil},
}
copyCount, err = conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows))
require.Error(t, err)
require.EqualValues(t, 0, copyCount)
require.Len(t, logger.logs, 1)
require.Equal(t, "CopyFrom", logger.logs[0].msg)
require.Equal(t, tracelog.LogLevelError, logger.logs[0].lvl)
})
}
func TestLogConnect(t *testing.T) {
t.Parallel()
logger := &testLogger{}
tracer := &tracelog.TraceLog{
Logger: logger,
LogLevel: tracelog.LogLevelTrace,
}
config := defaultConnTestRunner.CreateConfig(context.Background(), t)
config.Tracer = tracer
conn1, err := pgx.ConnectConfig(context.Background(), config)
require.NoError(t, err)
defer conn1.Close(context.Background())
require.Len(t, logger.logs, 1)
require.Equal(t, "Connect", logger.logs[0].msg)
require.Equal(t, tracelog.LogLevelInfo, logger.logs[0].lvl)
logger.logs = logger.logs[0:0]
config, err = pgx.ParseConfig("host=/invalid")
require.NoError(t, err)
config.Tracer = tracer
conn2, err := pgx.ConnectConfig(context.Background(), config)
require.Nil(t, conn2)
require.Error(t, err)
require.Len(t, logger.logs, 1)
require.Equal(t, "Connect", logger.logs[0].msg)
require.Equal(t, tracelog.LogLevelError, logger.logs[0].lvl)
}
func TestLogBatchStatementsOnExec(t *testing.T) {
t.Parallel()
logger := &testLogger{}
tracer := &tracelog.TraceLog{
Logger: logger,
LogLevel: tracelog.LogLevelTrace,
}
ctr := defaultConnTestRunner
ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
config := defaultConnTestRunner.CreateConfig(ctx, t)
config.Tracer = tracer
return config
}
pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
logger.logs = logger.logs[0:0] // Clear any logs written when establishing connection
batch := &pgx.Batch{}
batch.Queue("create table foo (id bigint)")
batch.Queue("drop table foo")
br := conn.SendBatch(context.Background(), batch)
_, err := br.Exec()
require.NoError(t, err)
_, err = br.Exec()
require.NoError(t, err)
err = br.Close()
require.NoError(t, err)
require.Len(t, logger.logs, 3)
assert.Equal(t, "BatchQuery", logger.logs[0].msg)
assert.Equal(t, "create table foo (id bigint)", logger.logs[0].data["sql"])
assert.Equal(t, "BatchQuery", logger.logs[1].msg)
assert.Equal(t, "drop table foo", logger.logs[1].data["sql"])
assert.Equal(t, "BatchClose", logger.logs[2].msg)
})
}
func TestLogBatchStatementsOnBatchResultClose(t *testing.T) {
t.Parallel()
logger := &testLogger{}
tracer := &tracelog.TraceLog{
Logger: logger,
LogLevel: tracelog.LogLevelTrace,
}
ctr := defaultConnTestRunner
ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
config := defaultConnTestRunner.CreateConfig(ctx, t)
config.Tracer = tracer
return config
}
pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
logger.logs = logger.logs[0:0] // Clear any logs written when establishing connection
batch := &pgx.Batch{}
batch.Queue("select generate_series(1,$1)", 100)
batch.Queue("select 1 = 1;")
br := conn.SendBatch(context.Background(), batch)
err := br.Close()
require.NoError(t, err)
require.Len(t, logger.logs, 3)
assert.Equal(t, "BatchQuery", logger.logs[0].msg)
assert.Equal(t, "select generate_series(1,$1)", logger.logs[0].data["sql"])
assert.Equal(t, "BatchQuery", logger.logs[1].msg)
assert.Equal(t, "select 1 = 1;", logger.logs[1].data["sql"])
assert.Equal(t, "BatchClose", logger.logs[2].msg)
})
}

107
tracer.go Normal file
View File

@ -0,0 +1,107 @@
package pgx
import (
"context"
"github.com/jackc/pgx/v5/pgconn"
)
// QueryTracer traces Query, QueryRow, and Exec.
type QueryTracer interface {
// TraceQueryStart is called at the beginning of Query, QueryRow, and Exec calls. The returned context is used for the
// rest of the call and will be passed to TraceQueryEnd.
TraceQueryStart(ctx context.Context, conn *Conn, data TraceQueryStartData) context.Context
TraceQueryEnd(ctx context.Context, conn *Conn, data TraceQueryEndData)
}
type TraceQueryStartData struct {
SQL string
Args []any
}
type TraceQueryEndData struct {
CommandTag pgconn.CommandTag
Err error
}
// BatchTracer traces SendBatch.
type BatchTracer interface {
// TraceBatchStart is called at the beginning of SendBatch calls. The returned context is used for the
// rest of the call and will be passed to TraceBatchQuery and TraceBatchEnd.
TraceBatchStart(ctx context.Context, conn *Conn, data TraceBatchStartData) context.Context
TraceBatchQuery(ctx context.Context, conn *Conn, data TraceBatchQueryData)
TraceBatchEnd(ctx context.Context, conn *Conn, data TraceBatchEndData)
}
type TraceBatchStartData struct {
Batch *Batch
}
type TraceBatchQueryData struct {
SQL string
Args []any
CommandTag pgconn.CommandTag
Err error
}
type TraceBatchEndData struct {
Err error
}
// CopyFromTracer traces CopyFrom.
type CopyFromTracer interface {
// TraceCopyFromStart is called at the beginning of CopyFrom calls. The returned context is used for the
// rest of the call and will be passed to TraceCopyFromEnd.
TraceCopyFromStart(ctx context.Context, conn *Conn, data TraceCopyFromStartData) context.Context
TraceCopyFromEnd(ctx context.Context, conn *Conn, data TraceCopyFromEndData)
}
type TraceCopyFromStartData struct {
TableName Identifier
ColumnNames []string
}
type TraceCopyFromEndData struct {
CommandTag pgconn.CommandTag
Err error
}
// PrepareTracer traces Prepare.
type PrepareTracer interface {
// TracePrepareStart is called at the beginning of Prepare calls. The returned context is used for the
// rest of the call and will be passed to TracePrepareEnd.
TracePrepareStart(ctx context.Context, conn *Conn, data TracePrepareStartData) context.Context
TracePrepareEnd(ctx context.Context, conn *Conn, data TracePrepareEndData)
}
type TracePrepareStartData struct {
Name string
SQL string
}
type TracePrepareEndData struct {
AlreadyPrepared bool
Err error
}
// ConnectTracer traces Connect and ConnectConfig.
type ConnectTracer interface {
// TraceConnectStart is called at the beginning of Connect and ConnectConfig calls. The returned context is used for
// the rest of the call and will be passed to TraceConnectEnd.
TraceConnectStart(ctx context.Context, data TraceConnectStartData) context.Context
TraceConnectEnd(ctx context.Context, data TraceConnectEndData)
}
type TraceConnectStartData struct {
ConnConfig *ConnConfig
}
type TraceConnectEndData struct {
Conn *Conn
Err error
}

538
tracer_test.go Normal file
View File

@ -0,0 +1,538 @@
package pgx_test
import (
"context"
"testing"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxtest"
"github.com/stretchr/testify/require"
)
type testTracer struct {
traceQueryStart func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context
traceQueryEnd func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData)
traceBatchStart func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context
traceBatchQuery func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData)
traceBatchEnd func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData)
traceCopyFromStart func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context
traceCopyFromEnd func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData)
tracePrepareStart func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context
tracePrepareEnd func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData)
traceConnectStart func(ctx context.Context, data pgx.TraceConnectStartData) context.Context
traceConnectEnd func(ctx context.Context, data pgx.TraceConnectEndData)
}
func (tt *testTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
if tt.traceQueryStart != nil {
return tt.traceQueryStart(ctx, conn, data)
}
return ctx
}
func (tt *testTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
if tt.traceQueryEnd != nil {
tt.traceQueryEnd(ctx, conn, data)
}
}
func (tt *testTracer) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
if tt.traceBatchStart != nil {
return tt.traceBatchStart(ctx, conn, data)
}
return ctx
}
func (tt *testTracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
if tt.traceBatchQuery != nil {
tt.traceBatchQuery(ctx, conn, data)
}
}
func (tt *testTracer) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
if tt.traceBatchEnd != nil {
tt.traceBatchEnd(ctx, conn, data)
}
}
func (tt *testTracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
if tt.traceCopyFromStart != nil {
return tt.traceCopyFromStart(ctx, conn, data)
}
return ctx
}
func (tt *testTracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
if tt.traceCopyFromEnd != nil {
tt.traceCopyFromEnd(ctx, conn, data)
}
}
func (tt *testTracer) TracePrepareStart(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context {
if tt.tracePrepareStart != nil {
return tt.tracePrepareStart(ctx, conn, data)
}
return ctx
}
func (tt *testTracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
if tt.tracePrepareEnd != nil {
tt.tracePrepareEnd(ctx, conn, data)
}
}
func (tt *testTracer) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context {
if tt.traceConnectStart != nil {
return tt.traceConnectStart(ctx, data)
}
return ctx
}
func (tt *testTracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) {
if tt.traceConnectEnd != nil {
tt.traceConnectEnd(ctx, data)
}
}
func TestTraceExec(t *testing.T) {
t.Parallel()
tracer := &testTracer{}
ctr := defaultConnTestRunner
ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
config := defaultConnTestRunner.CreateConfig(ctx, t)
config.Tracer = tracer
return config
}
pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
traceQueryStartCalled := false
tracer.traceQueryStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
traceQueryStartCalled = true
require.Equal(t, `select $1::text`, data.SQL)
require.Len(t, data.Args, 1)
require.Equal(t, `testing`, data.Args[0])
return context.WithValue(ctx, "fromTraceQueryStart", "foo")
}
traceQueryEndCalled := false
tracer.traceQueryEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
traceQueryEndCalled = true
require.Equal(t, "foo", ctx.Value("fromTraceQueryStart"))
require.Equal(t, `SELECT 1`, data.CommandTag.String())
require.NoError(t, data.Err)
}
_, err := conn.Exec(ctx, `select $1::text`, "testing")
require.NoError(t, err)
require.True(t, traceQueryStartCalled)
require.True(t, traceQueryEndCalled)
})
}
func TestTraceQuery(t *testing.T) {
t.Parallel()
tracer := &testTracer{}
ctr := defaultConnTestRunner
ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
config := defaultConnTestRunner.CreateConfig(ctx, t)
config.Tracer = tracer
return config
}
pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
traceQueryStartCalled := false
tracer.traceQueryStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
traceQueryStartCalled = true
require.Equal(t, `select $1::text`, data.SQL)
require.Len(t, data.Args, 1)
require.Equal(t, `testing`, data.Args[0])
return context.WithValue(ctx, "fromTraceQueryStart", "foo")
}
traceQueryEndCalled := false
tracer.traceQueryEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
traceQueryEndCalled = true
require.Equal(t, "foo", ctx.Value("fromTraceQueryStart"))
require.Equal(t, `SELECT 1`, data.CommandTag.String())
require.NoError(t, data.Err)
}
var s string
err := conn.QueryRow(ctx, `select $1::text`, "testing").Scan(&s)
require.NoError(t, err)
require.Equal(t, "testing", s)
require.True(t, traceQueryStartCalled)
require.True(t, traceQueryEndCalled)
})
}
func TestTraceBatchNormal(t *testing.T) {
t.Parallel()
tracer := &testTracer{}
ctr := defaultConnTestRunner
ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
config := defaultConnTestRunner.CreateConfig(ctx, t)
config.Tracer = tracer
return config
}
pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
traceBatchStartCalled := false
tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
traceBatchStartCalled = true
require.NotNil(t, data.Batch)
require.Equal(t, 2, data.Batch.Len())
return context.WithValue(ctx, "fromTraceBatchStart", "foo")
}
traceBatchQueryCalledCount := 0
tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
traceBatchQueryCalledCount++
require.Equal(t, "foo", ctx.Value("fromTraceBatchStart"))
require.NoError(t, data.Err)
}
traceBatchEndCalled := false
tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
traceBatchEndCalled = true
require.Equal(t, "foo", ctx.Value("fromTraceBatchStart"))
require.NoError(t, data.Err)
}
batch := &pgx.Batch{}
batch.Queue(`select 1`)
batch.Queue(`select 2`)
br := conn.SendBatch(context.Background(), batch)
require.True(t, traceBatchStartCalled)
var n int32
err := br.QueryRow().Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 1, n)
require.EqualValues(t, 1, traceBatchQueryCalledCount)
err = br.QueryRow().Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 2, n)
require.EqualValues(t, 2, traceBatchQueryCalledCount)
err = br.Close()
require.NoError(t, err)
require.True(t, traceBatchEndCalled)
})
}
func TestTraceBatchClose(t *testing.T) {
t.Parallel()
tracer := &testTracer{}
ctr := defaultConnTestRunner
ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
config := defaultConnTestRunner.CreateConfig(ctx, t)
config.Tracer = tracer
return config
}
pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
traceBatchStartCalled := false
tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
traceBatchStartCalled = true
require.NotNil(t, data.Batch)
require.Equal(t, 2, data.Batch.Len())
return context.WithValue(ctx, "fromTraceBatchStart", "foo")
}
traceBatchQueryCalledCount := 0
tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
traceBatchQueryCalledCount++
require.Equal(t, "foo", ctx.Value("fromTraceBatchStart"))
require.NoError(t, data.Err)
}
traceBatchEndCalled := false
tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
traceBatchEndCalled = true
require.Equal(t, "foo", ctx.Value("fromTraceBatchStart"))
require.NoError(t, data.Err)
}
batch := &pgx.Batch{}
batch.Queue(`select 1`)
batch.Queue(`select 2`)
br := conn.SendBatch(context.Background(), batch)
require.True(t, traceBatchStartCalled)
err := br.Close()
require.NoError(t, err)
require.EqualValues(t, 2, traceBatchQueryCalledCount)
require.True(t, traceBatchEndCalled)
})
}
func TestTraceBatchErrorWhileReadingResults(t *testing.T) {
t.Parallel()
tracer := &testTracer{}
ctr := defaultConnTestRunner
ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
config := defaultConnTestRunner.CreateConfig(ctx, t)
config.Tracer = tracer
return config
}
pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, []pgx.QueryExecMode{pgx.QueryExecModeSimpleProtocol}, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
traceBatchStartCalled := false
tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
traceBatchStartCalled = true
require.NotNil(t, data.Batch)
require.Equal(t, 3, data.Batch.Len())
return context.WithValue(ctx, "fromTraceBatchStart", "foo")
}
traceBatchQueryCalledCount := 0
tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
traceBatchQueryCalledCount++
require.Equal(t, "foo", ctx.Value("fromTraceBatchStart"))
if traceBatchQueryCalledCount == 2 {
require.Error(t, data.Err)
} else {
require.NoError(t, data.Err)
}
}
traceBatchEndCalled := false
tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
traceBatchEndCalled = true
require.Equal(t, "foo", ctx.Value("fromTraceBatchStart"))
require.Error(t, data.Err)
}
batch := &pgx.Batch{}
batch.Queue(`select 1`)
batch.Queue(`select 2/n-2 from generate_series(0,10) n`)
batch.Queue(`select 3`)
br := conn.SendBatch(context.Background(), batch)
require.True(t, traceBatchStartCalled)
commandTag, err := br.Exec()
require.NoError(t, err)
require.Equal(t, "SELECT 1", commandTag.String())
commandTag, err = br.Exec()
require.Error(t, err)
require.Equal(t, "", commandTag.String())
commandTag, err = br.Exec()
require.Error(t, err)
require.Equal(t, "", commandTag.String())
err = br.Close()
require.Error(t, err)
require.EqualValues(t, 2, traceBatchQueryCalledCount)
require.True(t, traceBatchEndCalled)
})
}
func TestTraceBatchErrorWhileReadingResultsWhileClosing(t *testing.T) {
t.Parallel()
tracer := &testTracer{}
ctr := defaultConnTestRunner
ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
config := defaultConnTestRunner.CreateConfig(ctx, t)
config.Tracer = tracer
return config
}
pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, []pgx.QueryExecMode{pgx.QueryExecModeSimpleProtocol}, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
traceBatchStartCalled := false
tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
traceBatchStartCalled = true
require.NotNil(t, data.Batch)
require.Equal(t, 3, data.Batch.Len())
return context.WithValue(ctx, "fromTraceBatchStart", "foo")
}
traceBatchQueryCalledCount := 0
tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
traceBatchQueryCalledCount++
require.Equal(t, "foo", ctx.Value("fromTraceBatchStart"))
if traceBatchQueryCalledCount == 2 {
require.Error(t, data.Err)
} else {
require.NoError(t, data.Err)
}
}
traceBatchEndCalled := false
tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
traceBatchEndCalled = true
require.Equal(t, "foo", ctx.Value("fromTraceBatchStart"))
require.Error(t, data.Err)
}
batch := &pgx.Batch{}
batch.Queue(`select 1`)
batch.Queue(`select 2/n-2 from generate_series(0,10) n`)
batch.Queue(`select 3`)
br := conn.SendBatch(context.Background(), batch)
require.True(t, traceBatchStartCalled)
err := br.Close()
require.Error(t, err)
require.EqualValues(t, 2, traceBatchQueryCalledCount)
require.True(t, traceBatchEndCalled)
})
}
func TestTraceCopyFrom(t *testing.T) {
t.Parallel()
tracer := &testTracer{}
ctr := defaultConnTestRunner
ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
config := defaultConnTestRunner.CreateConfig(ctx, t)
config.Tracer = tracer
return config
}
pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
traceCopyFromStartCalled := false
tracer.traceCopyFromStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
traceCopyFromStartCalled = true
require.Equal(t, pgx.Identifier{"foo"}, data.TableName)
require.Equal(t, []string{"a"}, data.ColumnNames)
return context.WithValue(ctx, "fromTraceCopyFromStart", "foo")
}
traceCopyFromEndCalled := false
tracer.traceCopyFromEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
traceCopyFromEndCalled = true
require.Equal(t, "foo", ctx.Value("fromTraceCopyFromStart"))
require.Equal(t, `COPY 2`, data.CommandTag.String())
require.NoError(t, data.Err)
}
_, err := conn.Exec(context.Background(), `create temporary table foo(a int4)`)
require.NoError(t, err)
inputRows := [][]any{
{int32(1)},
{nil},
}
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows))
require.NoError(t, err)
require.EqualValues(t, len(inputRows), copyCount)
require.True(t, traceCopyFromStartCalled)
require.True(t, traceCopyFromEndCalled)
})
}
func TestTracePrepare(t *testing.T) {
t.Parallel()
tracer := &testTracer{}
ctr := defaultConnTestRunner
ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
config := defaultConnTestRunner.CreateConfig(ctx, t)
config.Tracer = tracer
return config
}
pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
tracePrepareStartCalled := false
tracer.tracePrepareStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context {
tracePrepareStartCalled = true
require.Equal(t, `ps`, data.Name)
require.Equal(t, `select $1::text`, data.SQL)
return context.WithValue(ctx, "fromTracePrepareStart", "foo")
}
tracePrepareEndCalled := false
tracer.tracePrepareEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
tracePrepareEndCalled = true
require.False(t, data.AlreadyPrepared)
require.NoError(t, data.Err)
}
_, err := conn.Prepare(ctx, "ps", `select $1::text`)
require.NoError(t, err)
require.True(t, tracePrepareStartCalled)
require.True(t, tracePrepareEndCalled)
tracePrepareStartCalled = false
tracePrepareEndCalled = false
tracer.tracePrepareEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
tracePrepareEndCalled = true
require.True(t, data.AlreadyPrepared)
require.NoError(t, data.Err)
}
_, err = conn.Prepare(ctx, "ps", `select $1::text`)
require.NoError(t, err)
require.True(t, tracePrepareStartCalled)
require.True(t, tracePrepareEndCalled)
})
}
func TestTraceConnect(t *testing.T) {
t.Parallel()
tracer := &testTracer{}
config := defaultConnTestRunner.CreateConfig(context.Background(), t)
config.Tracer = tracer
traceConnectStartCalled := false
tracer.traceConnectStart = func(ctx context.Context, data pgx.TraceConnectStartData) context.Context {
traceConnectStartCalled = true
require.NotNil(t, data.ConnConfig)
return context.WithValue(ctx, "fromTraceConnectStart", "foo")
}
traceConnectEndCalled := false
tracer.traceConnectEnd = func(ctx context.Context, data pgx.TraceConnectEndData) {
traceConnectEndCalled = true
require.NotNil(t, data.Conn)
require.NoError(t, data.Err)
}
conn1, err := pgx.ConnectConfig(context.Background(), config)
require.NoError(t, err)
defer conn1.Close(context.Background())
require.True(t, traceConnectStartCalled)
require.True(t, traceConnectEndCalled)
config, err = pgx.ParseConfig("host=/invalid")
require.NoError(t, err)
config.Tracer = tracer
traceConnectStartCalled = false
traceConnectEndCalled = false
tracer.traceConnectEnd = func(ctx context.Context, data pgx.TraceConnectEndData) {
traceConnectEndCalled = true
require.Nil(t, data.Conn)
require.Error(t, data.Err)
}
conn2, err := pgx.ConnectConfig(context.Background(), config)
require.Nil(t, conn2)
require.Error(t, err)
require.True(t, traceConnectStartCalled)
require.True(t, traceConnectEndCalled)
}