diff --git a/CHANGELOG.md b/CHANGELOG.md index 387150c2..9a402c40 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -159,7 +159,9 @@ Previously, a batch with 10 unique parameterized statements executed 100 times w for each prepare / describe and 1 for executing them all. Now pipeline mode is used to prepare / describe all statements in a single network round trip. So it would only take 2 round trips. -## 3rd Party Logger Integration +## Tracing and Logging + +Internal logging support has been replaced with tracing hooks. This allows custom tracing integration with tools like OpenTelemetry. Package tracelog provides an adapter for pgx v4 loggers to act as a tracer. All integrations with 3rd party loggers have been extracted to separate repositories. This trims the pgx dependency tree. diff --git a/README.md b/README.md index c4f7239f..c7224b22 100644 --- a/README.md +++ b/README.md @@ -163,6 +163,8 @@ pgerrcode contains constants for the PostgreSQL error codes. ## Adapters for 3rd Party Loggers +These adapters can be used with the tracelog package. + * [github.com/jackc/pgx-go-kit-log](https://github.com/jackc/pgx-go-kit-log) * [github.com/jackc/pgx-log15](https://github.com/jackc/pgx-log15) * [github.com/jackc/pgx-logrus](https://github.com/jackc/pgx-logrus) diff --git a/batch.go b/batch.go index f2a9b4c8..a6951096 100644 --- a/batch.go +++ b/batch.go @@ -50,13 +50,14 @@ type BatchResults interface { } type batchResults struct { - ctx context.Context - conn *Conn - mrr *pgconn.MultiResultReader - err error - b *Batch - ix int - closed bool + ctx context.Context + conn *Conn + mrr *pgconn.MultiResultReader + err error + b *Batch + ix int + closed bool + endTraced bool } // Exec reads the results from the next query in the batch as if the query has been sent with Exec. @@ -75,35 +76,29 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) { if err == nil { err = errors.New("no result") } - if br.conn.shouldLog(LogLevelError) { - br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]any{ - "sql": query, - "args": logQueryArgs(arguments), - "err": err, + if br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ + SQL: query, + Args: arguments, + Err: err, }) } return pgconn.CommandTag{}, err } commandTag, err := br.mrr.ResultReader().Close() + br.err = err - if err != nil { - if br.conn.shouldLog(LogLevelError) { - br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]any{ - "sql": query, - "args": logQueryArgs(arguments), - "err": err, - }) - } - } else if br.conn.shouldLog(LogLevelInfo) { - br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]any{ - "sql": query, - "args": logQueryArgs(arguments), - "commandTag": commandTag, + if br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ + SQL: query, + Args: arguments, + CommandTag: commandTag, + Err: br.err, }) } - return commandTag, err + return commandTag, br.err } // Query reads the results from the next query in the batch as if the query has been sent with Query. @@ -123,6 +118,7 @@ func (br *batchResults) Query() (Rows, error) { } rows := br.conn.getRows(br.ctx, query, arguments) + rows.batchTracer = br.conn.batchTracer if !br.mrr.NextResult() { rows.err = br.mrr.Close() @@ -131,11 +127,11 @@ func (br *batchResults) Query() (Rows, error) { } rows.closed = true - if br.conn.shouldLog(LogLevelError) { - br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]any{ - "sql": query, - "args": logQueryArgs(arguments), - "err": rows.err, + if br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ + SQL: query, + Args: arguments, + Err: rows.err, }) } @@ -156,6 +152,15 @@ func (br *batchResults) QueryRow() Row { // Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to // resyncronize the connection with the server. In this case the underlying connection will have been closed. func (br *batchResults) Close() error { + defer func() { + if !br.endTraced { + if br.conn != nil && br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchEnd(br.ctx, br.conn, TraceBatchEndData{Err: br.err}) + } + br.endTraced = true + } + }() + if br.err != nil { return br.err } @@ -163,24 +168,26 @@ func (br *batchResults) Close() error { if br.closed { return nil } - br.closed = true - // log any queries that haven't yet been logged by Exec or Query - for { - query, args, ok := br.nextQueryAndArgs() - if !ok { - break - } - - if br.conn.shouldLog(LogLevelInfo) { - br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]any{ - "sql": query, - "args": logQueryArgs(args), - }) + // consume and log any queries that haven't yet been logged by Exec or Query + if br.conn.batchTracer != nil { + for br.err == nil && !br.closed && br.b != nil && br.ix < len(br.b.items) { + br.Exec() } } - return br.mrr.Close() + br.closed = true + + err := br.mrr.Close() + if br.err == nil { + br.err = err + } + + return br.err +} + +func (br *batchResults) earlyError() error { + return br.err } func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) { @@ -195,14 +202,15 @@ func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) { } type pipelineBatchResults struct { - ctx context.Context - conn *Conn - pipeline *pgconn.Pipeline - lastRows *baseRows - err error - b *Batch - ix int - closed bool + ctx context.Context + conn *Conn + pipeline *pgconn.Pipeline + lastRows *baseRows + err error + b *Batch + ix int + closed bool + endTraced bool } // Exec reads the results from the next query in the batch as if the query has been sent with Exec. @@ -227,25 +235,17 @@ func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) { var commandTag pgconn.CommandTag switch results := results.(type) { case *pgconn.ResultReader: - commandTag, err = results.Close() + commandTag, br.err = results.Close() default: return pgconn.CommandTag{}, fmt.Errorf("unexpected pipeline result: %T", results) } - if err != nil { - br.err = err - if br.conn.shouldLog(LogLevelError) { - br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]any{ - "sql": query, - "args": logQueryArgs(arguments), - "err": err, - }) - } - } else if br.conn.shouldLog(LogLevelInfo) { - br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]any{ - "sql": query, - "args": logQueryArgs(arguments), - "commandTag": commandTag, + if br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ + SQL: query, + Args: arguments, + CommandTag: commandTag, + Err: br.err, }) } @@ -274,6 +274,7 @@ func (br *pipelineBatchResults) Query() (Rows, error) { } rows := br.conn.getRows(br.ctx, query, arguments) + rows.batchTracer = br.conn.batchTracer br.lastRows = rows results, err := br.pipeline.GetResults() @@ -281,11 +282,12 @@ func (br *pipelineBatchResults) Query() (Rows, error) { br.err = err rows.err = err rows.closed = true - if br.conn.shouldLog(LogLevelError) { - br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]any{ - "sql": query, - "args": logQueryArgs(arguments), - "err": rows.err, + + if br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ + SQL: query, + Args: arguments, + Err: err, }) } } else { @@ -313,6 +315,15 @@ func (br *pipelineBatchResults) QueryRow() Row { // Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to // resyncronize the connection with the server. In this case the underlying connection will have been closed. func (br *pipelineBatchResults) Close() error { + defer func() { + if !br.endTraced { + if br.conn.batchTracer != nil { + br.conn.batchTracer.TraceBatchEnd(br.ctx, br.conn, TraceBatchEndData{Err: br.err}) + } + br.endTraced = true + } + }() + if br.err != nil { return br.err } @@ -325,24 +336,25 @@ func (br *pipelineBatchResults) Close() error { if br.closed { return nil } - br.closed = true - // log any queries that haven't yet been logged by Exec or Query - for { - query, args, ok := br.nextQueryAndArgs() - if !ok { - break - } - - if br.conn.shouldLog(LogLevelInfo) { - br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]any{ - "sql": query, - "args": logQueryArgs(args), - }) + // consume and log any queries that haven't yet been logged by Exec or Query + if br.conn.batchTracer != nil { + for br.err == nil && !br.closed && br.b != nil && br.ix < len(br.b.items) { + br.Exec() } } + br.closed = true - return br.pipeline.Close() + err := br.pipeline.Close() + if br.err == nil { + br.err = err + } + + return br.err +} + +func (br *pipelineBatchResults) earlyError() error { + return br.err } func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) { diff --git a/batch_test.go b/batch_test.go index e8d6f677..156e8f8f 100644 --- a/batch_test.go +++ b/batch_test.go @@ -733,94 +733,6 @@ func testConnSendBatch(t *testing.T, conn *pgx.Conn, queryCount int) { require.NoError(t, err) } -func TestLogBatchStatementsOnExec(t *testing.T) { - l1 := &testLogger{} - config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - config.Logger = l1 - - conn := mustConnect(t, config) - defer closeConn(t, conn) - - l1.logs = l1.logs[0:0] // Clear logs written when establishing connection - - batch := &pgx.Batch{} - batch.Queue("create table foo (id bigint)") - batch.Queue("drop table foo") - - br := conn.SendBatch(context.Background(), batch) - - _, err := br.Exec() - if err != nil { - t.Fatalf("Unexpected error creating table: %v", err) - } - - _, err = br.Exec() - if err != nil { - t.Fatalf("Unexpected error dropping table: %v", err) - } - - if len(l1.logs) != 2 { - t.Fatalf("Expected two log entries but got %d", len(l1.logs)) - } - - if l1.logs[0].msg != "BatchResult.Exec" { - t.Errorf("Expected first log message to be 'BatchResult.Exec' but was '%s", l1.logs[0].msg) - } - - if l1.logs[0].data["sql"] != "create table foo (id bigint)" { - t.Errorf("Expected the first query to be 'create table foo (id bigint)' but was '%s'", l1.logs[0].data["sql"]) - } - - if l1.logs[1].msg != "BatchResult.Exec" { - t.Errorf("Expected second log message to be 'BatchResult.Exec' but was '%s", l1.logs[1].msg) - } - - if l1.logs[1].data["sql"] != "drop table foo" { - t.Errorf("Expected the second query to be 'drop table foo' but was '%s'", l1.logs[1].data["sql"]) - } -} - -func TestLogBatchStatementsOnBatchResultClose(t *testing.T) { - l1 := &testLogger{} - config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - config.Logger = l1 - - conn := mustConnect(t, config) - defer closeConn(t, conn) - - l1.logs = l1.logs[0:0] // Clear logs written when establishing connection - - batch := &pgx.Batch{} - batch.Queue("select generate_series(1,$1)", 100) - batch.Queue("select 1 = 1;") - - br := conn.SendBatch(context.Background(), batch) - - if err := br.Close(); err != nil { - t.Fatalf("Unexpected batch error: %v", err) - } - - if len(l1.logs) != 2 { - t.Fatalf("Expected 2 log statements but found %d", len(l1.logs)) - } - - if l1.logs[0].msg != "BatchResult.Close" { - t.Errorf("Expected first log statement to be 'BatchResult.Close' but was %s", l1.logs[0].msg) - } - - if l1.logs[0].data["sql"] != "select generate_series(1,$1)" { - t.Errorf("Expected first query to be 'select generate_series(1,$1)' but was '%s'", l1.logs[0].data["sql"]) - } - - if l1.logs[1].msg != "BatchResult.Close" { - t.Errorf("Expected second log statement to be 'BatchResult.Close' but was %s", l1.logs[1].msg) - } - - if l1.logs[1].data["sql"] != "select 1 = 1;" { - t.Errorf("Expected second query to be 'select 1 = 1;' but was '%s'", l1.logs[1].data["sql"]) - } -} - func TestSendBatchSimpleProtocol(t *testing.T) { t.Parallel() diff --git a/bench_test.go b/bench_test.go index 31b3b38e..73e1b258 100644 --- a/bench_test.go +++ b/bench_test.go @@ -284,123 +284,6 @@ func BenchmarkPointerPointerWithPresentValues(b *testing.B) { } } -func BenchmarkSelectWithoutLogging(b *testing.B) { - conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) - defer closeConn(b, conn) - - benchmarkSelectWithLog(b, conn) -} - -type discardLogger struct{} - -func (dl discardLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]any) { -} - -func BenchmarkSelectWithLoggingTraceDiscard(b *testing.B) { - var logger discardLogger - config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.Logger = logger - config.LogLevel = pgx.LogLevelTrace - - conn := mustConnect(b, config) - defer closeConn(b, conn) - - benchmarkSelectWithLog(b, conn) -} - -func BenchmarkSelectWithLoggingDebugWithDiscard(b *testing.B) { - var logger discardLogger - config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.Logger = logger - config.LogLevel = pgx.LogLevelDebug - - conn := mustConnect(b, config) - defer closeConn(b, conn) - - benchmarkSelectWithLog(b, conn) -} - -func BenchmarkSelectWithLoggingInfoWithDiscard(b *testing.B) { - var logger discardLogger - config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.Logger = logger - config.LogLevel = pgx.LogLevelInfo - - conn := mustConnect(b, config) - defer closeConn(b, conn) - - benchmarkSelectWithLog(b, conn) -} - -func BenchmarkSelectWithLoggingErrorWithDiscard(b *testing.B) { - var logger discardLogger - config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) - config.Logger = logger - config.LogLevel = pgx.LogLevelError - - conn := mustConnect(b, config) - defer closeConn(b, conn) - - benchmarkSelectWithLog(b, conn) -} - -func benchmarkSelectWithLog(b *testing.B, conn *pgx.Conn) { - _, err := conn.Prepare(context.Background(), "test", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz") - if err != nil { - b.Fatal(err) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - var record struct { - id int32 - userName string - email string - name string - sex string - birthDate time.Time - lastLoginTime time.Time - } - - err = conn.QueryRow(context.Background(), "test").Scan( - &record.id, - &record.userName, - &record.email, - &record.name, - &record.sex, - &record.birthDate, - &record.lastLoginTime, - ) - if err != nil { - b.Fatal(err) - } - - // These checks both ensure that the correct data was returned - // and provide a benchmark of accessing the returned values. - if record.id != 1 { - b.Fatalf("bad value for id: %v", record.id) - } - if record.userName != "johnsmith" { - b.Fatalf("bad value for userName: %v", record.userName) - } - if record.email != "johnsmith@example.com" { - b.Fatalf("bad value for email: %v", record.email) - } - if record.name != "John Smith" { - b.Fatalf("bad value for name: %v", record.name) - } - if record.sex != "male" { - b.Fatalf("bad value for sex: %v", record.sex) - } - if record.birthDate != time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) { - b.Fatalf("bad value for birthDate: %v", record.birthDate) - } - if record.lastLoginTime != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) { - b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime) - } - } -} - const benchmarkWriteTableCreateSQL = `drop table if exists t; create table t( diff --git a/conn.go b/conn.go index 1f87e9e7..b8e0b232 100644 --- a/conn.go +++ b/conn.go @@ -19,8 +19,8 @@ import ( // then it can be modified. A manually initialized ConnConfig will cause ConnectConfig to panic. type ConnConfig struct { pgconn.Config - Logger Logger - LogLevel LogLevel + + Tracer QueryTracer // Original connection string that was parsed into config. connString string @@ -63,8 +63,11 @@ type Conn struct { preparedStatements map[string]*pgconn.StatementDescription statementCache stmtcache.Cache descriptionCache stmtcache.Cache - logger Logger - logLevel LogLevel + + queryTracer QueryTracer + batchTracer BatchTracer + copyFromTracer CopyFromTracer + prepareTracer PrepareTracer notifications []*pgconn.Notification @@ -94,9 +97,6 @@ func (ident Identifier) Sanitize() string { // ErrNoRows occurs when rows are expected but none are returned. var ErrNoRows = errors.New("no rows in result set") -// ErrInvalidLogLevel occurs on attempt to set an invalid log level. -var ErrInvalidLogLevel = errors.New("invalid log level") - var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache") var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache") @@ -182,7 +182,6 @@ func ParseConfig(connString string) (*ConnConfig, error) { connConfig := &ConnConfig{ Config: *config, createdByParseConfig: true, - LogLevel: LogLevelInfo, StatementCacheCapacity: statementCacheCapacity, DescriptionCacheCapacity: descriptionCacheCapacity, DefaultQueryExecMode: defaultQueryExecMode, @@ -194,6 +193,13 @@ func ParseConfig(connString string) (*ConnConfig, error) { // connect connects to a database. connect takes ownership of config. The caller must not use or access it again. func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { + if connectTracer, ok := config.Tracer.(ConnectTracer); ok { + ctx = connectTracer.TraceConnectStart(ctx, TraceConnectStartData{ConnConfig: config}) + defer func() { + connectTracer.TraceConnectEnd(ctx, TraceConnectEndData{Conn: c, Err: err}) + }() + } + // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from // zero values. if !config.createdByParseConfig { @@ -201,29 +207,28 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { } c = &Conn{ - config: config, - typeMap: pgtype.NewMap(), - logLevel: config.LogLevel, - logger: config.Logger, + config: config, + typeMap: pgtype.NewMap(), + queryTracer: config.Tracer, + } + + if t, ok := c.queryTracer.(BatchTracer); ok { + c.batchTracer = t + } + if t, ok := c.queryTracer.(CopyFromTracer); ok { + c.copyFromTracer = t + } + if t, ok := c.queryTracer.(PrepareTracer); ok { + c.prepareTracer = t } // Only install pgx notification system if no other callback handler is present. if config.Config.OnNotification == nil { config.Config.OnNotification = c.bufferNotifications - } else { - if c.shouldLog(LogLevelDebug) { - c.log(ctx, LogLevelDebug, "pgx notification handler disabled by application supplied OnNotification", map[string]any{"host": config.Config.Host}) - } } - if c.shouldLog(LogLevelInfo) { - c.log(ctx, LogLevelInfo, "Dialing PostgreSQL server", map[string]any{"host": config.Config.Host}) - } c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config) if err != nil { - if c.shouldLog(LogLevelError) { - c.log(ctx, LogLevelError, "connect failed", map[string]any{"err": err}) - } return nil, err } @@ -251,9 +256,6 @@ func (c *Conn) Close(ctx context.Context) error { } err := c.pgConn.Close(ctx) - if c.shouldLog(LogLevelInfo) { - c.log(ctx, LogLevelInfo, "closed connection", nil) - } return err } @@ -264,18 +266,23 @@ func (c *Conn) Close(ctx context.Context) error { // name and sql arguments. This allows a code path to Prepare and Query/Exec without // concern for if the statement has already been prepared. func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) { + if c.prepareTracer != nil { + ctx = c.prepareTracer.TracePrepareStart(ctx, c, TracePrepareStartData{Name: name, SQL: sql}) + } + if name != "" { var ok bool if sd, ok = c.preparedStatements[name]; ok && sd.SQL == sql { + if c.prepareTracer != nil { + c.prepareTracer.TracePrepareEnd(ctx, c, TracePrepareEndData{AlreadyPrepared: true}) + } return sd, nil } } - if c.shouldLog(LogLevelError) { + if c.prepareTracer != nil { defer func() { - if err != nil { - c.log(ctx, LogLevelError, "Prepare failed", map[string]any{"err": err, "name": name, "sql": sql}) - } + c.prepareTracer.TracePrepareEnd(ctx, c, TracePrepareEndData{Err: err}) }() } @@ -337,21 +344,6 @@ func (c *Conn) die(err error) { c.pgConn.Close(ctx) } -func (c *Conn) shouldLog(lvl LogLevel) bool { - return c.logger != nil && c.logLevel >= lvl -} - -func (c *Conn) log(ctx context.Context, lvl LogLevel, msg string, data map[string]any) { - if data == nil { - data = map[string]any{} - } - if c.pgConn != nil && c.pgConn.PID() != 0 { - data["pid"] = c.pgConn.PID() - } - - c.logger.Log(ctx, lvl, msg, data) -} - func quoteIdentifier(s string) string { return `"` + strings.ReplaceAll(s, `"`, `""`) + `"` } @@ -379,24 +371,18 @@ func (c *Conn) Config() *ConnConfig { return c.config.Copy() } // Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced // positionally from the sql string as $1, $2, etc. func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { + if c.queryTracer != nil { + ctx = c.queryTracer.TraceQueryStart(ctx, c, TraceQueryStartData{SQL: sql, Args: arguments}) + } + if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil { return pgconn.CommandTag{}, err } - startTime := time.Now() - commandTag, err := c.exec(ctx, sql, arguments...) - if err != nil { - if c.shouldLog(LogLevelError) { - endTime := time.Now() - c.log(ctx, LogLevelError, "Exec", map[string]any{"sql": sql, "args": logQueryArgs(arguments), "err": err, "time": endTime.Sub(startTime)}) - } - return commandTag, err - } - if c.shouldLog(LogLevelInfo) { - endTime := time.Now() - c.log(ctx, LogLevelInfo, "Exec", map[string]any{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag}) + if c.queryTracer != nil { + c.queryTracer.TraceQueryEnd(ctx, c, TraceQueryEndData{CommandTag: commandTag, Err: err}) } return commandTag, err @@ -537,7 +523,7 @@ func (c *Conn) getRows(ctx context.Context, sql string, args []any) *baseRows { r := &baseRows{} r.ctx = ctx - r.logger = c + r.queryTracer = c.queryTracer r.typeMap = c.typeMap r.startTime = time.Now() r.sql = sql @@ -628,7 +614,14 @@ type QueryRewriter interface { // QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely // needed. See the documentation for those types for details. func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error) { + if c.queryTracer != nil { + ctx = c.queryTracer.TraceQueryStart(ctx, c, TraceQueryStartData{SQL: sql, Args: args}) + } + if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil { + if c.queryTracer != nil { + c.queryTracer.TraceQueryEnd(ctx, c, TraceQueryEndData{Err: err}) + } return &baseRows{err: err, closed: true}, err } @@ -791,7 +784,17 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row { // SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless // explicit transaction control statements are executed. The returned BatchResults must be closed before the connection // is used again. -func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { +func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) { + if c.batchTracer != nil { + ctx = c.batchTracer.TraceBatchStart(ctx, c, TraceBatchStartData{Batch: b}) + defer func() { + err := br.(interface{ earlyError() error }).earlyError() + if err != nil { + c.batchTracer.TraceBatchEnd(ctx, c, TraceBatchEndData{Err: err}) + } + }() + } + if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil { return &batchResults{ctx: ctx, conn: c, err: err} } diff --git a/conn_test.go b/conn_test.go index 2ead63ce..b84093f4 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3,7 +3,6 @@ package pgx_test import ( "bytes" "context" - "log" "os" "strings" "sync" @@ -743,79 +742,6 @@ func TestInsertTimestampArray(t *testing.T) { }) } -type testLog struct { - lvl pgx.LogLevel - msg string - data map[string]any -} - -type testLogger struct { - logs []testLog -} - -func (l *testLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]any) { - data["ctxdata"] = ctx.Value("ctxdata") - l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data}) -} - -func TestLogPassesContext(t *testing.T) { - t.Parallel() - - l1 := &testLogger{} - config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - config.Logger = l1 - - conn := mustConnect(t, config) - defer closeConn(t, conn) - - l1.logs = l1.logs[0:0] // Clear logs written when establishing connection - - ctx := context.WithValue(context.Background(), "ctxdata", "foo") - - if _, err := conn.Exec(ctx, ";"); err != nil { - t.Fatal(err) - } - - if len(l1.logs) != 1 { - t.Fatal("Expected logger to be called once, but it wasn't") - } - - if l1.logs[0].data["ctxdata"] != "foo" { - t.Fatal("Expected context data to be passed to logger, but it wasn't") - } -} - -func TestLoggerFunc(t *testing.T) { - t.Parallel() - - const testMsg = "foo" - - buf := bytes.Buffer{} - logger := log.New(&buf, "", 0) - - createAdapterFn := func(logger *log.Logger) pgx.LoggerFunc { - return func(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { - logger.Printf("%s", testMsg) - } - } - - config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - config.Logger = createAdapterFn(logger) - - conn := mustConnect(t, config) - defer closeConn(t, conn) - - buf.Reset() // Clear logs written when establishing connection - - if _, err := conn.Exec(context.TODO(), ";"); err != nil { - t.Fatal(err) - } - - if strings.TrimSpace(buf.String()) != testMsg { - t.Errorf("Expected logger function to return '%s', but it was '%s'", testMsg, buf.String()) - } -} - func TestIdentifierSanitize(t *testing.T) { t.Parallel() diff --git a/copy_from.go b/copy_from.go index c5e9aae8..c8b98c57 100644 --- a/copy_from.go +++ b/copy_from.go @@ -5,7 +5,6 @@ import ( "context" "fmt" "io" - "time" "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/pgconn" @@ -89,6 +88,13 @@ type copyFrom struct { } func (ct *copyFrom) run(ctx context.Context) (int64, error) { + if ct.conn.copyFromTracer != nil { + ctx = ct.conn.copyFromTracer.TraceCopyFromStart(ctx, ct.conn, TraceCopyFromStartData{ + TableName: ct.tableName, + ColumnNames: ct.columnNames, + }) + } + quotedTableName := ct.tableName.Sanitize() cbuf := &bytes.Buffer{} for i, cn := range ct.columnNames { @@ -145,24 +151,19 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) { w.Close() }() - startTime := time.Now() - commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) r.Close() <-doneChan - rowsAffected := commandTag.RowsAffected() - endTime := time.Now() - if err == nil { - if ct.conn.shouldLog(LogLevelInfo) { - ct.conn.log(ctx, LogLevelInfo, "CopyFrom", map[string]any{"tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime), "rowCount": rowsAffected}) - } - } else if ct.conn.shouldLog(LogLevelError) { - ct.conn.log(ctx, LogLevelError, "CopyFrom", map[string]any{"err": err, "tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime)}) + if ct.conn.copyFromTracer != nil { + ct.conn.copyFromTracer.TraceCopyFromEnd(ctx, ct.conn, TraceCopyFromEndData{ + CommandTag: commandTag, + Err: err, + }) } - return rowsAffected, err + return commandTag.RowsAffected(), err } func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) { diff --git a/doc.go b/doc.go index cfc2af85..b10ab1df 100644 --- a/doc.go +++ b/doc.go @@ -12,15 +12,7 @@ The primary way of establishing a connection is with `pgx.Connect`. The database connection string can be in URL or DSN format. Both PostgreSQL settings and pgx settings can be specified here. In addition, a config struct can be created by `ParseConfig` and modified before establishing the connection with -`ConnectConfig`. - - config, err := pgx.ParseConfig(os.Getenv("DATABASE_URL")) - if err != nil { - // ... - } - config.Logger = log15adapter.NewLogger(log.New("module", "pgx")) - - conn, err := pgx.ConnectConfig(context.Background(), config) +`ConnectConfig` to configure settings such as tracing that cannot be configured with a connection string. Connection Pool @@ -315,11 +307,11 @@ notification is received or the context is canceled. } -Logging +Tracing and Logging -pgx defines a simple logger interface. Connections optionally accept a logger that satisfies this interface. Set -LogLevel to control logging verbosity. Adapters for github.com/inconshreveable/log15, github.com/sirupsen/logrus, -go.uber.org/zap, github.com/rs/zerolog, and the testing log are provided in the log directory. +pgx supports tracing by setting ConnConfig.Tracer. + +In addition, the tracelog package provides the TraceLog type which lets a traditional logger act as a Tracer. Lower Level PostgreSQL Functionality diff --git a/helper_test.go b/helper_test.go index f091d23e..26e54621 100644 --- a/helper_test.go +++ b/helper_test.go @@ -98,8 +98,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName return } - assert.Equalf(t, expected.Logger, actual.Logger, "%s - Logger", testName) - assert.Equalf(t, expected.LogLevel, actual.LogLevel, "%s - LogLevel", testName) + assert.Equalf(t, expected.Tracer, actual.Tracer, "%s - Tracer", testName) assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName) assert.Equalf(t, expected.StatementCacheCapacity, actual.StatementCacheCapacity, "%s - StatementCacheCapacity", testName) assert.Equalf(t, expected.DescriptionCacheCapacity, actual.DescriptionCacheCapacity, "%s - DescriptionCacheCapacity", testName) diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index eed8dc4f..09f04141 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -162,7 +162,7 @@ func (f *Frontend) SendExecute(msg *Execute) { prevLen := len(f.wbuf) f.wbuf = msg.Encode(f.wbuf) if f.tracer != nil { - f.tracer.traceExecute('F', int32(len(f.wbuf)-prevLen), msg) + f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg) } } diff --git a/pgproto3/trace.go b/pgproto3/trace.go index d3edc4aa..c09f68d1 100644 --- a/pgproto3/trace.go +++ b/pgproto3/trace.go @@ -79,7 +79,7 @@ func (t *tracer) traceMessage(sender byte, encodedLen int32, msg Message) { case *ErrorResponse: t.traceErrorResponse(sender, encodedLen, msg) case *Execute: - t.traceExecute(sender, encodedLen, msg) + t.TraceQueryute(sender, encodedLen, msg) case *Flush: t.traceFlush(sender, encodedLen, msg) case *FunctionCall: @@ -277,7 +277,7 @@ func (t *tracer) traceErrorResponse(sender byte, encodedLen int32, msg *ErrorRes t.finishTrace() } -func (t *tracer) traceExecute(sender byte, encodedLen int32, msg *Execute) { +func (t *tracer) TraceQueryute(sender byte, encodedLen int32, msg *Execute) { t.beginTrace(sender, encodedLen, "Execute") fmt.Fprintf(t.buf, "\t %s %d", traceDoubleQuotedString([]byte(msg.Portal)), msg.MaxRows) t.finishTrace() diff --git a/pgxpool/common_test.go b/pgxpool/common_test.go index eabc0e3c..16f4f553 100644 --- a/pgxpool/common_test.go +++ b/pgxpool/common_test.go @@ -160,8 +160,7 @@ func assertConnConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, test return } - assert.Equalf(t, expected.Logger, actual.Logger, "%s - Logger", testName) - assert.Equalf(t, expected.LogLevel, actual.LogLevel, "%s - LogLevel", testName) + assert.Equalf(t, expected.Tracer, actual.Tracer, "%s - Tracer", testName) assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName) assert.Equalf(t, expected.StatementCacheCapacity, actual.StatementCacheCapacity, "%s - StatementCacheCapacity", testName) assert.Equalf(t, expected.DescriptionCacheCapacity, actual.DescriptionCacheCapacity, "%s - DescriptionCacheCapacity", testName) diff --git a/rows.go b/rows.go index c91f3aff..ca5533d9 100644 --- a/rows.go +++ b/rows.go @@ -105,11 +105,6 @@ func (r *connRow) Scan(dest ...any) (err error) { return rows.Err() } -type rowLog interface { - shouldLog(lvl LogLevel) bool - log(ctx context.Context, lvl LogLevel, msg string, data map[string]any) -} - // baseRows implements the Rows interface for Conn.Query. type baseRows struct { typeMap *pgtype.Map @@ -127,12 +122,13 @@ type baseRows struct { conn *Conn multiResultReader *pgconn.MultiResultReader - logger rowLog - ctx context.Context - startTime time.Time - sql string - args []any - rowCount int + queryTracer QueryTracer + batchTracer BatchTracer + ctx context.Context + startTime time.Time + sql string + args []any + rowCount int } func (rows *baseRows) FieldDescriptions() []pgproto3.FieldDescription { @@ -161,20 +157,6 @@ func (rows *baseRows) Close() { } } - if rows.logger != nil { - endTime := time.Now() - - if rows.err == nil { - if rows.logger.shouldLog(LogLevelInfo) { - rows.logger.log(rows.ctx, LogLevelInfo, "Query", map[string]any{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount}) - } - } else { - if rows.logger.shouldLog(LogLevelError) { - rows.logger.log(rows.ctx, LogLevelError, "Query", map[string]any{"err": rows.err, "sql": rows.sql, "time": endTime.Sub(rows.startTime), "args": logQueryArgs(rows.args)}) - } - } - } - if rows.err != nil && rows.conn != nil && rows.sql != "" { if stmtcache.IsStatementInvalid(rows.err) { if sc := rows.conn.statementCache; sc != nil { @@ -186,6 +168,12 @@ func (rows *baseRows) Close() { } } } + + if rows.batchTracer != nil { + rows.batchTracer.TraceBatchQuery(rows.ctx, rows.conn, TraceBatchQueryData{SQL: rows.sql, Args: rows.args, CommandTag: rows.commandTag, Err: rows.err}) + } else if rows.queryTracer != nil { + rows.queryTracer.TraceQueryEnd(rows.ctx, rows.conn, TraceQueryEndData{rows.commandTag, rows.err}) + } } func (rows *baseRows) CommandTag() pgconn.CommandTag { diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index ee038add..ca2dccf3 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -17,6 +17,7 @@ import ( "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/stdlib" + "github.com/jackc/pgx/v5/tracelog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -976,7 +977,7 @@ func TestScanJSONIntoJSONRawMessage(t *testing.T) { } type testLog struct { - lvl pgx.LogLevel + lvl tracelog.LogLevel msg string data map[string]any } @@ -985,7 +986,7 @@ type testLogger struct { logs []testLog } -func (l *testLogger) Log(ctx context.Context, lvl pgx.LogLevel, msg string, data map[string]any) { +func (l *testLogger) Log(ctx context.Context, lvl tracelog.LogLevel, msg string, data map[string]any) { l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data}) } @@ -994,7 +995,7 @@ func TestRegisterConnConfig(t *testing.T) { require.NoError(t, err) logger := &testLogger{} - connConfig.Logger = logger + connConfig.Tracer = &tracelog.TraceLog{Logger: logger, LogLevel: tracelog.LogLevelInfo} // Issue 947: Register and unregister a ConnConfig and ensure that the // returned connection string is not reused. diff --git a/tracelog/tracelog.go b/tracelog/tracelog.go new file mode 100644 index 00000000..d51b9b95 --- /dev/null +++ b/tracelog/tracelog.go @@ -0,0 +1,295 @@ +// Package tracelog provides a tracer that acts as a traditional logger. +package tracelog + +import ( + "context" + "encoding/hex" + "errors" + "fmt" + "time" + + "github.com/jackc/pgx/v5" +) + +// LogLevel represents the pgx logging level. See LogLevel* constants for +// possible values. +type LogLevel int + +// The values for log levels are chosen such that the zero value means that no +// log level was specified. +const ( + LogLevelTrace = LogLevel(6) + LogLevelDebug = LogLevel(5) + LogLevelInfo = LogLevel(4) + LogLevelWarn = LogLevel(3) + LogLevelError = LogLevel(2) + LogLevelNone = LogLevel(1) +) + +func (ll LogLevel) String() string { + switch ll { + case LogLevelTrace: + return "trace" + case LogLevelDebug: + return "debug" + case LogLevelInfo: + return "info" + case LogLevelWarn: + return "warn" + case LogLevelError: + return "error" + case LogLevelNone: + return "none" + default: + return fmt.Sprintf("invalid level %d", ll) + } +} + +// Logger is the interface used to get log output from pgx. +type Logger interface { + // Log a message at the given level with data key/value pairs. data may be nil. + Log(ctx context.Context, level LogLevel, msg string, data map[string]any) +} + +// LoggerFunc is a wrapper around a function to satisfy the pgx.Logger interface +type LoggerFunc func(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) + +// Log delegates the logging request to the wrapped function +func (f LoggerFunc) Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) { + f(ctx, level, msg, data) +} + +// LogLevelFromString converts log level string to constant +// +// Valid levels: +// trace +// debug +// info +// warn +// error +// none +func LogLevelFromString(s string) (LogLevel, error) { + switch s { + case "trace": + return LogLevelTrace, nil + case "debug": + return LogLevelDebug, nil + case "info": + return LogLevelInfo, nil + case "warn": + return LogLevelWarn, nil + case "error": + return LogLevelError, nil + case "none": + return LogLevelNone, nil + default: + return 0, errors.New("invalid log level") + } +} + +func logQueryArgs(args []any) []any { + logArgs := make([]any, 0, len(args)) + + for _, a := range args { + switch v := a.(type) { + case []byte: + if len(v) < 64 { + a = hex.EncodeToString(v) + } else { + a = fmt.Sprintf("%x (truncated %d bytes)", v[:64], len(v)-64) + } + case string: + if len(v) > 64 { + a = fmt.Sprintf("%s (truncated %d bytes)", v[:64], len(v)-64) + } + } + logArgs = append(logArgs, a) + } + + return logArgs +} + +// TraceLog implements pgx.QueryTracer, pgx.BatchTracer, pgx.ConnectTracer, and pgx.CopyFromTracer. All fields are +// required. +type TraceLog struct { + Logger Logger + LogLevel LogLevel +} + +type ctxKey int + +const ( + _ ctxKey = iota + tracelogQueryCtxKey + tracelogBatchCtxKey + tracelogCopyFromCtxKey + tracelogConnectCtxKey +) + +type traceQueryData struct { + startTime time.Time + sql string + args []any +} + +func (tl *TraceLog) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + return context.WithValue(ctx, tracelogQueryCtxKey, &traceQueryData{ + startTime: time.Now(), + sql: data.SQL, + args: data.Args, + }) +} + +func (tl *TraceLog) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { + queryData := ctx.Value(tracelogQueryCtxKey).(*traceQueryData) + + endTime := time.Now() + interval := endTime.Sub(queryData.startTime) + + if data.Err != nil { + if tl.shouldLog(LogLevelError) { + tl.log(ctx, conn, LogLevelError, "Query", map[string]any{"sql": queryData.sql, "args": logQueryArgs(queryData.args), "err": data.Err, "time": interval}) + } + return + } + + if tl.shouldLog(LogLevelInfo) { + tl.log(ctx, conn, LogLevelInfo, "Query", map[string]any{"sql": queryData.sql, "args": logQueryArgs(queryData.args), "time": interval, "commandTag": data.CommandTag.String()}) + } +} + +type traceBatchData struct { + startTime time.Time +} + +func (tl *TraceLog) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + return context.WithValue(ctx, tracelogBatchCtxKey, &traceBatchData{ + startTime: time.Now(), + }) +} + +func (tl *TraceLog) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + if data.Err != nil { + if tl.shouldLog(LogLevelError) { + tl.log(ctx, conn, LogLevelError, "BatchQuery", map[string]any{"sql": data.SQL, "args": logQueryArgs(data.Args), "err": data.Err}) + } + return + } + + if tl.shouldLog(LogLevelInfo) { + tl.log(ctx, conn, LogLevelInfo, "BatchQuery", map[string]any{"sql": data.SQL, "args": logQueryArgs(data.Args), "commandTag": data.CommandTag.String()}) + } +} + +func (tl *TraceLog) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + queryData := ctx.Value(tracelogBatchCtxKey).(*traceBatchData) + + endTime := time.Now() + interval := endTime.Sub(queryData.startTime) + + if data.Err != nil { + if tl.shouldLog(LogLevelError) { + tl.log(ctx, conn, LogLevelError, "BatchClose", map[string]any{"err": data.Err, "time": interval}) + } + return + } + + if tl.shouldLog(LogLevelInfo) { + tl.log(ctx, conn, LogLevelInfo, "BatchClose", map[string]any{"time": interval}) + } +} + +type traceCopyFromData struct { + startTime time.Time + TableName pgx.Identifier + ColumnNames []string +} + +func (tl *TraceLog) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context { + return context.WithValue(ctx, tracelogCopyFromCtxKey, &traceCopyFromData{ + startTime: time.Now(), + TableName: data.TableName, + ColumnNames: data.ColumnNames, + }) +} + +func (tl *TraceLog) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) { + copyFromData := ctx.Value(tracelogCopyFromCtxKey).(*traceCopyFromData) + + endTime := time.Now() + interval := endTime.Sub(copyFromData.startTime) + + if data.Err != nil { + if tl.shouldLog(LogLevelError) { + tl.log(ctx, conn, LogLevelError, "CopyFrom", map[string]any{"tableName": copyFromData.TableName, "columnNames": copyFromData.ColumnNames, "err": data.Err, "time": interval}) + } + return + } + + if tl.shouldLog(LogLevelInfo) { + tl.log(ctx, conn, LogLevelInfo, "CopyFrom", map[string]any{"tableName": copyFromData.TableName, "columnNames": copyFromData.ColumnNames, "err": data.Err, "time": interval, "rowCount": data.CommandTag.RowsAffected()}) + } +} + +type traceConnectData struct { + startTime time.Time + connConfig *pgx.ConnConfig +} + +func (tl *TraceLog) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context { + return context.WithValue(ctx, tracelogConnectCtxKey, &traceConnectData{ + startTime: time.Now(), + connConfig: data.ConnConfig, + }) +} + +func (tl *TraceLog) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) { + connectData := ctx.Value(tracelogConnectCtxKey).(*traceConnectData) + + endTime := time.Now() + interval := endTime.Sub(connectData.startTime) + + if data.Err != nil { + if tl.shouldLog(LogLevelError) { + tl.Logger.Log(ctx, LogLevelError, "Connect", map[string]any{ + "host": connectData.connConfig.Host, + "port": connectData.connConfig.Port, + "database": connectData.connConfig.Database, + "time": interval, + "err": data.Err, + }) + } + return + } + + if data.Conn != nil { + if tl.shouldLog(LogLevelInfo) { + tl.log(ctx, data.Conn, LogLevelInfo, "Connect", map[string]any{ + "host": connectData.connConfig.Host, + "port": connectData.connConfig.Port, + "database": connectData.connConfig.Database, + "time": interval, + }) + } + } +} + +func (tl *TraceLog) shouldLog(lvl LogLevel) bool { + return tl.LogLevel >= lvl +} + +func (tl *TraceLog) log(ctx context.Context, conn *pgx.Conn, lvl LogLevel, msg string, data map[string]any) { + if data == nil { + data = map[string]any{} + } + + pgConn := conn.PgConn() + if pgConn != nil { + pid := pgConn.PID() + if pid != 0 { + data["pid"] = pid + } + } + + tl.Logger.Log(ctx, lvl, msg, data) +} diff --git a/tracelog/tracelog_test.go b/tracelog/tracelog_test.go new file mode 100644 index 00000000..ed0f8eab --- /dev/null +++ b/tracelog/tracelog_test.go @@ -0,0 +1,301 @@ +package tracelog_test + +import ( + "bytes" + "context" + "log" + "os" + "strings" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/jackc/pgx/v5/tracelog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var defaultConnTestRunner pgxtest.ConnTestRunner + +func init() { + defaultConnTestRunner = pgxtest.DefaultConnTestRunner() + defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + return config + } +} + +type testLog struct { + lvl tracelog.LogLevel + msg string + data map[string]any +} + +type testLogger struct { + logs []testLog +} + +func (l *testLogger) Log(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]any) { + data["ctxdata"] = ctx.Value("ctxdata") + l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data}) +} + +func TestContextGetsPassedToLogMethod(t *testing.T) { + t.Parallel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + logger.logs = logger.logs[0:0] // Clear any logs written when establishing connection + + ctx = context.WithValue(context.Background(), "ctxdata", "foo") + _, err := conn.Exec(ctx, `;`) + require.NoError(t, err) + require.Len(t, logger.logs, 1) + require.Equal(t, "foo", logger.logs[0].data["ctxdata"]) + }) +} + +func TestLoggerFunc(t *testing.T) { + t.Parallel() + + const testMsg = "foo" + + buf := bytes.Buffer{} + logger := log.New(&buf, "", 0) + + createAdapterFn := func(logger *log.Logger) tracelog.LoggerFunc { + return func(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]interface{}) { + logger.Printf("%s", testMsg) + } + } + + config := defaultConnTestRunner.CreateConfig(context.Background(), t) + config.Tracer = &tracelog.TraceLog{ + Logger: createAdapterFn(logger), + LogLevel: tracelog.LogLevelTrace, + } + + conn, err := pgx.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer conn.Close(context.Background()) + + buf.Reset() // Clear logs written when establishing connection + + if _, err := conn.Exec(context.TODO(), ";"); err != nil { + t.Fatal(err) + } + + if strings.TrimSpace(buf.String()) != testMsg { + t.Errorf("Expected logger function to return '%s', but it was '%s'", testMsg, buf.String()) + } +} + +func TestLogQuery(t *testing.T) { + t.Parallel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + logger.logs = logger.logs[0:0] // Clear any logs written when establishing connection + + _, err := conn.Exec(ctx, `select $1::text`, "testing") + require.NoError(t, err) + require.Len(t, logger.logs, 1) + require.Equal(t, "Query", logger.logs[0].msg) + require.Equal(t, tracelog.LogLevelInfo, logger.logs[0].lvl) + + _, err = conn.Exec(ctx, `foo`, "testing") + require.Error(t, err) + require.Len(t, logger.logs, 2) + require.Equal(t, "Query", logger.logs[1].msg) + require.Equal(t, tracelog.LogLevelError, logger.logs[1].lvl) + require.Equal(t, err, logger.logs[1].data["err"]) + }) +} + +func TestLogCopyFrom(t *testing.T) { + t.Parallel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Exec(context.Background(), `create temporary table foo(a int4)`) + require.NoError(t, err) + + logger.logs = logger.logs[0:0] + + inputRows := [][]any{ + {int32(1)}, + {nil}, + } + + copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows)) + require.NoError(t, err) + require.EqualValues(t, len(inputRows), copyCount) + require.Len(t, logger.logs, 1) + require.Equal(t, "CopyFrom", logger.logs[0].msg) + require.Equal(t, tracelog.LogLevelInfo, logger.logs[0].lvl) + + logger.logs = logger.logs[0:0] + + inputRows = [][]any{ + {"not an integer"}, + {nil}, + } + + copyCount, err = conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows)) + require.Error(t, err) + require.EqualValues(t, 0, copyCount) + require.Len(t, logger.logs, 1) + require.Equal(t, "CopyFrom", logger.logs[0].msg) + require.Equal(t, tracelog.LogLevelError, logger.logs[0].lvl) + }) +} + +func TestLogConnect(t *testing.T) { + t.Parallel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + config := defaultConnTestRunner.CreateConfig(context.Background(), t) + config.Tracer = tracer + + conn1, err := pgx.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer conn1.Close(context.Background()) + require.Len(t, logger.logs, 1) + require.Equal(t, "Connect", logger.logs[0].msg) + require.Equal(t, tracelog.LogLevelInfo, logger.logs[0].lvl) + + logger.logs = logger.logs[0:0] + + config, err = pgx.ParseConfig("host=/invalid") + require.NoError(t, err) + config.Tracer = tracer + + conn2, err := pgx.ConnectConfig(context.Background(), config) + require.Nil(t, conn2) + require.Error(t, err) + require.Len(t, logger.logs, 1) + require.Equal(t, "Connect", logger.logs[0].msg) + require.Equal(t, tracelog.LogLevelError, logger.logs[0].lvl) +} + +func TestLogBatchStatementsOnExec(t *testing.T) { + t.Parallel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + logger.logs = logger.logs[0:0] // Clear any logs written when establishing connection + + batch := &pgx.Batch{} + batch.Queue("create table foo (id bigint)") + batch.Queue("drop table foo") + + br := conn.SendBatch(context.Background(), batch) + + _, err := br.Exec() + require.NoError(t, err) + + _, err = br.Exec() + require.NoError(t, err) + + err = br.Close() + require.NoError(t, err) + + require.Len(t, logger.logs, 3) + assert.Equal(t, "BatchQuery", logger.logs[0].msg) + assert.Equal(t, "create table foo (id bigint)", logger.logs[0].data["sql"]) + assert.Equal(t, "BatchQuery", logger.logs[1].msg) + assert.Equal(t, "drop table foo", logger.logs[1].data["sql"]) + assert.Equal(t, "BatchClose", logger.logs[2].msg) + + }) +} + +func TestLogBatchStatementsOnBatchResultClose(t *testing.T) { + t.Parallel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + logger.logs = logger.logs[0:0] // Clear any logs written when establishing connection + + batch := &pgx.Batch{} + batch.Queue("select generate_series(1,$1)", 100) + batch.Queue("select 1 = 1;") + + br := conn.SendBatch(context.Background(), batch) + err := br.Close() + require.NoError(t, err) + + require.Len(t, logger.logs, 3) + assert.Equal(t, "BatchQuery", logger.logs[0].msg) + assert.Equal(t, "select generate_series(1,$1)", logger.logs[0].data["sql"]) + assert.Equal(t, "BatchQuery", logger.logs[1].msg) + assert.Equal(t, "select 1 = 1;", logger.logs[1].data["sql"]) + assert.Equal(t, "BatchClose", logger.logs[2].msg) + }) +} diff --git a/tracer.go b/tracer.go new file mode 100644 index 00000000..58ca99f7 --- /dev/null +++ b/tracer.go @@ -0,0 +1,107 @@ +package pgx + +import ( + "context" + + "github.com/jackc/pgx/v5/pgconn" +) + +// QueryTracer traces Query, QueryRow, and Exec. +type QueryTracer interface { + // TraceQueryStart is called at the beginning of Query, QueryRow, and Exec calls. The returned context is used for the + // rest of the call and will be passed to TraceQueryEnd. + TraceQueryStart(ctx context.Context, conn *Conn, data TraceQueryStartData) context.Context + + TraceQueryEnd(ctx context.Context, conn *Conn, data TraceQueryEndData) +} + +type TraceQueryStartData struct { + SQL string + Args []any +} + +type TraceQueryEndData struct { + CommandTag pgconn.CommandTag + Err error +} + +// BatchTracer traces SendBatch. +type BatchTracer interface { + // TraceBatchStart is called at the beginning of SendBatch calls. The returned context is used for the + // rest of the call and will be passed to TraceBatchQuery and TraceBatchEnd. + TraceBatchStart(ctx context.Context, conn *Conn, data TraceBatchStartData) context.Context + + TraceBatchQuery(ctx context.Context, conn *Conn, data TraceBatchQueryData) + TraceBatchEnd(ctx context.Context, conn *Conn, data TraceBatchEndData) +} + +type TraceBatchStartData struct { + Batch *Batch +} + +type TraceBatchQueryData struct { + SQL string + Args []any + CommandTag pgconn.CommandTag + Err error +} + +type TraceBatchEndData struct { + Err error +} + +// CopyFromTracer traces CopyFrom. +type CopyFromTracer interface { + // TraceCopyFromStart is called at the beginning of CopyFrom calls. The returned context is used for the + // rest of the call and will be passed to TraceCopyFromEnd. + TraceCopyFromStart(ctx context.Context, conn *Conn, data TraceCopyFromStartData) context.Context + + TraceCopyFromEnd(ctx context.Context, conn *Conn, data TraceCopyFromEndData) +} + +type TraceCopyFromStartData struct { + TableName Identifier + ColumnNames []string +} + +type TraceCopyFromEndData struct { + CommandTag pgconn.CommandTag + Err error +} + +// PrepareTracer traces Prepare. +type PrepareTracer interface { + // TracePrepareStart is called at the beginning of Prepare calls. The returned context is used for the + // rest of the call and will be passed to TracePrepareEnd. + TracePrepareStart(ctx context.Context, conn *Conn, data TracePrepareStartData) context.Context + + TracePrepareEnd(ctx context.Context, conn *Conn, data TracePrepareEndData) +} + +type TracePrepareStartData struct { + Name string + SQL string +} + +type TracePrepareEndData struct { + AlreadyPrepared bool + Err error +} + +// ConnectTracer traces Connect and ConnectConfig. +type ConnectTracer interface { + // TraceConnectStart is called at the beginning of Connect and ConnectConfig calls. The returned context is used for + // the rest of the call and will be passed to TraceConnectEnd. + TraceConnectStart(ctx context.Context, data TraceConnectStartData) context.Context + + TraceConnectEnd(ctx context.Context, data TraceConnectEndData) +} + +type TraceConnectStartData struct { + ConnConfig *ConnConfig +} + +type TraceConnectEndData struct { + Conn *Conn + Err error +} diff --git a/tracer_test.go b/tracer_test.go new file mode 100644 index 00000000..86375b34 --- /dev/null +++ b/tracer_test.go @@ -0,0 +1,538 @@ +package pgx_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" +) + +type testTracer struct { + traceQueryStart func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context + traceQueryEnd func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) + traceBatchStart func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context + traceBatchQuery func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) + traceBatchEnd func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) + traceCopyFromStart func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context + traceCopyFromEnd func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) + tracePrepareStart func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context + tracePrepareEnd func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) + traceConnectStart func(ctx context.Context, data pgx.TraceConnectStartData) context.Context + traceConnectEnd func(ctx context.Context, data pgx.TraceConnectEndData) +} + +func (tt *testTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + if tt.traceQueryStart != nil { + return tt.traceQueryStart(ctx, conn, data) + } + return ctx +} + +func (tt *testTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { + if tt.traceQueryEnd != nil { + tt.traceQueryEnd(ctx, conn, data) + } +} + +func (tt *testTracer) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + if tt.traceBatchStart != nil { + return tt.traceBatchStart(ctx, conn, data) + } + return ctx +} + +func (tt *testTracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + if tt.traceBatchQuery != nil { + tt.traceBatchQuery(ctx, conn, data) + } +} + +func (tt *testTracer) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + if tt.traceBatchEnd != nil { + tt.traceBatchEnd(ctx, conn, data) + } +} + +func (tt *testTracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context { + if tt.traceCopyFromStart != nil { + return tt.traceCopyFromStart(ctx, conn, data) + } + return ctx +} + +func (tt *testTracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) { + if tt.traceCopyFromEnd != nil { + tt.traceCopyFromEnd(ctx, conn, data) + } +} + +func (tt *testTracer) TracePrepareStart(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context { + if tt.tracePrepareStart != nil { + return tt.tracePrepareStart(ctx, conn, data) + } + return ctx +} + +func (tt *testTracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) { + if tt.tracePrepareEnd != nil { + tt.tracePrepareEnd(ctx, conn, data) + } +} + +func (tt *testTracer) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context { + if tt.traceConnectStart != nil { + return tt.traceConnectStart(ctx, data) + } + return ctx +} + +func (tt *testTracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) { + if tt.traceConnectEnd != nil { + tt.traceConnectEnd(ctx, data) + } +} + +func TestTraceExec(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceQueryStartCalled := false + tracer.traceQueryStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + traceQueryStartCalled = true + require.Equal(t, `select $1::text`, data.SQL) + require.Len(t, data.Args, 1) + require.Equal(t, `testing`, data.Args[0]) + return context.WithValue(ctx, "fromTraceQueryStart", "foo") + } + + traceQueryEndCalled := false + tracer.traceQueryEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { + traceQueryEndCalled = true + require.Equal(t, "foo", ctx.Value("fromTraceQueryStart")) + require.Equal(t, `SELECT 1`, data.CommandTag.String()) + require.NoError(t, data.Err) + } + + _, err := conn.Exec(ctx, `select $1::text`, "testing") + require.NoError(t, err) + require.True(t, traceQueryStartCalled) + require.True(t, traceQueryEndCalled) + }) +} + +func TestTraceQuery(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceQueryStartCalled := false + tracer.traceQueryStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + traceQueryStartCalled = true + require.Equal(t, `select $1::text`, data.SQL) + require.Len(t, data.Args, 1) + require.Equal(t, `testing`, data.Args[0]) + return context.WithValue(ctx, "fromTraceQueryStart", "foo") + } + + traceQueryEndCalled := false + tracer.traceQueryEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { + traceQueryEndCalled = true + require.Equal(t, "foo", ctx.Value("fromTraceQueryStart")) + require.Equal(t, `SELECT 1`, data.CommandTag.String()) + require.NoError(t, data.Err) + } + + var s string + err := conn.QueryRow(ctx, `select $1::text`, "testing").Scan(&s) + require.NoError(t, err) + require.Equal(t, "testing", s) + require.True(t, traceQueryStartCalled) + require.True(t, traceQueryEndCalled) + }) +} + +func TestTraceBatchNormal(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceBatchStartCalled := false + tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + traceBatchStartCalled = true + require.NotNil(t, data.Batch) + require.Equal(t, 2, data.Batch.Len()) + return context.WithValue(ctx, "fromTraceBatchStart", "foo") + } + + traceBatchQueryCalledCount := 0 + tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + traceBatchQueryCalledCount++ + require.Equal(t, "foo", ctx.Value("fromTraceBatchStart")) + require.NoError(t, data.Err) + } + + traceBatchEndCalled := false + tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + traceBatchEndCalled = true + require.Equal(t, "foo", ctx.Value("fromTraceBatchStart")) + require.NoError(t, data.Err) + } + + batch := &pgx.Batch{} + batch.Queue(`select 1`) + batch.Queue(`select 2`) + + br := conn.SendBatch(context.Background(), batch) + require.True(t, traceBatchStartCalled) + + var n int32 + err := br.QueryRow().Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 1, n) + require.EqualValues(t, 1, traceBatchQueryCalledCount) + + err = br.QueryRow().Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 2, n) + require.EqualValues(t, 2, traceBatchQueryCalledCount) + + err = br.Close() + require.NoError(t, err) + + require.True(t, traceBatchEndCalled) + }) +} + +func TestTraceBatchClose(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceBatchStartCalled := false + tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + traceBatchStartCalled = true + require.NotNil(t, data.Batch) + require.Equal(t, 2, data.Batch.Len()) + return context.WithValue(ctx, "fromTraceBatchStart", "foo") + } + + traceBatchQueryCalledCount := 0 + tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + traceBatchQueryCalledCount++ + require.Equal(t, "foo", ctx.Value("fromTraceBatchStart")) + require.NoError(t, data.Err) + } + + traceBatchEndCalled := false + tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + traceBatchEndCalled = true + require.Equal(t, "foo", ctx.Value("fromTraceBatchStart")) + require.NoError(t, data.Err) + } + + batch := &pgx.Batch{} + batch.Queue(`select 1`) + batch.Queue(`select 2`) + + br := conn.SendBatch(context.Background(), batch) + require.True(t, traceBatchStartCalled) + err := br.Close() + require.NoError(t, err) + require.EqualValues(t, 2, traceBatchQueryCalledCount) + require.True(t, traceBatchEndCalled) + }) +} + +func TestTraceBatchErrorWhileReadingResults(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, []pgx.QueryExecMode{pgx.QueryExecModeSimpleProtocol}, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceBatchStartCalled := false + tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + traceBatchStartCalled = true + require.NotNil(t, data.Batch) + require.Equal(t, 3, data.Batch.Len()) + return context.WithValue(ctx, "fromTraceBatchStart", "foo") + } + + traceBatchQueryCalledCount := 0 + tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + traceBatchQueryCalledCount++ + require.Equal(t, "foo", ctx.Value("fromTraceBatchStart")) + if traceBatchQueryCalledCount == 2 { + require.Error(t, data.Err) + } else { + require.NoError(t, data.Err) + } + } + + traceBatchEndCalled := false + tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + traceBatchEndCalled = true + require.Equal(t, "foo", ctx.Value("fromTraceBatchStart")) + require.Error(t, data.Err) + } + + batch := &pgx.Batch{} + batch.Queue(`select 1`) + batch.Queue(`select 2/n-2 from generate_series(0,10) n`) + batch.Queue(`select 3`) + + br := conn.SendBatch(context.Background(), batch) + require.True(t, traceBatchStartCalled) + + commandTag, err := br.Exec() + require.NoError(t, err) + require.Equal(t, "SELECT 1", commandTag.String()) + + commandTag, err = br.Exec() + require.Error(t, err) + require.Equal(t, "", commandTag.String()) + + commandTag, err = br.Exec() + require.Error(t, err) + require.Equal(t, "", commandTag.String()) + + err = br.Close() + require.Error(t, err) + require.EqualValues(t, 2, traceBatchQueryCalledCount) + require.True(t, traceBatchEndCalled) + }) +} + +func TestTraceBatchErrorWhileReadingResultsWhileClosing(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, []pgx.QueryExecMode{pgx.QueryExecModeSimpleProtocol}, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceBatchStartCalled := false + tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + traceBatchStartCalled = true + require.NotNil(t, data.Batch) + require.Equal(t, 3, data.Batch.Len()) + return context.WithValue(ctx, "fromTraceBatchStart", "foo") + } + + traceBatchQueryCalledCount := 0 + tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + traceBatchQueryCalledCount++ + require.Equal(t, "foo", ctx.Value("fromTraceBatchStart")) + if traceBatchQueryCalledCount == 2 { + require.Error(t, data.Err) + } else { + require.NoError(t, data.Err) + } + } + + traceBatchEndCalled := false + tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + traceBatchEndCalled = true + require.Equal(t, "foo", ctx.Value("fromTraceBatchStart")) + require.Error(t, data.Err) + } + + batch := &pgx.Batch{} + batch.Queue(`select 1`) + batch.Queue(`select 2/n-2 from generate_series(0,10) n`) + batch.Queue(`select 3`) + + br := conn.SendBatch(context.Background(), batch) + require.True(t, traceBatchStartCalled) + err := br.Close() + require.Error(t, err) + require.EqualValues(t, 2, traceBatchQueryCalledCount) + require.True(t, traceBatchEndCalled) + }) +} + +func TestTraceCopyFrom(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceCopyFromStartCalled := false + tracer.traceCopyFromStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context { + traceCopyFromStartCalled = true + require.Equal(t, pgx.Identifier{"foo"}, data.TableName) + require.Equal(t, []string{"a"}, data.ColumnNames) + return context.WithValue(ctx, "fromTraceCopyFromStart", "foo") + } + + traceCopyFromEndCalled := false + tracer.traceCopyFromEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) { + traceCopyFromEndCalled = true + require.Equal(t, "foo", ctx.Value("fromTraceCopyFromStart")) + require.Equal(t, `COPY 2`, data.CommandTag.String()) + require.NoError(t, data.Err) + } + + _, err := conn.Exec(context.Background(), `create temporary table foo(a int4)`) + require.NoError(t, err) + + inputRows := [][]any{ + {int32(1)}, + {nil}, + } + + copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows)) + require.NoError(t, err) + require.EqualValues(t, len(inputRows), copyCount) + require.True(t, traceCopyFromStartCalled) + require.True(t, traceCopyFromEndCalled) + }) +} + +func TestTracePrepare(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + tracePrepareStartCalled := false + tracer.tracePrepareStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context { + tracePrepareStartCalled = true + require.Equal(t, `ps`, data.Name) + require.Equal(t, `select $1::text`, data.SQL) + return context.WithValue(ctx, "fromTracePrepareStart", "foo") + } + + tracePrepareEndCalled := false + tracer.tracePrepareEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) { + tracePrepareEndCalled = true + require.False(t, data.AlreadyPrepared) + require.NoError(t, data.Err) + } + + _, err := conn.Prepare(ctx, "ps", `select $1::text`) + require.NoError(t, err) + require.True(t, tracePrepareStartCalled) + require.True(t, tracePrepareEndCalled) + + tracePrepareStartCalled = false + tracePrepareEndCalled = false + tracer.tracePrepareEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) { + tracePrepareEndCalled = true + require.True(t, data.AlreadyPrepared) + require.NoError(t, data.Err) + } + + _, err = conn.Prepare(ctx, "ps", `select $1::text`) + require.NoError(t, err) + require.True(t, tracePrepareStartCalled) + require.True(t, tracePrepareEndCalled) + }) +} + +func TestTraceConnect(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + config := defaultConnTestRunner.CreateConfig(context.Background(), t) + config.Tracer = tracer + + traceConnectStartCalled := false + tracer.traceConnectStart = func(ctx context.Context, data pgx.TraceConnectStartData) context.Context { + traceConnectStartCalled = true + require.NotNil(t, data.ConnConfig) + return context.WithValue(ctx, "fromTraceConnectStart", "foo") + } + + traceConnectEndCalled := false + tracer.traceConnectEnd = func(ctx context.Context, data pgx.TraceConnectEndData) { + traceConnectEndCalled = true + require.NotNil(t, data.Conn) + require.NoError(t, data.Err) + } + + conn1, err := pgx.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer conn1.Close(context.Background()) + require.True(t, traceConnectStartCalled) + require.True(t, traceConnectEndCalled) + + config, err = pgx.ParseConfig("host=/invalid") + require.NoError(t, err) + config.Tracer = tracer + + traceConnectStartCalled = false + traceConnectEndCalled = false + tracer.traceConnectEnd = func(ctx context.Context, data pgx.TraceConnectEndData) { + traceConnectEndCalled = true + require.Nil(t, data.Conn) + require.Error(t, data.Err) + } + + conn2, err := pgx.ConnectConfig(context.Background(), config) + require.Nil(t, conn2) + require.Error(t, err) + require.True(t, traceConnectStartCalled) + require.True(t, traceConnectEndCalled) +}