diff --git a/pgxpool/pool.go b/pgxpool/pool.go index 6998e7e8..fe0e914a 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -95,6 +95,8 @@ type Pool struct { healthCheckChan chan struct{} + acquireTracer AcquireTracer + closeOnce sync.Once closeChan chan struct{} } @@ -195,6 +197,10 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { closeChan: make(chan struct{}), } + if t, ok := config.ConnConfig.Tracer.(AcquireTracer); ok { + p.acquireTracer = t + } + var err error p.p, err = puddle.NewPool( &puddle.Config[*connResource]{ @@ -498,7 +504,14 @@ func (p *Pool) createIdleResources(parentCtx context.Context, targetResources in } // Acquire returns a connection (*Conn) from the Pool -func (p *Pool) Acquire(ctx context.Context) (*Conn, error) { +func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) { + if p.acquireTracer != nil { + ctx = p.acquireTracer.TraceAcquireStart(ctx, TraceAcquireStartData{ConnConfig: p.config.ConnConfig}) + defer func() { + p.acquireTracer.TraceAcquireEnd(ctx, TraceAcquireEndData{Err: err}) + }() + } + for { res, err := p.p.Acquire(ctx) if err != nil { diff --git a/pgxpool/tracer.go b/pgxpool/tracer.go new file mode 100644 index 00000000..621091a8 --- /dev/null +++ b/pgxpool/tracer.go @@ -0,0 +1,23 @@ +package pgxpool + +import ( + "context" + + "github.com/jackc/pgx/v5" +) + +// AcquireTracer traces Acquire. +type AcquireTracer interface { + // TraceAcquireStart is called at the beginning of Acquire. + // The returned context is used for the rest of the call and will be passed to the TraceAcquireEnd. + TraceAcquireStart(ctx context.Context, data TraceAcquireStartData) context.Context + TraceAcquireEnd(ctx context.Context, data TraceAcquireEndData) +} + +type TraceAcquireStartData struct { + ConnConfig *pgx.ConnConfig +} + +type TraceAcquireEndData struct { + Err error +} diff --git a/pgxpool/tracer_test.go b/pgxpool/tracer_test.go new file mode 100644 index 00000000..8e3a17d1 --- /dev/null +++ b/pgxpool/tracer_test.go @@ -0,0 +1,90 @@ +package pgxpool_test + +import ( + "context" + "os" + "testing" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/stretchr/testify/require" +) + +type testTracer struct { + traceAcquireStart func(ctx context.Context, data pgxpool.TraceAcquireStartData) context.Context + traceAcquireEnd func(ctx context.Context, data pgxpool.TraceAcquireEndData) +} + +type ctxKey string + +func (tt *testTracer) TraceAcquireStart(ctx context.Context, data pgxpool.TraceAcquireStartData) context.Context { + if tt.traceAcquireStart != nil { + return tt.traceAcquireStart(ctx, data) + } + return ctx +} + +func (tt *testTracer) TraceAcquireEnd(ctx context.Context, data pgxpool.TraceAcquireEndData) { + if tt.traceAcquireEnd != nil { + tt.traceAcquireEnd(ctx, data) + } +} + +func (tt *testTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + return ctx +} + +func (tt *testTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { +} + +func TestTraceAcquire(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.ConnConfig.Tracer = tracer + + pool, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + defer pool.Close() + + traceAcquireStartCalled := false + tracer.traceAcquireStart = func(ctx context.Context, data pgxpool.TraceAcquireStartData) context.Context { + traceAcquireStartCalled = true + require.NotNil(t, data.ConnConfig) + return context.WithValue(ctx, ctxKey("fromTraceAcquireStart"), "foo") + } + + traceAcquireEndCalled := false + tracer.traceAcquireEnd = func(ctx context.Context, data pgxpool.TraceAcquireEndData) { + traceAcquireEndCalled = true + require.Equal(t, "foo", ctx.Value(ctxKey(ctxKey("fromTraceAcquireStart")))) + require.NoError(t, data.Err) + } + + c, err := pool.Acquire(ctx) + require.NoError(t, err) + defer c.Release() + require.True(t, traceAcquireStartCalled) + require.True(t, traceAcquireEndCalled) + + traceAcquireStartCalled = false + traceAcquireEndCalled = false + tracer.traceAcquireEnd = func(ctx context.Context, data pgxpool.TraceAcquireEndData) { + traceAcquireEndCalled = true + require.Error(t, data.Err) + } + + ctx, cancel = context.WithCancel(ctx) + cancel() + c, err = pool.Acquire(ctx) + require.ErrorIs(t, err, context.Canceled) + require.True(t, traceAcquireStartCalled) + require.True(t, traceAcquireEndCalled) +}