diff --git a/stmtcache/lru_test.go b/stmtcache/lru_test.go index a4108155..f594ceac 100644 --- a/stmtcache/lru_test.go +++ b/stmtcache/lru_test.go @@ -235,6 +235,40 @@ func TestLRUModeDescribe(t *testing.T) { require.Empty(t, fetchServerStatements(t, ctx, conn)) } +func TestLRUContext(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) + require.NoError(t, err) + defer conn.Close(ctx) + + cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2) + + // test 1 : getting a value for the first time with a cancelled context returns an error + ctx1, cancel1 := context.WithCancel(ctx) + cancel1() + + desc, err := cache.Get(ctx1, "SELECT 1") + require.Error(t, err) + require.Nil(t, desc) + + // test 2 : when querying for the 2nd time a cached value, if the context is canceled return an error + ctx2, cancel2 := context.WithCancel(ctx) + + desc, err = cache.Get(ctx2, "SELECT 2") + require.NoError(t, err) + require.NotNil(t, desc) + + cancel2() + + desc, err = cache.Get(ctx2, "SELECT 2") + require.Error(t, err) + require.Nil(t, desc) +} + func fetchServerStatements(t testing.TB, ctx context.Context, conn *pgconn.PgConn) []string { result := conn.ExecParams(ctx, `select statement from pg_prepared_statements`, nil, nil, nil, nil).Read() require.NoError(t, result.Err)