add pool to trace acquire

pull/2017/head
ngavinsir 2024-05-11 10:32:18 +07:00 committed by Jack Christensen
parent a39632db43
commit 19fcb54564
3 changed files with 27 additions and 19 deletions

View File

@ -506,9 +506,13 @@ func (p *Pool) createIdleResources(parentCtx context.Context, targetResources in
// Acquire returns a connection (*Conn) from the Pool
func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) {
if p.acquireTracer != nil {
ctx = p.acquireTracer.TraceAcquireStart(ctx, TraceAcquireStartData{ConnConfig: p.config.ConnConfig})
ctx = p.acquireTracer.TraceAcquireStart(ctx, p, TraceAcquireStartData{})
defer func() {
p.acquireTracer.TraceAcquireEnd(ctx, TraceAcquireEndData{Err: err})
var conn *pgx.Conn
if c != nil {
conn = c.Conn()
}
p.acquireTracer.TraceAcquireEnd(ctx, p, TraceAcquireEndData{Conn: conn, Err: err})
}()
}

View File

@ -10,14 +10,14 @@ import (
type AcquireTracer interface {
// TraceAcquireStart is called at the beginning of Acquire.
// The returned context is used for the rest of the call and will be passed to the TraceAcquireEnd.
TraceAcquireStart(ctx context.Context, data TraceAcquireStartData) context.Context
TraceAcquireEnd(ctx context.Context, data TraceAcquireEndData)
TraceAcquireStart(ctx context.Context, pool *Pool, data TraceAcquireStartData) context.Context
// TraceAcquireEnd is called when a connection has been acquired
TraceAcquireEnd(ctx context.Context, pool *Pool, data TraceAcquireEndData)
}
type TraceAcquireStartData struct {
ConnConfig *pgx.ConnConfig
}
type TraceAcquireStartData struct{}
type TraceAcquireEndData struct {
Err error
Conn *pgx.Conn
Err error
}

View File

@ -12,22 +12,22 @@ import (
)
type testTracer struct {
traceAcquireStart func(ctx context.Context, data pgxpool.TraceAcquireStartData) context.Context
traceAcquireEnd func(ctx context.Context, data pgxpool.TraceAcquireEndData)
traceAcquireStart func(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context
traceAcquireEnd func(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData)
}
type ctxKey string
func (tt *testTracer) TraceAcquireStart(ctx context.Context, data pgxpool.TraceAcquireStartData) context.Context {
func (tt *testTracer) TraceAcquireStart(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context {
if tt.traceAcquireStart != nil {
return tt.traceAcquireStart(ctx, data)
return tt.traceAcquireStart(ctx, pool, data)
}
return ctx
}
func (tt *testTracer) TraceAcquireEnd(ctx context.Context, data pgxpool.TraceAcquireEndData) {
func (tt *testTracer) TraceAcquireEnd(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) {
if tt.traceAcquireEnd != nil {
tt.traceAcquireEnd(ctx, data)
tt.traceAcquireEnd(ctx, pool, data)
}
}
@ -55,16 +55,18 @@ func TestTraceAcquire(t *testing.T) {
defer pool.Close()
traceAcquireStartCalled := false
tracer.traceAcquireStart = func(ctx context.Context, data pgxpool.TraceAcquireStartData) context.Context {
tracer.traceAcquireStart = func(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context {
traceAcquireStartCalled = true
require.NotNil(t, data.ConnConfig)
require.NotNil(t, pool)
return context.WithValue(ctx, ctxKey("fromTraceAcquireStart"), "foo")
}
traceAcquireEndCalled := false
tracer.traceAcquireEnd = func(ctx context.Context, data pgxpool.TraceAcquireEndData) {
tracer.traceAcquireEnd = func(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) {
traceAcquireEndCalled = true
require.Equal(t, "foo", ctx.Value(ctxKey(ctxKey("fromTraceAcquireStart"))))
require.NotNil(t, pool)
require.NotNil(t, data.Conn)
require.NoError(t, data.Err)
}
@ -76,14 +78,16 @@ func TestTraceAcquire(t *testing.T) {
traceAcquireStartCalled = false
traceAcquireEndCalled = false
tracer.traceAcquireEnd = func(ctx context.Context, data pgxpool.TraceAcquireEndData) {
tracer.traceAcquireEnd = func(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) {
traceAcquireEndCalled = true
require.NotNil(t, pool)
require.Nil(t, data.Conn)
require.Error(t, data.Err)
}
ctx, cancel = context.WithCancel(ctx)
cancel()
c, err = pool.Acquire(ctx)
_, err = pool.Acquire(ctx)
require.ErrorIs(t, err, context.Canceled)
require.True(t, traceAcquireStartCalled)
require.True(t, traceAcquireEndCalled)