diff --git a/pgxpool/pool.go b/pgxpool/pool.go index fe0e914a..656dac25 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -506,9 +506,13 @@ func (p *Pool) createIdleResources(parentCtx context.Context, targetResources in // Acquire returns a connection (*Conn) from the Pool func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) { if p.acquireTracer != nil { - ctx = p.acquireTracer.TraceAcquireStart(ctx, TraceAcquireStartData{ConnConfig: p.config.ConnConfig}) + ctx = p.acquireTracer.TraceAcquireStart(ctx, p, TraceAcquireStartData{}) defer func() { - p.acquireTracer.TraceAcquireEnd(ctx, TraceAcquireEndData{Err: err}) + var conn *pgx.Conn + if c != nil { + conn = c.Conn() + } + p.acquireTracer.TraceAcquireEnd(ctx, p, TraceAcquireEndData{Conn: conn, Err: err}) }() } diff --git a/pgxpool/tracer.go b/pgxpool/tracer.go index 621091a8..ba740bf9 100644 --- a/pgxpool/tracer.go +++ b/pgxpool/tracer.go @@ -10,14 +10,14 @@ import ( type AcquireTracer interface { // TraceAcquireStart is called at the beginning of Acquire. // The returned context is used for the rest of the call and will be passed to the TraceAcquireEnd. - TraceAcquireStart(ctx context.Context, data TraceAcquireStartData) context.Context - TraceAcquireEnd(ctx context.Context, data TraceAcquireEndData) + TraceAcquireStart(ctx context.Context, pool *Pool, data TraceAcquireStartData) context.Context + // TraceAcquireEnd is called when a connection has been acquired + TraceAcquireEnd(ctx context.Context, pool *Pool, data TraceAcquireEndData) } -type TraceAcquireStartData struct { - ConnConfig *pgx.ConnConfig -} +type TraceAcquireStartData struct{} type TraceAcquireEndData struct { - Err error + Conn *pgx.Conn + Err error } diff --git a/pgxpool/tracer_test.go b/pgxpool/tracer_test.go index 8e3a17d1..c71f3eff 100644 --- a/pgxpool/tracer_test.go +++ b/pgxpool/tracer_test.go @@ -12,22 +12,22 @@ import ( ) type testTracer struct { - traceAcquireStart func(ctx context.Context, data pgxpool.TraceAcquireStartData) context.Context - traceAcquireEnd func(ctx context.Context, data pgxpool.TraceAcquireEndData) + traceAcquireStart func(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context + traceAcquireEnd func(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) } type ctxKey string -func (tt *testTracer) TraceAcquireStart(ctx context.Context, data pgxpool.TraceAcquireStartData) context.Context { +func (tt *testTracer) TraceAcquireStart(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context { if tt.traceAcquireStart != nil { - return tt.traceAcquireStart(ctx, data) + return tt.traceAcquireStart(ctx, pool, data) } return ctx } -func (tt *testTracer) TraceAcquireEnd(ctx context.Context, data pgxpool.TraceAcquireEndData) { +func (tt *testTracer) TraceAcquireEnd(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) { if tt.traceAcquireEnd != nil { - tt.traceAcquireEnd(ctx, data) + tt.traceAcquireEnd(ctx, pool, data) } } @@ -55,16 +55,18 @@ func TestTraceAcquire(t *testing.T) { defer pool.Close() traceAcquireStartCalled := false - tracer.traceAcquireStart = func(ctx context.Context, data pgxpool.TraceAcquireStartData) context.Context { + tracer.traceAcquireStart = func(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context { traceAcquireStartCalled = true - require.NotNil(t, data.ConnConfig) + require.NotNil(t, pool) return context.WithValue(ctx, ctxKey("fromTraceAcquireStart"), "foo") } traceAcquireEndCalled := false - tracer.traceAcquireEnd = func(ctx context.Context, data pgxpool.TraceAcquireEndData) { + tracer.traceAcquireEnd = func(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) { traceAcquireEndCalled = true require.Equal(t, "foo", ctx.Value(ctxKey(ctxKey("fromTraceAcquireStart")))) + require.NotNil(t, pool) + require.NotNil(t, data.Conn) require.NoError(t, data.Err) } @@ -76,14 +78,16 @@ func TestTraceAcquire(t *testing.T) { traceAcquireStartCalled = false traceAcquireEndCalled = false - tracer.traceAcquireEnd = func(ctx context.Context, data pgxpool.TraceAcquireEndData) { + tracer.traceAcquireEnd = func(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) { traceAcquireEndCalled = true + require.NotNil(t, pool) + require.Nil(t, data.Conn) require.Error(t, data.Err) } ctx, cancel = context.WithCancel(ctx) cancel() - c, err = pool.Acquire(ctx) + _, err = pool.Acquire(ctx) require.ErrorIs(t, err, context.Canceled) require.True(t, traceAcquireStartCalled) require.True(t, traceAcquireEndCalled)