mirror of https://github.com/jackc/pgx.git
Add QueryRewriter interface
parent
f9857b73d9
commit
b72b0daa5a
2
batch.go
2
batch.go
|
@ -14,7 +14,7 @@ type batchItem struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Batch queries are a way of bundling multiple queries together to avoid
|
// Batch queries are a way of bundling multiple queries together to avoid
|
||||||
// unnecessary network round trips.
|
// unnecessary network round trips. A Batch must only be sent once.
|
||||||
type Batch struct {
|
type Batch struct {
|
||||||
items []*batchItem
|
items []*batchItem
|
||||||
}
|
}
|
||||||
|
|
|
@ -239,6 +239,36 @@ func TestConnSendBatchWithPreparedStatement(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnSendBatchWithQueryRewriter(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
batch := &pgx.Batch{}
|
||||||
|
batch.Queue("something to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{1}})
|
||||||
|
batch.Queue("something else to be replaced", &testQueryRewriter{sql: "select $1::text", args: []any{"hello"}})
|
||||||
|
batch.Queue("more to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{3}})
|
||||||
|
|
||||||
|
br := conn.SendBatch(context.Background(), batch)
|
||||||
|
|
||||||
|
var n int32
|
||||||
|
err := br.QueryRow().Scan(&n)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 1, n)
|
||||||
|
|
||||||
|
var s string
|
||||||
|
err = br.QueryRow().Scan(&s)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "hello", s)
|
||||||
|
|
||||||
|
err = br.QueryRow().Scan(&n)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 3, n)
|
||||||
|
|
||||||
|
err = br.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// https://github.com/jackc/pgx/issues/856
|
// https://github.com/jackc/pgx/issues/856
|
||||||
func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.T) {
|
func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
45
conn.go
45
conn.go
|
@ -404,6 +404,7 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.C
|
||||||
|
|
||||||
func (c *Conn) exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) {
|
func (c *Conn) exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) {
|
||||||
mode := c.config.DefaultQueryExecMode
|
mode := c.config.DefaultQueryExecMode
|
||||||
|
var queryRewriter QueryRewriter
|
||||||
|
|
||||||
optionLoop:
|
optionLoop:
|
||||||
for len(arguments) > 0 {
|
for len(arguments) > 0 {
|
||||||
|
@ -411,11 +412,18 @@ optionLoop:
|
||||||
case QueryExecMode:
|
case QueryExecMode:
|
||||||
mode = arg
|
mode = arg
|
||||||
arguments = arguments[1:]
|
arguments = arguments[1:]
|
||||||
|
case QueryRewriter:
|
||||||
|
queryRewriter = arg
|
||||||
|
arguments = arguments[1:]
|
||||||
default:
|
default:
|
||||||
break optionLoop
|
break optionLoop
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if queryRewriter != nil {
|
||||||
|
sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
|
||||||
|
}
|
||||||
|
|
||||||
// Always use simple protocol when there are no arguments.
|
// Always use simple protocol when there are no arguments.
|
||||||
if len(arguments) == 0 {
|
if len(arguments) == 0 {
|
||||||
mode = QueryExecModeSimpleProtocol
|
mode = QueryExecModeSimpleProtocol
|
||||||
|
@ -682,6 +690,11 @@ type QueryResultFormats []int16
|
||||||
// QueryResultFormatsByOID controls the result format (text=0, binary=1) of a query by the result column OID.
|
// QueryResultFormatsByOID controls the result format (text=0, binary=1) of a query by the result column OID.
|
||||||
type QueryResultFormatsByOID map[uint32]int16
|
type QueryResultFormatsByOID map[uint32]int16
|
||||||
|
|
||||||
|
// QueryRewriter rewrites a query when used as the first arguments to a query method.
|
||||||
|
type QueryRewriter interface {
|
||||||
|
RewriteQuery(ctx context.Context, conn *Conn, sql string, args ...any) (newSQL string, newArgs []any)
|
||||||
|
}
|
||||||
|
|
||||||
// Query executes sql with args. It is safe to attempt to read from the returned Rows even if an error is returned. The
|
// Query executes sql with args. It is safe to attempt to read from the returned Rows even if an error is returned. The
|
||||||
// error will be the available in rows.Err() after rows are closed. So it is allowed to ignore the error returned from
|
// error will be the available in rows.Err() after rows are closed. So it is allowed to ignore the error returned from
|
||||||
// Query and handle it in Rows.
|
// Query and handle it in Rows.
|
||||||
|
@ -696,6 +709,7 @@ func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error)
|
||||||
var resultFormats QueryResultFormats
|
var resultFormats QueryResultFormats
|
||||||
var resultFormatsByOID QueryResultFormatsByOID
|
var resultFormatsByOID QueryResultFormatsByOID
|
||||||
mode := c.config.DefaultQueryExecMode
|
mode := c.config.DefaultQueryExecMode
|
||||||
|
var queryRewriter QueryRewriter
|
||||||
|
|
||||||
optionLoop:
|
optionLoop:
|
||||||
for len(args) > 0 {
|
for len(args) > 0 {
|
||||||
|
@ -709,11 +723,18 @@ optionLoop:
|
||||||
case QueryExecMode:
|
case QueryExecMode:
|
||||||
mode = arg
|
mode = arg
|
||||||
args = args[1:]
|
args = args[1:]
|
||||||
|
case QueryRewriter:
|
||||||
|
queryRewriter = arg
|
||||||
|
args = args[1:]
|
||||||
default:
|
default:
|
||||||
break optionLoop
|
break optionLoop
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if queryRewriter != nil {
|
||||||
|
sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args)
|
||||||
|
}
|
||||||
|
|
||||||
c.eqb.Reset()
|
c.eqb.Reset()
|
||||||
anynil.NormalizeSlice(args)
|
anynil.NormalizeSlice(args)
|
||||||
rows := c.getRows(ctx, sql, args)
|
rows := c.getRows(ctx, sql, args)
|
||||||
|
@ -883,6 +904,30 @@ func (c *Conn) QueryFunc(ctx context.Context, sql string, args []any, scans []an
|
||||||
func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
||||||
mode := c.config.DefaultQueryExecMode
|
mode := c.config.DefaultQueryExecMode
|
||||||
|
|
||||||
|
for _, bi := range b.items {
|
||||||
|
var queryRewriter QueryRewriter
|
||||||
|
sql := bi.query
|
||||||
|
arguments := bi.arguments
|
||||||
|
|
||||||
|
optionLoop:
|
||||||
|
for len(arguments) > 0 {
|
||||||
|
switch arg := arguments[0].(type) {
|
||||||
|
case QueryRewriter:
|
||||||
|
queryRewriter = arg
|
||||||
|
arguments = arguments[1:]
|
||||||
|
default:
|
||||||
|
break optionLoop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if queryRewriter != nil {
|
||||||
|
sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
|
||||||
|
}
|
||||||
|
|
||||||
|
bi.query = sql
|
||||||
|
bi.arguments = arguments
|
||||||
|
}
|
||||||
|
|
||||||
if mode == QueryExecModeSimpleProtocol {
|
if mode == QueryExecModeSimpleProtocol {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
for i, bi := range b.items {
|
for i, bi := range b.items {
|
||||||
|
|
19
conn_test.go
19
conn_test.go
|
@ -230,6 +230,25 @@ func TestExec(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type testQueryRewriter struct {
|
||||||
|
sql string
|
||||||
|
args []any
|
||||||
|
}
|
||||||
|
|
||||||
|
func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args ...any) (newSQL string, newArgs []any) {
|
||||||
|
return qr.sql, qr.args
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecWithQueryRewriter(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
qr := testQueryRewriter{sql: "select $1::int", args: []any{42}}
|
||||||
|
_, err := conn.Exec(ctx, "should be replaced", &qr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestExecFailure(t *testing.T) {
|
func TestExecFailure(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|
|
@ -1864,6 +1864,26 @@ func TestQueryErrorWithDisabledStatementCache(t *testing.T) {
|
||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQueryWithQueryRewriter(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
qr := testQueryRewriter{sql: "select $1::int", args: []any{42}}
|
||||||
|
rows, err := conn.Query(ctx, "should be replaced", &qr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var n int32
|
||||||
|
var rowCount int
|
||||||
|
for rows.Next() {
|
||||||
|
rowCount++
|
||||||
|
err = rows.Scan(&n)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, rows.Err())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnQueryFunc(t *testing.T) {
|
func TestConnQueryFunc(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue