Jack Christensen 2022-04-23 18:43:04 -05:00
parent b72b0daa5a
commit 107196ab0c
4 changed files with 364 additions and 2 deletions

View File

@ -692,7 +692,7 @@ type QueryResultFormatsByOID map[uint32]int16
// QueryRewriter rewrites a query when used as the first arguments to a query method.
type QueryRewriter interface {
RewriteQuery(ctx context.Context, conn *Conn, sql string, args ...any) (newSQL string, newArgs []any)
RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any)
}
// Query executes sql with args. It is safe to attempt to read from the returned Rows even if an error is returned. The

View File

@ -235,7 +235,7 @@ type testQueryRewriter struct {
args []any
}
func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args ...any) (newSQL string, newArgs []any) {
func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args []any) (newSQL string, newArgs []any) {
return qr.sql, qr.args
}

266
named_args.go Normal file
View File

@ -0,0 +1,266 @@
package pgx
import (
"context"
"strconv"
"strings"
"unicode/utf8"
)
// NamedArgs can be used as the first argument to a query method. It will replace every '@' named placeholder with a '$'
// ordinal placeholder and construct the appropriate arguments.
//
// For example, the following two queries are equivalent:
//
// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2}))
// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2}))
type NamedArgs map[string]any
// RewriteQuery implements the QueryRewriter interface.
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any) {
l := &sqlLexer{
src: sql,
stateFn: rawState,
nameToOrdinal: make(map[namedArg]int, len(na)),
}
for l.stateFn != nil {
l.stateFn = l.stateFn(l)
}
sb := strings.Builder{}
for _, p := range l.parts {
switch p := p.(type) {
case string:
sb.WriteString(p)
case namedArg:
sb.WriteRune('$')
sb.WriteString(strconv.Itoa(l.nameToOrdinal[p]))
}
}
newArgs = make([]any, len(l.nameToOrdinal))
for name, ordinal := range l.nameToOrdinal {
newArgs[ordinal-1] = na[string(name)]
}
return sb.String(), newArgs
}
type namedArg string
type sqlLexer struct {
src string
start int
pos int
nested int // multiline comment nesting level.
stateFn stateFn
parts []any
nameToOrdinal map[namedArg]int
}
type stateFn func(*sqlLexer) stateFn
func rawState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case 'e', 'E':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '\'' {
l.pos += width
return escapeStringState
}
case '\'':
return singleQuoteState
case '"':
return doubleQuoteState
case '@':
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
if isLetter(nextRune) {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos-width])
}
l.start = l.pos
return namedArgState
}
case '-':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '-' {
l.pos += width
return oneLineCommentState
}
case '/':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '*' {
l.pos += width
return multilineCommentState
}
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func isLetter(r rune) bool {
return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z')
}
func namedArgState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
if r == utf8.RuneError {
if l.pos-l.start > 0 {
na := namedArg(l.src[l.start:l.pos])
if _, found := l.nameToOrdinal[na]; !found {
l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1
}
l.parts = append(l.parts, na)
l.start = l.pos
}
return nil
} else if !(isLetter(r) || (r >= '0' && r <= '9')) {
l.pos -= width
na := namedArg(l.src[l.start:l.pos])
if _, found := l.nameToOrdinal[na]; !found {
l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1
}
l.parts = append(l.parts, namedArg(na))
l.start = l.pos
return rawState
}
}
}
func singleQuoteState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\'':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '\'' {
return rawState
}
l.pos += width
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func doubleQuoteState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '"':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '"' {
return rawState
}
l.pos += width
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func escapeStringState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\\':
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
case '\'':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '\'' {
return rawState
}
l.pos += width
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func oneLineCommentState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\\':
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
case '\n', '\r':
return rawState
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func multilineCommentState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '/':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '*' {
l.pos += width
l.nested++
}
case '*':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '/' {
continue
}
l.pos += width
if l.nested == 0 {
return rawState
}
l.nested--
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}

96
named_args_test.go Normal file
View File

@ -0,0 +1,96 @@
package pgx_test
import (
"context"
"testing"
"github.com/jackc/pgx/v5"
"github.com/stretchr/testify/assert"
)
func TestNamedArgsRewriteQuery(t *testing.T) {
t.Parallel()
for i, tt := range []struct {
sql string
args []any
namedArgs pgx.NamedArgs
expectedSQL string
expectedArgs []any
}{
{
sql: "select * from users where id = @id",
namedArgs: pgx.NamedArgs{"id": int32(42)},
expectedSQL: "select * from users where id = $1",
expectedArgs: []any{int32(42)},
},
{
sql: "select * from t where foo < @abc and baz = @def and bar < @abc",
namedArgs: pgx.NamedArgs{"abc": int32(42), "def": int32(1)},
expectedSQL: "select * from t where foo < $1 and baz = $2 and bar < $1",
expectedArgs: []any{int32(42), int32(1)},
},
{
sql: "select @a::int, @b::text",
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
expectedSQL: "select $1::int, $2::text",
expectedArgs: []any{int32(42), "foo"},
},
{
sql: "at end @",
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
expectedSQL: "at end @",
expectedArgs: []any{},
},
{
sql: "ignores without letter after @ foo bar",
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
expectedSQL: "ignores without letter after @ foo bar",
expectedArgs: []any{},
},
{
sql: "name must start with letter @1 foo bar",
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
expectedSQL: "name must start with letter @1 foo bar",
expectedArgs: []any{},
},
{
sql: `select *, '@foo' as "@bar" from users where id = @id`,
namedArgs: pgx.NamedArgs{"id": int32(42)},
expectedSQL: `select *, '@foo' as "@bar" from users where id = $1`,
expectedArgs: []any{int32(42)},
},
{
sql: `select * -- @foo
from users -- @single line comments
where id = @id;`,
namedArgs: pgx.NamedArgs{"id": int32(42)},
expectedSQL: `select * -- @foo
from users -- @single line comments
where id = $1;`,
expectedArgs: []any{int32(42)},
},
{
sql: `select * /* @multi line
@comment
*/
/* /* with @nesting */ */
from users
where id = @id;`,
namedArgs: pgx.NamedArgs{"id": int32(42)},
expectedSQL: `select * /* @multi line
@comment
*/
/* /* with @nesting */ */
from users
where id = $1;`,
expectedArgs: []any{int32(42)},
},
// test comments and quotes
} {
sql, args := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args)
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
}
}