Add QueryRewriter interface

non-blocking
Jack Christensen 2022-04-23 17:26:42 -05:00
parent f9857b73d9
commit b72b0daa5a
5 changed files with 115 additions and 1 deletions

View File

@ -14,7 +14,7 @@ type batchItem struct {
}
// Batch queries are a way of bundling multiple queries together to avoid
// unnecessary network round trips.
// unnecessary network round trips. A Batch must only be sent once.
type Batch struct {
items []*batchItem
}

View File

@ -239,6 +239,36 @@ func TestConnSendBatchWithPreparedStatement(t *testing.T) {
})
}
func TestConnSendBatchWithQueryRewriter(t *testing.T) {
t.Parallel()
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
batch := &pgx.Batch{}
batch.Queue("something to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{1}})
batch.Queue("something else to be replaced", &testQueryRewriter{sql: "select $1::text", args: []any{"hello"}})
batch.Queue("more to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{3}})
br := conn.SendBatch(context.Background(), batch)
var n int32
err := br.QueryRow().Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 1, n)
var s string
err = br.QueryRow().Scan(&s)
require.NoError(t, err)
require.Equal(t, "hello", s)
err = br.QueryRow().Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 3, n)
err = br.Close()
require.NoError(t, err)
})
}
// https://github.com/jackc/pgx/issues/856
func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.T) {
t.Parallel()

45
conn.go
View File

@ -404,6 +404,7 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.C
func (c *Conn) exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) {
mode := c.config.DefaultQueryExecMode
var queryRewriter QueryRewriter
optionLoop:
for len(arguments) > 0 {
@ -411,11 +412,18 @@ optionLoop:
case QueryExecMode:
mode = arg
arguments = arguments[1:]
case QueryRewriter:
queryRewriter = arg
arguments = arguments[1:]
default:
break optionLoop
}
}
if queryRewriter != nil {
sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
}
// Always use simple protocol when there are no arguments.
if len(arguments) == 0 {
mode = QueryExecModeSimpleProtocol
@ -682,6 +690,11 @@ type QueryResultFormats []int16
// QueryResultFormatsByOID controls the result format (text=0, binary=1) of a query by the result column OID.
type QueryResultFormatsByOID map[uint32]int16
// QueryRewriter rewrites a query when used as the first arguments to a query method.
type QueryRewriter interface {
RewriteQuery(ctx context.Context, conn *Conn, sql string, args ...any) (newSQL string, newArgs []any)
}
// Query executes sql with args. It is safe to attempt to read from the returned Rows even if an error is returned. The
// error will be the available in rows.Err() after rows are closed. So it is allowed to ignore the error returned from
// Query and handle it in Rows.
@ -696,6 +709,7 @@ func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error)
var resultFormats QueryResultFormats
var resultFormatsByOID QueryResultFormatsByOID
mode := c.config.DefaultQueryExecMode
var queryRewriter QueryRewriter
optionLoop:
for len(args) > 0 {
@ -709,11 +723,18 @@ optionLoop:
case QueryExecMode:
mode = arg
args = args[1:]
case QueryRewriter:
queryRewriter = arg
args = args[1:]
default:
break optionLoop
}
}
if queryRewriter != nil {
sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args)
}
c.eqb.Reset()
anynil.NormalizeSlice(args)
rows := c.getRows(ctx, sql, args)
@ -883,6 +904,30 @@ func (c *Conn) QueryFunc(ctx context.Context, sql string, args []any, scans []an
func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
mode := c.config.DefaultQueryExecMode
for _, bi := range b.items {
var queryRewriter QueryRewriter
sql := bi.query
arguments := bi.arguments
optionLoop:
for len(arguments) > 0 {
switch arg := arguments[0].(type) {
case QueryRewriter:
queryRewriter = arg
arguments = arguments[1:]
default:
break optionLoop
}
}
if queryRewriter != nil {
sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
}
bi.query = sql
bi.arguments = arguments
}
if mode == QueryExecModeSimpleProtocol {
var sb strings.Builder
for i, bi := range b.items {

View File

@ -230,6 +230,25 @@ func TestExec(t *testing.T) {
})
}
type testQueryRewriter struct {
sql string
args []any
}
func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args ...any) (newSQL string, newArgs []any) {
return qr.sql, qr.args
}
func TestExecWithQueryRewriter(t *testing.T) {
t.Parallel()
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
qr := testQueryRewriter{sql: "select $1::int", args: []any{42}}
_, err := conn.Exec(ctx, "should be replaced", &qr)
require.NoError(t, err)
})
}
func TestExecFailure(t *testing.T) {
t.Parallel()

View File

@ -1864,6 +1864,26 @@ func TestQueryErrorWithDisabledStatementCache(t *testing.T) {
ensureConnValid(t, conn)
}
func TestQueryWithQueryRewriter(t *testing.T) {
t.Parallel()
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
qr := testQueryRewriter{sql: "select $1::int", args: []any{42}}
rows, err := conn.Query(ctx, "should be replaced", &qr)
require.NoError(t, err)
var n int32
var rowCount int
for rows.Next() {
rowCount++
err = rows.Scan(&n)
require.NoError(t, err)
}
require.NoError(t, rows.Err())
})
}
func TestConnQueryFunc(t *testing.T) {
t.Parallel()