From b72b0daa5a8888f69562bab6674bd7f8ac4c99c4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 23 Apr 2022 17:26:42 -0500 Subject: [PATCH] Add QueryRewriter interface --- batch.go | 2 +- batch_test.go | 30 ++++++++++++++++++++++++++++++ conn.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ conn_test.go | 19 +++++++++++++++++++ query_test.go | 20 ++++++++++++++++++++ 5 files changed, 115 insertions(+), 1 deletion(-) diff --git a/batch.go b/batch.go index 103d9aed..98f216dd 100644 --- a/batch.go +++ b/batch.go @@ -14,7 +14,7 @@ type batchItem struct { } // Batch queries are a way of bundling multiple queries together to avoid -// unnecessary network round trips. +// unnecessary network round trips. A Batch must only be sent once. type Batch struct { items []*batchItem } diff --git a/batch_test.go b/batch_test.go index 5558b823..96cf61c2 100644 --- a/batch_test.go +++ b/batch_test.go @@ -239,6 +239,36 @@ func TestConnSendBatchWithPreparedStatement(t *testing.T) { }) } +func TestConnSendBatchWithQueryRewriter(t *testing.T) { + t.Parallel() + + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + batch := &pgx.Batch{} + batch.Queue("something to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{1}}) + batch.Queue("something else to be replaced", &testQueryRewriter{sql: "select $1::text", args: []any{"hello"}}) + batch.Queue("more to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{3}}) + + br := conn.SendBatch(context.Background(), batch) + + var n int32 + err := br.QueryRow().Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 1, n) + + var s string + err = br.QueryRow().Scan(&s) + require.NoError(t, err) + require.Equal(t, "hello", s) + + err = br.QueryRow().Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 3, n) + + err = br.Close() + require.NoError(t, err) + }) +} + // https://github.com/jackc/pgx/issues/856 func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.T) { t.Parallel() diff --git a/conn.go b/conn.go index d0c9fe33..dd4a7301 100644 --- a/conn.go +++ b/conn.go @@ -404,6 +404,7 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.C func (c *Conn) exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) { mode := c.config.DefaultQueryExecMode + var queryRewriter QueryRewriter optionLoop: for len(arguments) > 0 { @@ -411,11 +412,18 @@ optionLoop: case QueryExecMode: mode = arg arguments = arguments[1:] + case QueryRewriter: + queryRewriter = arg + arguments = arguments[1:] default: break optionLoop } } + if queryRewriter != nil { + sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments) + } + // Always use simple protocol when there are no arguments. if len(arguments) == 0 { mode = QueryExecModeSimpleProtocol @@ -682,6 +690,11 @@ type QueryResultFormats []int16 // QueryResultFormatsByOID controls the result format (text=0, binary=1) of a query by the result column OID. type QueryResultFormatsByOID map[uint32]int16 +// QueryRewriter rewrites a query when used as the first arguments to a query method. +type QueryRewriter interface { + RewriteQuery(ctx context.Context, conn *Conn, sql string, args ...any) (newSQL string, newArgs []any) +} + // Query executes sql with args. It is safe to attempt to read from the returned Rows even if an error is returned. The // error will be the available in rows.Err() after rows are closed. So it is allowed to ignore the error returned from // Query and handle it in Rows. @@ -696,6 +709,7 @@ func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error) var resultFormats QueryResultFormats var resultFormatsByOID QueryResultFormatsByOID mode := c.config.DefaultQueryExecMode + var queryRewriter QueryRewriter optionLoop: for len(args) > 0 { @@ -709,11 +723,18 @@ optionLoop: case QueryExecMode: mode = arg args = args[1:] + case QueryRewriter: + queryRewriter = arg + args = args[1:] default: break optionLoop } } + if queryRewriter != nil { + sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args) + } + c.eqb.Reset() anynil.NormalizeSlice(args) rows := c.getRows(ctx, sql, args) @@ -883,6 +904,30 @@ func (c *Conn) QueryFunc(ctx context.Context, sql string, args []any, scans []an func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { mode := c.config.DefaultQueryExecMode + for _, bi := range b.items { + var queryRewriter QueryRewriter + sql := bi.query + arguments := bi.arguments + + optionLoop: + for len(arguments) > 0 { + switch arg := arguments[0].(type) { + case QueryRewriter: + queryRewriter = arg + arguments = arguments[1:] + default: + break optionLoop + } + } + + if queryRewriter != nil { + sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments) + } + + bi.query = sql + bi.arguments = arguments + } + if mode == QueryExecModeSimpleProtocol { var sb strings.Builder for i, bi := range b.items { diff --git a/conn_test.go b/conn_test.go index 61f6c951..392ea623 100644 --- a/conn_test.go +++ b/conn_test.go @@ -230,6 +230,25 @@ func TestExec(t *testing.T) { }) } +type testQueryRewriter struct { + sql string + args []any +} + +func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args ...any) (newSQL string, newArgs []any) { + return qr.sql, qr.args +} + +func TestExecWithQueryRewriter(t *testing.T) { + t.Parallel() + + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + qr := testQueryRewriter{sql: "select $1::int", args: []any{42}} + _, err := conn.Exec(ctx, "should be replaced", &qr) + require.NoError(t, err) + }) +} + func TestExecFailure(t *testing.T) { t.Parallel() diff --git a/query_test.go b/query_test.go index 0e310eef..78cacb6c 100644 --- a/query_test.go +++ b/query_test.go @@ -1864,6 +1864,26 @@ func TestQueryErrorWithDisabledStatementCache(t *testing.T) { ensureConnValid(t, conn) } +func TestQueryWithQueryRewriter(t *testing.T) { + t.Parallel() + + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + qr := testQueryRewriter{sql: "select $1::int", args: []any{42}} + rows, err := conn.Query(ctx, "should be replaced", &qr) + require.NoError(t, err) + + var n int32 + var rowCount int + for rows.Next() { + rowCount++ + err = rows.Scan(&n) + require.NoError(t, err) + } + + require.NoError(t, rows.Err()) + }) +} + func TestConnQueryFunc(t *testing.T) { t.Parallel()