mirror of https://github.com/jackc/pgx.git
Add NamedArgs
https://github.com/jackc/pgx/issues/1186 https://github.com/jackc/pgx/issues/387non-blocking v5.0.0-alpha.3
parent
b72b0daa5a
commit
107196ab0c
2
conn.go
2
conn.go
|
@ -692,7 +692,7 @@ type QueryResultFormatsByOID map[uint32]int16
|
|||
|
||||
// QueryRewriter rewrites a query when used as the first arguments to a query method.
|
||||
type QueryRewriter interface {
|
||||
RewriteQuery(ctx context.Context, conn *Conn, sql string, args ...any) (newSQL string, newArgs []any)
|
||||
RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any)
|
||||
}
|
||||
|
||||
// Query executes sql with args. It is safe to attempt to read from the returned Rows even if an error is returned. The
|
||||
|
|
|
@ -235,7 +235,7 @@ type testQueryRewriter struct {
|
|||
args []any
|
||||
}
|
||||
|
||||
func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args ...any) (newSQL string, newArgs []any) {
|
||||
func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args []any) (newSQL string, newArgs []any) {
|
||||
return qr.sql, qr.args
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,266 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// NamedArgs can be used as the first argument to a query method. It will replace every '@' named placeholder with a '$'
|
||||
// ordinal placeholder and construct the appropriate arguments.
|
||||
//
|
||||
// For example, the following two queries are equivalent:
|
||||
//
|
||||
// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2}))
|
||||
// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2}))
|
||||
type NamedArgs map[string]any
|
||||
|
||||
// RewriteQuery implements the QueryRewriter interface.
|
||||
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any) {
|
||||
l := &sqlLexer{
|
||||
src: sql,
|
||||
stateFn: rawState,
|
||||
nameToOrdinal: make(map[namedArg]int, len(na)),
|
||||
}
|
||||
|
||||
for l.stateFn != nil {
|
||||
l.stateFn = l.stateFn(l)
|
||||
}
|
||||
|
||||
sb := strings.Builder{}
|
||||
for _, p := range l.parts {
|
||||
switch p := p.(type) {
|
||||
case string:
|
||||
sb.WriteString(p)
|
||||
case namedArg:
|
||||
sb.WriteRune('$')
|
||||
sb.WriteString(strconv.Itoa(l.nameToOrdinal[p]))
|
||||
}
|
||||
}
|
||||
|
||||
newArgs = make([]any, len(l.nameToOrdinal))
|
||||
for name, ordinal := range l.nameToOrdinal {
|
||||
newArgs[ordinal-1] = na[string(name)]
|
||||
}
|
||||
|
||||
return sb.String(), newArgs
|
||||
}
|
||||
|
||||
type namedArg string
|
||||
|
||||
type sqlLexer struct {
|
||||
src string
|
||||
start int
|
||||
pos int
|
||||
nested int // multiline comment nesting level.
|
||||
stateFn stateFn
|
||||
parts []any
|
||||
|
||||
nameToOrdinal map[namedArg]int
|
||||
}
|
||||
|
||||
type stateFn func(*sqlLexer) stateFn
|
||||
|
||||
func rawState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case 'e', 'E':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune == '\'' {
|
||||
l.pos += width
|
||||
return escapeStringState
|
||||
}
|
||||
case '\'':
|
||||
return singleQuoteState
|
||||
case '"':
|
||||
return doubleQuoteState
|
||||
case '@':
|
||||
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if isLetter(nextRune) {
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos-width])
|
||||
}
|
||||
l.start = l.pos
|
||||
return namedArgState
|
||||
}
|
||||
case '-':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune == '-' {
|
||||
l.pos += width
|
||||
return oneLineCommentState
|
||||
}
|
||||
case '/':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune == '*' {
|
||||
l.pos += width
|
||||
return multilineCommentState
|
||||
}
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isLetter(r rune) bool {
|
||||
return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z')
|
||||
}
|
||||
|
||||
func namedArgState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
if r == utf8.RuneError {
|
||||
if l.pos-l.start > 0 {
|
||||
na := namedArg(l.src[l.start:l.pos])
|
||||
if _, found := l.nameToOrdinal[na]; !found {
|
||||
l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1
|
||||
}
|
||||
l.parts = append(l.parts, na)
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
} else if !(isLetter(r) || (r >= '0' && r <= '9')) {
|
||||
l.pos -= width
|
||||
na := namedArg(l.src[l.start:l.pos])
|
||||
if _, found := l.nameToOrdinal[na]; !found {
|
||||
l.nameToOrdinal[na] = len(l.nameToOrdinal) + 1
|
||||
}
|
||||
l.parts = append(l.parts, namedArg(na))
|
||||
l.start = l.pos
|
||||
return rawState
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func singleQuoteState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '\'':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '\'' {
|
||||
return rawState
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func doubleQuoteState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '"':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '"' {
|
||||
return rawState
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func escapeStringState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '\\':
|
||||
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
case '\'':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '\'' {
|
||||
return rawState
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func oneLineCommentState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '\\':
|
||||
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
case '\n', '\r':
|
||||
return rawState
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func multilineCommentState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '/':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune == '*' {
|
||||
l.pos += width
|
||||
l.nested++
|
||||
}
|
||||
case '*':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '/' {
|
||||
continue
|
||||
}
|
||||
|
||||
l.pos += width
|
||||
if l.nested == 0 {
|
||||
return rawState
|
||||
}
|
||||
l.nested--
|
||||
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,96 @@
|
|||
package pgx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNamedArgsRewriteQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for i, tt := range []struct {
|
||||
sql string
|
||||
args []any
|
||||
namedArgs pgx.NamedArgs
|
||||
expectedSQL string
|
||||
expectedArgs []any
|
||||
}{
|
||||
{
|
||||
sql: "select * from users where id = @id",
|
||||
namedArgs: pgx.NamedArgs{"id": int32(42)},
|
||||
expectedSQL: "select * from users where id = $1",
|
||||
expectedArgs: []any{int32(42)},
|
||||
},
|
||||
{
|
||||
sql: "select * from t where foo < @abc and baz = @def and bar < @abc",
|
||||
namedArgs: pgx.NamedArgs{"abc": int32(42), "def": int32(1)},
|
||||
expectedSQL: "select * from t where foo < $1 and baz = $2 and bar < $1",
|
||||
expectedArgs: []any{int32(42), int32(1)},
|
||||
},
|
||||
{
|
||||
sql: "select @a::int, @b::text",
|
||||
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
|
||||
expectedSQL: "select $1::int, $2::text",
|
||||
expectedArgs: []any{int32(42), "foo"},
|
||||
},
|
||||
{
|
||||
sql: "at end @",
|
||||
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
|
||||
expectedSQL: "at end @",
|
||||
expectedArgs: []any{},
|
||||
},
|
||||
{
|
||||
sql: "ignores without letter after @ foo bar",
|
||||
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
|
||||
expectedSQL: "ignores without letter after @ foo bar",
|
||||
expectedArgs: []any{},
|
||||
},
|
||||
{
|
||||
sql: "name must start with letter @1 foo bar",
|
||||
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
|
||||
expectedSQL: "name must start with letter @1 foo bar",
|
||||
expectedArgs: []any{},
|
||||
},
|
||||
{
|
||||
sql: `select *, '@foo' as "@bar" from users where id = @id`,
|
||||
namedArgs: pgx.NamedArgs{"id": int32(42)},
|
||||
expectedSQL: `select *, '@foo' as "@bar" from users where id = $1`,
|
||||
expectedArgs: []any{int32(42)},
|
||||
},
|
||||
{
|
||||
sql: `select * -- @foo
|
||||
from users -- @single line comments
|
||||
where id = @id;`,
|
||||
namedArgs: pgx.NamedArgs{"id": int32(42)},
|
||||
expectedSQL: `select * -- @foo
|
||||
from users -- @single line comments
|
||||
where id = $1;`,
|
||||
expectedArgs: []any{int32(42)},
|
||||
},
|
||||
{
|
||||
sql: `select * /* @multi line
|
||||
@comment
|
||||
*/
|
||||
/* /* with @nesting */ */
|
||||
from users
|
||||
where id = @id;`,
|
||||
namedArgs: pgx.NamedArgs{"id": int32(42)},
|
||||
expectedSQL: `select * /* @multi line
|
||||
@comment
|
||||
*/
|
||||
/* /* with @nesting */ */
|
||||
from users
|
||||
where id = $1;`,
|
||||
expectedArgs: []any{int32(42)},
|
||||
},
|
||||
|
||||
// test comments and quotes
|
||||
} {
|
||||
sql, args := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args)
|
||||
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
|
||||
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue