mirror of https://github.com/jackc/pgx.git

Add tracing support

Replaces existing logging support. Package tracelog provides an adapter for old-style logging. https://github.com/jackc/pgx/issues/1061

branch: pull/1281/head
parent: 9201cc0341
commit: 78875bb95a
@@ -159,7 +159,9 @@ Previously, a batch with 10 unique parameterized statements executed 100 times w
 for each prepare / describe and 1 for executing them all. Now pipeline mode is used to prepare / describe all statements
 in a single network round trip. So it would only take 2 round trips.

-## 3rd Party Logger Integration
+## Tracing and Logging

+Internal logging support has been replaced with tracing hooks. This allows custom tracing integration with tools like OpenTelemetry. Package tracelog provides an adapter for pgx v4 loggers to act as a tracer.
+
 All integrations with 3rd party loggers have been extracted to separate repositories. This trims the pgx dependency
 tree.
@@ -163,6 +163,8 @@ pgerrcode contains constants for the PostgreSQL error codes.

## Adapters for 3rd Party Loggers

These adapters can be used with the tracelog package.

* [github.com/jackc/pgx-go-kit-log](https://github.com/jackc/pgx-go-kit-log)
* [github.com/jackc/pgx-log15](https://github.com/jackc/pgx-log15)
* [github.com/jackc/pgx-logrus](https://github.com/jackc/pgx-logrus)
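Note: a minimal usage sketch, not part of this commit, showing how the tracelog adapter can stand in for the removed logger support. Only pgx.ParseConfig, pgx.ConnectConfig, ConnConfig.Tracer, tracelog.TraceLog, tracelog.LoggerFunc, and tracelog.LogLevelInfo come from the code in this diff; the DATABASE_URL source and the standard-library logger are placeholders.

package main

import (
	"context"
	"log"
	"os"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/tracelog"
)

func main() {
	config, err := pgx.ParseConfig(os.Getenv("DATABASE_URL")) // placeholder connection string source
	if err != nil {
		log.Fatal(err)
	}

	// Adapt a plain standard-library logger to the tracelog.Logger interface.
	logger := tracelog.LoggerFunc(func(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]any) {
		log.Printf("%s: %s %v", level, msg, data)
	})

	// TraceLog implements the new tracer hooks and forwards them to the logger.
	config.Tracer = &tracelog.TraceLog{Logger: logger, LogLevel: tracelog.LogLevelInfo}

	conn, err := pgx.ConnectConfig(context.Background(), config)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(context.Background())
}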
batch.go

@@ -50,13 +50,14 @@ type BatchResults interface {
 }

 type batchResults struct {
-	ctx    context.Context
-	conn   *Conn
-	mrr    *pgconn.MultiResultReader
-	err    error
-	b      *Batch
-	ix     int
-	closed bool
+	ctx       context.Context
+	conn      *Conn
+	mrr       *pgconn.MultiResultReader
+	err       error
+	b         *Batch
+	ix        int
+	closed    bool
+	endTraced bool
 }

 // Exec reads the results from the next query in the batch as if the query has been sent with Exec.
@@ -75,35 +76,29 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) {
 		if err == nil {
 			err = errors.New("no result")
 		}
-		if br.conn.shouldLog(LogLevelError) {
-			br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]any{
-				"sql":  query,
-				"args": logQueryArgs(arguments),
-				"err":  err,
+		if br.conn.batchTracer != nil {
+			br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
+				SQL:  query,
+				Args: arguments,
+				Err:  err,
 			})
 		}
 		return pgconn.CommandTag{}, err
 	}

 	commandTag, err := br.mrr.ResultReader().Close()
+	br.err = err

-	if err != nil {
-		if br.conn.shouldLog(LogLevelError) {
-			br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]any{
-				"sql":  query,
-				"args": logQueryArgs(arguments),
-				"err":  err,
-			})
-		}
-	} else if br.conn.shouldLog(LogLevelInfo) {
-		br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]any{
-			"sql":        query,
-			"args":       logQueryArgs(arguments),
-			"commandTag": commandTag,
+	if br.conn.batchTracer != nil {
+		br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
+			SQL:        query,
+			Args:       arguments,
+			CommandTag: commandTag,
+			Err:        br.err,
 		})
 	}

-	return commandTag, err
+	return commandTag, br.err
 }

 // Query reads the results from the next query in the batch as if the query has been sent with Query.
@ -123,6 +118,7 @@ func (br *batchResults) Query() (Rows, error) {
|
|||
}
|
||||
|
||||
rows := br.conn.getRows(br.ctx, query, arguments)
|
||||
rows.batchTracer = br.conn.batchTracer
|
||||
|
||||
if !br.mrr.NextResult() {
|
||||
rows.err = br.mrr.Close()
|
||||
|
@ -131,11 +127,11 @@ func (br *batchResults) Query() (Rows, error) {
|
|||
}
|
||||
rows.closed = true
|
||||
|
||||
if br.conn.shouldLog(LogLevelError) {
|
||||
br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]any{
|
||||
"sql": query,
|
||||
"args": logQueryArgs(arguments),
|
||||
"err": rows.err,
|
||||
if br.conn.batchTracer != nil {
|
||||
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
|
||||
SQL: query,
|
||||
Args: arguments,
|
||||
Err: rows.err,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -156,6 +152,15 @@ func (br *batchResults) QueryRow() Row {
|
|||
// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to
|
||||
// resyncronize the connection with the server. In this case the underlying connection will have been closed.
|
||||
func (br *batchResults) Close() error {
|
||||
defer func() {
|
||||
if !br.endTraced {
|
||||
if br.conn != nil && br.conn.batchTracer != nil {
|
||||
br.conn.batchTracer.TraceBatchEnd(br.ctx, br.conn, TraceBatchEndData{Err: br.err})
|
||||
}
|
||||
br.endTraced = true
|
||||
}
|
||||
}()
|
||||
|
||||
if br.err != nil {
|
||||
return br.err
|
||||
}
|
||||
|
@ -163,24 +168,26 @@ func (br *batchResults) Close() error {
|
|||
if br.closed {
|
||||
return nil
|
||||
}
|
||||
br.closed = true
|
||||
|
||||
// log any queries that haven't yet been logged by Exec or Query
|
||||
for {
|
||||
query, args, ok := br.nextQueryAndArgs()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
if br.conn.shouldLog(LogLevelInfo) {
|
||||
br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]any{
|
||||
"sql": query,
|
||||
"args": logQueryArgs(args),
|
||||
})
|
||||
// consume and log any queries that haven't yet been logged by Exec or Query
|
||||
if br.conn.batchTracer != nil {
|
||||
for br.err == nil && !br.closed && br.b != nil && br.ix < len(br.b.items) {
|
||||
br.Exec()
|
||||
}
|
||||
}
|
||||
|
||||
return br.mrr.Close()
|
||||
br.closed = true
|
||||
|
||||
err := br.mrr.Close()
|
||||
if br.err == nil {
|
||||
br.err = err
|
||||
}
|
||||
|
||||
return br.err
|
||||
}
|
||||
|
||||
func (br *batchResults) earlyError() error {
|
||||
return br.err
|
||||
}
|
||||
|
||||
func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
|
||||
|
@@ -195,14 +202,15 @@ func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
 }

 type pipelineBatchResults struct {
-	ctx      context.Context
-	conn     *Conn
-	pipeline *pgconn.Pipeline
-	lastRows *baseRows
-	err      error
-	b        *Batch
-	ix       int
-	closed   bool
+	ctx       context.Context
+	conn      *Conn
+	pipeline  *pgconn.Pipeline
+	lastRows  *baseRows
+	err       error
+	b         *Batch
+	ix        int
+	closed    bool
+	endTraced bool
 }

 // Exec reads the results from the next query in the batch as if the query has been sent with Exec.
@ -227,25 +235,17 @@ func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) {
|
|||
var commandTag pgconn.CommandTag
|
||||
switch results := results.(type) {
|
||||
case *pgconn.ResultReader:
|
||||
commandTag, err = results.Close()
|
||||
commandTag, br.err = results.Close()
|
||||
default:
|
||||
return pgconn.CommandTag{}, fmt.Errorf("unexpected pipeline result: %T", results)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
br.err = err
|
||||
if br.conn.shouldLog(LogLevelError) {
|
||||
br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]any{
|
||||
"sql": query,
|
||||
"args": logQueryArgs(arguments),
|
||||
"err": err,
|
||||
})
|
||||
}
|
||||
} else if br.conn.shouldLog(LogLevelInfo) {
|
||||
br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]any{
|
||||
"sql": query,
|
||||
"args": logQueryArgs(arguments),
|
||||
"commandTag": commandTag,
|
||||
if br.conn.batchTracer != nil {
|
||||
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
|
||||
SQL: query,
|
||||
Args: arguments,
|
||||
CommandTag: commandTag,
|
||||
Err: br.err,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -274,6 +274,7 @@ func (br *pipelineBatchResults) Query() (Rows, error) {
|
|||
}
|
||||
|
||||
rows := br.conn.getRows(br.ctx, query, arguments)
|
||||
rows.batchTracer = br.conn.batchTracer
|
||||
br.lastRows = rows
|
||||
|
||||
results, err := br.pipeline.GetResults()
|
||||
|
@ -281,11 +282,12 @@ func (br *pipelineBatchResults) Query() (Rows, error) {
|
|||
br.err = err
|
||||
rows.err = err
|
||||
rows.closed = true
|
||||
if br.conn.shouldLog(LogLevelError) {
|
||||
br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]any{
|
||||
"sql": query,
|
||||
"args": logQueryArgs(arguments),
|
||||
"err": rows.err,
|
||||
|
||||
if br.conn.batchTracer != nil {
|
||||
br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
|
||||
SQL: query,
|
||||
Args: arguments,
|
||||
Err: err,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
|
@ -313,6 +315,15 @@ func (br *pipelineBatchResults) QueryRow() Row {
|
|||
// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to
|
||||
// resyncronize the connection with the server. In this case the underlying connection will have been closed.
|
||||
func (br *pipelineBatchResults) Close() error {
|
||||
defer func() {
|
||||
if !br.endTraced {
|
||||
if br.conn.batchTracer != nil {
|
||||
br.conn.batchTracer.TraceBatchEnd(br.ctx, br.conn, TraceBatchEndData{Err: br.err})
|
||||
}
|
||||
br.endTraced = true
|
||||
}
|
||||
}()
|
||||
|
||||
if br.err != nil {
|
||||
return br.err
|
||||
}
|
||||
|
@ -325,24 +336,25 @@ func (br *pipelineBatchResults) Close() error {
|
|||
if br.closed {
|
||||
return nil
|
||||
}
|
||||
br.closed = true
|
||||
|
||||
// log any queries that haven't yet been logged by Exec or Query
|
||||
for {
|
||||
query, args, ok := br.nextQueryAndArgs()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
if br.conn.shouldLog(LogLevelInfo) {
|
||||
br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]any{
|
||||
"sql": query,
|
||||
"args": logQueryArgs(args),
|
||||
})
|
||||
// consume and log any queries that haven't yet been logged by Exec or Query
|
||||
if br.conn.batchTracer != nil {
|
||||
for br.err == nil && !br.closed && br.b != nil && br.ix < len(br.b.items) {
|
||||
br.Exec()
|
||||
}
|
||||
}
|
||||
br.closed = true
|
||||
|
||||
return br.pipeline.Close()
|
||||
err := br.pipeline.Close()
|
||||
if br.err == nil {
|
||||
br.err = err
|
||||
}
|
||||
|
||||
return br.err
|
||||
}
|
||||
|
||||
func (br *pipelineBatchResults) earlyError() error {
|
||||
return br.err
|
||||
}
|
||||
|
||||
func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
|
||||
|
|
|
@ -733,94 +733,6 @@ func testConnSendBatch(t *testing.T, conn *pgx.Conn, queryCount int) {
|
|||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestLogBatchStatementsOnExec(t *testing.T) {
|
||||
l1 := &testLogger{}
|
||||
config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.Logger = l1
|
||||
|
||||
conn := mustConnect(t, config)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
l1.logs = l1.logs[0:0] // Clear logs written when establishing connection
|
||||
|
||||
batch := &pgx.Batch{}
|
||||
batch.Queue("create table foo (id bigint)")
|
||||
batch.Queue("drop table foo")
|
||||
|
||||
br := conn.SendBatch(context.Background(), batch)
|
||||
|
||||
_, err := br.Exec()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error creating table: %v", err)
|
||||
}
|
||||
|
||||
_, err = br.Exec()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error dropping table: %v", err)
|
||||
}
|
||||
|
||||
if len(l1.logs) != 2 {
|
||||
t.Fatalf("Expected two log entries but got %d", len(l1.logs))
|
||||
}
|
||||
|
||||
if l1.logs[0].msg != "BatchResult.Exec" {
|
||||
t.Errorf("Expected first log message to be 'BatchResult.Exec' but was '%s", l1.logs[0].msg)
|
||||
}
|
||||
|
||||
if l1.logs[0].data["sql"] != "create table foo (id bigint)" {
|
||||
t.Errorf("Expected the first query to be 'create table foo (id bigint)' but was '%s'", l1.logs[0].data["sql"])
|
||||
}
|
||||
|
||||
if l1.logs[1].msg != "BatchResult.Exec" {
|
||||
t.Errorf("Expected second log message to be 'BatchResult.Exec' but was '%s", l1.logs[1].msg)
|
||||
}
|
||||
|
||||
if l1.logs[1].data["sql"] != "drop table foo" {
|
||||
t.Errorf("Expected the second query to be 'drop table foo' but was '%s'", l1.logs[1].data["sql"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogBatchStatementsOnBatchResultClose(t *testing.T) {
|
||||
l1 := &testLogger{}
|
||||
config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.Logger = l1
|
||||
|
||||
conn := mustConnect(t, config)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
l1.logs = l1.logs[0:0] // Clear logs written when establishing connection
|
||||
|
||||
batch := &pgx.Batch{}
|
||||
batch.Queue("select generate_series(1,$1)", 100)
|
||||
batch.Queue("select 1 = 1;")
|
||||
|
||||
br := conn.SendBatch(context.Background(), batch)
|
||||
|
||||
if err := br.Close(); err != nil {
|
||||
t.Fatalf("Unexpected batch error: %v", err)
|
||||
}
|
||||
|
||||
if len(l1.logs) != 2 {
|
||||
t.Fatalf("Expected 2 log statements but found %d", len(l1.logs))
|
||||
}
|
||||
|
||||
if l1.logs[0].msg != "BatchResult.Close" {
|
||||
t.Errorf("Expected first log statement to be 'BatchResult.Close' but was %s", l1.logs[0].msg)
|
||||
}
|
||||
|
||||
if l1.logs[0].data["sql"] != "select generate_series(1,$1)" {
|
||||
t.Errorf("Expected first query to be 'select generate_series(1,$1)' but was '%s'", l1.logs[0].data["sql"])
|
||||
}
|
||||
|
||||
if l1.logs[1].msg != "BatchResult.Close" {
|
||||
t.Errorf("Expected second log statement to be 'BatchResult.Close' but was %s", l1.logs[1].msg)
|
||||
}
|
||||
|
||||
if l1.logs[1].data["sql"] != "select 1 = 1;" {
|
||||
t.Errorf("Expected second query to be 'select 1 = 1;' but was '%s'", l1.logs[1].data["sql"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendBatchSimpleProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
bench_test.go
@ -284,123 +284,6 @@ func BenchmarkPointerPointerWithPresentValues(b *testing.B) {
|
|||
}
|
||||
}
|
||||
|
||||
func BenchmarkSelectWithoutLogging(b *testing.B) {
|
||||
conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")))
|
||||
defer closeConn(b, conn)
|
||||
|
||||
benchmarkSelectWithLog(b, conn)
|
||||
}
|
||||
|
||||
type discardLogger struct{}
|
||||
|
||||
func (dl discardLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]any) {
|
||||
}
|
||||
|
||||
func BenchmarkSelectWithLoggingTraceDiscard(b *testing.B) {
|
||||
var logger discardLogger
|
||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.Logger = logger
|
||||
config.LogLevel = pgx.LogLevelTrace
|
||||
|
||||
conn := mustConnect(b, config)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
benchmarkSelectWithLog(b, conn)
|
||||
}
|
||||
|
||||
func BenchmarkSelectWithLoggingDebugWithDiscard(b *testing.B) {
|
||||
var logger discardLogger
|
||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.Logger = logger
|
||||
config.LogLevel = pgx.LogLevelDebug
|
||||
|
||||
conn := mustConnect(b, config)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
benchmarkSelectWithLog(b, conn)
|
||||
}
|
||||
|
||||
func BenchmarkSelectWithLoggingInfoWithDiscard(b *testing.B) {
|
||||
var logger discardLogger
|
||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.Logger = logger
|
||||
config.LogLevel = pgx.LogLevelInfo
|
||||
|
||||
conn := mustConnect(b, config)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
benchmarkSelectWithLog(b, conn)
|
||||
}
|
||||
|
||||
func BenchmarkSelectWithLoggingErrorWithDiscard(b *testing.B) {
|
||||
var logger discardLogger
|
||||
config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.Logger = logger
|
||||
config.LogLevel = pgx.LogLevelError
|
||||
|
||||
conn := mustConnect(b, config)
|
||||
defer closeConn(b, conn)
|
||||
|
||||
benchmarkSelectWithLog(b, conn)
|
||||
}
|
||||
|
||||
func benchmarkSelectWithLog(b *testing.B, conn *pgx.Conn) {
|
||||
_, err := conn.Prepare(context.Background(), "test", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var record struct {
|
||||
id int32
|
||||
userName string
|
||||
email string
|
||||
name string
|
||||
sex string
|
||||
birthDate time.Time
|
||||
lastLoginTime time.Time
|
||||
}
|
||||
|
||||
err = conn.QueryRow(context.Background(), "test").Scan(
|
||||
&record.id,
|
||||
&record.userName,
|
||||
&record.email,
|
||||
&record.name,
|
||||
&record.sex,
|
||||
&record.birthDate,
|
||||
&record.lastLoginTime,
|
||||
)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// These checks both ensure that the correct data was returned
|
||||
// and provide a benchmark of accessing the returned values.
|
||||
if record.id != 1 {
|
||||
b.Fatalf("bad value for id: %v", record.id)
|
||||
}
|
||||
if record.userName != "johnsmith" {
|
||||
b.Fatalf("bad value for userName: %v", record.userName)
|
||||
}
|
||||
if record.email != "johnsmith@example.com" {
|
||||
b.Fatalf("bad value for email: %v", record.email)
|
||||
}
|
||||
if record.name != "John Smith" {
|
||||
b.Fatalf("bad value for name: %v", record.name)
|
||||
}
|
||||
if record.sex != "male" {
|
||||
b.Fatalf("bad value for sex: %v", record.sex)
|
||||
}
|
||||
if record.birthDate != time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) {
|
||||
b.Fatalf("bad value for birthDate: %v", record.birthDate)
|
||||
}
|
||||
if record.lastLoginTime != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) {
|
||||
b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const benchmarkWriteTableCreateSQL = `drop table if exists t;
|
||||
|
||||
create table t(
|
||||
|
|
conn.go
@ -19,8 +19,8 @@ import (
|
|||
// then it can be modified. A manually initialized ConnConfig will cause ConnectConfig to panic.
|
||||
type ConnConfig struct {
|
||||
pgconn.Config
|
||||
Logger Logger
|
||||
LogLevel LogLevel
|
||||
|
||||
Tracer QueryTracer
|
||||
|
||||
// Original connection string that was parsed into config.
|
||||
connString string
|
||||
|
@ -63,8 +63,11 @@ type Conn struct {
|
|||
preparedStatements map[string]*pgconn.StatementDescription
|
||||
statementCache stmtcache.Cache
|
||||
descriptionCache stmtcache.Cache
|
||||
logger Logger
|
||||
logLevel LogLevel
|
||||
|
||||
queryTracer QueryTracer
|
||||
batchTracer BatchTracer
|
||||
copyFromTracer CopyFromTracer
|
||||
prepareTracer PrepareTracer
|
||||
|
||||
notifications []*pgconn.Notification
|
||||
|
||||
|
@ -94,9 +97,6 @@ func (ident Identifier) Sanitize() string {
|
|||
// ErrNoRows occurs when rows are expected but none are returned.
|
||||
var ErrNoRows = errors.New("no rows in result set")
|
||||
|
||||
// ErrInvalidLogLevel occurs on attempt to set an invalid log level.
|
||||
var ErrInvalidLogLevel = errors.New("invalid log level")
|
||||
|
||||
var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
|
||||
var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
|
||||
|
||||
|
@ -182,7 +182,6 @@ func ParseConfig(connString string) (*ConnConfig, error) {
|
|||
connConfig := &ConnConfig{
|
||||
Config: *config,
|
||||
createdByParseConfig: true,
|
||||
LogLevel: LogLevelInfo,
|
||||
StatementCacheCapacity: statementCacheCapacity,
|
||||
DescriptionCacheCapacity: descriptionCacheCapacity,
|
||||
DefaultQueryExecMode: defaultQueryExecMode,
|
||||
|
@ -194,6 +193,13 @@ func ParseConfig(connString string) (*ConnConfig, error) {
|
|||
|
||||
// connect connects to a database. connect takes ownership of config. The caller must not use or access it again.
|
||||
func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
|
||||
if connectTracer, ok := config.Tracer.(ConnectTracer); ok {
|
||||
ctx = connectTracer.TraceConnectStart(ctx, TraceConnectStartData{ConnConfig: config})
|
||||
defer func() {
|
||||
connectTracer.TraceConnectEnd(ctx, TraceConnectEndData{Conn: c, Err: err})
|
||||
}()
|
||||
}
|
||||
|
||||
// Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from
|
||||
// zero values.
|
||||
if !config.createdByParseConfig {
|
||||
|
@ -201,29 +207,28 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
|
|||
}
|
||||
|
||||
c = &Conn{
|
||||
config: config,
|
||||
typeMap: pgtype.NewMap(),
|
||||
logLevel: config.LogLevel,
|
||||
logger: config.Logger,
|
||||
config: config,
|
||||
typeMap: pgtype.NewMap(),
|
||||
queryTracer: config.Tracer,
|
||||
}
|
||||
|
||||
if t, ok := c.queryTracer.(BatchTracer); ok {
|
||||
c.batchTracer = t
|
||||
}
|
||||
if t, ok := c.queryTracer.(CopyFromTracer); ok {
|
||||
c.copyFromTracer = t
|
||||
}
|
||||
if t, ok := c.queryTracer.(PrepareTracer); ok {
|
||||
c.prepareTracer = t
|
||||
}
|
||||
|
||||
// Only install pgx notification system if no other callback handler is present.
|
||||
if config.Config.OnNotification == nil {
|
||||
config.Config.OnNotification = c.bufferNotifications
|
||||
} else {
|
||||
if c.shouldLog(LogLevelDebug) {
|
||||
c.log(ctx, LogLevelDebug, "pgx notification handler disabled by application supplied OnNotification", map[string]any{"host": config.Config.Host})
|
||||
}
|
||||
}
|
||||
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
c.log(ctx, LogLevelInfo, "Dialing PostgreSQL server", map[string]any{"host": config.Config.Host})
|
||||
}
|
||||
c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config)
|
||||
if err != nil {
|
||||
if c.shouldLog(LogLevelError) {
|
||||
c.log(ctx, LogLevelError, "connect failed", map[string]any{"err": err})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -251,9 +256,6 @@ func (c *Conn) Close(ctx context.Context) error {
|
|||
}
|
||||
|
||||
err := c.pgConn.Close(ctx)
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
c.log(ctx, LogLevelInfo, "closed connection", nil)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -264,18 +266,23 @@ func (c *Conn) Close(ctx context.Context) error {
|
|||
// name and sql arguments. This allows a code path to Prepare and Query/Exec without
|
||||
// concern for if the statement has already been prepared.
|
||||
func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) {
|
||||
if c.prepareTracer != nil {
|
||||
ctx = c.prepareTracer.TracePrepareStart(ctx, c, TracePrepareStartData{Name: name, SQL: sql})
|
||||
}
|
||||
|
||||
if name != "" {
|
||||
var ok bool
|
||||
if sd, ok = c.preparedStatements[name]; ok && sd.SQL == sql {
|
||||
if c.prepareTracer != nil {
|
||||
c.prepareTracer.TracePrepareEnd(ctx, c, TracePrepareEndData{AlreadyPrepared: true})
|
||||
}
|
||||
return sd, nil
|
||||
}
|
||||
}
|
||||
|
||||
if c.shouldLog(LogLevelError) {
|
||||
if c.prepareTracer != nil {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
c.log(ctx, LogLevelError, "Prepare failed", map[string]any{"err": err, "name": name, "sql": sql})
|
||||
}
|
||||
c.prepareTracer.TracePrepareEnd(ctx, c, TracePrepareEndData{Err: err})
|
||||
}()
|
||||
}
|
||||
|
||||
|
@ -337,21 +344,6 @@ func (c *Conn) die(err error) {
|
|||
c.pgConn.Close(ctx)
|
||||
}
|
||||
|
||||
func (c *Conn) shouldLog(lvl LogLevel) bool {
|
||||
return c.logger != nil && c.logLevel >= lvl
|
||||
}
|
||||
|
||||
func (c *Conn) log(ctx context.Context, lvl LogLevel, msg string, data map[string]any) {
|
||||
if data == nil {
|
||||
data = map[string]any{}
|
||||
}
|
||||
if c.pgConn != nil && c.pgConn.PID() != 0 {
|
||||
data["pid"] = c.pgConn.PID()
|
||||
}
|
||||
|
||||
c.logger.Log(ctx, lvl, msg, data)
|
||||
}
|
||||
|
||||
func quoteIdentifier(s string) string {
|
||||
return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
|
||||
}
|
||||
|
@ -379,24 +371,18 @@ func (c *Conn) Config() *ConnConfig { return c.config.Copy() }
|
|||
// Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced
|
||||
// positionally from the sql string as $1, $2, etc.
|
||||
func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) {
|
||||
if c.queryTracer != nil {
|
||||
ctx = c.queryTracer.TraceQueryStart(ctx, c, TraceQueryStartData{SQL: sql, Args: arguments})
|
||||
}
|
||||
|
||||
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
|
||||
return pgconn.CommandTag{}, err
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
commandTag, err := c.exec(ctx, sql, arguments...)
|
||||
if err != nil {
|
||||
if c.shouldLog(LogLevelError) {
|
||||
endTime := time.Now()
|
||||
c.log(ctx, LogLevelError, "Exec", map[string]any{"sql": sql, "args": logQueryArgs(arguments), "err": err, "time": endTime.Sub(startTime)})
|
||||
}
|
||||
return commandTag, err
|
||||
}
|
||||
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
endTime := time.Now()
|
||||
c.log(ctx, LogLevelInfo, "Exec", map[string]any{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag})
|
||||
if c.queryTracer != nil {
|
||||
c.queryTracer.TraceQueryEnd(ctx, c, TraceQueryEndData{CommandTag: commandTag, Err: err})
|
||||
}
|
||||
|
||||
return commandTag, err
|
||||
|
@ -537,7 +523,7 @@ func (c *Conn) getRows(ctx context.Context, sql string, args []any) *baseRows {
|
|||
r := &baseRows{}
|
||||
|
||||
r.ctx = ctx
|
||||
r.logger = c
|
||||
r.queryTracer = c.queryTracer
|
||||
r.typeMap = c.typeMap
|
||||
r.startTime = time.Now()
|
||||
r.sql = sql
|
||||
|
@ -628,7 +614,14 @@ type QueryRewriter interface {
|
|||
// QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely
|
||||
// needed. See the documentation for those types for details.
|
||||
func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error) {
|
||||
if c.queryTracer != nil {
|
||||
ctx = c.queryTracer.TraceQueryStart(ctx, c, TraceQueryStartData{SQL: sql, Args: args})
|
||||
}
|
||||
|
||||
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
|
||||
if c.queryTracer != nil {
|
||||
c.queryTracer.TraceQueryEnd(ctx, c, TraceQueryEndData{Err: err})
|
||||
}
|
||||
return &baseRows{err: err, closed: true}, err
|
||||
}
|
||||
|
||||
|
@ -791,7 +784,17 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row {
|
|||
// SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless
|
||||
// explicit transaction control statements are executed. The returned BatchResults must be closed before the connection
|
||||
// is used again.
|
||||
func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
||||
func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
|
||||
if c.batchTracer != nil {
|
||||
ctx = c.batchTracer.TraceBatchStart(ctx, c, TraceBatchStartData{Batch: b})
|
||||
defer func() {
|
||||
err := br.(interface{ earlyError() error }).earlyError()
|
||||
if err != nil {
|
||||
c.batchTracer.TraceBatchEnd(ctx, c, TraceBatchEndData{Err: err})
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
|
||||
return &batchResults{ctx: ctx, conn: c, err: err}
|
||||
}
|
||||
|
|
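SendBatch now drives the batch hooks: TraceBatchStart fires before the queued queries are sent, TraceBatchQuery fires once per statement as results are read (or drained by Close), and TraceBatchEnd fires when the results are closed. A hedged sketch of a custom tracer follows; the batchLogger type, the DSN, and the log output are illustrative assumptions, while the interface methods and Trace*Data types are the ones added in this commit.

package main

import (
	"context"
	"log"

	"github.com/jackc/pgx/v5"
)

// batchLogger is an illustrative tracer. ConnConfig.Tracer must be a
// pgx.QueryTracer; because this type also implements the BatchTracer
// methods, the connection wires up the batch hooks added in this commit.
type batchLogger struct{}

func (batchLogger) TraceQueryStart(ctx context.Context, _ *pgx.Conn, _ pgx.TraceQueryStartData) context.Context {
	return ctx
}

func (batchLogger) TraceQueryEnd(ctx context.Context, _ *pgx.Conn, _ pgx.TraceQueryEndData) {}

func (batchLogger) TraceBatchStart(ctx context.Context, _ *pgx.Conn, _ pgx.TraceBatchStartData) context.Context {
	log.Print("batch start")
	return ctx
}

func (batchLogger) TraceBatchQuery(ctx context.Context, _ *pgx.Conn, data pgx.TraceBatchQueryData) {
	log.Printf("batch query: %s err=%v", data.SQL, data.Err)
}

func (batchLogger) TraceBatchEnd(ctx context.Context, _ *pgx.Conn, data pgx.TraceBatchEndData) {
	log.Printf("batch end: err=%v", data.Err)
}

func main() {
	config, err := pgx.ParseConfig("postgres://localhost/postgres") // placeholder DSN
	if err != nil {
		log.Fatal(err)
	}
	config.Tracer = batchLogger{}

	conn, err := pgx.ConnectConfig(context.Background(), config)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(context.Background())

	b := &pgx.Batch{}
	b.Queue("select 1")
	b.Queue("select 2")
	// Close drains any statements not read via Exec/Query, so TraceBatchQuery
	// still fires for each of them before TraceBatchEnd.
	if err := conn.SendBatch(context.Background(), b).Close(); err != nil {
		log.Fatal(err)
	}
}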
conn_test.go
@ -3,7 +3,6 @@ package pgx_test
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -743,79 +742,6 @@ func TestInsertTimestampArray(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
type testLog struct {
|
||||
lvl pgx.LogLevel
|
||||
msg string
|
||||
data map[string]any
|
||||
}
|
||||
|
||||
type testLogger struct {
|
||||
logs []testLog
|
||||
}
|
||||
|
||||
func (l *testLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]any) {
|
||||
data["ctxdata"] = ctx.Value("ctxdata")
|
||||
l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data})
|
||||
}
|
||||
|
||||
func TestLogPassesContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
l1 := &testLogger{}
|
||||
config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.Logger = l1
|
||||
|
||||
conn := mustConnect(t, config)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
l1.logs = l1.logs[0:0] // Clear logs written when establishing connection
|
||||
|
||||
ctx := context.WithValue(context.Background(), "ctxdata", "foo")
|
||||
|
||||
if _, err := conn.Exec(ctx, ";"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(l1.logs) != 1 {
|
||||
t.Fatal("Expected logger to be called once, but it wasn't")
|
||||
}
|
||||
|
||||
if l1.logs[0].data["ctxdata"] != "foo" {
|
||||
t.Fatal("Expected context data to be passed to logger, but it wasn't")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerFunc(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const testMsg = "foo"
|
||||
|
||||
buf := bytes.Buffer{}
|
||||
logger := log.New(&buf, "", 0)
|
||||
|
||||
createAdapterFn := func(logger *log.Logger) pgx.LoggerFunc {
|
||||
return func(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
|
||||
logger.Printf("%s", testMsg)
|
||||
}
|
||||
}
|
||||
|
||||
config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
config.Logger = createAdapterFn(logger)
|
||||
|
||||
conn := mustConnect(t, config)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
buf.Reset() // Clear logs written when establishing connection
|
||||
|
||||
if _, err := conn.Exec(context.TODO(), ";"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(buf.String()) != testMsg {
|
||||
t.Errorf("Expected logger function to return '%s', but it was '%s'", testMsg, buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdentifierSanitize(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
copy_from.go
@ -5,7 +5,6 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
|
@ -89,6 +88,13 @@ type copyFrom struct {
|
|||
}
|
||||
|
||||
func (ct *copyFrom) run(ctx context.Context) (int64, error) {
|
||||
if ct.conn.copyFromTracer != nil {
|
||||
ctx = ct.conn.copyFromTracer.TraceCopyFromStart(ctx, ct.conn, TraceCopyFromStartData{
|
||||
TableName: ct.tableName,
|
||||
ColumnNames: ct.columnNames,
|
||||
})
|
||||
}
|
||||
|
||||
quotedTableName := ct.tableName.Sanitize()
|
||||
cbuf := &bytes.Buffer{}
|
||||
for i, cn := range ct.columnNames {
|
||||
|
@ -145,24 +151,19 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
|
|||
w.Close()
|
||||
}()
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
|
||||
|
||||
r.Close()
|
||||
<-doneChan
|
||||
|
||||
rowsAffected := commandTag.RowsAffected()
|
||||
endTime := time.Now()
|
||||
if err == nil {
|
||||
if ct.conn.shouldLog(LogLevelInfo) {
|
||||
ct.conn.log(ctx, LogLevelInfo, "CopyFrom", map[string]any{"tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime), "rowCount": rowsAffected})
|
||||
}
|
||||
} else if ct.conn.shouldLog(LogLevelError) {
|
||||
ct.conn.log(ctx, LogLevelError, "CopyFrom", map[string]any{"err": err, "tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime)})
|
||||
if ct.conn.copyFromTracer != nil {
|
||||
ct.conn.copyFromTracer.TraceCopyFromEnd(ctx, ct.conn, TraceCopyFromEndData{
|
||||
CommandTag: commandTag,
|
||||
Err: err,
|
||||
})
|
||||
}
|
||||
|
||||
return rowsAffected, err
|
||||
return commandTag.RowsAffected(), err
|
||||
}
|
||||
|
||||
func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) {
|
||||
|
|
doc.go

@@ -12,15 +12,7 @@ The primary way of establishing a connection is with `pgx.Connect`.

 The database connection string can be in URL or DSN format. Both PostgreSQL settings and pgx settings can be specified
 here. In addition, a config struct can be created by `ParseConfig` and modified before establishing the connection with
-`ConnectConfig`.
-
-	config, err := pgx.ParseConfig(os.Getenv("DATABASE_URL"))
-	if err != nil {
-		// ...
-	}
-	config.Logger = log15adapter.NewLogger(log.New("module", "pgx"))
-
-	conn, err := pgx.ConnectConfig(context.Background(), config)
+`ConnectConfig` to configure settings such as tracing that cannot be configured with a connection string.

 Connection Pool

@@ -315,11 +307,11 @@ notification is received or the context is canceled.
 	}


-Logging
+Tracing and Logging

-pgx defines a simple logger interface. Connections optionally accept a logger that satisfies this interface. Set
-LogLevel to control logging verbosity. Adapters for github.com/inconshreveable/log15, github.com/sirupsen/logrus,
-go.uber.org/zap, github.com/rs/zerolog, and the testing log are provided in the log directory.
+pgx supports tracing by setting ConnConfig.Tracer.
+
+In addition, the tracelog package provides the TraceLog type which lets a traditional logger act as a Tracer.

 Lower Level PostgreSQL Functionality
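The doc change above points at ConnConfig.Tracer as the new extension point. As an illustration only, not part of this commit, a QueryTracer can carry per-query state through the context returned by TraceQueryStart, the same pattern the tracelog package in this diff uses for timing; the timingTracer type and the DSN here are assumptions.

package main

import (
	"context"
	"log"
	"time"

	"github.com/jackc/pgx/v5"
)

type startTimeKey struct{}

// timingTracer measures query duration by stashing the start time in the
// context returned from TraceQueryStart and reading it back in TraceQueryEnd.
type timingTracer struct{}

func (timingTracer) TraceQueryStart(ctx context.Context, _ *pgx.Conn, _ pgx.TraceQueryStartData) context.Context {
	return context.WithValue(ctx, startTimeKey{}, time.Now())
}

func (timingTracer) TraceQueryEnd(ctx context.Context, _ *pgx.Conn, data pgx.TraceQueryEndData) {
	start, _ := ctx.Value(startTimeKey{}).(time.Time)
	log.Printf("query took %s (err=%v)", time.Since(start), data.Err)
}

func main() {
	config, err := pgx.ParseConfig("postgres://localhost/postgres") // placeholder DSN
	if err != nil {
		log.Fatal(err)
	}
	config.Tracer = timingTracer{}

	conn, err := pgx.ConnectConfig(context.Background(), config)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(context.Background())

	if _, err := conn.Exec(context.Background(), "select 1"); err != nil {
		log.Fatal(err)
	}
}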
@ -98,8 +98,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName
|
|||
return
|
||||
}
|
||||
|
||||
assert.Equalf(t, expected.Logger, actual.Logger, "%s - Logger", testName)
|
||||
assert.Equalf(t, expected.LogLevel, actual.LogLevel, "%s - LogLevel", testName)
|
||||
assert.Equalf(t, expected.Tracer, actual.Tracer, "%s - Tracer", testName)
|
||||
assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName)
|
||||
assert.Equalf(t, expected.StatementCacheCapacity, actual.StatementCacheCapacity, "%s - StatementCacheCapacity", testName)
|
||||
assert.Equalf(t, expected.DescriptionCacheCapacity, actual.DescriptionCacheCapacity, "%s - DescriptionCacheCapacity", testName)
|
||||
|
|
|
@ -162,7 +162,7 @@ func (f *Frontend) SendExecute(msg *Execute) {
|
|||
prevLen := len(f.wbuf)
|
||||
f.wbuf = msg.Encode(f.wbuf)
|
||||
if f.tracer != nil {
|
||||
f.tracer.traceExecute('F', int32(len(f.wbuf)-prevLen), msg)
|
||||
f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -79,7 +79,7 @@ func (t *tracer) traceMessage(sender byte, encodedLen int32, msg Message) {
|
|||
case *ErrorResponse:
|
||||
t.traceErrorResponse(sender, encodedLen, msg)
|
||||
case *Execute:
|
||||
t.traceExecute(sender, encodedLen, msg)
|
||||
t.TraceQueryute(sender, encodedLen, msg)
|
||||
case *Flush:
|
||||
t.traceFlush(sender, encodedLen, msg)
|
||||
case *FunctionCall:
|
||||
|
@ -277,7 +277,7 @@ func (t *tracer) traceErrorResponse(sender byte, encodedLen int32, msg *ErrorRes
|
|||
t.finishTrace()
|
||||
}
|
||||
|
||||
func (t *tracer) traceExecute(sender byte, encodedLen int32, msg *Execute) {
|
||||
func (t *tracer) TraceQueryute(sender byte, encodedLen int32, msg *Execute) {
|
||||
t.beginTrace(sender, encodedLen, "Execute")
|
||||
fmt.Fprintf(t.buf, "\t %s %d", traceDoubleQuotedString([]byte(msg.Portal)), msg.MaxRows)
|
||||
t.finishTrace()
|
||||
|
|
|
@ -160,8 +160,7 @@ func assertConnConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, test
|
|||
return
|
||||
}
|
||||
|
||||
assert.Equalf(t, expected.Logger, actual.Logger, "%s - Logger", testName)
|
||||
assert.Equalf(t, expected.LogLevel, actual.LogLevel, "%s - LogLevel", testName)
|
||||
assert.Equalf(t, expected.Tracer, actual.Tracer, "%s - Tracer", testName)
|
||||
assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName)
|
||||
assert.Equalf(t, expected.StatementCacheCapacity, actual.StatementCacheCapacity, "%s - StatementCacheCapacity", testName)
|
||||
assert.Equalf(t, expected.DescriptionCacheCapacity, actual.DescriptionCacheCapacity, "%s - DescriptionCacheCapacity", testName)
|
||||
|
|
rows.go
@ -105,11 +105,6 @@ func (r *connRow) Scan(dest ...any) (err error) {
|
|||
return rows.Err()
|
||||
}
|
||||
|
||||
type rowLog interface {
|
||||
shouldLog(lvl LogLevel) bool
|
||||
log(ctx context.Context, lvl LogLevel, msg string, data map[string]any)
|
||||
}
|
||||
|
||||
// baseRows implements the Rows interface for Conn.Query.
|
||||
type baseRows struct {
|
||||
typeMap *pgtype.Map
|
||||
|
@ -127,12 +122,13 @@ type baseRows struct {
|
|||
conn *Conn
|
||||
multiResultReader *pgconn.MultiResultReader
|
||||
|
||||
logger rowLog
|
||||
ctx context.Context
|
||||
startTime time.Time
|
||||
sql string
|
||||
args []any
|
||||
rowCount int
|
||||
queryTracer QueryTracer
|
||||
batchTracer BatchTracer
|
||||
ctx context.Context
|
||||
startTime time.Time
|
||||
sql string
|
||||
args []any
|
||||
rowCount int
|
||||
}
|
||||
|
||||
func (rows *baseRows) FieldDescriptions() []pgproto3.FieldDescription {
|
||||
|
@ -161,20 +157,6 @@ func (rows *baseRows) Close() {
|
|||
}
|
||||
}
|
||||
|
||||
if rows.logger != nil {
|
||||
endTime := time.Now()
|
||||
|
||||
if rows.err == nil {
|
||||
if rows.logger.shouldLog(LogLevelInfo) {
|
||||
rows.logger.log(rows.ctx, LogLevelInfo, "Query", map[string]any{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount})
|
||||
}
|
||||
} else {
|
||||
if rows.logger.shouldLog(LogLevelError) {
|
||||
rows.logger.log(rows.ctx, LogLevelError, "Query", map[string]any{"err": rows.err, "sql": rows.sql, "time": endTime.Sub(rows.startTime), "args": logQueryArgs(rows.args)})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if rows.err != nil && rows.conn != nil && rows.sql != "" {
|
||||
if stmtcache.IsStatementInvalid(rows.err) {
|
||||
if sc := rows.conn.statementCache; sc != nil {
|
||||
|
@ -186,6 +168,12 @@ func (rows *baseRows) Close() {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
if rows.batchTracer != nil {
|
||||
rows.batchTracer.TraceBatchQuery(rows.ctx, rows.conn, TraceBatchQueryData{SQL: rows.sql, Args: rows.args, CommandTag: rows.commandTag, Err: rows.err})
|
||||
} else if rows.queryTracer != nil {
|
||||
rows.queryTracer.TraceQueryEnd(rows.ctx, rows.conn, TraceQueryEndData{rows.commandTag, rows.err})
|
||||
}
|
||||
}
|
||||
|
||||
func (rows *baseRows) CommandTag() pgconn.CommandTag {
|
||||
|
|
|
@ -17,6 +17,7 @@ import (
|
|||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/jackc/pgx/v5/tracelog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
@ -976,7 +977,7 @@ func TestScanJSONIntoJSONRawMessage(t *testing.T) {
|
|||
}
|
||||
|
||||
type testLog struct {
|
||||
lvl pgx.LogLevel
|
||||
lvl tracelog.LogLevel
|
||||
msg string
|
||||
data map[string]any
|
||||
}
|
||||
|
@ -985,7 +986,7 @@ type testLogger struct {
|
|||
logs []testLog
|
||||
}
|
||||
|
||||
func (l *testLogger) Log(ctx context.Context, lvl pgx.LogLevel, msg string, data map[string]any) {
|
||||
func (l *testLogger) Log(ctx context.Context, lvl tracelog.LogLevel, msg string, data map[string]any) {
|
||||
l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data})
|
||||
}
|
||||
|
||||
|
@ -994,7 +995,7 @@ func TestRegisterConnConfig(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
logger := &testLogger{}
|
||||
connConfig.Logger = logger
|
||||
connConfig.Tracer = &tracelog.TraceLog{Logger: logger, LogLevel: tracelog.LogLevelInfo}
|
||||
|
||||
// Issue 947: Register and unregister a ConnConfig and ensure that the
|
||||
// returned connection string is not reused.
|
||||
|
|
|
@ -0,0 +1,295 @@
|
|||
// Package tracelog provides a tracer that acts as a traditional logger.
|
||||
package tracelog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
// LogLevel represents the pgx logging level. See LogLevel* constants for
|
||||
// possible values.
|
||||
type LogLevel int
|
||||
|
||||
// The values for log levels are chosen such that the zero value means that no
|
||||
// log level was specified.
|
||||
const (
|
||||
LogLevelTrace = LogLevel(6)
|
||||
LogLevelDebug = LogLevel(5)
|
||||
LogLevelInfo = LogLevel(4)
|
||||
LogLevelWarn = LogLevel(3)
|
||||
LogLevelError = LogLevel(2)
|
||||
LogLevelNone = LogLevel(1)
|
||||
)
|
||||
|
||||
func (ll LogLevel) String() string {
|
||||
switch ll {
|
||||
case LogLevelTrace:
|
||||
return "trace"
|
||||
case LogLevelDebug:
|
||||
return "debug"
|
||||
case LogLevelInfo:
|
||||
return "info"
|
||||
case LogLevelWarn:
|
||||
return "warn"
|
||||
case LogLevelError:
|
||||
return "error"
|
||||
case LogLevelNone:
|
||||
return "none"
|
||||
default:
|
||||
return fmt.Sprintf("invalid level %d", ll)
|
||||
}
|
||||
}
|
||||
|
||||
// Logger is the interface used to get log output from pgx.
|
||||
type Logger interface {
|
||||
// Log a message at the given level with data key/value pairs. data may be nil.
|
||||
Log(ctx context.Context, level LogLevel, msg string, data map[string]any)
|
||||
}
|
||||
|
||||
// LoggerFunc is a wrapper around a function to satisfy the pgx.Logger interface
|
||||
type LoggerFunc func(ctx context.Context, level LogLevel, msg string, data map[string]interface{})
|
||||
|
||||
// Log delegates the logging request to the wrapped function
|
||||
func (f LoggerFunc) Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) {
|
||||
f(ctx, level, msg, data)
|
||||
}
|
||||
|
||||
// LogLevelFromString converts log level string to constant
|
||||
//
|
||||
// Valid levels:
|
||||
// trace
|
||||
// debug
|
||||
// info
|
||||
// warn
|
||||
// error
|
||||
// none
|
||||
func LogLevelFromString(s string) (LogLevel, error) {
|
||||
switch s {
|
||||
case "trace":
|
||||
return LogLevelTrace, nil
|
||||
case "debug":
|
||||
return LogLevelDebug, nil
|
||||
case "info":
|
||||
return LogLevelInfo, nil
|
||||
case "warn":
|
||||
return LogLevelWarn, nil
|
||||
case "error":
|
||||
return LogLevelError, nil
|
||||
case "none":
|
||||
return LogLevelNone, nil
|
||||
default:
|
||||
return 0, errors.New("invalid log level")
|
||||
}
|
||||
}
|
||||
|
||||
func logQueryArgs(args []any) []any {
|
||||
logArgs := make([]any, 0, len(args))
|
||||
|
||||
for _, a := range args {
|
||||
switch v := a.(type) {
|
||||
case []byte:
|
||||
if len(v) < 64 {
|
||||
a = hex.EncodeToString(v)
|
||||
} else {
|
||||
a = fmt.Sprintf("%x (truncated %d bytes)", v[:64], len(v)-64)
|
||||
}
|
||||
case string:
|
||||
if len(v) > 64 {
|
||||
a = fmt.Sprintf("%s (truncated %d bytes)", v[:64], len(v)-64)
|
||||
}
|
||||
}
|
||||
logArgs = append(logArgs, a)
|
||||
}
|
||||
|
||||
return logArgs
|
||||
}
|
||||
|
||||
// TraceLog implements pgx.QueryTracer, pgx.BatchTracer, pgx.ConnectTracer, and pgx.CopyFromTracer. All fields are
|
||||
// required.
|
||||
type TraceLog struct {
|
||||
Logger Logger
|
||||
LogLevel LogLevel
|
||||
}
|
||||
|
||||
type ctxKey int
|
||||
|
||||
const (
|
||||
_ ctxKey = iota
|
||||
tracelogQueryCtxKey
|
||||
tracelogBatchCtxKey
|
||||
tracelogCopyFromCtxKey
|
||||
tracelogConnectCtxKey
|
||||
)
|
||||
|
||||
type traceQueryData struct {
|
||||
startTime time.Time
|
||||
sql string
|
||||
args []any
|
||||
}
|
||||
|
||||
func (tl *TraceLog) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
|
||||
return context.WithValue(ctx, tracelogQueryCtxKey, &traceQueryData{
|
||||
startTime: time.Now(),
|
||||
sql: data.SQL,
|
||||
args: data.Args,
|
||||
})
|
||||
}
|
||||
|
||||
func (tl *TraceLog) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
|
||||
queryData := ctx.Value(tracelogQueryCtxKey).(*traceQueryData)
|
||||
|
||||
endTime := time.Now()
|
||||
interval := endTime.Sub(queryData.startTime)
|
||||
|
||||
if data.Err != nil {
|
||||
if tl.shouldLog(LogLevelError) {
|
||||
tl.log(ctx, conn, LogLevelError, "Query", map[string]any{"sql": queryData.sql, "args": logQueryArgs(queryData.args), "err": data.Err, "time": interval})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if tl.shouldLog(LogLevelInfo) {
|
||||
tl.log(ctx, conn, LogLevelInfo, "Query", map[string]any{"sql": queryData.sql, "args": logQueryArgs(queryData.args), "time": interval, "commandTag": data.CommandTag.String()})
|
||||
}
|
||||
}
|
||||
|
||||
type traceBatchData struct {
|
||||
startTime time.Time
|
||||
}
|
||||
|
||||
func (tl *TraceLog) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
|
||||
return context.WithValue(ctx, tracelogBatchCtxKey, &traceBatchData{
|
||||
startTime: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
func (tl *TraceLog) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
|
||||
if data.Err != nil {
|
||||
if tl.shouldLog(LogLevelError) {
|
||||
tl.log(ctx, conn, LogLevelError, "BatchQuery", map[string]any{"sql": data.SQL, "args": logQueryArgs(data.Args), "err": data.Err})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if tl.shouldLog(LogLevelInfo) {
|
||||
tl.log(ctx, conn, LogLevelInfo, "BatchQuery", map[string]any{"sql": data.SQL, "args": logQueryArgs(data.Args), "commandTag": data.CommandTag.String()})
|
||||
}
|
||||
}
|
||||
|
||||
func (tl *TraceLog) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
|
||||
queryData := ctx.Value(tracelogBatchCtxKey).(*traceBatchData)
|
||||
|
||||
endTime := time.Now()
|
||||
interval := endTime.Sub(queryData.startTime)
|
||||
|
||||
if data.Err != nil {
|
||||
if tl.shouldLog(LogLevelError) {
|
||||
tl.log(ctx, conn, LogLevelError, "BatchClose", map[string]any{"err": data.Err, "time": interval})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if tl.shouldLog(LogLevelInfo) {
|
||||
tl.log(ctx, conn, LogLevelInfo, "BatchClose", map[string]any{"time": interval})
|
||||
}
|
||||
}
|
||||
|
||||
type traceCopyFromData struct {
|
||||
startTime time.Time
|
||||
TableName pgx.Identifier
|
||||
ColumnNames []string
|
||||
}
|
||||
|
||||
func (tl *TraceLog) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
|
||||
return context.WithValue(ctx, tracelogCopyFromCtxKey, &traceCopyFromData{
|
||||
startTime: time.Now(),
|
||||
TableName: data.TableName,
|
||||
ColumnNames: data.ColumnNames,
|
||||
})
|
||||
}
|
||||
|
||||
func (tl *TraceLog) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
|
||||
copyFromData := ctx.Value(tracelogCopyFromCtxKey).(*traceCopyFromData)
|
||||
|
||||
endTime := time.Now()
|
||||
interval := endTime.Sub(copyFromData.startTime)
|
||||
|
||||
if data.Err != nil {
|
||||
if tl.shouldLog(LogLevelError) {
|
||||
tl.log(ctx, conn, LogLevelError, "CopyFrom", map[string]any{"tableName": copyFromData.TableName, "columnNames": copyFromData.ColumnNames, "err": data.Err, "time": interval})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if tl.shouldLog(LogLevelInfo) {
|
||||
tl.log(ctx, conn, LogLevelInfo, "CopyFrom", map[string]any{"tableName": copyFromData.TableName, "columnNames": copyFromData.ColumnNames, "err": data.Err, "time": interval, "rowCount": data.CommandTag.RowsAffected()})
|
||||
}
|
||||
}
|
||||
|
||||
type traceConnectData struct {
|
||||
startTime time.Time
|
||||
connConfig *pgx.ConnConfig
|
||||
}
|
||||
|
||||
func (tl *TraceLog) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context {
|
||||
return context.WithValue(ctx, tracelogConnectCtxKey, &traceConnectData{
|
||||
startTime: time.Now(),
|
||||
connConfig: data.ConnConfig,
|
||||
})
|
||||
}
|
||||
|
||||
func (tl *TraceLog) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) {
|
||||
connectData := ctx.Value(tracelogConnectCtxKey).(*traceConnectData)
|
||||
|
||||
endTime := time.Now()
|
||||
interval := endTime.Sub(connectData.startTime)
|
||||
|
||||
if data.Err != nil {
|
||||
if tl.shouldLog(LogLevelError) {
|
||||
tl.Logger.Log(ctx, LogLevelError, "Connect", map[string]any{
|
||||
"host": connectData.connConfig.Host,
|
||||
"port": connectData.connConfig.Port,
|
||||
"database": connectData.connConfig.Database,
|
||||
"time": interval,
|
||||
"err": data.Err,
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if data.Conn != nil {
|
||||
if tl.shouldLog(LogLevelInfo) {
|
||||
tl.log(ctx, data.Conn, LogLevelInfo, "Connect", map[string]any{
|
||||
"host": connectData.connConfig.Host,
|
||||
"port": connectData.connConfig.Port,
|
||||
"database": connectData.connConfig.Database,
|
||||
"time": interval,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (tl *TraceLog) shouldLog(lvl LogLevel) bool {
|
||||
return tl.LogLevel >= lvl
|
||||
}
|
||||
|
||||
func (tl *TraceLog) log(ctx context.Context, conn *pgx.Conn, lvl LogLevel, msg string, data map[string]any) {
|
||||
if data == nil {
|
||||
data = map[string]any{}
|
||||
}
|
||||
|
||||
pgConn := conn.PgConn()
|
||||
if pgConn != nil {
|
||||
pid := pgConn.PID()
|
||||
if pid != 0 {
|
||||
data["pid"] = pid
|
||||
}
|
||||
}
|
||||
|
||||
tl.Logger.Log(ctx, lvl, msg, data)
|
||||
}
|
|
@ -0,0 +1,301 @@
|
|||
package tracelog_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxtest"
|
||||
"github.com/jackc/pgx/v5/tracelog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var defaultConnTestRunner pgxtest.ConnTestRunner
|
||||
|
||||
func init() {
|
||||
defaultConnTestRunner = pgxtest.DefaultConnTestRunner()
|
||||
defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
|
||||
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
return config
|
||||
}
|
||||
}
|
||||
|
||||
type testLog struct {
|
||||
lvl tracelog.LogLevel
|
||||
msg string
|
||||
data map[string]any
|
||||
}
|
||||
|
||||
type testLogger struct {
|
||||
logs []testLog
|
||||
}
|
||||
|
||||
func (l *testLogger) Log(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]any) {
|
||||
data["ctxdata"] = ctx.Value("ctxdata")
|
||||
l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data})
|
||||
}
|
||||
|
||||
func TestContextGetsPassedToLogMethod(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := &testLogger{}
|
||||
tracer := &tracelog.TraceLog{
|
||||
Logger: logger,
|
||||
LogLevel: tracelog.LogLevelTrace,
|
||||
}
|
||||
|
||||
ctr := defaultConnTestRunner
|
||||
ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
|
||||
config := defaultConnTestRunner.CreateConfig(ctx, t)
|
||||
config.Tracer = tracer
|
||||
return config
|
||||
}
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
logger.logs = logger.logs[0:0] // Clear any logs written when establishing connection
|
||||
|
||||
ctx = context.WithValue(context.Background(), "ctxdata", "foo")
|
||||
_, err := conn.Exec(ctx, `;`)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, logger.logs, 1)
|
||||
require.Equal(t, "foo", logger.logs[0].data["ctxdata"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoggerFunc(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const testMsg = "foo"
|
||||
|
||||
buf := bytes.Buffer{}
|
||||
logger := log.New(&buf, "", 0)
|
||||
|
||||
createAdapterFn := func(logger *log.Logger) tracelog.LoggerFunc {
|
||||
return func(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]interface{}) {
|
||||
logger.Printf("%s", testMsg)
|
||||
}
|
||||
}
|
||||
|
||||
config := defaultConnTestRunner.CreateConfig(context.Background(), t)
|
||||
config.Tracer = &tracelog.TraceLog{
|
||||
Logger: createAdapterFn(logger),
|
||||
LogLevel: tracelog.LogLevelTrace,
|
||||
}
|
||||
|
||||
conn, err := pgx.ConnectConfig(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(context.Background())
|
||||
|
||||
buf.Reset() // Clear logs written when establishing connection
|
||||
|
||||
if _, err := conn.Exec(context.TODO(), ";"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(buf.String()) != testMsg {
|
||||
t.Errorf("Expected logger function to return '%s', but it was '%s'", testMsg, buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := &testLogger{}
|
||||
tracer := &tracelog.TraceLog{
|
||||
Logger: logger,
|
||||
LogLevel: tracelog.LogLevelTrace,
|
||||
}
|
||||
|
||||
ctr := defaultConnTestRunner
|
||||
ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
|
||||
config := defaultConnTestRunner.CreateConfig(ctx, t)
|
||||
config.Tracer = tracer
|
||||
return config
|
||||
}
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
logger.logs = logger.logs[0:0] // Clear any logs written when establishing connection
|
||||
|
||||
_, err := conn.Exec(ctx, `select $1::text`, "testing")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, logger.logs, 1)
|
||||
require.Equal(t, "Query", logger.logs[0].msg)
|
||||
require.Equal(t, tracelog.LogLevelInfo, logger.logs[0].lvl)
|
||||
|
||||
_, err = conn.Exec(ctx, `foo`, "testing")
|
||||
require.Error(t, err)
|
||||
require.Len(t, logger.logs, 2)
|
||||
require.Equal(t, "Query", logger.logs[1].msg)
|
||||
require.Equal(t, tracelog.LogLevelError, logger.logs[1].lvl)
|
||||
require.Equal(t, err, logger.logs[1].data["err"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestLogCopyFrom(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := &testLogger{}
|
||||
tracer := &tracelog.TraceLog{
|
||||
Logger: logger,
|
||||
LogLevel: tracelog.LogLevelTrace,
|
||||
}
|
||||
|
||||
ctr := defaultConnTestRunner
|
||||
ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
|
||||
config := defaultConnTestRunner.CreateConfig(ctx, t)
|
||||
config.Tracer = tracer
|
||||
return config
|
||||
}
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
_, err := conn.Exec(context.Background(), `create temporary table foo(a int4)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
logger.logs = logger.logs[0:0]
|
||||
|
||||
inputRows := [][]any{
|
||||
{int32(1)},
|
||||
{nil},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows))
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, len(inputRows), copyCount)
|
||||
require.Len(t, logger.logs, 1)
|
||||
require.Equal(t, "CopyFrom", logger.logs[0].msg)
|
||||
require.Equal(t, tracelog.LogLevelInfo, logger.logs[0].lvl)
|
||||
|
||||
logger.logs = logger.logs[0:0]
|
||||
|
||||
inputRows = [][]any{
|
||||
{"not an integer"},
|
||||
{nil},
|
||||
}
|
||||
|
||||
copyCount, err = conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows))
|
||||
require.Error(t, err)
|
||||
require.EqualValues(t, 0, copyCount)
|
||||
require.Len(t, logger.logs, 1)
|
||||
require.Equal(t, "CopyFrom", logger.logs[0].msg)
|
||||
require.Equal(t, tracelog.LogLevelError, logger.logs[0].lvl)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLogConnect(t *testing.T) {
	t.Parallel()

	logger := &testLogger{}
	tracer := &tracelog.TraceLog{
		Logger:   logger,
		LogLevel: tracelog.LogLevelTrace,
	}

	config := defaultConnTestRunner.CreateConfig(context.Background(), t)
	config.Tracer = tracer

	conn1, err := pgx.ConnectConfig(context.Background(), config)
	require.NoError(t, err)
	defer conn1.Close(context.Background())
	require.Len(t, logger.logs, 1)
	require.Equal(t, "Connect", logger.logs[0].msg)
	require.Equal(t, tracelog.LogLevelInfo, logger.logs[0].lvl)

	logger.logs = logger.logs[0:0]

	config, err = pgx.ParseConfig("host=/invalid")
	require.NoError(t, err)
	config.Tracer = tracer

	conn2, err := pgx.ConnectConfig(context.Background(), config)
	require.Nil(t, conn2)
	require.Error(t, err)
	require.Len(t, logger.logs, 1)
	require.Equal(t, "Connect", logger.logs[0].msg)
	require.Equal(t, tracelog.LogLevelError, logger.logs[0].lvl)
}

func TestLogBatchStatementsOnExec(t *testing.T) {
	t.Parallel()

	logger := &testLogger{}
	tracer := &tracelog.TraceLog{
		Logger:   logger,
		LogLevel: tracelog.LogLevelTrace,
	}

	ctr := defaultConnTestRunner
	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		config := defaultConnTestRunner.CreateConfig(ctx, t)
		config.Tracer = tracer
		return config
	}

	pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		logger.logs = logger.logs[0:0] // Clear any logs written when establishing connection

		batch := &pgx.Batch{}
		batch.Queue("create table foo (id bigint)")
		batch.Queue("drop table foo")

		br := conn.SendBatch(context.Background(), batch)

		_, err := br.Exec()
		require.NoError(t, err)

		_, err = br.Exec()
		require.NoError(t, err)

		err = br.Close()
		require.NoError(t, err)

		require.Len(t, logger.logs, 3)
		assert.Equal(t, "BatchQuery", logger.logs[0].msg)
		assert.Equal(t, "create table foo (id bigint)", logger.logs[0].data["sql"])
		assert.Equal(t, "BatchQuery", logger.logs[1].msg)
		assert.Equal(t, "drop table foo", logger.logs[1].data["sql"])
		assert.Equal(t, "BatchClose", logger.logs[2].msg)
	})
}

func TestLogBatchStatementsOnBatchResultClose(t *testing.T) {
	t.Parallel()

	logger := &testLogger{}
	tracer := &tracelog.TraceLog{
		Logger:   logger,
		LogLevel: tracelog.LogLevelTrace,
	}

	ctr := defaultConnTestRunner
	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		config := defaultConnTestRunner.CreateConfig(ctx, t)
		config.Tracer = tracer
		return config
	}

	pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		logger.logs = logger.logs[0:0] // Clear any logs written when establishing connection

		batch := &pgx.Batch{}
		batch.Queue("select generate_series(1,$1)", 100)
		batch.Queue("select 1 = 1;")

		br := conn.SendBatch(context.Background(), batch)
		err := br.Close()
		require.NoError(t, err)

		require.Len(t, logger.logs, 3)
		assert.Equal(t, "BatchQuery", logger.logs[0].msg)
		assert.Equal(t, "select generate_series(1,$1)", logger.logs[0].data["sql"])
		assert.Equal(t, "BatchQuery", logger.logs[1].msg)
		assert.Equal(t, "select 1 = 1;", logger.logs[1].data["sql"])
		assert.Equal(t, "BatchClose", logger.logs[2].msg)
	})
}

@@ -0,0 +1,107 @@
package pgx

import (
	"context"

	"github.com/jackc/pgx/v5/pgconn"
)

// QueryTracer traces Query, QueryRow, and Exec.
type QueryTracer interface {
	// TraceQueryStart is called at the beginning of Query, QueryRow, and Exec calls. The returned context is used for the
	// rest of the call and will be passed to TraceQueryEnd.
	TraceQueryStart(ctx context.Context, conn *Conn, data TraceQueryStartData) context.Context

	TraceQueryEnd(ctx context.Context, conn *Conn, data TraceQueryEndData)
}

type TraceQueryStartData struct {
	SQL  string
	Args []any
}

type TraceQueryEndData struct {
	CommandTag pgconn.CommandTag
	Err        error
}

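// A minimal sketch (hypothetical, not part of this file) of a QueryTracer that logs
// query durations. It assumes the standard library "log" and "time" packages and
// carries the start time in the context returned from TraceQueryStart:
//
//	type timingTracer struct{}
//
//	type startTimeKey struct{}
//
//	func (timingTracer) TraceQueryStart(ctx context.Context, _ *Conn, data TraceQueryStartData) context.Context {
//		return context.WithValue(ctx, startTimeKey{}, time.Now())
//	}
//
//	func (timingTracer) TraceQueryEnd(ctx context.Context, _ *Conn, data TraceQueryEndData) {
//		start, _ := ctx.Value(startTimeKey{}).(time.Time)
//		log.Printf("sql=%q tag=%s err=%v duration=%s", data.SQL, data.CommandTag.String(), data.Err, time.Since(start))
//	}
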
// BatchTracer traces SendBatch.
type BatchTracer interface {
	// TraceBatchStart is called at the beginning of SendBatch calls. The returned context is used for the
	// rest of the call and will be passed to TraceBatchQuery and TraceBatchEnd.
	TraceBatchStart(ctx context.Context, conn *Conn, data TraceBatchStartData) context.Context

	TraceBatchQuery(ctx context.Context, conn *Conn, data TraceBatchQueryData)
	TraceBatchEnd(ctx context.Context, conn *Conn, data TraceBatchEndData)
}

type TraceBatchStartData struct {
	Batch *Batch
}

type TraceBatchQueryData struct {
	SQL        string
	Args       []any
	CommandTag pgconn.CommandTag
	Err        error
}

type TraceBatchEndData struct {
	Err error
}

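// Note: as exercised by the TestTraceBatch* tests in this commit, TraceBatchStart fires
// once per SendBatch call, TraceBatchQuery fires once per queued query as its result is
// read (or as the batch results are closed), and TraceBatchEnd fires when the batch
// results are closed. When a queued query fails, the remaining queries are not traced
// individually and the error is reported to TraceBatchEnd.
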
// CopyFromTracer traces CopyFrom.
type CopyFromTracer interface {
	// TraceCopyFromStart is called at the beginning of CopyFrom calls. The returned context is used for the
	// rest of the call and will be passed to TraceCopyFromEnd.
	TraceCopyFromStart(ctx context.Context, conn *Conn, data TraceCopyFromStartData) context.Context

	TraceCopyFromEnd(ctx context.Context, conn *Conn, data TraceCopyFromEndData)
}

type TraceCopyFromStartData struct {
	TableName   Identifier
	ColumnNames []string
}

type TraceCopyFromEndData struct {
	CommandTag pgconn.CommandTag
	Err        error
}

// PrepareTracer traces Prepare.
type PrepareTracer interface {
	// TracePrepareStart is called at the beginning of Prepare calls. The returned context is used for the
	// rest of the call and will be passed to TracePrepareEnd.
	TracePrepareStart(ctx context.Context, conn *Conn, data TracePrepareStartData) context.Context

	TracePrepareEnd(ctx context.Context, conn *Conn, data TracePrepareEndData)
}

type TracePrepareStartData struct {
	Name string
	SQL  string
}

type TracePrepareEndData struct {
	AlreadyPrepared bool
	Err             error
}

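// Note: as exercised by TestTracePrepare in this commit, AlreadyPrepared is set when
// Prepare is called with a statement name and SQL that the connection has already
// prepared; the call still succeeds and is still traced.
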
// ConnectTracer traces Connect and ConnectConfig.
type ConnectTracer interface {
	// TraceConnectStart is called at the beginning of Connect and ConnectConfig calls. The returned context is used for
	// the rest of the call and will be passed to TraceConnectEnd.
	TraceConnectStart(ctx context.Context, data TraceConnectStartData) context.Context

	TraceConnectEnd(ctx context.Context, data TraceConnectEndData)
}

type TraceConnectStartData struct {
	ConnConfig *ConnConfig
}

type TraceConnectEndData struct {
	Conn *Conn
	Err  error
}

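// A usage sketch (hypothetical, not part of this file): a tracer is attached through the
// ConnConfig.Tracer field before connecting, as the tests in this commit do, and a single
// value may satisfy several of the tracer interfaces at once.
//
//	cfg, err := ParseConfig(connString)
//	if err != nil {
//		return err
//	}
//	cfg.Tracer = timingTracer{} // e.g. the hypothetical tracer sketched above
//	conn, err := ConnectConfig(ctx, cfg)
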
@@ -0,0 +1,538 @@
package pgx_test

import (
	"context"
	"testing"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxtest"
	"github.com/stretchr/testify/require"
)

type testTracer struct {
	traceQueryStart    func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context
	traceQueryEnd      func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData)
	traceBatchStart    func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context
	traceBatchQuery    func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData)
	traceBatchEnd      func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData)
	traceCopyFromStart func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context
	traceCopyFromEnd   func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData)
	tracePrepareStart  func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context
	tracePrepareEnd    func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData)
	traceConnectStart  func(ctx context.Context, data pgx.TraceConnectStartData) context.Context
	traceConnectEnd    func(ctx context.Context, data pgx.TraceConnectEndData)
}

func (tt *testTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
	if tt.traceQueryStart != nil {
		return tt.traceQueryStart(ctx, conn, data)
	}
	return ctx
}

func (tt *testTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
	if tt.traceQueryEnd != nil {
		tt.traceQueryEnd(ctx, conn, data)
	}
}

func (tt *testTracer) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
	if tt.traceBatchStart != nil {
		return tt.traceBatchStart(ctx, conn, data)
	}
	return ctx
}

func (tt *testTracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
	if tt.traceBatchQuery != nil {
		tt.traceBatchQuery(ctx, conn, data)
	}
}

func (tt *testTracer) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
	if tt.traceBatchEnd != nil {
		tt.traceBatchEnd(ctx, conn, data)
	}
}

func (tt *testTracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
	if tt.traceCopyFromStart != nil {
		return tt.traceCopyFromStart(ctx, conn, data)
	}
	return ctx
}

func (tt *testTracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
	if tt.traceCopyFromEnd != nil {
		tt.traceCopyFromEnd(ctx, conn, data)
	}
}

func (tt *testTracer) TracePrepareStart(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context {
	if tt.tracePrepareStart != nil {
		return tt.tracePrepareStart(ctx, conn, data)
	}
	return ctx
}

func (tt *testTracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
	if tt.tracePrepareEnd != nil {
		tt.tracePrepareEnd(ctx, conn, data)
	}
}

func (tt *testTracer) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context {
	if tt.traceConnectStart != nil {
		return tt.traceConnectStart(ctx, data)
	}
	return ctx
}

func (tt *testTracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) {
	if tt.traceConnectEnd != nil {
		tt.traceConnectEnd(ctx, data)
	}
}

func TestTraceExec(t *testing.T) {
	t.Parallel()

	tracer := &testTracer{}

	ctr := defaultConnTestRunner
	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		config := defaultConnTestRunner.CreateConfig(ctx, t)
		config.Tracer = tracer
		return config
	}

	pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		traceQueryStartCalled := false
		tracer.traceQueryStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
			traceQueryStartCalled = true
			require.Equal(t, `select $1::text`, data.SQL)
			require.Len(t, data.Args, 1)
			require.Equal(t, `testing`, data.Args[0])
			return context.WithValue(ctx, "fromTraceQueryStart", "foo")
		}

		traceQueryEndCalled := false
		tracer.traceQueryEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
			traceQueryEndCalled = true
			require.Equal(t, "foo", ctx.Value("fromTraceQueryStart"))
			require.Equal(t, `SELECT 1`, data.CommandTag.String())
			require.NoError(t, data.Err)
		}

		_, err := conn.Exec(ctx, `select $1::text`, "testing")
		require.NoError(t, err)
		require.True(t, traceQueryStartCalled)
		require.True(t, traceQueryEndCalled)
	})
}

func TestTraceQuery(t *testing.T) {
	t.Parallel()

	tracer := &testTracer{}

	ctr := defaultConnTestRunner
	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		config := defaultConnTestRunner.CreateConfig(ctx, t)
		config.Tracer = tracer
		return config
	}

	pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		traceQueryStartCalled := false
		tracer.traceQueryStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
			traceQueryStartCalled = true
			require.Equal(t, `select $1::text`, data.SQL)
			require.Len(t, data.Args, 1)
			require.Equal(t, `testing`, data.Args[0])
			return context.WithValue(ctx, "fromTraceQueryStart", "foo")
		}

		traceQueryEndCalled := false
		tracer.traceQueryEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
			traceQueryEndCalled = true
			require.Equal(t, "foo", ctx.Value("fromTraceQueryStart"))
			require.Equal(t, `SELECT 1`, data.CommandTag.String())
			require.NoError(t, data.Err)
		}

		var s string
		err := conn.QueryRow(ctx, `select $1::text`, "testing").Scan(&s)
		require.NoError(t, err)
		require.Equal(t, "testing", s)
		require.True(t, traceQueryStartCalled)
		require.True(t, traceQueryEndCalled)
	})
}

func TestTraceBatchNormal(t *testing.T) {
	t.Parallel()

	tracer := &testTracer{}

	ctr := defaultConnTestRunner
	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		config := defaultConnTestRunner.CreateConfig(ctx, t)
		config.Tracer = tracer
		return config
	}

	pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		traceBatchStartCalled := false
		tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
			traceBatchStartCalled = true
			require.NotNil(t, data.Batch)
			require.Equal(t, 2, data.Batch.Len())
			return context.WithValue(ctx, "fromTraceBatchStart", "foo")
		}

		traceBatchQueryCalledCount := 0
		tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
			traceBatchQueryCalledCount++
			require.Equal(t, "foo", ctx.Value("fromTraceBatchStart"))
			require.NoError(t, data.Err)
		}

		traceBatchEndCalled := false
		tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
			traceBatchEndCalled = true
			require.Equal(t, "foo", ctx.Value("fromTraceBatchStart"))
			require.NoError(t, data.Err)
		}

		batch := &pgx.Batch{}
		batch.Queue(`select 1`)
		batch.Queue(`select 2`)

		br := conn.SendBatch(context.Background(), batch)
		require.True(t, traceBatchStartCalled)

		var n int32
		err := br.QueryRow().Scan(&n)
		require.NoError(t, err)
		require.EqualValues(t, 1, n)
		require.EqualValues(t, 1, traceBatchQueryCalledCount)

		err = br.QueryRow().Scan(&n)
		require.NoError(t, err)
		require.EqualValues(t, 2, n)
		require.EqualValues(t, 2, traceBatchQueryCalledCount)

		err = br.Close()
		require.NoError(t, err)

		require.True(t, traceBatchEndCalled)
	})
}

func TestTraceBatchClose(t *testing.T) {
	t.Parallel()

	tracer := &testTracer{}

	ctr := defaultConnTestRunner
	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		config := defaultConnTestRunner.CreateConfig(ctx, t)
		config.Tracer = tracer
		return config
	}

	pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		traceBatchStartCalled := false
		tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
			traceBatchStartCalled = true
			require.NotNil(t, data.Batch)
			require.Equal(t, 2, data.Batch.Len())
			return context.WithValue(ctx, "fromTraceBatchStart", "foo")
		}

		traceBatchQueryCalledCount := 0
		tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
			traceBatchQueryCalledCount++
			require.Equal(t, "foo", ctx.Value("fromTraceBatchStart"))
			require.NoError(t, data.Err)
		}

		traceBatchEndCalled := false
		tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
			traceBatchEndCalled = true
			require.Equal(t, "foo", ctx.Value("fromTraceBatchStart"))
			require.NoError(t, data.Err)
		}

		batch := &pgx.Batch{}
		batch.Queue(`select 1`)
		batch.Queue(`select 2`)

		br := conn.SendBatch(context.Background(), batch)
		require.True(t, traceBatchStartCalled)
		err := br.Close()
		require.NoError(t, err)
		require.EqualValues(t, 2, traceBatchQueryCalledCount)
		require.True(t, traceBatchEndCalled)
	})
}

func TestTraceBatchErrorWhileReadingResults(t *testing.T) {
	t.Parallel()

	tracer := &testTracer{}

	ctr := defaultConnTestRunner
	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		config := defaultConnTestRunner.CreateConfig(ctx, t)
		config.Tracer = tracer
		return config
	}

	pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, []pgx.QueryExecMode{pgx.QueryExecModeSimpleProtocol}, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		traceBatchStartCalled := false
		tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
			traceBatchStartCalled = true
			require.NotNil(t, data.Batch)
			require.Equal(t, 3, data.Batch.Len())
			return context.WithValue(ctx, "fromTraceBatchStart", "foo")
		}

		traceBatchQueryCalledCount := 0
		tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
			traceBatchQueryCalledCount++
			require.Equal(t, "foo", ctx.Value("fromTraceBatchStart"))
			if traceBatchQueryCalledCount == 2 {
				require.Error(t, data.Err)
			} else {
				require.NoError(t, data.Err)
			}
		}

		traceBatchEndCalled := false
		tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
			traceBatchEndCalled = true
			require.Equal(t, "foo", ctx.Value("fromTraceBatchStart"))
			require.Error(t, data.Err)
		}

		batch := &pgx.Batch{}
		batch.Queue(`select 1`)
		batch.Queue(`select 2/n-2 from generate_series(0,10) n`)
		batch.Queue(`select 3`)

		br := conn.SendBatch(context.Background(), batch)
		require.True(t, traceBatchStartCalled)

		commandTag, err := br.Exec()
		require.NoError(t, err)
		require.Equal(t, "SELECT 1", commandTag.String())

		commandTag, err = br.Exec()
		require.Error(t, err)
		require.Equal(t, "", commandTag.String())

		commandTag, err = br.Exec()
		require.Error(t, err)
		require.Equal(t, "", commandTag.String())

		err = br.Close()
		require.Error(t, err)
		require.EqualValues(t, 2, traceBatchQueryCalledCount)
		require.True(t, traceBatchEndCalled)
	})
}

func TestTraceBatchErrorWhileReadingResultsWhileClosing(t *testing.T) {
	t.Parallel()

	tracer := &testTracer{}

	ctr := defaultConnTestRunner
	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		config := defaultConnTestRunner.CreateConfig(ctx, t)
		config.Tracer = tracer
		return config
	}

	pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, []pgx.QueryExecMode{pgx.QueryExecModeSimpleProtocol}, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		traceBatchStartCalled := false
		tracer.traceBatchStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
			traceBatchStartCalled = true
			require.NotNil(t, data.Batch)
			require.Equal(t, 3, data.Batch.Len())
			return context.WithValue(ctx, "fromTraceBatchStart", "foo")
		}

		traceBatchQueryCalledCount := 0
		tracer.traceBatchQuery = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
			traceBatchQueryCalledCount++
			require.Equal(t, "foo", ctx.Value("fromTraceBatchStart"))
			if traceBatchQueryCalledCount == 2 {
				require.Error(t, data.Err)
			} else {
				require.NoError(t, data.Err)
			}
		}

		traceBatchEndCalled := false
		tracer.traceBatchEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
			traceBatchEndCalled = true
			require.Equal(t, "foo", ctx.Value("fromTraceBatchStart"))
			require.Error(t, data.Err)
		}

		batch := &pgx.Batch{}
		batch.Queue(`select 1`)
		batch.Queue(`select 2/n-2 from generate_series(0,10) n`)
		batch.Queue(`select 3`)

		br := conn.SendBatch(context.Background(), batch)
		require.True(t, traceBatchStartCalled)
		err := br.Close()
		require.Error(t, err)
		require.EqualValues(t, 2, traceBatchQueryCalledCount)
		require.True(t, traceBatchEndCalled)
	})
}

func TestTraceCopyFrom(t *testing.T) {
	t.Parallel()

	tracer := &testTracer{}

	ctr := defaultConnTestRunner
	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		config := defaultConnTestRunner.CreateConfig(ctx, t)
		config.Tracer = tracer
		return config
	}

	pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		traceCopyFromStartCalled := false
		tracer.traceCopyFromStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
			traceCopyFromStartCalled = true
			require.Equal(t, pgx.Identifier{"foo"}, data.TableName)
			require.Equal(t, []string{"a"}, data.ColumnNames)
			return context.WithValue(ctx, "fromTraceCopyFromStart", "foo")
		}

		traceCopyFromEndCalled := false
		tracer.traceCopyFromEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
			traceCopyFromEndCalled = true
			require.Equal(t, "foo", ctx.Value("fromTraceCopyFromStart"))
			require.Equal(t, `COPY 2`, data.CommandTag.String())
			require.NoError(t, data.Err)
		}

		_, err := conn.Exec(context.Background(), `create temporary table foo(a int4)`)
		require.NoError(t, err)

		inputRows := [][]any{
			{int32(1)},
			{nil},
		}

		copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows))
		require.NoError(t, err)
		require.EqualValues(t, len(inputRows), copyCount)
		require.True(t, traceCopyFromStartCalled)
		require.True(t, traceCopyFromEndCalled)
	})
}

func TestTracePrepare(t *testing.T) {
	t.Parallel()

	tracer := &testTracer{}

	ctr := defaultConnTestRunner
	ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
		config := defaultConnTestRunner.CreateConfig(ctx, t)
		config.Tracer = tracer
		return config
	}

	pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		tracePrepareStartCalled := false
		tracer.tracePrepareStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context {
			tracePrepareStartCalled = true
			require.Equal(t, `ps`, data.Name)
			require.Equal(t, `select $1::text`, data.SQL)
			return context.WithValue(ctx, "fromTracePrepareStart", "foo")
		}

		tracePrepareEndCalled := false
		tracer.tracePrepareEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
			tracePrepareEndCalled = true
			require.False(t, data.AlreadyPrepared)
			require.NoError(t, data.Err)
		}

		_, err := conn.Prepare(ctx, "ps", `select $1::text`)
		require.NoError(t, err)
		require.True(t, tracePrepareStartCalled)
		require.True(t, tracePrepareEndCalled)

		tracePrepareStartCalled = false
		tracePrepareEndCalled = false
		tracer.tracePrepareEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
			tracePrepareEndCalled = true
			require.True(t, data.AlreadyPrepared)
			require.NoError(t, data.Err)
		}

		_, err = conn.Prepare(ctx, "ps", `select $1::text`)
		require.NoError(t, err)
		require.True(t, tracePrepareStartCalled)
		require.True(t, tracePrepareEndCalled)
	})
}

func TestTraceConnect(t *testing.T) {
	t.Parallel()

	tracer := &testTracer{}

	config := defaultConnTestRunner.CreateConfig(context.Background(), t)
	config.Tracer = tracer

	traceConnectStartCalled := false
	tracer.traceConnectStart = func(ctx context.Context, data pgx.TraceConnectStartData) context.Context {
		traceConnectStartCalled = true
		require.NotNil(t, data.ConnConfig)
		return context.WithValue(ctx, "fromTraceConnectStart", "foo")
	}

	traceConnectEndCalled := false
	tracer.traceConnectEnd = func(ctx context.Context, data pgx.TraceConnectEndData) {
		traceConnectEndCalled = true
		require.NotNil(t, data.Conn)
		require.NoError(t, data.Err)
	}

	conn1, err := pgx.ConnectConfig(context.Background(), config)
	require.NoError(t, err)
	defer conn1.Close(context.Background())
	require.True(t, traceConnectStartCalled)
	require.True(t, traceConnectEndCalled)

	config, err = pgx.ParseConfig("host=/invalid")
	require.NoError(t, err)
	config.Tracer = tracer

	traceConnectStartCalled = false
	traceConnectEndCalled = false
	tracer.traceConnectEnd = func(ctx context.Context, data pgx.TraceConnectEndData) {
		traceConnectEndCalled = true
		require.Nil(t, data.Conn)
		require.Error(t, data.Err)
	}

	conn2, err := pgx.ConnectConfig(context.Background(), config)
	require.Nil(t, conn2)
	require.Error(t, err)
	require.True(t, traceConnectStartCalled)
	require.True(t, traceConnectEndCalled)
}