mirror of https://github.com/jackc/pgx.git
QueryRewriter.RewriteQuery now returns an error
https://github.com/jackc/pgx/issues/1186#issuecomment-1288207250pull/1364/head
parent
6515e183ff
commit
7d3b9c1e44
24
conn.go
24
conn.go
|
@ -407,7 +407,10 @@ optionLoop:
|
||||||
}
|
}
|
||||||
|
|
||||||
if queryRewriter != nil {
|
if queryRewriter != nil {
|
||||||
sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
|
sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
|
||||||
|
if err != nil {
|
||||||
|
return pgconn.CommandTag{}, fmt.Errorf("rewrite query failed: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Always use simple protocol when there are no arguments.
|
// Always use simple protocol when there are no arguments.
|
||||||
|
@ -600,7 +603,7 @@ type QueryResultFormatsByOID map[uint32]int16
|
||||||
|
|
||||||
// QueryRewriter rewrites a query when used as the first arguments to a query method.
|
// QueryRewriter rewrites a query when used as the first arguments to a query method.
|
||||||
type QueryRewriter interface {
|
type QueryRewriter interface {
|
||||||
RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any)
|
RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query sends a query to the server and returns a Rows to read the results. Only errors encountered sending the query
|
// Query sends a query to the server and returns a Rows to read the results. Only errors encountered sending the query
|
||||||
|
@ -659,7 +662,16 @@ optionLoop:
|
||||||
}
|
}
|
||||||
|
|
||||||
if queryRewriter != nil {
|
if queryRewriter != nil {
|
||||||
sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args)
|
var err error
|
||||||
|
originalSQL := sql
|
||||||
|
originalArgs := args
|
||||||
|
sql, args, err = queryRewriter.RewriteQuery(ctx, c, sql, args)
|
||||||
|
if err != nil {
|
||||||
|
rows := c.getRows(ctx, originalSQL, originalArgs)
|
||||||
|
err = fmt.Errorf("rewrite query failed: %v", err)
|
||||||
|
rows.fatal(err)
|
||||||
|
return rows, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bypass any statement caching.
|
// Bypass any statement caching.
|
||||||
|
@ -826,7 +838,11 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if queryRewriter != nil {
|
if queryRewriter != nil {
|
||||||
sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
|
var err error
|
||||||
|
sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
|
||||||
|
if err != nil {
|
||||||
|
return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("rewrite query failed: %v", err)}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bi.query = sql
|
bi.query = sql
|
||||||
|
|
|
@ -236,8 +236,8 @@ type testQueryRewriter struct {
|
||||||
args []any
|
args []any
|
||||||
}
|
}
|
||||||
|
|
||||||
func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args []any) (newSQL string, newArgs []any) {
|
func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
|
||||||
return qr.sql, qr.args
|
return qr.sql, qr.args, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExecWithQueryRewriter(t *testing.T) {
|
func TestExecWithQueryRewriter(t *testing.T) {
|
||||||
|
|
|
@ -12,12 +12,12 @@ import (
|
||||||
//
|
//
|
||||||
// For example, the following two queries are equivalent:
|
// For example, the following two queries are equivalent:
|
||||||
//
|
//
|
||||||
// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2})
|
// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2})
|
||||||
// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2)
|
// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2)
|
||||||
type NamedArgs map[string]any
|
type NamedArgs map[string]any
|
||||||
|
|
||||||
// RewriteQuery implements the QueryRewriter interface.
|
// RewriteQuery implements the QueryRewriter interface.
|
||||||
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any) {
|
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
|
||||||
l := &sqlLexer{
|
l := &sqlLexer{
|
||||||
src: sql,
|
src: sql,
|
||||||
stateFn: rawState,
|
stateFn: rawState,
|
||||||
|
@ -44,7 +44,7 @@ func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, ar
|
||||||
newArgs[ordinal-1] = na[string(name)]
|
newArgs[ordinal-1] = na[string(name)]
|
||||||
}
|
}
|
||||||
|
|
||||||
return sb.String(), newArgs
|
return sb.String(), newArgs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type namedArg string
|
type namedArg string
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNamedArgsRewriteQuery(t *testing.T) {
|
func TestNamedArgsRewriteQuery(t *testing.T) {
|
||||||
|
@ -95,7 +96,8 @@ func TestNamedArgsRewriteQuery(t *testing.T) {
|
||||||
|
|
||||||
// test comments and quotes
|
// test comments and quotes
|
||||||
} {
|
} {
|
||||||
sql, args := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args)
|
sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args)
|
||||||
|
require.NoError(t, err)
|
||||||
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
|
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
|
||||||
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
|
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue