From 1df45d758d4f57b45d56fbce1e5bb8cf9aabaaf0 Mon Sep 17 00:00:00 2001 From: Ethan Pailes Date: Mon, 9 Nov 2020 08:48:57 -0500 Subject: [PATCH] fix stmtcache invalidation This patch fixes jackc/pgx#841. The meat of the fix lives in [a PR to the pgconn repo][1]. This change just checks for errors after executing a prepared statement and informs the underlying stmtcache about them so that it can properly clean up. We don't try to get fancy with retries or anything like that, just return the error and allow the application to handle it. I had to make [some][1] [changes][2] to to the jackc/pgconn package as well as this package. Fixes #841 [1]: https://github.com/jackc/pgconn/pull/56 [2]: https://github.com/jackc/pgconn/pull/55 --- conn.go | 1 + conn_test.go | 136 +++++++++++++++++++++++++++++++++++++++++++++++++++ go.mod | 2 +- go.sum | 4 +- rows.go | 10 +++- 5 files changed, 148 insertions(+), 5 deletions(-) diff --git a/conn.go b/conn.go index 28b0f87c..6c6d545f 100644 --- a/conn.go +++ b/conn.go @@ -529,6 +529,7 @@ func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *con r.startTime = time.Now() r.sql = sql r.args = args + r.conn = c return r } diff --git a/conn_test.go b/conn_test.go index ba4eda95..592ab8c8 100644 --- a/conn_test.go +++ b/conn_test.go @@ -879,3 +879,139 @@ func TestDomainType(t *testing.T) { } }) } + +func TestStmtCacheInvalidationConn(t *testing.T) { + ctx := context.Background() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + // create a table and fill it with some data + _, err := conn.Exec(ctx, ` + DROP TABLE IF EXISTS drop_cols; + CREATE TABLE drop_cols ( + id SERIAL PRIMARY KEY NOT NULL, + f1 int NOT NULL, + f2 int NOT NULL + ); + `) + require.NoError(t, err) + _, err = conn.Exec(ctx, "INSERT INTO drop_cols (f1, f2) VALUES (1, 2)") + require.NoError(t, err) + + getSQL := "SELECT * FROM drop_cols WHERE id = $1" + + // This query will populate the statement cache. We don't care about the result. + rows, err := conn.Query(ctx, getSQL, 1) + require.NoError(t, err) + rows.Close() + + // Now, change the schema of the table out from under the statement, making it invalid. + _, err = conn.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1") + require.NoError(t, err) + + // We must get an error the first time we try to re-execute a bad statement. + // It is up to the application to determine if it wants to try again. We punt to + // the application because there is no clear recovery path in the case of failed transactions + // or batch operations and because automatic retry is tricky and we don't want to get + // it wrong at such an importaint layer of the stack. + rows, err = conn.Query(ctx, getSQL, 1) + require.NoError(t, err) + rows.Next() + nextErr := rows.Err() + rows.Close() + for _, err := range []error{nextErr, rows.Err()} { + if err == nil { + t.Fatal("expected InvalidCachedStatementPlanError: no error") + } + if !strings.Contains(err.Error(), "cached plan must not change result type") { + t.Fatalf("expected InvalidCachedStatementPlanError, got: %s", err.Error()) + } + } + + // On retry, the statement should have been flushed from the cache. + rows, err = conn.Query(ctx, getSQL, 1) + require.NoError(t, err) + rows.Next() + err = rows.Err() + require.NoError(t, err) + rows.Close() + require.NoError(t, rows.Err()) + + ensureConnValid(t, conn) +} + +func TestStmtCacheInvalidationTx(t *testing.T) { + ctx := context.Background() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + // create a table and fill it with some data + _, err := conn.Exec(ctx, ` + DROP TABLE IF EXISTS drop_cols; + CREATE TABLE drop_cols ( + id SERIAL PRIMARY KEY NOT NULL, + f1 int NOT NULL, + f2 int NOT NULL + ); + `) + require.NoError(t, err) + _, err = conn.Exec(ctx, "INSERT INTO drop_cols (f1, f2) VALUES (1, 2)") + require.NoError(t, err) + + tx, err := conn.Begin(ctx) + require.NoError(t, err) + + getSQL := "SELECT * FROM drop_cols WHERE id = $1" + + // This query will populate the statement cache. We don't care about the result. + rows, err := tx.Query(ctx, getSQL, 1) + require.NoError(t, err) + rows.Close() + + // Now, change the schema of the table out from under the statement, making it invalid. + _, err = tx.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1") + require.NoError(t, err) + + // We must get an error the first time we try to re-execute a bad statement. + // It is up to the application to determine if it wants to try again. We punt to + // the application because there is no clear recovery path in the case of failed transactions + // or batch operations and because automatic retry is tricky and we don't want to get + // it wrong at such an importaint layer of the stack. + rows, err = tx.Query(ctx, getSQL, 1) + require.NoError(t, err) + rows.Next() + nextErr := rows.Err() + rows.Close() + for _, err := range []error{nextErr, rows.Err()} { + if err == nil { + t.Fatal("expected InvalidCachedStatementPlanError: no error") + } + if !strings.Contains(err.Error(), "cached plan must not change result type") { + t.Fatalf("expected InvalidCachedStatementPlanError, got: %s", err.Error()) + } + } + + rows, err = tx.Query(ctx, getSQL, 1) + require.NoError(t, err) // error does not pop up immediately + rows.Next() + err = rows.Err() + // Retries within the same transaction are errors (really anything except a rollbakc + // will be an error in this transaction). + require.Error(t, err) + rows.Close() + + err = tx.Rollback(ctx) + require.NoError(t, err) + + // once we've rolled back, retries will work + rows, err = conn.Query(ctx, getSQL, 1) + require.NoError(t, err) + rows.Next() + err = rows.Err() + require.NoError(t, err) + rows.Close() + + ensureConnValid(t, conn) +} diff --git a/go.mod b/go.mod index cfe30794..ad34a683 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.12 require ( github.com/cockroachdb/apd v1.1.0 github.com/gofrs/uuid v3.2.0+incompatible - github.com/jackc/pgconn v1.7.2 + github.com/jackc/pgconn v1.7.3-0.20201111215259-cba610c24526 github.com/jackc/pgio v1.0.0 github.com/jackc/pgproto3/v2 v2.0.6 github.com/jackc/pgtype v1.6.1 diff --git a/go.sum b/go.sum index f8f6d30b..c8dfa0d2 100644 --- a/go.sum +++ b/go.sum @@ -50,8 +50,8 @@ github.com/jackc/pgconn v1.7.0 h1:pwjzcYyfmz/HQOQlENvG1OcDqauTGaqlVahq934F0/U= github.com/jackc/pgconn v1.7.0/go.mod h1:sF/lPpNEMEOp+IYhyQGdAvrG20gWf6A1tKlr0v7JMeA= github.com/jackc/pgconn v1.7.1 h1:Ii3hORkg9yTX+8etl2LtfFnL+YzmnR6VSLeTflQBkaQ= github.com/jackc/pgconn v1.7.1/go.mod h1:sF/lPpNEMEOp+IYhyQGdAvrG20gWf6A1tKlr0v7JMeA= -github.com/jackc/pgconn v1.7.2 h1:195tt17jkjy+FrFlY0pgyrul5kRLb7BGXY3JTrNxeXU= -github.com/jackc/pgconn v1.7.2/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= +github.com/jackc/pgconn v1.7.3-0.20201111215259-cba610c24526 h1:5u4fYBcaCLuQFvquOCBaT2a7KLbUGgNowbOLgVz6DWI= +github.com/jackc/pgconn v1.7.3-0.20201111215259-cba610c24526/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2 h1:JVX6jT/XfzNqIjye4717ITLaNwV9mWbJx0dLCpcRzdA= diff --git a/rows.go b/rows.go index 957192f6..88949b45 100644 --- a/rows.go +++ b/rows.go @@ -106,6 +106,7 @@ type connRows struct { sql string args []interface{} closed bool + conn *Conn resultReader *pgconn.ResultReader multiResultReader *pgconn.MultiResultReader @@ -145,8 +146,13 @@ func (rows *connRows) Close() { endTime := time.Now() rows.logger.log(rows.ctx, LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount}) } - } else if rows.logger.shouldLog(LogLevelError) { - rows.logger.log(rows.ctx, LogLevelError, "Query", map[string]interface{}{"err": rows.err, "sql": rows.sql, "args": logQueryArgs(rows.args)}) + } else { + if rows.logger.shouldLog(LogLevelError) { + rows.logger.log(rows.ctx, LogLevelError, "Query", map[string]interface{}{"err": rows.err, "sql": rows.sql, "args": logQueryArgs(rows.args)}) + } + if rows.err != nil { + rows.conn.stmtcache.StatementErrored(rows.sql, rows.err) + } } } }