From 7d3b9c1e44a6bfbc6a432b1fd4ddc7f6bf0073b9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 29 Oct 2022 09:32:35 -0500 Subject: [PATCH] QueryRewriter.RewriteQuery now returns an error https://github.com/jackc/pgx/issues/1186#issuecomment-1288207250 --- conn.go | 24 ++++++++++++++++++++---- conn_test.go | 4 ++-- named_args.go | 8 ++++---- named_args_test.go | 4 +++- 4 files changed, 29 insertions(+), 11 deletions(-) diff --git a/conn.go b/conn.go index a95cff21..4a68893a 100644 --- a/conn.go +++ b/conn.go @@ -407,7 +407,10 @@ optionLoop: } if queryRewriter != nil { - sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments) + sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments) + if err != nil { + return pgconn.CommandTag{}, fmt.Errorf("rewrite query failed: %v", err) + } } // Always use simple protocol when there are no arguments. @@ -600,7 +603,7 @@ type QueryResultFormatsByOID map[uint32]int16 // QueryRewriter rewrites a query when used as the first arguments to a query method. type QueryRewriter interface { - RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any) + RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) } // Query sends a query to the server and returns a Rows to read the results. Only errors encountered sending the query @@ -659,7 +662,16 @@ optionLoop: } if queryRewriter != nil { - sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args) + var err error + originalSQL := sql + originalArgs := args + sql, args, err = queryRewriter.RewriteQuery(ctx, c, sql, args) + if err != nil { + rows := c.getRows(ctx, originalSQL, originalArgs) + err = fmt.Errorf("rewrite query failed: %v", err) + rows.fatal(err) + return rows, err + } } // Bypass any statement caching. @@ -826,7 +838,11 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) { } if queryRewriter != nil { - sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments) + var err error + sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments) + if err != nil { + return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("rewrite query failed: %v", err)} + } } bi.query = sql diff --git a/conn_test.go b/conn_test.go index 90382519..204ff615 100644 --- a/conn_test.go +++ b/conn_test.go @@ -236,8 +236,8 @@ type testQueryRewriter struct { args []any } -func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args []any) (newSQL string, newArgs []any) { - return qr.sql, qr.args +func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args []any) (newSQL string, newArgs []any, err error) { + return qr.sql, qr.args, nil } func TestExecWithQueryRewriter(t *testing.T) { diff --git a/named_args.go b/named_args.go index 391fb8cc..1bc32337 100644 --- a/named_args.go +++ b/named_args.go @@ -12,12 +12,12 @@ import ( // // For example, the following two queries are equivalent: // -// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2}) -// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2) +// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2}) +// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2) type NamedArgs map[string]any // RewriteQuery implements the QueryRewriter interface. -func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any) { +func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) { l := &sqlLexer{ src: sql, stateFn: rawState, @@ -44,7 +44,7 @@ func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, ar newArgs[ordinal-1] = na[string(name)] } - return sb.String(), newArgs + return sb.String(), newArgs, nil } type namedArg string diff --git a/named_args_test.go b/named_args_test.go index 116e03dc..bd54faa1 100644 --- a/named_args_test.go +++ b/named_args_test.go @@ -6,6 +6,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNamedArgsRewriteQuery(t *testing.T) { @@ -95,7 +96,8 @@ func TestNamedArgsRewriteQuery(t *testing.T) { // test comments and quotes } { - sql, args := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args) + sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args) + require.NoError(t, err) assert.Equalf(t, tt.expectedSQL, sql, "%d", i) assert.Equalf(t, tt.expectedArgs, args, "%d", i) }