mirror of https://github.com/jackc/pgx.git
Add QueryRewriter interface
parent
f9857b73d9
commit
b72b0daa5a
2
batch.go
2
batch.go
|
@ -14,7 +14,7 @@ type batchItem struct {
|
|||
}
|
||||
|
||||
// Batch queries are a way of bundling multiple queries together to avoid
|
||||
// unnecessary network round trips.
|
||||
// unnecessary network round trips. A Batch must only be sent once.
|
||||
type Batch struct {
|
||||
items []*batchItem
|
||||
}
|
||||
|
|
|
@ -239,6 +239,36 @@ func TestConnSendBatchWithPreparedStatement(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestConnSendBatchWithQueryRewriter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
batch := &pgx.Batch{}
|
||||
batch.Queue("something to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{1}})
|
||||
batch.Queue("something else to be replaced", &testQueryRewriter{sql: "select $1::text", args: []any{"hello"}})
|
||||
batch.Queue("more to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{3}})
|
||||
|
||||
br := conn.SendBatch(context.Background(), batch)
|
||||
|
||||
var n int32
|
||||
err := br.QueryRow().Scan(&n)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, n)
|
||||
|
||||
var s string
|
||||
err = br.QueryRow().Scan(&s)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "hello", s)
|
||||
|
||||
err = br.QueryRow().Scan(&n)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 3, n)
|
||||
|
||||
err = br.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/856
|
||||
func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
|
45
conn.go
45
conn.go
|
@ -404,6 +404,7 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.C
|
|||
|
||||
func (c *Conn) exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) {
|
||||
mode := c.config.DefaultQueryExecMode
|
||||
var queryRewriter QueryRewriter
|
||||
|
||||
optionLoop:
|
||||
for len(arguments) > 0 {
|
||||
|
@ -411,11 +412,18 @@ optionLoop:
|
|||
case QueryExecMode:
|
||||
mode = arg
|
||||
arguments = arguments[1:]
|
||||
case QueryRewriter:
|
||||
queryRewriter = arg
|
||||
arguments = arguments[1:]
|
||||
default:
|
||||
break optionLoop
|
||||
}
|
||||
}
|
||||
|
||||
if queryRewriter != nil {
|
||||
sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
|
||||
}
|
||||
|
||||
// Always use simple protocol when there are no arguments.
|
||||
if len(arguments) == 0 {
|
||||
mode = QueryExecModeSimpleProtocol
|
||||
|
@ -682,6 +690,11 @@ type QueryResultFormats []int16
|
|||
// QueryResultFormatsByOID controls the result format (text=0, binary=1) of a query by the result column OID.
|
||||
type QueryResultFormatsByOID map[uint32]int16
|
||||
|
||||
// QueryRewriter rewrites a query when used as the first arguments to a query method.
|
||||
type QueryRewriter interface {
|
||||
RewriteQuery(ctx context.Context, conn *Conn, sql string, args ...any) (newSQL string, newArgs []any)
|
||||
}
|
||||
|
||||
// Query executes sql with args. It is safe to attempt to read from the returned Rows even if an error is returned. The
|
||||
// error will be the available in rows.Err() after rows are closed. So it is allowed to ignore the error returned from
|
||||
// Query and handle it in Rows.
|
||||
|
@ -696,6 +709,7 @@ func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error)
|
|||
var resultFormats QueryResultFormats
|
||||
var resultFormatsByOID QueryResultFormatsByOID
|
||||
mode := c.config.DefaultQueryExecMode
|
||||
var queryRewriter QueryRewriter
|
||||
|
||||
optionLoop:
|
||||
for len(args) > 0 {
|
||||
|
@ -709,11 +723,18 @@ optionLoop:
|
|||
case QueryExecMode:
|
||||
mode = arg
|
||||
args = args[1:]
|
||||
case QueryRewriter:
|
||||
queryRewriter = arg
|
||||
args = args[1:]
|
||||
default:
|
||||
break optionLoop
|
||||
}
|
||||
}
|
||||
|
||||
if queryRewriter != nil {
|
||||
sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args)
|
||||
}
|
||||
|
||||
c.eqb.Reset()
|
||||
anynil.NormalizeSlice(args)
|
||||
rows := c.getRows(ctx, sql, args)
|
||||
|
@ -883,6 +904,30 @@ func (c *Conn) QueryFunc(ctx context.Context, sql string, args []any, scans []an
|
|||
func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
||||
mode := c.config.DefaultQueryExecMode
|
||||
|
||||
for _, bi := range b.items {
|
||||
var queryRewriter QueryRewriter
|
||||
sql := bi.query
|
||||
arguments := bi.arguments
|
||||
|
||||
optionLoop:
|
||||
for len(arguments) > 0 {
|
||||
switch arg := arguments[0].(type) {
|
||||
case QueryRewriter:
|
||||
queryRewriter = arg
|
||||
arguments = arguments[1:]
|
||||
default:
|
||||
break optionLoop
|
||||
}
|
||||
}
|
||||
|
||||
if queryRewriter != nil {
|
||||
sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
|
||||
}
|
||||
|
||||
bi.query = sql
|
||||
bi.arguments = arguments
|
||||
}
|
||||
|
||||
if mode == QueryExecModeSimpleProtocol {
|
||||
var sb strings.Builder
|
||||
for i, bi := range b.items {
|
||||
|
|
19
conn_test.go
19
conn_test.go
|
@ -230,6 +230,25 @@ func TestExec(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
type testQueryRewriter struct {
|
||||
sql string
|
||||
args []any
|
||||
}
|
||||
|
||||
func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args ...any) (newSQL string, newArgs []any) {
|
||||
return qr.sql, qr.args
|
||||
}
|
||||
|
||||
func TestExecWithQueryRewriter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
qr := testQueryRewriter{sql: "select $1::int", args: []any{42}}
|
||||
_, err := conn.Exec(ctx, "should be replaced", &qr)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestExecFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
|
@ -1864,6 +1864,26 @@ func TestQueryErrorWithDisabledStatementCache(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryWithQueryRewriter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
qr := testQueryRewriter{sql: "select $1::int", args: []any{42}}
|
||||
rows, err := conn.Query(ctx, "should be replaced", &qr)
|
||||
require.NoError(t, err)
|
||||
|
||||
var n int32
|
||||
var rowCount int
|
||||
for rows.Next() {
|
||||
rowCount++
|
||||
err = rows.Scan(&n)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.NoError(t, rows.Err())
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnQueryFunc(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
Loading…
Reference in New Issue