QueryRewriter.RewriteQuery now returns an error

https://github.com/jackc/pgx/issues/1186#issuecomment-1288207250
pull/1364/head
Jack Christensen 2022-10-29 09:32:35 -05:00
parent 6515e183ff
commit 7d3b9c1e44
4 changed files with 29 additions and 11 deletions

24
conn.go
View File

@ -407,7 +407,10 @@ optionLoop:
} }
if queryRewriter != nil { if queryRewriter != nil {
sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments) sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
if err != nil {
return pgconn.CommandTag{}, fmt.Errorf("rewrite query failed: %v", err)
}
} }
// Always use simple protocol when there are no arguments. // Always use simple protocol when there are no arguments.
@ -600,7 +603,7 @@ type QueryResultFormatsByOID map[uint32]int16
// QueryRewriter rewrites a query when used as the first arguments to a query method. // QueryRewriter rewrites a query when used as the first arguments to a query method.
type QueryRewriter interface { type QueryRewriter interface {
RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error)
} }
// Query sends a query to the server and returns a Rows to read the results. Only errors encountered sending the query // Query sends a query to the server and returns a Rows to read the results. Only errors encountered sending the query
@ -659,7 +662,16 @@ optionLoop:
} }
if queryRewriter != nil { if queryRewriter != nil {
sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args) var err error
originalSQL := sql
originalArgs := args
sql, args, err = queryRewriter.RewriteQuery(ctx, c, sql, args)
if err != nil {
rows := c.getRows(ctx, originalSQL, originalArgs)
err = fmt.Errorf("rewrite query failed: %v", err)
rows.fatal(err)
return rows, err
}
} }
// Bypass any statement caching. // Bypass any statement caching.
@ -826,7 +838,11 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
} }
if queryRewriter != nil { if queryRewriter != nil {
sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments) var err error
sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("rewrite query failed: %v", err)}
}
} }
bi.query = sql bi.query = sql

View File

@ -236,8 +236,8 @@ type testQueryRewriter struct {
args []any args []any
} }
func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args []any) (newSQL string, newArgs []any) { func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
return qr.sql, qr.args return qr.sql, qr.args, nil
} }
func TestExecWithQueryRewriter(t *testing.T) { func TestExecWithQueryRewriter(t *testing.T) {

View File

@ -12,12 +12,12 @@ import (
// //
// For example, the following two queries are equivalent: // For example, the following two queries are equivalent:
// //
// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2}) // conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2})
// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2) // conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2)
type NamedArgs map[string]any type NamedArgs map[string]any
// RewriteQuery implements the QueryRewriter interface. // RewriteQuery implements the QueryRewriter interface.
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any) { func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
l := &sqlLexer{ l := &sqlLexer{
src: sql, src: sql,
stateFn: rawState, stateFn: rawState,
@ -44,7 +44,7 @@ func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, ar
newArgs[ordinal-1] = na[string(name)] newArgs[ordinal-1] = na[string(name)]
} }
return sb.String(), newArgs return sb.String(), newArgs, nil
} }
type namedArg string type namedArg string

View File

@ -6,6 +6,7 @@ import (
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestNamedArgsRewriteQuery(t *testing.T) { func TestNamedArgsRewriteQuery(t *testing.T) {
@ -95,7 +96,8 @@ func TestNamedArgsRewriteQuery(t *testing.T) {
// test comments and quotes // test comments and quotes
} { } {
sql, args := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args) sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args)
require.NoError(t, err)
assert.Equalf(t, tt.expectedSQL, sql, "%d", i) assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
assert.Equalf(t, tt.expectedArgs, args, "%d", i) assert.Equalf(t, tt.expectedArgs, args, "%d", i)
} }