package stmtcache_test

import (
	"context"
	"fmt"
	"math/rand"
	"os"
	"regexp"
	"testing"
	"time"

	"github.com/jackc/pgx/v5/internal/stmtcache"
	"github.com/jackc/pgx/v5/pgconn"
	"github.com/stretchr/testify/require"
)

// crdbPreparePrefixRE matches the "PREPARE <name> AS " prefix that CockroachDB
// reports in pg_prepared_statements for statements created by the LRU cache.
// It is compiled once here rather than once per row in fetchServerStatements.
var crdbPreparePrefixRE = regexp.MustCompile(`^PREPARE lrupsc[0-9_]+ AS `)

// TestLRUModePrepare walks a capacity-2 cache in ModePrepare through a miss, a
// hit, a second miss, an eviction and a Clear, checking after every step that
// both cache.Len() and the server's list of prepared statements match.
func TestLRUModePrepare(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer conn.Close(ctx)

	cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)
	require.EqualValues(t, 0, cache.Len())
	require.EqualValues(t, 2, cache.Cap())
	require.EqualValues(t, stmtcache.ModePrepare, cache.Mode())

	// First request for a statement prepares it on the server.
	psd, err := cache.Get(ctx, "select 1")
	require.NoError(t, err)
	require.NotNil(t, psd)
	require.EqualValues(t, 1, cache.Len())
	require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))

	// Repeating the request must not create a second entry.
	psd, err = cache.Get(ctx, "select 1")
	require.NoError(t, err)
	require.NotNil(t, psd)
	require.EqualValues(t, 1, cache.Len())
	require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))

	psd, err = cache.Get(ctx, "select 2")
	require.NoError(t, err)
	require.NotNil(t, psd)
	require.EqualValues(t, 2, cache.Len())
	require.ElementsMatch(t, []string{"select 1", "select 2"}, fetchServerStatements(t, ctx, conn))

	// The cache is full, so adding a third statement pushes "select 1" out.
	psd, err = cache.Get(ctx, "select 3")
	require.NoError(t, err)
	require.NotNil(t, psd)
	require.EqualValues(t, 2, cache.Len())
	require.ElementsMatch(t, []string{"select 2", "select 3"}, fetchServerStatements(t, ctx, conn))

	err = cache.Clear(ctx)
	require.NoError(t, err)
	require.EqualValues(t, 0, cache.Len())
	require.Empty(t, fetchServerStatements(t, ctx, conn))
}

// TestLRUStmtInvalidation checks what happens to a cached statement after
// StatementErrored is called with a "cached plan must not change result type"
// error, both outside a transaction and inside a failed transaction.
func TestLRUStmtInvalidation(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer conn.Close(ctx)

	// we construct a fake error because it's not super straightforward to actually call
	// a prepared statement from the LRU cache without the helper routines which live
	// in pgx proper.
	fakeInvalidCachePlanError := &pgconn.PgError{
		Severity: "ERROR",
		Code:     "0A000",
		Message:  "cached plan must not change result type",
	}

	cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)

	//
	// outside of a transaction, we eagerly flush the statement
	//

	_, err = cache.Get(ctx, "select 1")
	require.NoError(t, err)
	require.EqualValues(t, 1, cache.Len())
	require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))

	cache.StatementErrored("select 1", fakeInvalidCachePlanError)
	_, err = cache.Get(ctx, "select 2")
	require.NoError(t, err)
	require.EqualValues(t, 1, cache.Len())
	require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn))

	err = cache.Clear(ctx)
	require.NoError(t, err)

	//
	// within an errored transaction, we defer the flush to after the first get
	// that happens after the transaction is rolled back
	//

	_, err = cache.Get(ctx, "select 1")
	require.NoError(t, err)
	require.EqualValues(t, 1, cache.Len())
	require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn))

	res := conn.Exec(ctx, "begin")
	require.NoError(t, res.Close())
	require.Equal(t, byte('T'), conn.TxStatus())

	// "selec" is deliberately invalid SQL: it puts the transaction into the
	// failed ('E') state.
	res = conn.Exec(ctx, "selec")
	require.Error(t, res.Close())
	require.Equal(t, byte('E'), conn.TxStatus())

	cache.StatementErrored("select 1", fakeInvalidCachePlanError)
	require.EqualValues(t, 1, cache.Len())

	res = conn.Exec(ctx, "rollback")
	require.NoError(t, res.Close())

	// The Get itself must succeed; previously its error was assigned but
	// never checked.
	_, err = cache.Get(ctx, "select 2")
	require.NoError(t, err)
	require.EqualValues(t, 1, cache.Len())
	require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn))
}

// TestLRUStmtInvalidationIntegration provokes a real "cached plan must not
// change result type" error by altering a table underneath a prepared
// statement, reports it via StatementErrored, and checks that the next Get
// for the same SQL yields a differently named statement that executes cleanly.
func TestLRUStmtInvalidationIntegration(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer conn.Close(ctx)

	cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2)

	result := conn.ExecParams(ctx, "create temporary table stmtcache_table (a text)", nil, nil, nil, nil).Read()
	require.NoError(t, result.Err)

	sql := "select * from stmtcache_table"
	sd1, err := cache.Get(ctx, sql)
	require.NoError(t, err)

	result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read()
	require.NoError(t, result.Err)

	// Adding a column changes the result type of "select *".
	result = conn.ExecParams(ctx, "alter table stmtcache_table add column b text", nil, nil, nil, nil).Read()
	require.NoError(t, result.Err)

	result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read()
	require.EqualError(t, result.Err, "ERROR: cached plan must not change result type (SQLSTATE 0A000)")

	cache.StatementErrored(sql, result.Err)

	sd2, err := cache.Get(ctx, sql)
	require.NoError(t, err)
	require.NotEqual(t, sd1.Name, sd2.Name)

	result = conn.ExecPrepared(ctx, sd2.Name, nil, nil, nil).Read()
	require.NoError(t, result.Err)
}

// TestLRUModePrepareStress issues 1000 Gets for statements drawn at random
// from 50 distinct queries against a capacity-8 cache, executing every
// returned statement to confirm it is usable.
func TestLRUModePrepareStress(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
	defer cancel()

	conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer conn.Close(ctx)

	cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 8)
	require.EqualValues(t, 0, cache.Len())
	require.EqualValues(t, 8, cache.Cap())
	require.EqualValues(t, stmtcache.ModePrepare, cache.Mode())

	for i := 0; i < 1000; i++ {
		psd, err := cache.Get(ctx, fmt.Sprintf("select %d", rand.Intn(50)))
		require.NoError(t, err)
		require.NotNil(t, psd)

		result := conn.ExecPrepared(ctx, psd.Name, nil, nil, nil).Read()
		require.NoError(t, result.Err)
	}
}

// TestLRUModeDescribe runs the same miss/hit/eviction/Clear sequence as
// TestLRUModePrepare against a ModeDescribe cache, asserting at every step
// that pg_prepared_statements stays empty while cache.Len() still tracks the
// entries.
func TestLRUModeDescribe(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer conn.Close(ctx)

	cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2)
	require.EqualValues(t, 0, cache.Len())
	require.EqualValues(t, 2, cache.Cap())
	require.EqualValues(t, stmtcache.ModeDescribe, cache.Mode())

	psd, err := cache.Get(ctx, "select 1")
	require.NoError(t, err)
	require.NotNil(t, psd)
	require.EqualValues(t, 1, cache.Len())
	require.Empty(t, fetchServerStatements(t, ctx, conn))

	psd, err = cache.Get(ctx, "select 1")
	require.NoError(t, err)
	require.NotNil(t, psd)
	require.EqualValues(t, 1, cache.Len())
	require.Empty(t, fetchServerStatements(t, ctx, conn))

	psd, err = cache.Get(ctx, "select 2")
	require.NoError(t, err)
	require.NotNil(t, psd)
	require.EqualValues(t, 2, cache.Len())
	require.Empty(t, fetchServerStatements(t, ctx, conn))

	psd, err = cache.Get(ctx, "select 3")
	require.NoError(t, err)
	require.NotNil(t, psd)
	require.EqualValues(t, 2, cache.Len())
	require.Empty(t, fetchServerStatements(t, ctx, conn))

	err = cache.Clear(ctx)
	require.NoError(t, err)
	require.EqualValues(t, 0, cache.Len())
	require.Empty(t, fetchServerStatements(t, ctx, conn))
}

// TestLRUContext checks that Get returns an error and a nil description when
// called with a cancelled context, both for a statement that is not yet cached
// and for one that already is.
func TestLRUContext(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer conn.Close(ctx)

	cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2)

	// test 1: getting a value for the first time with a cancelled context returns an error
	ctx1, cancel1 := context.WithCancel(ctx)
	cancel1()

	desc, err := cache.Get(ctx1, "SELECT 1")
	require.Error(t, err)
	require.Nil(t, desc)

	// test 2: requesting an already cached value with a cancelled context also returns an error
	ctx2, cancel2 := context.WithCancel(ctx)

	desc, err = cache.Get(ctx2, "SELECT 2")
	require.NoError(t, err)
	require.NotNil(t, desc)

	cancel2()

	desc, err = cache.Get(ctx2, "SELECT 2")
	require.Error(t, err)
	require.Nil(t, desc)
}

// fetchServerStatements returns the statement text of every row in
// pg_prepared_statements for conn, failing the test if the query errors.
//
// When the server reports a "crdb_version" parameter (CockroachDB), the
// listing query itself is skipped and the "PREPARE <name> AS " prefix is
// stripped so the result is comparable with PostgreSQL's output.
func fetchServerStatements(t testing.TB, ctx context.Context, conn *pgconn.PgConn) []string {
	// Report assertion failures at the caller's line, not inside this helper.
	t.Helper()

	result := conn.ExecParams(ctx, `select statement from pg_prepared_statements`, nil, nil, nil, nil).Read()
	require.NoError(t, result.Err)

	// The server dialect cannot change between rows, so check it once.
	isCockroachDB := conn.ParameterStatus("crdb_version") != ""

	var statements []string
	for _, r := range result.Rows {
		statement := string(r[0])
		if isCockroachDB {
			if statement == "PREPARE  AS select statement from pg_prepared_statements" {
				// CockroachDB includes the currently running unnamed prepared statement while PostgreSQL does not. Ignore it.
				continue
			}
			// CockroachDB includes the "PREPARE ... AS" text in the statement even if it was prepared through the extended
			// protocol while PostgreSQL does not. Normalize the statement.
			statement = crdbPreparePrefixRE.ReplaceAllString(statement, "")
		}
		statements = append(statements, statement)
	}
	return statements
}